diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 34c295ac3..2e5473728 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -46,7 +46,6 @@ PRETRAINED_INIT_CONFIGURATION = {
     "num_max_dispatch_tokens_per_rank" : 256,
     "moe_use_aux_free" : False,
     "vocab_size" : -1,
-    "use_rope": True,
     "hidden_dropout_prob" : 0.0,
     "initializer_range" : 0.02,
     "max_position_embeddings" : 512,
@@ -89,6 +88,7 @@ class ModelConfig:
             if hasattr(self, key):
                 setattr(self, key, value)
 
+        assert self.model_name_or_path != ""
         pretrained_config, _ = PretrainedConfig.get_config_dict(self.model_name_or_path)
         self.pretrained_config = PretrainedConfig.from_dict(pretrained_config)
 
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index df8e6918b..6a5d30d21 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -32,6 +32,7 @@ from typing import Dict, List, Optional, Tuple
 import numpy as np
 import paddle
 import zmq
+from opentelemetry import trace
 from tqdm import tqdm
 
 from fastdeploy.engine.args_utils import EngineArgs
@@ -42,13 +43,13 @@ from fastdeploy.input.preprocess import InputPreprocessor
 from fastdeploy.inter_communicator import (EngineCacheQueue, EngineWorkerQueue,
                                            IPCSignal, ZmqClient)
 from fastdeploy.metrics.metrics import main_process_metrics
+from fastdeploy.metrics.trace_util import start_span, start_span_request
 from fastdeploy.model_executor.guided_decoding import schema_checker
 from fastdeploy.output.token_processor import (TokenProcessor,
                                                WarmUpTokenProcessor)
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
 from fastdeploy.utils import EngineError, console_logger, llm_logger
-from fastdeploy.metrics.trace_util import extract_from_metadata, start_span, start_span_request
-from opentelemetry import trace
+
 
 class LLMEngine(object):
     """
@@ -358,9 +359,9 @@ class LLMEngine(object):
             request, insert_task = None, []
             results: List[Tuple[str, Optional[str]]] = list()
             if data:
-                request = Request.from_dict(data)
-                start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
-
+                request = Request.from_dict(data)
+                start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
+
                 llm_logger.debug(f"Receive request: {request}")
 
 
@@ -693,7 +694,7 @@ class LLMEngine(object):
         Insert tasks to engine.
""" for task in tasks: - start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER) + start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER) # TODO 返回至 scheduler if allocated: current_tasks = [] @@ -1032,10 +1033,9 @@ class LLMEngine(object): f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}" f" --speculative_model_quantization {self.cfg.speculative_config.quantization}" f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}" - f" --graph_optimiaztion_config '{self.cfg.graph_optimization_config.to_json_string()}'" + f" --graph_optimization_config '{self.cfg.graph_optimization_config.to_json_string()}'" f" --guided_decoding_backend {self.cfg.guided_decoding_backend}" - f" --load_strategy {self.cfg.model_config.load_strategy}" - f" --enable_mm {self.cfg.enable_mm}") + f" --load_strategy {self.cfg.model_config.load_strategy}") worker_append_flag = { @@ -1050,6 +1050,7 @@ class LLMEngine(object): "disable_any_whitespace": self.cfg.disable_any_whitespace, "enable-custom-all-reduce": self.cfg.parallel_config.enable_custom_all_reduce, "enable_logprob": self.cfg.enable_logprob, + "enable_mm": self.cfg.enable_mm, } for worker_flag, value in worker_append_flag.items(): if value: diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index c7ad68ec1..f1c856604 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -58,7 +58,6 @@ class VocabParallelEmbedding(nn.Layer): self.column_cut = False self.world_size: int = hcg.get_model_parallel_world_size() self.ring_id: int = hcg.get_model_parallel_group().id - self.use_rope: bool = fd_config.model_config.use_rope self.use_ep: bool = fd_config.parallel_config.use_ep self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob self.initializer_range: float = fd_config.model_config.initializer_range @@ -92,14 +91,6 @@ class VocabParallelEmbedding(nn.Layer): self.embeddings.weight.is_distributed = True self.embeddings.weight.split_axis = 1 - if not self.use_rope: - self.position_embeddings = nn.Embedding( - self.max_position_embeddings, - embedding_dim, - weight_attr=paddle.ParamAttr(initializer=nn.initializer.Normal( - mean=0.0, std=self.initializer_range), ), - ) - self.prefix = prefix self.dropout = nn.Dropout(self.hidden_dropout_prob) diff --git a/fastdeploy/model_executor/layers/hydra_head.py b/fastdeploy/model_executor/layers/hydra_head.py deleted file mode 100644 index 2f3f026a5..000000000 --- a/fastdeploy/model_executor/layers/hydra_head.py +++ /dev/null @@ -1,217 +0,0 @@ -""" -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" - -import paddle -import paddle.nn.functional as F -from paddle import nn -from paddle.distributed import fleet -from paddle.distributed.fleet.meta_parallel import (ColumnParallelLinear, - VocabParallelEmbedding) -from paddleformers.utils.log import logger - -from .utils import get_tensor - - -class ResBlock(nn.Layer): - """ - A Residual Block module. - - This module performs a linear transformation followed by a SiLU activation, - and then adds the result to the original input, creating a residual connection. - - Args: - hidden_size (int): The size of the hidden layers in the block. - """ - - def __init__(self, hidden_size, num_condition=0): - super().__init__() - self.linear = nn.Linear(hidden_size * (num_condition + 1), hidden_size) - if num_condition > 0: - self.res_connection = nn.Linear( - hidden_size * (num_condition + 1), hidden_size - ) - else: - self.res_connection = nn.Identity() - # Initialize as an identity mapping - # _no_grad_fill_(self.linear.weight, 0) - # Use SiLU activation to keep consistent with the Llama model - self.act = nn.Silu() - - @paddle.no_grad() - def forward(self, x): - """ - Forward pass of the ResBlock. - - Args: - x (paddle.Tensor): Input tensor. - - Returns: - paddle.Tensor: Output after the residual connection and activation. - """ - return self.res_connection(x) + self.act(self.linear(x)) - - -class HydraHead(nn.Layer): - """ - A Hydra Head module. - - This module performs multi hydra head layers, - each of which is a hydra_lm_head followed by a head - - Args: - hydra_num_heads (int): The number of hyhra heads. - hydra_num_layers (int): The number of layers. - hidden_size (int): The size of the hidden layers in the block. - tensor_parallel_degree(int): TP degree. - vocab_size (int): The size of vocabulary. - """ - - def __init__( - self, - hydra_num_heads, - hydra_num_layers, - hidden_size, - tensor_parallel_degree, - vocab_size, - ): - super().__init__() - self.hydra_num_heads = hydra_num_heads - self.hydra_num_layers = hydra_num_layers - self.hidden_size = hidden_size - self.tensor_parallel_degree = tensor_parallel_degree - self.vocab_size = vocab_size - - self.hydra_mlp = nn.LayerList( - [ - nn.Sequential( - ResBlock(self.hidden_size, hydra_head_idx + 1), - *([ResBlock(self.hidden_size)] * (self.hydra_num_layers - 1)), - ) - for hydra_head_idx in range(self.hydra_num_heads) - ] - ) - - if self.tensor_parallel_degree > 1: - self.hydra_lm_head = nn.LayerList( - [ - ColumnParallelLinear( - self.hidden_size, - self.vocab_size, - weight_attr=paddle.ParamAttr( - initializer=nn.initializer.Normal(mean=0.0, std=0.0) - ), - gather_output=True, - has_bias=False, - ) - for _ in range(self.hydra_num_heads) - ] - ) - else: - self.hydra_lm_head = nn.LayerList( - [ - nn.Linear(self.hidden_size, self.vocab_size, bias_attr=False) - for _ in range(self.hydra_num_heads) - ] - ) - - self.embeddings = VocabParallelEmbedding( - vocab_size, - hidden_size, - mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(), - weight_attr=paddle.ParamAttr(initializer=nn.initializer.Normal(mean=0.0)), - ) - - def custom_set_state_dict(self, state_dict): - """ - Load Parameter of Hydra Head from state_dict with custom names. - - Args: - state_dict (dict): KV pair of name and parameters. 
- """ - for hydra_head_idx in range(self.hydra_num_heads): - self.hydra_mlp[hydra_head_idx][0].res_connection.weight.set_value( - get_tensor( - state_dict.pop(f"0.{hydra_head_idx}.0.res_connection.weight") - ) - ) - self.hydra_mlp[hydra_head_idx][0].res_connection.bias.set_value( - get_tensor(state_dict.pop(f"0.{hydra_head_idx}.0.res_connection.bias")) - ) - - for layer_idx in range(self.hydra_num_layers): - self.hydra_mlp[hydra_head_idx][layer_idx].linear.weight.set_value( - get_tensor( - state_dict.pop(f"0.{hydra_head_idx}.{layer_idx}.linear.weight") - ) - ) - self.hydra_mlp[hydra_head_idx][layer_idx].linear.bias.set_value( - get_tensor( - state_dict.pop(f"0.{hydra_head_idx}.{layer_idx}.linear.bias") - ) - ) - - self.hydra_lm_head[hydra_head_idx].weight.set_value( - get_tensor(state_dict.pop(f"1.{hydra_head_idx}.weight")) - ) - - self.embeddings.weight.set_value( - get_tensor(state_dict.pop("embeddings.weight")) - ) - - def set_state_dict(self, state_dict): - """ - Load Parameter of Hydra Head from state_dict. - - Args: - state_dict (dict): KV pair of name and parameters. - """ - is_custom = True - for key in state_dict.keys(): - if key != "embeddings.weight" and ( - "hydra_mlp" in key or "hydra_head" in key - ): - is_custom = False - break - - if is_custom: - logger.info("Hydra use custom set_state_dict") - self.custom_set_state_dict(state_dict) - else: - logger.info("Hydra use default set_state_dict") - super().set_state_dict(state_dict) - - @paddle.no_grad() - def forward(self, input_ids, hidden_states, next_tokens): - """ - Forward pass of Hydra Head - - Args: - input_ids: [batch_size, 1] The tokens sampled by the previous head go through the embedding, - starting with the last accept token - hidden_states: [batch_size, hidden_size] The hidden_states of the last accept_tokens - """ - hydra_inputs = [hidden_states] - input_embeds = self.embeddings(input_ids) - for hydra_head_idx in range(self.hydra_num_heads): - hydra_inputs.append(input_embeds) - head_input = paddle.concat(hydra_inputs, axis=-1) - hidden_states = self.hydra_mlp[hydra_head_idx](head_input) - logits = self.hydra_lm_head[hydra_head_idx](hidden_states) - probs = F.softmax(logits) - _, topk_tokens = paddle.topk(probs, k=1, axis=-1) - next_tokens[:, 1 + hydra_head_idx : 2 + hydra_head_idx] = topk_tokens[:] - - input_embeds = self.embeddings(next_tokens[:, 1 + hydra_head_idx]) diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 98ec7090b..873bc041f 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -606,8 +606,8 @@ class Ernie4_5_PretrainedModel(PretrainedModel): return final_actions mappings = get_tensor_parallel_split_mappings( config.num_hidden_layers, - config.moe_num_experts, - config.moe_layer_start_index, + getattr(config, "moe_num_experts", 0), + getattr(config, "moe_layer_start_index", -1), config.prefix_name, ) return mappings diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py index 8aa11897d..dcb95ea2d 100644 --- a/fastdeploy/rl/rollout_config.py +++ b/fastdeploy/rl/rollout_config.py @@ -51,12 +51,13 @@ class RolloutModelConfig: enable_prefix_caching: bool = False, splitwise_role: str = "mixed", expert_parallel_size: int = 1, - enable_expert_parallell: bool = False, + enable_expert_parallel: bool = False, ori_vocab_size: int = None, quantization: str = "None", guided_decoding_backend: str = "off", disable_any_whitespace: bool = True, enable_logprob: 
+        graph_optimization_config: str = None,
     ):
         # Required parameters
         self.model_name_or_path = model_name_or_path
@@ -90,12 +91,13 @@ class RolloutModelConfig:
         self.enable_prefix_caching = enable_prefix_caching
         self.splitwise_role = splitwise_role
         self.expert_parallel_size = expert_parallel_size
-        self.enable_expert_parallell = enable_expert_parallell
+        self.enable_expert_parallel = enable_expert_parallel
         self.ori_vocab_size = ori_vocab_size
         self.quantization = quantization
         self.guided_decoding_backend = guided_decoding_backend
         self.disable_any_whitespace = disable_any_whitespace
         self.enable_logprob = enable_logprob
+        self.graph_optimization_config = graph_optimization_config
 
     def __str__(self):
         return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py
index 99ab3455f..7c1e0fe0e 100644
--- a/fastdeploy/rl/rollout_model.py
+++ b/fastdeploy/rl/rollout_model.py
@@ -39,17 +39,17 @@ class RolloutModel(nn.Layer):
         """Initialize with FastDeploy configuration."""
         super(RolloutModel, self).__init__()
         self.fd_config = rollout_model_config.initialize()
-        self._init_model()
+        self.rollout_model = self._init_model()
 
-    def _init_model(self):
+    def _init_model(self) -> nn.Layer:
         """Load model from loader based on config."""
         context = paddle.LazyGuard()
         architectures = f"{self.fd_config.model_config.architectures[0]}RL"
         with context:
             model_cls = ModelRegistry.get_class(architectures)
             model = model_cls(self.fd_config)
-
-        self.rollout_model = model.eval()
+        model.eval()
+        return model
 
     def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
@@ -74,15 +74,14 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
         super(Ernie4_5_MoeForCausalLMRL, self).__init__(fd_config)
 
     @classmethod
-    def name(self):
+    def name(self) -> str:
         """name"""
         return "Ernie4_5_MoeForCausalLMRL"
 
-    def get_name_mappings_to_training(self):
+    def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
-        have_bias = self.fd_config.model_config.get("have_norm_bias", False)
         # Prepare placeholders
-        place_holders = ["weight"] + (["bias"] if have_bias else [])
+        place_holders = ["weight"]
 
         # Initialize mapping dictionary
         infer_to_train = {}
@@ -94,7 +93,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
             f"{base_name}.embed_tokens.weight",
             "lm_head.linear.weight": "lm_head.weight"
         }
-        if self.fd_config.model_config.get("tie_word_embeddings", False):
+        if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
             # Support tie_word_embeddings
             logger.debug("enable tie_word_embeddings")
             static_mappings.pop("lm_head.linear.weight")
@@ -153,15 +152,14 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
         super(Ernie4_5_VLMoeForConditionalGenerationRL, self).__init__(fd_config)
 
     @classmethod
-    def name(self):
+    def name(self) -> str:
         """name"""
         return "Ernie4_5_VLMoeForConditionalGenerationRL"
 
-    def get_name_mappings_to_training(self):
+    def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
-        have_bias = self.fd_config.model_config.get("have_norm_bias", False)
         # Prepare placeholders
-        place_holders = ["weight"] + (["bias"] if have_bias else [])
+        place_holders = ["weight"]
 
         # Initialize mapping dictionary
         infer_to_train = {}
@@ -173,7 +171,7 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
             f"{base_name}.embed_tokens.weight",
             "lm_head.linear.weight": "lm_head.weight"
         }
-        if self.fd_config.model_config.get("tie_word_embeddings", False):
+        if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
             # Support tie_word_embeddings
             logger.debug("enable tie_word_embeddings")
             static_mappings.pop("lm_head.linear.weight")
@@ -257,11 +255,11 @@ class Qwen2ForCausalLMRL(Qwen2ForCausalLM):
         super(Qwen2ForCausalLMRL, self).__init__(fd_config)
 
     @classmethod
-    def name(self):
+    def name(self) -> str:
         """name"""
         return "Qwen2ForCausalLMRL"
 
-    def get_name_mappings_to_training(self):
+    def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -307,11 +305,11 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
         super(Qwen3MoeForCausalLMRL, self).__init__(fd_config)
 
     @classmethod
-    def name(self):
+    def name(self) -> str:
         """name"""
         return "Qwen3MoeForCausalLMRL"
 
-    def get_name_mappings_to_training(self):
+    def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -379,6 +377,6 @@ class Qwen3ForCausalLMRL(Qwen3ForCausalLM):
         super(Qwen3ForCausalLMRL, self).__init__(fd_config)
 
     @classmethod
-    def name(self):
+    def name(self) -> str:
         """name"""
         return "Qwen3ForCausalLMRL"
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index d2bb41a26..c1ab3082b 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -47,14 +47,12 @@ from fastdeploy.platforms import current_platform
 if not current_platform.is_dcu():
     from fastdeploy.spec_decode import MTPProposer, NgramProposer
 
-from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
 from fastdeploy.input.mm_processor import DataProcessor
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import \
     ScatterOp
 from fastdeploy.worker.model_runner_base import ModelRunnerBase
 from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
-from fastdeploy.worker.utils import check_safetensors_model
 
 
 class GPUModelRunner(ModelRunnerBase):
@@ -81,16 +79,7 @@ class GPUModelRunner(ModelRunnerBase):
 
         # VL model config:
         if self.enable_mm:
-            model_path = os.path.dirname(self.parallel_config.model_name_or_path)
-            self.is_safetensors_model = check_safetensors_model(
-                self.parallel_config.model_name_or_path)
-            if not self.is_safetensors_model:
-                self.tokenizer_path = self.image_preprocessor_path = model_path
-            else:
-                self.tokenizer_path = self.parallel_config.model_name_or_path
-                self.image_preprocessor_path = self.parallel_config.model_name_or_path
-            self.vision_model_name_or_path = os.path.join(
-                model_path, "DFNRopeVisionTransformer")
+            self._init_image_preprocess()
 
         self.amp_black = [
             "reduce_sum",
@@ -734,8 +723,6 @@ class GPUModelRunner(ModelRunnerBase):
             f"Starting to load model {self.model_config.architectures[0]}")
         time_before_load = time.perf_counter()
         # 1. Load original model
-        if self.enable_mm:
-            self.load_mm_config_and_image_preprocess()
         self.model = get_model_from_loader(fd_config=self.fd_config)
         # 1.1 Load RL dynamic model
         if self.fd_config.load_config.dynamic_load_weight:
@@ -1440,8 +1427,8 @@ class GPUModelRunner(ModelRunnerBase):
 
     def _init_image_preprocess(self) -> None:
         processor = DataProcessor(
-            tokenizer_name=self.tokenizer_path,
-            image_preprocessor_name=str(self.image_preprocessor_path),
+            tokenizer_name=self.parallel_config.model_name_or_path,
+            image_preprocessor_name=str(self.parallel_config.model_name_or_path),
         )
         processor.eval()
         image_preprocess = processor.image_preprocessor
@@ -1459,31 +1446,6 @@ class GPUModelRunner(ModelRunnerBase):
                                                        -1)
         self.image_preprocess = image_preprocess
 
-    def load_mm_config_and_image_preprocess(self) -> None:
-        tokenizer = ErnieBotTokenizer.from_pretrained(
-            self.tokenizer_path,
-            model_max_length=self.parallel_config.max_model_len,
-            padding_side="right",
-            use_fast=False,
-        )
-        tokenizer.ignored_index = -100
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.unk_token
-
-        self.fd_config.model_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_size
-        self.fd_config.model_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank
-        vision_config = self.fd_config.model_config.vision_config
-        vision_config.dtype = self.fd_config.model_config.dtype
-        vision_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_size
-        vision_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank
-        self.fd_config.model_config.im_patch_id = tokenizer.get_vocab()[
-            "<|IMAGE_PLACEHOLDER|>"
-        ]
-        self.fd_config.model_config.think_end_id = tokenizer.get_vocab()[""]
-        self.fd_config.model_config.sequence_parallel = self.parallel_config.sequence_parallel
-        self.model_config = self.fd_config.model_config
-        self._init_image_preprocess()
-
     def _preprocess_mm_task(self, one: dict) -> None:
         """process batch"""
 
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 99504008c..e273d0714 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -26,6 +26,7 @@ import paddle.distributed.fleet as fleet
 from fastdeploy.config import (DecodingConfig, DeviceConfig, FDConfig,
                                GraphOptimizationConfig, LoadConfig, ModelConfig,
                                ParallelConfig, SpeculativeConfig)
+from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
 from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
 from fastdeploy.inter_communicator import IPCSignal
 from fastdeploy.model_executor.layers.quantization import \
@@ -83,6 +84,30 @@ def init_distributed_environment(seed: int = 20) -> List[int]:
     return ranks, local_rank
 
+def update_fd_config_for_mm(fd_config: FDConfig) -> None:
+    if fd_config.model_config.enable_mm:
+        tokenizer = ErnieBotTokenizer.from_pretrained(
+            fd_config.parallel_config.model_name_or_path,
+            model_max_length=fd_config.parallel_config.max_model_len,
+            padding_side="right",
+            use_fast=False,
+        )
+        tokenizer.ignored_index = -100
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.unk_token
+
+        fd_config.model_config.tensor_parallel_degree = fd_config.parallel_config.tensor_parallel_size
+        fd_config.model_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
+        vision_config = fd_config.model_config.vision_config
+        vision_config.dtype = fd_config.model_config.dtype
+        # vision_config.tensor_parallel_degree = fd_config.parallel_config.tensor_parallel_size
+        # vision_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
+        fd_config.model_config.im_patch_id = tokenizer.get_vocab()[
+            "<|IMAGE_PLACEHOLDER|>"
+        ]
+        fd_config.model_config.think_end_id = tokenizer.get_vocab()[""]
+        fd_config.model_config.sequence_parallel = fd_config.parallel_config.sequence_parallel
+
 
 class PaddleDisWorkerProc():
     """
     Paddle Distrubuted wrapper for fastdeploy.worker.Worker,
@@ -504,9 +529,9 @@ def parse_args():
                         type=int,
                         default=1,
                         help="expert parallel size")
-    parser.add_argument("--enable_expert_parallell",
+    parser.add_argument("--enable_expert_parallel",
                         action='store_true',
-                        help="enable expert parallell")
+                        help="enable expert parallel")
     parser.add_argument("--ori_vocab_size",
                         type=int,
                         default=None)
@@ -517,7 +542,7 @@ def parse_args():
                         "default is None. The priority of this configuration "\
                         "is lower than that of the config file. " \
                         "More complex quantization methods need to be configured via the config file.")
-    parser.add_argument("--graph_optimiaztion_config",
+    parser.add_argument("--graph_optimization_config",
                         type=json.loads,
                         default=None,
                         help=" Configation of Graph optimization backend. "
@@ -541,9 +566,8 @@ def parse_args():
                         "'ipc': real-time IPC streaming with automatic resharding, "
                         "'ipc_snapshot': load from disk snapshot of IPC weights.")
     parser.add_argument("--enable_mm",
-                        type=str,
-                        default="false",
-                        help="Whether to use vl")
+                        action='store_true',
+                        help="Whether to enable vl model")
     parser.add_argument("--enable_logprob",
                         action='store_true',
                         help="Enable output of token-level log probabilities.")
@@ -572,11 +596,13 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     parallel_config.expert_parallel_rank = int(local_rank / ranks)
     load_config = LoadConfig(vars(args))
 
-    graph_opt_config = GraphOptimizationConfig(
-        use_cudagraph=args.graph_optimiaztion_config["use_cudagraph"],
-        graph_opt_level=args.graph_optimiaztion_config["graph_opt_level"],
-        cudagraph_capture_sizes=args.graph_optimiaztion_config["cudagraph_capture_sizes"]
-    )
+    graph_opt_config = GraphOptimizationConfig()
+    if args.graph_optimization_config is not None:
+        graph_opt_config = GraphOptimizationConfig(
+            use_cudagraph=args.graph_optimization_config["use_cudagraph"],
+            graph_opt_level=args.graph_optimization_config["graph_opt_level"],
+            cudagraph_capture_sizes=args.graph_optimization_config["cudagraph_capture_sizes"]
+        )
 
     # Note(tangbinhan): used for load_checkpoint
     model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
@@ -650,7 +676,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     )
 
     # Set VL tag
-    model_config.enable_mm = getattr(args, 'enable_mm', 'false').lower() == 'true'
+    model_config.enable_mm = args.enable_mm
 
     logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
     logger.info(f"- Load strategy: {load_config.load_strategy}")
@@ -662,6 +688,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
                          decoding_config=decoding_config,
                          quant_config=quant_config,
                          graph_opt_config=graph_opt_config)
+    update_fd_config_for_mm(fd_config)
 
     return fd_config
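
Editor's note (not part of the patch): the worker CLI changes above switch `--enable_mm` from a string flag defaulting to "false" to a plain `store_true` switch, and rename `--graph_optimiaztion_config` to `--graph_optimization_config`, which is still parsed with `json.loads`. A minimal argparse sketch of the new behavior, using made-up example values:

```python
import argparse
import json

parser = argparse.ArgumentParser()
# Boolean switch: True only when the flag is present on the command line.
parser.add_argument("--enable_mm", action="store_true",
                    help="Whether to enable vl model")
# JSON string parsed into a dict; stays None when the flag is omitted.
parser.add_argument("--graph_optimization_config", type=json.loads, default=None)

args = parser.parse_args([
    "--enable_mm",
    "--graph_optimization_config",
    '{"use_cudagraph": false, "graph_opt_level": 0, "cudagraph_capture_sizes": []}',
])
assert args.enable_mm is True
assert args.graph_optimization_config["graph_opt_level"] == 0
```

Because the JSON flag defaults to `None`, `initialize_fd_config` can fall back to a default `GraphOptimizationConfig()` when the flag is absent, which is what the patched branch does.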
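
A second illustrative sketch of the fallback pattern adopted in `ernie4_5_moe.py` and `rollout_model.py`: the values are read from a config object rather than a dict, so dict-style `.get(...)` is replaced with `getattr(..., default)`, which returns a fallback instead of raising `AttributeError` for checkpoints that define no MoE or tie-embedding fields. `DummyConfig` here is a hypothetical stand-in, not a FastDeploy class:

```python
class DummyConfig:
    """Hypothetical stand-in for a pretrained config without MoE fields."""
    num_hidden_layers = 4

cfg = DummyConfig()
# getattr returns the fallback instead of raising AttributeError.
assert getattr(cfg, "moe_num_experts", 0) == 0
assert getattr(cfg, "tie_word_embeddings", False) is False
```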