use dist.all_reduce(min) to sync num_blocks_local (#2933)

* run pre-commit checks on all files

* all-reduce num_blocks_local to its minimum across ranks

* fix the nranks=1 case

* run pre-commit at the commit-msg stage
Yuanle Liu
2025-07-21 16:23:36 +08:00
committed by GitHub
parent 67990e0572
commit 2f74e93d7e
9 changed files with 71 additions and 66 deletions
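Context for the main change: during profiling each rank measures its own free GPU memory, so the per-rank KV-cache block counts can differ, and the cache has to be sized to the smallest count or the roomier ranks would over-allocate on their peers. Previously every worker wrote its count into a per-rank IPC slot and the engine took the minimum; this commit moves the reduction onto the workers with a paddle.distributed MIN all-reduce. A minimal, illustrative sketch of that pattern follows (the helper name, the numbers, and the launch command are hypothetical, not code from this commit):

# sketch_min_allreduce.py -- illustrative only, not part of the FastDeploy tree
import paddle
import paddle.distributed as dist


def sync_num_blocks(num_blocks_local: int) -> int:
    """Return the smallest per-rank block count across all ranks."""
    if dist.get_world_size() <= 1:
        # Single-rank run: nothing to synchronize (the "fix nranks=1" part).
        return num_blocks_local
    t = paddle.full(shape=[1], fill_value=num_blocks_local, dtype="int32")
    dist.all_reduce(t, op=dist.ReduceOp.MIN)
    return int(t.item())


if __name__ == "__main__":
    # Run with e.g.: python -m paddle.distributed.launch --gpus 0,1 sketch_min_allreduce.py
    if dist.get_world_size() > 1:
        dist.init_parallel_env()
    # Pretend each rank profiled a different amount of free memory.
    num_blocks_local = 1000 + dist.get_rank() * 37
    print(f"rank {dist.get_rank()}: local={num_blocks_local}, global_min={sync_num_blocks(num_blocks_local)}")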

View File

@@ -3,6 +3,7 @@ default_install_hook_types:
 - commit-msg
 default_stages:
 - pre-commit # Run locally
+- commit-msg
 # - manual # Run in CI
 repos:
 - repo: https://github.com/psf/black.git

View File

@@ -860,7 +860,7 @@ class LLMEngine:
             )
         if self.do_profile:
-            get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32)
+            get_profile_block_num = np.zeros([1], dtype=np.int32)
             self.get_profile_block_num_signal = IPCSignal(
                 name="get_profile_block_num",
                 array=get_profile_block_num,
@@ -1118,15 +1118,9 @@ class LLMEngine:
         Stop profiling of the model server and reset variables.
         """
         self.do_profile = 0
-        num_gpu_blocks = -1
-        for i in range(self.cfg.tensor_parallel_size):
-            while self.get_profile_block_num_signal.value[i] == 0:
-                time.sleep(1)
-            if num_gpu_blocks < 0:
-                num_gpu_blocks = self.get_profile_block_num_signal.value[i]
-            else:
-                num_gpu_blocks = min(num_gpu_blocks, self.get_profile_block_num_signal.value[i])
+        while self.get_profile_block_num_signal.value[0] == 0:
+            time.sleep(1)
+        num_gpu_blocks = self.get_profile_block_num_signal.value[0]
         self.cfg.cache_config.reset(num_gpu_blocks)
         self.resource_manager.reset_cache_config(self.cfg.cache_config)
         if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":

View File

@@ -141,7 +141,8 @@ class EngineClient:
         task["preprocess_end_time"] = time.time()
         preprocess_cost_time = task["preprocess_end_time"] - task["preprocess_start_time"]
         api_server_logger.info(
-            f"Cache request with request_id ({task.get('request_id')}), " f"preprocess time cost {preprocess_cost_time}"
+            f"Cache request with request_id ({task.get('request_id')}), "
+            f"preprocess time cost {preprocess_cost_time}"
         )
         self.vaild_parameters(task)

View File

@@ -110,6 +110,7 @@ class XPUForwardMeta(ForwardMeta):
     """
    XPUForwardMeta is used to store the global meta information of the forward, and some XPU specific meta info.
    """
+
    # Accumulated offset
    cum_offsets: Optional[paddle.Tensor] = None
    # TODO(wanghaitao): Supplementary notes

View File

@@ -375,10 +375,21 @@ class PaddleDisWorkerProc:
             logger.info(f"------- model_block_memory_used:{model_block_memory_used} --------")
             logger.info(f"------- num_blocks_local:{num_blocks_local} --------")
-            logger.info(f"self.fd_config.parallel_config.do_profile:{self.fd_config.parallel_config.do_profile}")
-            # 3. Send IPCSignal
-            get_profile_block_num = np.zeros(shape=[self.ranks], dtype=np.int32)
-            self.get_profile_block_num_signal = IPCSignal(
-                name="get_profile_block_num",
-                array=get_profile_block_num,
+            if num_blocks_local <= 0:
+                raise ValueError(
+                    "The total number of blocks cannot be less than zero."
+                    "Please increase gpu_memory_utilization"
+                    "Or decrease max_num_batched_tokens(max model length) "
+                )
+            if self.ranks > 1:
+                num_blocks_local = paddle.full(shape=[1], fill_value=num_blocks_local, dtype="int32")
+                dist.all_reduce(num_blocks_local, op=dist.ReduceOp.MIN)
+                num_blocks_local = num_blocks_local.item()
+            if self.local_rank == 0:
+                # 3. Send IPCSignal
+                get_profile_block_num = np.zeros(shape=[1], dtype=np.int32)
+                self.get_profile_block_num_signal = IPCSignal(
+                    name="get_profile_block_num",
+                    array=get_profile_block_num,
@@ -386,31 +397,11 @@ class PaddleDisWorkerProc:
-                suffix=self.parallel_config.engine_pid,
-                create=False,
-            )
-            self.get_profile_block_num_signal.value[self.local_rank] = num_blocks_local
-            # Wait all worker send the signal
-            while np.any(self.get_profile_block_num_signal.value <= 0):
-                time.sleep(0.01)
-            num_blocks_global = self.get_profile_block_num_signal.value.min().item()
-            if num_blocks_global < 0:
-                logger.error(
-                    "The total number of blocks cannot be less than zero."
-                    "Please increase gpu_memory_utilization"
-                    "Or decrease max_num_batched_tokens(max model length) "
-                )
-                raise ValueError(
-                    "The total number of blocks cannot be less than zero."
-                    "Please increase gpu_memory_utilization"
-                    "Or decrease max_num_batched_tokens(max model length) "
-                )
-            self.get_profile_block_num_signal.value[self.local_rank] = num_blocks_global
+                    suffix=self.parallel_config.engine_pid,
+                    create=False,
+                )
+                self.get_profile_block_num_signal.value[0] = num_blocks_local
         else:
-            num_blocks_global = self.fd_config.parallel_config.total_block_num
-            # NOTE(liuzichang): Too big num_blocks_global will lead to error 700
+            num_blocks_local = self.fd_config.parallel_config.total_block_num

         # 4. Updata share inputs
-        self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_global)
+        self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_local)

     def init_device(self) -> None:
         """Initialize device and Construct model runner"""

View File

@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import difflib
-import os
 import argparse
+import difflib

 from paddleformers.trl.llm_utils import init_dist_env
@@ -24,12 +23,7 @@ from fastdeploy.rl.rollout_model import RolloutModel
 _, ranks = init_dist_env()
 parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--model_path",
-    type=str,
-    required=True,
-    help="Path to the model directory"
-)
+parser.add_argument("--model_path", type=str, required=True, help="Path to the model directory")
 args = parser.parse_args()

 # base result
@@ -55,6 +49,7 @@ for k, v in actor_eval_model.state_dict().items():
 for k, v in actor_eval_model.get_name_mappings_to_training().items():
     content += f"{k}:{v}\n"

+
 def compare_strings(a: str, b: str) -> bool:
     if a == b:
         print("✅ 两个字符串完全一致")
@@ -68,8 +63,11 @@ def compare_strings(a: str, b: str) -> bool:
         return False

 with open("baseline.txt", "r", encoding="utf-8") as f:
     baseline = f.read()

-assert compare_strings(baseline, content), "In the unittest of RL scenario, your modification " \
-    "caused inconsistency in the content before and after. Please fix it. " \
-    "Can request assistance from yuanlehome or gzy19990617 (github id)."
+assert compare_strings(baseline, content), (
+    "In the unittest of RL scenario, your modification "
+    "caused inconsistency in the content before and after. Please fix it. "
+    "Can request assistance from yuanlehome or gzy19990617 (github id)."
+)

View File

@@ -12,11 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import subprocess
 import sys
-import os
-import time
-import pytest


 def test_rollout_model_with_distributed_launch():
@@ -29,26 +27,24 @@ def test_rollout_model_with_distributed_launch():
     base_path = os.getenv("MODEL_PATH")
     if base_path:
-        model_path=os.path.join(base_path, "ernie-4_5-vl-28b-a3b-bf16-paddle")
+        model_path = os.path.join(base_path, "ernie-4_5-vl-28b-a3b-bf16-paddle")
     else:
-        model_path="./ernie-4_5-vl-28b-a3b-bf16-paddle"
+        model_path = "./ernie-4_5-vl-28b-a3b-bf16-paddle"

     command = [
         sys.executable,
-        "-m", "paddle.distributed.launch",
-        "--gpus", "0,1",
+        "-m",
+        "paddle.distributed.launch",
+        "--gpus",
+        "0,1",
         rollout_script,
-        "--model_path", model_path,
+        "--model_path",
+        model_path,
     ]
     print(f"Executing command: {' '.join(command)}")

-    process = subprocess.Popen(
-        command,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-        text=True
-    )
+    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

     try:
         stdout, stderr = process.communicate(timeout=300)

View File

@@ -8,9 +8,33 @@ prompts = [
 sampling_params = SamplingParams(temperature=0.8, top_p=0.00001, max_tokens=16)

 # 加载模型
-llm = LLM(model="/data1/fastdeploy/ERNIE_300B_4L", tensor_parallel_size=16, max_model_len=8192, static_decode_blocks=0, quantization='wint8', block_size=16)
+llm = LLM(
+    model="/data1/fastdeploy/ERNIE_300B_4L",
+    tensor_parallel_size=16,
+    max_model_len=8192,
+    static_decode_blocks=0,
+    quantization="wint8",
+    block_size=16,
+)

 # 批量进行推理llm内部基于资源情况进行请求排队、动态插入处理
 outputs = llm.generate(prompts, sampling_params)

-assert outputs[0].outputs.token_ids==[23768, 97000, 47814, 59335, 68170, 183, 49080, 94717, 82966, 99140, 31615, 51497, 94851, 60764, 10889, 2]
+assert outputs[0].outputs.token_ids == [
+    23768,
+    97000,
+    47814,
+    59335,
+    68170,
+    183,
+    49080,
+    94717,
+    82966,
+    99140,
+    31615,
+    51497,
+    94851,
+    60764,
+    10889,
+    2,
+]

View File

@@ -73,6 +73,5 @@ def test_sampler():
     print(next_tokens)

-
 if __name__ == "__main__":
     test_sampler()