use dist.all_reduce(min) to sync num_blocks_local (#2933)

* pre-commit all files check
* reduce min num_blocks_local
* fix nranks=1
* pre-commit when commit-msg
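The core of the change: each tensor-parallel rank profiles its free memory and derives a local KV-cache block budget (num_blocks_local), and the ranks now agree on the smallest budget with an all_reduce(MIN) instead of each rank writing into its own per-rank IPC slot. A minimal standalone sketch of that pattern, assuming the processes are launched with paddle.distributed.launch and the process group is already initialized; this is an illustration, not the FastDeploy code itself:

    import paddle
    import paddle.distributed as dist

    def sync_num_blocks_local(num_blocks_local: int) -> int:
        # Every rank contributes its own block budget; all ranks receive the global minimum.
        if dist.get_world_size() > 1:
            tensor = paddle.full(shape=[1], fill_value=num_blocks_local, dtype="int32")
            dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
            return tensor.item()
        # Single-rank case (nranks=1): nothing to synchronize.
        return num_blocks_local

The MIN reduction is the natural choice here because the shared KV cache must fit on the most memory-constrained rank.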
@@ -3,6 +3,7 @@ default_install_hook_types:
   - commit-msg
 default_stages:
   - pre-commit # Run locally
+  - commit-msg
   # - manual # Run in CI
 repos:
   - repo: https://github.com/psf/black.git

@@ -860,7 +860,7 @@ class LLMEngine:
            )

        if self.do_profile:
-            get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32)
+            get_profile_block_num = np.zeros([1], dtype=np.int32)
             self.get_profile_block_num_signal = IPCSignal(
                 name="get_profile_block_num",
                 array=get_profile_block_num,
@@ -1118,15 +1118,9 @@ class LLMEngine:
         Stop profiling of the model server and reset variables.
         """
         self.do_profile = 0
-        num_gpu_blocks = -1
-        for i in range(self.cfg.tensor_parallel_size):
-            while self.get_profile_block_num_signal.value[i] == 0:
-                time.sleep(1)
-            if num_gpu_blocks < 0:
-                num_gpu_blocks = self.get_profile_block_num_signal.value[i]
-            else:
-                num_gpu_blocks = min(num_gpu_blocks, self.get_profile_block_num_signal.value[i])
-
+        while self.get_profile_block_num_signal.value[0] == 0:
+            time.sleep(1)
+        num_gpu_blocks = self.get_profile_block_num_signal.value[0]
         self.cfg.cache_config.reset(num_gpu_blocks)
         self.resource_manager.reset_cache_config(self.cfg.cache_config)
         if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":

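With the diff above, stop_profile() waits on a single shared slot instead of scanning one slot per tensor-parallel rank, because the workers now publish one already-reduced value. A rough single-process sketch of that handshake, with a plain numpy array standing in for the shared memory behind IPCSignal (IPCSignal is FastDeploy-internal; the helper names here are hypothetical):

    import time
    import numpy as np

    # Stand-in for the int32 array backing IPCSignal("get_profile_block_num", ...).
    get_profile_block_num = np.zeros([1], dtype=np.int32)

    def worker_publish(num_blocks: int) -> None:
        # Worker local_rank 0 writes the block count agreed on by all ranks.
        get_profile_block_num[0] = num_blocks

    def engine_wait_for_blocks() -> int:
        # Engine side: poll the single slot until a non-zero value appears.
        while get_profile_block_num[0] == 0:
            time.sleep(1)
        return int(get_profile_block_num[0])

    worker_publish(987)               # hypothetical profiled value
    print(engine_wait_for_blocks())   # 987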
@@ -141,7 +141,8 @@ class EngineClient:
         task["preprocess_end_time"] = time.time()
         preprocess_cost_time = task["preprocess_end_time"] - task["preprocess_start_time"]
         api_server_logger.info(
-            f"Cache request with request_id ({task.get('request_id')}), " f"preprocess time cost {preprocess_cost_time}"
+            f"Cache request with request_id ({task.get('request_id')}), "
+            f"preprocess time cost {preprocess_cost_time}"
         )

         self.vaild_parameters(task)

@@ -110,6 +110,7 @@ class XPUForwardMeta(ForwardMeta):
     """
     XPUForwardMeta is used to store the global meta information of the forward, and some XPU specific meta info.
     """
+
     # Accumulated offset
     cum_offsets: Optional[paddle.Tensor] = None
     # TODO(wanghaitao): Supplementary notes

@@ -375,42 +375,33 @@ class PaddleDisWorkerProc:
             logger.info(f"------- model_block_memory_used:{model_block_memory_used} --------")
             logger.info(f"------- num_blocks_local:{num_blocks_local} --------")

-            logger.info(f"self.fd_config.parallel_config.do_profile:{self.fd_config.parallel_config.do_profile}")
-
-            # 3. Send IPCSignal
-            get_profile_block_num = np.zeros(shape=[self.ranks], dtype=np.int32)
-            self.get_profile_block_num_signal = IPCSignal(
-                name="get_profile_block_num",
-                array=get_profile_block_num,
-                dtype=np.int32,
-                suffix=self.parallel_config.engine_pid,
-                create=False,
-            )
-            self.get_profile_block_num_signal.value[self.local_rank] = num_blocks_local
-
-            # Wait all worker send the signal
-            while np.any(self.get_profile_block_num_signal.value <= 0):
-                time.sleep(0.01)
-            num_blocks_global = self.get_profile_block_num_signal.value.min().item()
-
-            if num_blocks_global < 0:
-                logger.error(
-                    "The total number of blocks cannot be less than zero."
-                    "Please increase gpu_memory_utilization"
-                    "Or decrease max_num_batched_tokens(max model length) "
-                )
+            if num_blocks_local <= 0:
                 raise ValueError(
                     "The total number of blocks cannot be less than zero."
                     "Please increase gpu_memory_utilization"
                     "Or decrease max_num_batched_tokens(max model length) "
                 )

-            self.get_profile_block_num_signal.value[self.local_rank] = num_blocks_global
+            if self.ranks > 1:
+                num_blocks_local = paddle.full(shape=[1], fill_value=num_blocks_local, dtype="int32")
+                dist.all_reduce(num_blocks_local, op=dist.ReduceOp.MIN)
+                num_blocks_local = num_blocks_local.item()
+
+            if self.local_rank == 0:
+                # 3. Send IPCSignal
+                get_profile_block_num = np.zeros(shape=[1], dtype=np.int32)
+                self.get_profile_block_num_signal = IPCSignal(
+                    name="get_profile_block_num",
+                    array=get_profile_block_num,
+                    dtype=np.int32,
+                    suffix=self.parallel_config.engine_pid,
+                    create=False,
+                )
+                self.get_profile_block_num_signal.value[0] = num_blocks_local
         else:
-            num_blocks_global = self.fd_config.parallel_config.total_block_num
-            # NOTE(liuzichang): Too big num_blocks_global will lead to error 700
+            num_blocks_local = self.fd_config.parallel_config.total_block_num
         # 4. Updata share inputs
-        self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_global)
+        self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_local)

     def init_device(self) -> None:
         """Initialize device and Construct model runner"""

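Why ReduceOp.MIN and the rank-0-only publish in the hunk above: each rank profiles memory after loading its own model shard, so the only block count every rank can actually allocate is the smallest one, and once the reduce has run every rank holds that same number, which is why a single IPC slot written by local_rank 0 is enough for the engine. A tiny worked example with hypothetical per-rank budgets, using plain Python in place of the collective:

    # Hypothetical block budgets measured by four tensor-parallel ranks.
    num_blocks_per_rank = [1031, 987, 1010, 995]

    # After dist.all_reduce(..., op=ReduceOp.MIN) every rank holds the same value:
    num_blocks_global = min(num_blocks_per_rank)
    assert num_blocks_global == 987

    # Only local_rank 0 writes 987 into get_profile_block_num_signal.value[0];
    # the engine reads that one slot and resizes the KV cache for all ranks.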
@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import difflib
-import os
 import argparse
+import difflib

 from paddleformers.trl.llm_utils import init_dist_env

@@ -24,12 +23,7 @@ from fastdeploy.rl.rollout_model import RolloutModel
 _, ranks = init_dist_env()

 parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--model_path",
-    type=str,
-    required=True,
-    help="Path to the model directory"
-)
+parser.add_argument("--model_path", type=str, required=True, help="Path to the model directory")
 args = parser.parse_args()

 # base result

@@ -55,6 +49,7 @@ for k, v in actor_eval_model.state_dict().items():
 for k, v in actor_eval_model.get_name_mappings_to_training().items():
     content += f"{k}:{v}\n"

+
 def compare_strings(a: str, b: str) -> bool:
     if a == b:
         print("✅ The two strings are identical")

@@ -68,8 +63,11 @@ def compare_strings(a: str, b: str) -> bool:

     return False

+
 with open("baseline.txt", "r", encoding="utf-8") as f:
     baseline = f.read()
-assert compare_strings(baseline, content), "In the unittest of RL scenario, your modification " \
-    "caused inconsistency in the content before and after. Please fix it. " \
+assert compare_strings(baseline, content), (
+    "In the unittest of RL scenario, your modification "
+    "caused inconsistency in the content before and after. Please fix it. "
     "Can request assistance from yuanlehome or gzy19990617 (github id)."
+)

@@ -12,11 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import subprocess
 import sys
-import os
-import time
-import pytest


 def test_rollout_model_with_distributed_launch():

@@ -29,26 +27,24 @@ def test_rollout_model_with_distributed_launch():

     base_path = os.getenv("MODEL_PATH")
     if base_path:
-        model_path=os.path.join(base_path, "ernie-4_5-vl-28b-a3b-bf16-paddle")
+        model_path = os.path.join(base_path, "ernie-4_5-vl-28b-a3b-bf16-paddle")
     else:
-        model_path="./ernie-4_5-vl-28b-a3b-bf16-paddle"
+        model_path = "./ernie-4_5-vl-28b-a3b-bf16-paddle"

     command = [
         sys.executable,
-        "-m", "paddle.distributed.launch",
-        "--gpus", "0,1",
+        "-m",
+        "paddle.distributed.launch",
+        "--gpus",
+        "0,1",
         rollout_script,
-        "--model_path", model_path,
+        "--model_path",
+        model_path,
     ]

     print(f"Executing command: {' '.join(command)}")

-    process = subprocess.Popen(
-        command,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-        text=True
-    )
+    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

     try:
         stdout, stderr = process.communicate(timeout=300)

@@ -8,9 +8,33 @@ prompts = [
 sampling_params = SamplingParams(temperature=0.8, top_p=0.00001, max_tokens=16)

 # Load the model
-llm = LLM(model="/data1/fastdeploy/ERNIE_300B_4L", tensor_parallel_size=16, max_model_len=8192, static_decode_blocks=0, quantization='wint8', block_size=16)
+llm = LLM(
+    model="/data1/fastdeploy/ERNIE_300B_4L",
+    tensor_parallel_size=16,
+    max_model_len=8192,
+    static_decode_blocks=0,
+    quantization="wint8",
+    block_size=16,
+)

 # Batch inference (the LLM queues requests internally and inserts them dynamically based on available resources)
 outputs = llm.generate(prompts, sampling_params)

-assert outputs[0].outputs.token_ids==[23768, 97000, 47814, 59335, 68170, 183, 49080, 94717, 82966, 99140, 31615, 51497, 94851, 60764, 10889, 2]
+assert outputs[0].outputs.token_ids == [
+    23768,
+    97000,
+    47814,
+    59335,
+    68170,
+    183,
+    49080,
+    94717,
+    82966,
+    99140,
+    31615,
+    51497,
+    94851,
+    60764,
+    10889,
+    2,
+]

@@ -73,6 +73,5 @@ def test_sampler():
     print(next_tokens)


-
 if __name__ == "__main__":
     test_sampler()