Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 08:37:06 +08:00
Fix rollout_model init (#2881)
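This commit fixes the rollout model initialization path: it corrects two misspelled parse_args flags (--enable_expert_parallell → --enable_expert_parallel and --graph_optimiaztion_config → --graph_optimization_config), turns --enable_mm into a boolean store_true flag, guards GraphOptimizationConfig construction when no JSON config is supplied, and adds update_fd_config_for_mm() to populate tokenizer-derived multimodal fields once the FDConfig has been assembled.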
@@ -26,6 +26,7 @@ import paddle.distributed.fleet as fleet
 from fastdeploy.config import (DecodingConfig, DeviceConfig, FDConfig,
                                GraphOptimizationConfig, LoadConfig,
                                ModelConfig, ParallelConfig, SpeculativeConfig)
+from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
 from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
 from fastdeploy.inter_communicator import IPCSignal
 from fastdeploy.model_executor.layers.quantization import \
@@ -83,6 +84,30 @@ def init_distributed_environment(seed: int = 20) -> List[int]:
 
     return ranks, local_rank
 
+def update_fd_config_for_mm(fd_config: FDConfig) -> None:
+    if fd_config.model_config.enable_mm:
+        tokenizer = ErnieBotTokenizer.from_pretrained(
+            fd_config.parallel_config.model_name_or_path,
+            model_max_length=fd_config.parallel_config.max_model_len,
+            padding_side="right",
+            use_fast=False,
+        )
+        tokenizer.ignored_index = -100
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.unk_token
+
+        fd_config.model_config.tensor_parallel_degree = fd_config.parallel_config.tensor_parallel_size
+        fd_config.model_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
+        vision_config = fd_config.model_config.vision_config
+        vision_config.dtype = fd_config.model_config.dtype
+        # vision_config.tensor_parallel_degree = fd_config.parallel_config.tensor_parallel_size
+        # vision_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
+        fd_config.model_config.im_patch_id = tokenizer.get_vocab()[
+            "<|IMAGE_PLACEHOLDER|>"
+        ]
+        fd_config.model_config.think_end_id = tokenizer.get_vocab()["</think>"]
+        fd_config.model_config.sequence_parallel = fd_config.parallel_config.sequence_parallel
+
 class PaddleDisWorkerProc():
     """
     Paddle Distrubuted wrapper for fastdeploy.worker.Worker,
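Two details in the new helper are worth noting: ignored_index = -100 follows the ignore_index convention common to cross-entropy loss implementations (labels with that value are skipped), and the special-token ids are resolved by indexing get_vocab() directly, so a tokenizer whose vocabulary lacks <|IMAGE_PLACEHOLDER|> or </think> will raise KeyError here. A minimal stand-in sketch of that lookup (the vocab dict and ids below are made up, not FastDeploy's):

    # Stand-in vocabulary; real ids come from ErnieBotTokenizer.get_vocab().
    vocab = {"<|IMAGE_PLACEHOLDER|>": 100295, "</think>": 100296}  # hypothetical ids

    im_patch_id = vocab["<|IMAGE_PLACEHOLDER|>"]  # direct indexing, as in the hunk;
    think_end_id = vocab["</think>"]              # a missing token raises KeyError
    print(im_patch_id, think_end_id)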
@@ -504,9 +529,9 @@ def parse_args():
                         type=int,
                         default=1,
                         help="expert parallel size")
-    parser.add_argument("--enable_expert_parallell",
+    parser.add_argument("--enable_expert_parallel",
                         action='store_true',
-                        help="enable expert parallell")
+                        help="enable expert parallel")
     parser.add_argument("--ori_vocab_size", type=int, default=None)
 
     parser.add_argument("--quantization",
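Since the misspelled --enable_expert_parallell is removed rather than aliased, launch scripts still passing the old spelling will now exit with argparse's unrecognized-arguments error.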
@@ -517,7 +542,7 @@ def parse_args():
                         "default is None. The priority of this configuration "\
                         "is lower than that of the config file. " \
                         "More complex quantization methods need to be configured via the config file.")
-    parser.add_argument("--graph_optimiaztion_config",
+    parser.add_argument("--graph_optimization_config",
                         type=json.loads,
                         default=None,
                         help=" Configation of Graph optimization backend. "
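Because the renamed flag keeps type=json.loads, the value is passed on the command line as a JSON string and argparse decodes it into a dict before initialize_fd_config reads it. A self-contained sketch of that behavior (the key values are illustrative only):

    import argparse
    import json

    parser = argparse.ArgumentParser()
    parser.add_argument("--graph_optimization_config", type=json.loads, default=None)

    args = parser.parse_args([
        "--graph_optimization_config",
        '{"use_cudagraph": true, "graph_opt_level": 1, "cudagraph_capture_sizes": [1, 2, 4]}',
    ])
    print(args.graph_optimization_config["use_cudagraph"])  # True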
@@ -541,9 +566,8 @@ def parse_args():
                         "'ipc': real-time IPC streaming with automatic resharding, "
                         "'ipc_snapshot': load from disk snapshot of IPC weights.")
     parser.add_argument("--enable_mm",
-                        type=str,
-                        default="false",
-                        help="Whether to use vl")
+                        action='store_true',
+                        help="Whether to enable vl model")
     parser.add_argument("--enable_logprob",
                         action='store_true',
                         help="Enable output of token-level log probabilities.")
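Switching --enable_mm from a string option (type=str, default="false") to action='store_true' makes the parsed value a real boolean: absent means False, present means True. That is what lets the string comparison later in this diff be replaced by a plain attribute read. A quick demonstration with plain argparse, independent of FastDeploy:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--enable_mm", action="store_true",
                        help="Whether to enable vl model")

    print(parser.parse_args([]).enable_mm)               # False (flag absent)
    print(parser.parse_args(["--enable_mm"]).enable_mm)  # True  (flag present)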
@@ -572,11 +596,13 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         parallel_config.expert_parallel_rank = int(local_rank / ranks)
     load_config = LoadConfig(vars(args))
 
-    graph_opt_config = GraphOptimizationConfig(
-        use_cudagraph=args.graph_optimiaztion_config["use_cudagraph"],
-        graph_opt_level=args.graph_optimiaztion_config["graph_opt_level"],
-        cudagraph_capture_sizes=args.graph_optimiaztion_config["cudagraph_capture_sizes"]
-    )
+    graph_opt_config = GraphOptimizationConfig()
+    if args.graph_optimization_config is not None:
+        graph_opt_config = GraphOptimizationConfig(
+            use_cudagraph=args.graph_optimization_config["use_cudagraph"],
+            graph_opt_level=args.graph_optimization_config["graph_opt_level"],
+            cudagraph_capture_sizes=args.graph_optimization_config["cudagraph_capture_sizes"]
+        )
 
     # Note(tangbinhan): used for load_checkpoint
     model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
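The old code indexed args.graph_optimiaztion_config unconditionally, so any launch that omitted the flag crashed on the None default; the new code starts from a default-constructed GraphOptimizationConfig and only rebuilds it when a JSON config was supplied. Note that the rebuilt path still indexes all three keys, so a partial JSON object (say, only use_cudagraph) would raise KeyError. A generic sketch of this guard-with-defaults pattern using plain dicts (the default values here are hypothetical, not the repo's):

    import json
    from typing import Optional

    def build_graph_opt_settings(raw_json: Optional[str]) -> dict:
        # Defaults stand in for GraphOptimizationConfig(); the field names
        # come from the hunk above, the default values are made up.
        settings = {"use_cudagraph": False,
                    "graph_opt_level": 0,
                    "cudagraph_capture_sizes": None}
        if raw_json is not None:
            settings.update(json.loads(raw_json))  # user-supplied overrides win
        return settings

    print(build_graph_opt_settings(None))
    print(build_graph_opt_settings('{"use_cudagraph": true}'))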
@@ -650,7 +676,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     )
 
     # Set VL tag
-    model_config.enable_mm = getattr(args, 'enable_mm', 'false').lower() == 'true'
+    model_config.enable_mm = args.enable_mm
     logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
     logger.info(f"- Load strategy: {load_config.load_strategy}")
 
@@ -662,6 +688,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         decoding_config=decoding_config,
         quant_config=quant_config,
         graph_opt_config=graph_opt_config)
+    update_fd_config_for_mm(fd_config)
 
     return fd_config
 
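The new update_fd_config_for_mm(fd_config) call deliberately runs after the FDConfig is fully assembled: the helper reads model_config.enable_mm (set from args.enable_mm above) and several parallel_config fields before patching the tokenizer-derived multimodal attributes onto model_config.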