mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Sync] Update to latest code (#2679)
* [Sync] Update to latest code * Add new code files * Add new code files * update code * Try to fix build.sh * Try to fix build.sh * Update code * Update requirements.txt * Update code --------- Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
This commit is contained in:
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
@@ -23,6 +22,7 @@ import paddle
|
||||
import paddle.distributed as dist
|
||||
import paddle.distributed.fleet as fleet
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.config import (DecodingConfig, DeviceConfig, FDConfig,
|
||||
GraphOptimizationConfig, LoadConfig,
|
||||
ModelConfig, MoEConfig, MoEPhase,
|
||||
@@ -61,14 +61,21 @@ class PaddleDisWorkerProc():
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a distributed worker and task queue for single-node multi-GPU setup.
|
||||
Args:
|
||||
fd_config (FDConfig): Arguments related to inference, containing
|
||||
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
|
||||
num_attention_heads, and ffn_hidden_size.
|
||||
"""
|
||||
self.fd_config = fd_config
|
||||
self.parallel_config = fd_config.parallel_config
|
||||
|
||||
# Initialize distributed enviroment
|
||||
(self.rank, self.local_rank) = self.init_distributed_enviroment()
|
||||
(self.ranks, self.local_rank) = self.init_distributed_enviroment()
|
||||
|
||||
assert self.parallel_config.tensor_parallel_degree * self.parallel_config.expert_parallel_degree == self.rank
|
||||
assert self.parallel_config.tensor_parallel_degree * self.parallel_config.expert_parallel_degree == self.ranks
|
||||
|
||||
self.fd_config.parallel_config.tensor_parallel_rank = \
|
||||
self.local_rank % self.parallel_config.tensor_parallel_degree
|
||||
@@ -81,8 +88,6 @@ class PaddleDisWorkerProc():
|
||||
self.fd_config.moe_config.num_experts_start_offset = \
|
||||
self.fd_config.parallel_config.expert_parallel_rank * self.fd_config.moe_config.num_experts_per_rank
|
||||
|
||||
self.fd_config.parallel_config.column_cut = False
|
||||
|
||||
# For auto TP split
|
||||
self.fd_config.model_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_degree
|
||||
self.fd_config.model_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank
|
||||
@@ -95,7 +100,7 @@ class PaddleDisWorkerProc():
|
||||
# TODO(gongshaotian): Use worker factory to get worker
|
||||
self.worker = get_worker(fd_config=fd_config,
|
||||
local_rank=self.local_rank,
|
||||
rank=self.rank)
|
||||
rank=self.ranks)
|
||||
|
||||
# Initialize task queue
|
||||
task_address = ('0.0.0.0',
|
||||
@@ -109,7 +114,7 @@ class PaddleDisWorkerProc():
|
||||
local_data_parallel_id=self.fd_config.parallel_config.
|
||||
expert_parallel_rank)
|
||||
|
||||
def init_health_status(self):
|
||||
def init_health_status(self) -> None:
|
||||
"""
|
||||
Initialize the health status of the worker.
|
||||
Worker Status:
|
||||
@@ -134,7 +139,7 @@ class PaddleDisWorkerProc():
|
||||
self.worker_ready_signal.value[self.local_rank % 8] = 1
|
||||
|
||||
# init worker_healthy_live_signal
|
||||
workers_alive = np.zeros(shape=[self.rank], dtype=np.int32)
|
||||
workers_alive = np.zeros(shape=[self.ranks], dtype=np.int32)
|
||||
self.worker_healthy_live_signal = IPCSignal(
|
||||
name="worker_healthy_live_signal",
|
||||
array=workers_alive,
|
||||
@@ -183,16 +188,7 @@ class PaddleDisWorkerProc():
|
||||
suffix=self.parallel_config.engine_pid,
|
||||
create=False)
|
||||
|
||||
# init model_weights_status
|
||||
workers_model_weights = np.zeros(shape=[1], dtype=np.int32)
|
||||
self.model_weights_status = IPCSignal(
|
||||
name="model_weights_status",
|
||||
array=workers_model_weights,
|
||||
dtype=np.int32,
|
||||
suffix=self.parallel_config.engine_pid,
|
||||
create=False)
|
||||
|
||||
def event_loop_ep(self):
|
||||
def event_loop_ep(self) -> None:
|
||||
"""
|
||||
Tmp loop function for ep utill DP is supported
|
||||
"""
|
||||
@@ -217,7 +213,7 @@ class PaddleDisWorkerProc():
|
||||
# These generated tokens can be obtained through get_output op.
|
||||
self.worker.execute_model()
|
||||
|
||||
def event_loop_normal(self):
|
||||
def event_loop_normal(self) -> None:
|
||||
""" Main event loop for Paddle Distrubuted Workers.
|
||||
TODO(gongshaotian): support remote calling of functions that control worker.
|
||||
"""
|
||||
@@ -225,6 +221,12 @@ class PaddleDisWorkerProc():
|
||||
self.nnode = 1
|
||||
req_ids = []
|
||||
while True:
|
||||
if self.local_rank == 0:
|
||||
if self.model_weights_status.value[0] != 0:
|
||||
self.exist_task_signal.value[0] = 2
|
||||
else:
|
||||
self.exist_task_signal.value[0] = 0
|
||||
|
||||
if self.parallel_config.tensor_parallel_degree > 1:
|
||||
# Synchronize before updating weights
|
||||
paddle.distributed.barrier()
|
||||
@@ -234,7 +236,7 @@ class PaddleDisWorkerProc():
|
||||
time.time())
|
||||
|
||||
# The first worker detects whether there are tasks in the task queue
|
||||
mp_num_per_node = self.rank / self.nnode
|
||||
mp_num_per_node = self.ranks / self.nnode
|
||||
if self.local_rank % mp_num_per_node == 0:
|
||||
if self.task_queue.num_tasks() > 0:
|
||||
if self.nnode > 1:
|
||||
@@ -249,6 +251,14 @@ class PaddleDisWorkerProc():
|
||||
# TODO(@wufeisheng): Split TP group and EP group
|
||||
paddle.distributed.barrier()
|
||||
|
||||
if self.fd_config.load_config.dynamic_load_weight:
|
||||
if self.exist_task_signal.value[0] == 2:
|
||||
from fastdeploy.rl.dynamic_weight_manager import \
|
||||
DynamicWeightManager
|
||||
DynamicWeightManager.check_model_weights_status(
|
||||
self.model_weights_status, self.worker.model_runner,
|
||||
self.parallel_config.engine_pid)
|
||||
|
||||
if self.exist_task_signal.value[
|
||||
self.fd_config.parallel_config.expert_parallel_rank] == 1 or \
|
||||
self.task_queue.read_finish_flag.get() == 1:
|
||||
@@ -275,7 +285,7 @@ class PaddleDisWorkerProc():
|
||||
self.worker.preprocess_new_task(req_dicts)
|
||||
|
||||
if not self.worker.model_runner.not_need_stop():
|
||||
if self.rank > 1:
|
||||
if self.ranks > 1:
|
||||
paddle.distributed.barrier()
|
||||
|
||||
time.sleep(0.001)
|
||||
@@ -288,15 +298,15 @@ class PaddleDisWorkerProc():
|
||||
self.exist_prefill_task_signal.value[
|
||||
0] = self.worker.prefill_finished()
|
||||
|
||||
def init_distributed_enviroment(self, seed=20) -> List[int]:
|
||||
def init_distributed_enviroment(self, seed: int = 20) -> List[int]:
|
||||
""" Initialize Paddle Fleet and get rank of worker """
|
||||
# Global rank
|
||||
self.rank = dist.get_world_size()
|
||||
self.ranks = dist.get_world_size()
|
||||
dist_strategy = fleet.DistributedStrategy()
|
||||
|
||||
dist_strategy.hybrid_configs = {
|
||||
"dp_degree": 1,
|
||||
"mp_degree": self.rank,
|
||||
"mp_degree": self.ranks,
|
||||
"pp_degree": 1,
|
||||
"sharding_degree": 1,
|
||||
}
|
||||
@@ -308,10 +318,19 @@ class PaddleDisWorkerProc():
|
||||
# Local rank
|
||||
self.local_rank = fleet.worker_index()
|
||||
|
||||
return self.rank, self.local_rank
|
||||
return self.ranks, self.local_rank
|
||||
|
||||
def determine_num_available_blocks(self):
|
||||
"""
|
||||
def determine_num_available_blocks(self) -> None:
|
||||
"""Profiles the peak memory usage of the model to determine how many
|
||||
KV blocks may be allocated without OOMs.
|
||||
|
||||
The engine will first conduct a profiling of the existing memory usage.
|
||||
Then, it calculate the maximum possible number of GPU and CPU blocks
|
||||
that can be allocated with the remaining free memory.
|
||||
|
||||
.. tip::
|
||||
You may limit the usage of GPU memory
|
||||
by adjusting the `gpu_memory_utilization` parameter.
|
||||
"""
|
||||
if self.fd_config.parallel_config.do_profile:
|
||||
# 1. Get available memory(bytes)
|
||||
@@ -343,7 +362,8 @@ class PaddleDisWorkerProc():
|
||||
)
|
||||
|
||||
# 3. Send IPCSignal
|
||||
get_profile_block_num = np.zeros(shape=[self.rank], dtype=np.int32)
|
||||
get_profile_block_num = np.zeros(shape=[self.ranks],
|
||||
dtype=np.int32)
|
||||
self.get_profile_block_num_signal = IPCSignal(
|
||||
name="get_profile_block_num",
|
||||
array=get_profile_block_num,
|
||||
@@ -366,12 +386,12 @@ class PaddleDisWorkerProc():
|
||||
# 4. Updata share inputs
|
||||
self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_global)
|
||||
|
||||
def init_device(self):
|
||||
""" """
|
||||
def init_device(self) -> None:
|
||||
""" Initialize device and Construct model runner """
|
||||
self.worker.init_device()
|
||||
|
||||
def load_model(self):
|
||||
""" """
|
||||
def load_model(self) -> None:
|
||||
""" Load weights and create model """
|
||||
self.worker.load_model()
|
||||
|
||||
|
||||
@@ -428,9 +448,6 @@ def parse_args():
|
||||
parser.add_argument("--do_profile",
|
||||
action='store_true',
|
||||
help="do profile or not")
|
||||
parser.add_argument("--dynamic_load_weight",
|
||||
action='store_true',
|
||||
help="dynamic load weight or not")
|
||||
parser.add_argument("--pad_token_id",
|
||||
type=int,
|
||||
default=-1,
|
||||
@@ -467,14 +484,6 @@ def parse_args():
|
||||
default="WINT8",
|
||||
type=str,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention_backend",
|
||||
default="APPEND_ATTN",
|
||||
type=str,
|
||||
choices=[
|
||||
"APPEND_ATTN",
|
||||
],
|
||||
)
|
||||
parser.add_argument("--max_num_batched_tokens",
|
||||
type=int,
|
||||
default=2048,
|
||||
@@ -527,11 +536,26 @@ def parse_args():
|
||||
parser.add_argument("--disable_any_whitespace",
|
||||
action='store_false',
|
||||
help="Disable any whitespace for guided decoding.")
|
||||
parser.add_argument("--dynamic_load_weight",
|
||||
action='store_true',
|
||||
help="Enable dynamic weight loading strategy")
|
||||
parser.add_argument(
|
||||
"--load_strategy",
|
||||
type=str,
|
||||
choices=['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta', 'normal'],
|
||||
default='meta',
|
||||
help="Weight loading method when dynamic loading is enabled: "
|
||||
"'ipc': real-time IPC streaming with automatic resharding, "
|
||||
"'ipc_no_reshard': IPC streaming without weight processing, "
|
||||
"'ipc_snapshot': load from disk snapshot of IPC weights, "
|
||||
"'meta': provide RL traing worker, no_weights_load"
|
||||
"'normal':normal load weight")
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def initialize_fd_config(args) -> FDConfig:
|
||||
def initialize_fd_config(args: argparse.Namespace) -> FDConfig:
|
||||
"""Initialize FDConfig
|
||||
TODO(gongshaotian): Unified all configs to FDConfig
|
||||
"""
|
||||
@@ -554,7 +578,7 @@ def initialize_fd_config(args) -> FDConfig:
|
||||
# model_config = ModelConfig()
|
||||
|
||||
decoding_config = DecodingConfig()
|
||||
decoding_config = MoEConfig()
|
||||
|
||||
speculative_config = SpeculativeConfig()
|
||||
parallel_config = ParallelConfig()
|
||||
load_config = LoadConfig()
|
||||
@@ -592,7 +616,6 @@ def initialize_fd_config(args) -> FDConfig:
|
||||
parallel_config.pad_token_id = args.pad_token_id
|
||||
parallel_config.eos_tokens_lens = args.eos_tokens_lens
|
||||
parallel_config.enable_chunked_prefill = args.enable_chunked_prefill
|
||||
parallel_config.attention_backend = args.attention_backend
|
||||
parallel_config.max_num_batched_tokens = args.max_num_batched_tokens
|
||||
parallel_config.enable_prefix_caching = args.enable_prefix_caching
|
||||
|
||||
@@ -600,6 +623,7 @@ def initialize_fd_config(args) -> FDConfig:
|
||||
parallel_config.tensor_parallel_degree = args.tensor_parallel_size
|
||||
parallel_config.expert_parallel_degree = args.expert_parallel_size
|
||||
parallel_config.splitwise_role = args.splitwise_role
|
||||
load_config.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
|
||||
|
||||
parallel_config.guided_decoding_backend = args.guided_decoding_backend
|
||||
parallel_config.disable_any_whitespace = args.disable_any_whitespace
|
||||
@@ -659,19 +683,20 @@ def initialize_fd_config(args) -> FDConfig:
|
||||
|
||||
moe_config.num_max_dispatch_tokens_per_rank = config.get(
|
||||
"num_max_dispatch_tokens_per_rank", 256)
|
||||
moe_config.moe_use_aux_free = config.get("moe_use_aux_free", False)
|
||||
|
||||
model_config.ori_vocab_size = config.get("vocab_size", -1)
|
||||
if "Ernie4_5_ForCausalLM" in config.get("architectures"):
|
||||
model_config.ori_vocab_size = args.ori_vocab_size
|
||||
|
||||
quantization_config = config.get("quantization_config", None)
|
||||
if "DeepseekV3ForCausalLM" in config.get("architectures"):
|
||||
from paddleformers.transformers import AutoConfig
|
||||
model_config.deepseekv3 = AutoConfig.from_pretrained(
|
||||
args.model_name_or_path)
|
||||
|
||||
# Note(@wufeisheng): The `is_quantized` flag should be explicitly set to `true`
|
||||
# when the weights are actually quantized offline. For backward compatibility
|
||||
# with preview logic:
|
||||
# - If `quantization_config` is provided but `is_quantized` is not explicitly set,
|
||||
# the value of `is_quantized` will be determined by whether `kv_cache_quant_type`
|
||||
# has been configured.
|
||||
#TODO(@yuanrisheng): kv_cache quant config can only be
|
||||
# stored in model config file, which should be unified
|
||||
quantization_config = config.get("quantization_config", None)
|
||||
if not model_config.is_quantized:
|
||||
if quantization_config is not None:
|
||||
if "kv_cache_quant_type" not in quantization_config:
|
||||
@@ -689,9 +714,14 @@ def initialize_fd_config(args) -> FDConfig:
|
||||
elif args.quantization != "None":
|
||||
quantization_config = {}
|
||||
quant_config_name = args.quantization
|
||||
if use_moe and quant_config_name == "wint4":
|
||||
quantization_config["quantization"] = quant_config_name
|
||||
# use some trick code for ernie model and will unify it in future.
|
||||
is_ernie = "Ernie4_5_ForCausalLM" in config.get("architectures") or \
|
||||
"Ernie4_5_MoeForCausalLM" in config.get("architectures")
|
||||
if use_moe and quant_config_name == "wint4" and is_ernie:
|
||||
quantization_config["dense_quant_type"] = "wint8"
|
||||
quantization_config["moe_quant_type"] = "wint4"
|
||||
quantization_config["quantization"] = "mix_quant"
|
||||
quant_config_name = "mix_quant"
|
||||
else:
|
||||
quant_config_name = None
|
||||
@@ -706,20 +736,26 @@ def initialize_fd_config(args) -> FDConfig:
|
||||
if quant_config is not None:
|
||||
if model_config.is_quantized:
|
||||
logger.info(
|
||||
"=====The currently loaded model is an offline quantized model====="
|
||||
"Model Status: Offline Quantized (pre-quantized weights loaded)"
|
||||
)
|
||||
else:
|
||||
logger.info("=====The currently loaded model is the original model\
|
||||
The model will be quantized online=====")
|
||||
logger.info(f"{json.dumps(quantization_config, indent=2)}")
|
||||
logger.info(
|
||||
"Model Status: Original (will apply online quantization)")
|
||||
|
||||
logger.info(f"Quantization Method: {args.quantization or 'None'}")
|
||||
else:
|
||||
logger.info(
|
||||
"No quantization config found and use original weight and act dtype."
|
||||
)
|
||||
logger.info("============================================")
|
||||
|
||||
model_config.architectures = config.get("architectures")
|
||||
|
||||
logger.info("===========load_config==============")
|
||||
load_config.dynamic_load_weight = args.dynamic_load_weight
|
||||
load_config.load_strategy = args.load_strategy
|
||||
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
|
||||
logger.info(f"- Load strategy: {load_config.load_strategy}")
|
||||
|
||||
fd_config = FDConfig(model_config=model_config,
|
||||
parallel_config=parallel_config,
|
||||
speculative_config=speculative_config,
|
||||
@@ -733,7 +769,7 @@ def initialize_fd_config(args) -> FDConfig:
|
||||
return fd_config
|
||||
|
||||
|
||||
def run_worker_proc():
|
||||
def run_worker_proc() -> None:
|
||||
"""
|
||||
start worker process
|
||||
"""
|
||||
|
Reference in New Issue
Block a user