Mirror of https://github.com/PaddlePaddle/FastDeploy.git

Fix rollout_model init (#2881)
@@ -46,7 +46,6 @@ PRETRAINED_INIT_CONFIGURATION = {
     "num_max_dispatch_tokens_per_rank" : 256,
     "moe_use_aux_free" : False,
     "vocab_size" : -1,
-    "use_rope": True,
     "hidden_dropout_prob" : 0.0,
     "initializer_range" : 0.02,
     "max_position_embeddings" : 512,
@@ -89,6 +88,7 @@ class ModelConfig:
             if hasattr(self, key):
                 setattr(self, key, value)

+        assert self.model_name_or_path != ""
        pretrained_config, _ = PretrainedConfig.get_config_dict(self.model_name_or_path)
        self.pretrained_config = PretrainedConfig.from_dict(pretrained_config)

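Note: the added assertion makes an empty model_name_or_path fail immediately instead of surfacing later inside PretrainedConfig.get_config_dict. A minimal standalone sketch of the guard pattern (toy class, not the real ModelConfig):

    class ConfigSketch:                                   # toy stand-in for ModelConfig
        def __init__(self, model_name_or_path: str = ""):
            self.model_name_or_path = model_name_or_path
            assert self.model_name_or_path != ""          # fail fast, before any config lookup

    try:
        ConfigSketch("")
    except AssertionError:
        print("empty model path rejected before config loading")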
@@ -32,6 +32,7 @@ from typing import Dict, List, Optional, Tuple
 import numpy as np
 import paddle
 import zmq
+from opentelemetry import trace
 from tqdm import tqdm

 from fastdeploy.engine.args_utils import EngineArgs
@@ -42,13 +43,13 @@ from fastdeploy.input.preprocess import InputPreprocessor
 from fastdeploy.inter_communicator import (EngineCacheQueue, EngineWorkerQueue,
                                            IPCSignal, ZmqClient)
 from fastdeploy.metrics.metrics import main_process_metrics
+from fastdeploy.metrics.trace_util import start_span, start_span_request
 from fastdeploy.model_executor.guided_decoding import schema_checker
 from fastdeploy.output.token_processor import (TokenProcessor,
                                                WarmUpTokenProcessor)
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
 from fastdeploy.utils import EngineError, console_logger, llm_logger
-from fastdeploy.metrics.trace_util import extract_from_metadata, start_span, start_span_request
-from opentelemetry import trace


 class LLMEngine(object):
     """
@@ -1032,10 +1033,9 @@ class LLMEngine(object):
                 f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}"
                 f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
                 f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
-                f" --graph_optimiaztion_config '{self.cfg.graph_optimization_config.to_json_string()}'"
+                f" --graph_optimization_config '{self.cfg.graph_optimization_config.to_json_string()}'"
                 f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
-                f" --load_strategy {self.cfg.model_config.load_strategy}"
-                f" --enable_mm {self.cfg.enable_mm}")
+                f" --load_strategy {self.cfg.model_config.load_strategy}")


         worker_append_flag = {
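Note: the misspelled --graph_optimiaztion_config flag built into the worker launch command is corrected here to --graph_optimization_config, matching the rename of the argparse option in parse_args() further down in this diff. A minimal argparse sketch of why the two spellings must agree:

    import argparse

    parser = argparse.ArgumentParser()
    # spelling as the worker-side parser now defines it (see the parse_args hunk below)
    parser.add_argument("--graph_optimization_config", default=None)

    parser.parse_args(["--graph_optimization_config", "{}"])      # accepted
    # parser.parse_args(["--graph_optimiaztion_config", "{}"])    # error: unrecognized arguments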
@@ -1050,6 +1050,7 @@ class LLMEngine(object):
             "disable_any_whitespace": self.cfg.disable_any_whitespace,
             "enable-custom-all-reduce": self.cfg.parallel_config.enable_custom_all_reduce,
             "enable_logprob": self.cfg.enable_logprob,
+            "enable_mm": self.cfg.enable_mm,
         }
         for worker_flag, value in worker_append_flag.items():
             if value:
@@ -58,7 +58,6 @@ class VocabParallelEmbedding(nn.Layer):
         self.column_cut = False
         self.world_size: int = hcg.get_model_parallel_world_size()
         self.ring_id: int = hcg.get_model_parallel_group().id
-        self.use_rope: bool = fd_config.model_config.use_rope
         self.use_ep: bool = fd_config.parallel_config.use_ep
         self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
         self.initializer_range: float = fd_config.model_config.initializer_range
@@ -92,14 +91,6 @@ class VocabParallelEmbedding(nn.Layer):
             self.embeddings.weight.is_distributed = True
             self.embeddings.weight.split_axis = 1

-        if not self.use_rope:
-            self.position_embeddings = nn.Embedding(
-                self.max_position_embeddings,
-                embedding_dim,
-                weight_attr=paddle.ParamAttr(initializer=nn.initializer.Normal(
-                    mean=0.0, std=self.initializer_range), ),
-            )
-
         self.prefix = prefix
         self.dropout = nn.Dropout(self.hidden_dropout_prob)

@@ -1,217 +0,0 @@
-"""
-# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-
-import paddle
-import paddle.nn.functional as F
-from paddle import nn
-from paddle.distributed import fleet
-from paddle.distributed.fleet.meta_parallel import (ColumnParallelLinear,
-                                                    VocabParallelEmbedding)
-from paddleformers.utils.log import logger
-
-from .utils import get_tensor
-
-
-class ResBlock(nn.Layer):
-    """
-    A Residual Block module.
-
-    This module performs a linear transformation followed by a SiLU activation,
-    and then adds the result to the original input, creating a residual connection.
-
-    Args:
-        hidden_size (int): The size of the hidden layers in the block.
-    """
-
-    def __init__(self, hidden_size, num_condition=0):
-        super().__init__()
-        self.linear = nn.Linear(hidden_size * (num_condition + 1), hidden_size)
-        if num_condition > 0:
-            self.res_connection = nn.Linear(
-                hidden_size * (num_condition + 1), hidden_size
-            )
-        else:
-            self.res_connection = nn.Identity()
-        # Initialize as an identity mapping
-        # _no_grad_fill_(self.linear.weight, 0)
-        # Use SiLU activation to keep consistent with the Llama model
-        self.act = nn.Silu()
-
-    @paddle.no_grad()
-    def forward(self, x):
-        """
-        Forward pass of the ResBlock.
-
-        Args:
-            x (paddle.Tensor): Input tensor.
-
-        Returns:
-            paddle.Tensor: Output after the residual connection and activation.
-        """
-        return self.res_connection(x) + self.act(self.linear(x))
-
-
-class HydraHead(nn.Layer):
-    """
-    A Hydra Head module.
-
-    This module performs multi hydra head layers,
-    each of which is a hydra_lm_head followed by a head
-
-    Args:
-        hydra_num_heads (int): The number of hyhra heads.
-        hydra_num_layers (int): The number of layers.
-        hidden_size (int): The size of the hidden layers in the block.
-        tensor_parallel_degree(int): TP degree.
-        vocab_size (int): The size of vocabulary.
-    """
-
-    def __init__(
-        self,
-        hydra_num_heads,
-        hydra_num_layers,
-        hidden_size,
-        tensor_parallel_degree,
-        vocab_size,
-    ):
-        super().__init__()
-        self.hydra_num_heads = hydra_num_heads
-        self.hydra_num_layers = hydra_num_layers
-        self.hidden_size = hidden_size
-        self.tensor_parallel_degree = tensor_parallel_degree
-        self.vocab_size = vocab_size
-
-        self.hydra_mlp = nn.LayerList(
-            [
-                nn.Sequential(
-                    ResBlock(self.hidden_size, hydra_head_idx + 1),
-                    *([ResBlock(self.hidden_size)] * (self.hydra_num_layers - 1)),
-                )
-                for hydra_head_idx in range(self.hydra_num_heads)
-            ]
-        )
-
-        if self.tensor_parallel_degree > 1:
-            self.hydra_lm_head = nn.LayerList(
-                [
-                    ColumnParallelLinear(
-                        self.hidden_size,
-                        self.vocab_size,
-                        weight_attr=paddle.ParamAttr(
-                            initializer=nn.initializer.Normal(mean=0.0, std=0.0)
-                        ),
-                        gather_output=True,
-                        has_bias=False,
-                    )
-                    for _ in range(self.hydra_num_heads)
-                ]
-            )
-        else:
-            self.hydra_lm_head = nn.LayerList(
-                [
-                    nn.Linear(self.hidden_size, self.vocab_size, bias_attr=False)
-                    for _ in range(self.hydra_num_heads)
-                ]
-            )
-
-        self.embeddings = VocabParallelEmbedding(
-            vocab_size,
-            hidden_size,
-            mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-            weight_attr=paddle.ParamAttr(initializer=nn.initializer.Normal(mean=0.0)),
-        )
-
-    def custom_set_state_dict(self, state_dict):
-        """
-        Load Parameter of Hydra Head from state_dict with custom names.
-
-        Args:
-            state_dict (dict): KV pair of name and parameters.
-        """
-        for hydra_head_idx in range(self.hydra_num_heads):
-            self.hydra_mlp[hydra_head_idx][0].res_connection.weight.set_value(
-                get_tensor(
-                    state_dict.pop(f"0.{hydra_head_idx}.0.res_connection.weight")
-                )
-            )
-            self.hydra_mlp[hydra_head_idx][0].res_connection.bias.set_value(
-                get_tensor(state_dict.pop(f"0.{hydra_head_idx}.0.res_connection.bias"))
-            )
-
-            for layer_idx in range(self.hydra_num_layers):
-                self.hydra_mlp[hydra_head_idx][layer_idx].linear.weight.set_value(
-                    get_tensor(
-                        state_dict.pop(f"0.{hydra_head_idx}.{layer_idx}.linear.weight")
-                    )
-                )
-                self.hydra_mlp[hydra_head_idx][layer_idx].linear.bias.set_value(
-                    get_tensor(
-                        state_dict.pop(f"0.{hydra_head_idx}.{layer_idx}.linear.bias")
-                    )
-                )
-
-            self.hydra_lm_head[hydra_head_idx].weight.set_value(
-                get_tensor(state_dict.pop(f"1.{hydra_head_idx}.weight"))
-            )
-
-        self.embeddings.weight.set_value(
-            get_tensor(state_dict.pop("embeddings.weight"))
-        )
-
-    def set_state_dict(self, state_dict):
-        """
-        Load Parameter of Hydra Head from state_dict.
-
-        Args:
-            state_dict (dict): KV pair of name and parameters.
-        """
-        is_custom = True
-        for key in state_dict.keys():
-            if key != "embeddings.weight" and (
-                "hydra_mlp" in key or "hydra_head" in key
-            ):
-                is_custom = False
-                break
-
-        if is_custom:
-            logger.info("Hydra use custom set_state_dict")
-            self.custom_set_state_dict(state_dict)
-        else:
-            logger.info("Hydra use default set_state_dict")
-            super().set_state_dict(state_dict)
-
-    @paddle.no_grad()
-    def forward(self, input_ids, hidden_states, next_tokens):
-        """
-        Forward pass of Hydra Head
-
-        Args:
-            input_ids: [batch_size, 1] The tokens sampled by the previous head go through the embedding,
-                starting with the last accept token
-            hidden_states: [batch_size, hidden_size] The hidden_states of the last accept_tokens
-        """
-        hydra_inputs = [hidden_states]
-        input_embeds = self.embeddings(input_ids)
-        for hydra_head_idx in range(self.hydra_num_heads):
-            hydra_inputs.append(input_embeds)
-            head_input = paddle.concat(hydra_inputs, axis=-1)
-            hidden_states = self.hydra_mlp[hydra_head_idx](head_input)
-            logits = self.hydra_lm_head[hydra_head_idx](hidden_states)
-            probs = F.softmax(logits)
-            _, topk_tokens = paddle.topk(probs, k=1, axis=-1)
-            next_tokens[:, 1 + hydra_head_idx : 2 + hydra_head_idx] = topk_tokens[:]
-
-            input_embeds = self.embeddings(next_tokens[:, 1 + hydra_head_idx])
@@ -606,8 +606,8 @@ class Ernie4_5_PretrainedModel(PretrainedModel):
             return final_actions
        mappings = get_tensor_parallel_split_mappings(
            config.num_hidden_layers,
-           config.moe_num_experts,
-           config.moe_layer_start_index,
+           getattr(config, "moe_num_experts", 0),
+           getattr(config, "moe_layer_start_index", -1),
            config.prefix_name,
        )
        return mappings
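Note: switching to getattr(config, ..., default) keeps the tensor-parallel split mapping working for checkpoints whose config carries no MoE fields, instead of raising AttributeError. A minimal sketch of the difference, with a plain object standing in for the config:

    from types import SimpleNamespace

    dense_config = SimpleNamespace(num_hidden_layers=24, prefix_name="ernie")   # no MoE fields

    # dense_config.moe_num_experts would raise AttributeError for a dense checkpoint
    print(getattr(dense_config, "moe_num_experts", 0))          # 0
    print(getattr(dense_config, "moe_layer_start_index", -1))   # -1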
@@ -51,12 +51,13 @@ class RolloutModelConfig:
         enable_prefix_caching: bool = False,
         splitwise_role: str = "mixed",
         expert_parallel_size: int = 1,
-        enable_expert_parallell: bool = False,
+        enable_expert_parallel: bool = False,
         ori_vocab_size: int = None,
         quantization: str = "None",
         guided_decoding_backend: str = "off",
         disable_any_whitespace: bool = True,
         enable_logprob: bool = False,
+        graph_optimization_config: str = None,
     ):
         # Required parameters
         self.model_name_or_path = model_name_or_path
@@ -90,12 +91,13 @@ class RolloutModelConfig:
         self.enable_prefix_caching = enable_prefix_caching
         self.splitwise_role = splitwise_role
         self.expert_parallel_size = expert_parallel_size
-        self.enable_expert_parallell = enable_expert_parallell
+        self.enable_expert_parallel = enable_expert_parallel
         self.ori_vocab_size = ori_vocab_size
         self.quantization = quantization
         self.guided_decoding_backend = guided_decoding_backend
         self.disable_any_whitespace = disable_any_whitespace
         self.enable_logprob = enable_logprob
+        self.graph_optimization_config = graph_optimization_config

     def __str__(self):
         return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
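Note: callers constructing RolloutModelConfig need the corrected keyword enable_expert_parallel (the misspelled enable_expert_parallell no longer exists) and may pass the new graph_optimization_config field. A hedged usage sketch, assuming the remaining constructor parameters keep the defaults shown in the signature above; the model path is illustrative:

    rollout_config = RolloutModelConfig(
        model_name_or_path="/path/to/model",     # illustrative path
        enable_expert_parallel=False,            # renamed keyword
        graph_optimization_config=None,          # new field stored on the config
    )
    print(rollout_config)                        # __str__ dumps one "key: value" line per attribute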
@@ -39,17 +39,17 @@ class RolloutModel(nn.Layer):
         """Initialize with FastDeploy configuration."""
         super(RolloutModel, self).__init__()
         self.fd_config = rollout_model_config.initialize()
-        self._init_model()
+        self.rollout_model = self._init_model()

-    def _init_model(self):
+    def _init_model(self) -> nn.Layer:
         """Load model from loader based on config."""
         context = paddle.LazyGuard()
         architectures = f"{self.fd_config.model_config.architectures[0]}RL"
         with context:
             model_cls = ModelRegistry.get_class(architectures)
             model = model_cls(self.fd_config)
-        self.rollout_model = model.eval()
+        model.eval()
+        return model

     def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
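Note: this hunk is the fix the commit title refers to. In current Paddle releases, nn.Layer.eval() switches the layer to eval mode in place and returns None (unlike the PyTorch idiom of returning the layer), so the old self.rollout_model = model.eval() appears to have left rollout_model set to None. _init_model now calls eval() for its side effect and returns the model, and __init__ stores the return value. A minimal sketch of the pitfall, assuming only paddle is installed:

    import paddle
    from paddle import nn

    layer = nn.Linear(4, 4)
    print(layer.eval())        # None: eval() mutates the layer in place and returns nothing
    print(layer.training)      # False: the mode change itself did happen

    # old pattern: rollout_model = model.eval()        -> rollout_model is None
    # new pattern: model.eval(); rollout_model = model -> rollout_model is the layer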
@@ -74,15 +74,14 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
         super(Ernie4_5_MoeForCausalLMRL, self).__init__(fd_config)

     @classmethod
-    def name(self):
+    def name(self) -> str:
         """name"""
         return "Ernie4_5_MoeForCausalLMRL"

-    def get_name_mappings_to_training(self):
+    def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
-        have_bias = self.fd_config.model_config.get("have_norm_bias", False)
         # Prepare placeholders
-        place_holders = ["weight"] + (["bias"] if have_bias else [])
+        place_holders = ["weight"]

         # Initialize mapping dictionary
         infer_to_train = {}
@@ -94,7 +93,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
                 f"{base_name}.embed_tokens.weight",
             "lm_head.linear.weight": "lm_head.weight"
         }
-        if self.fd_config.model_config.get("tie_word_embeddings", False):
+        if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
             # Support tie_word_embeddings
             logger.debug("enable tie_word_embeddings")
             static_mappings.pop("lm_head.linear.weight")
@@ -153,15 +152,14 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
         super(Ernie4_5_VLMoeForConditionalGenerationRL, self).__init__(fd_config)

     @classmethod
-    def name(self):
+    def name(self) -> str:
         """name"""
         return "Ernie4_5_VLMoeForConditionalGenerationRL"

-    def get_name_mappings_to_training(self):
+    def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
-        have_bias = self.fd_config.model_config.get("have_norm_bias", False)
         # Prepare placeholders
-        place_holders = ["weight"] + (["bias"] if have_bias else [])
+        place_holders = ["weight"]

         # Initialize mapping dictionary
         infer_to_train = {}
@@ -173,7 +171,7 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
                 f"{base_name}.embed_tokens.weight",
             "lm_head.linear.weight": "lm_head.weight"
         }
-        if self.fd_config.model_config.get("tie_word_embeddings", False):
+        if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
             # Support tie_word_embeddings
             logger.debug("enable tie_word_embeddings")
             static_mappings.pop("lm_head.linear.weight")
@@ -257,11 +255,11 @@ class Qwen2ForCausalLMRL(Qwen2ForCausalLM):
         super(Qwen2ForCausalLMRL, self).__init__(fd_config)

     @classmethod
-    def name(self):
+    def name(self) -> str:
         """name"""
         return "Qwen2ForCausalLMRL"

-    def get_name_mappings_to_training(self):
+    def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -307,11 +305,11 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
         super(Qwen3MoeForCausalLMRL, self).__init__(fd_config)

     @classmethod
-    def name(self):
+    def name(self) -> str:
         """name"""
         return "Qwen3MoeForCausalLMRL"

-    def get_name_mappings_to_training(self):
+    def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -379,6 +377,6 @@ class Qwen3ForCausalLMRL(Qwen3ForCausalLM):
         super(Qwen3ForCausalLMRL, self).__init__(fd_config)

     @classmethod
-    def name(self):
+    def name(self) -> str:
         """name"""
         return "Qwen3ForCausalLMRL"
@@ -47,14 +47,12 @@ from fastdeploy.platforms import current_platform
 if not current_platform.is_dcu():
     from fastdeploy.spec_decode import MTPProposer, NgramProposer

-from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
 from fastdeploy.input.mm_processor import DataProcessor
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import \
     ScatterOp
 from fastdeploy.worker.model_runner_base import ModelRunnerBase
 from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
-from fastdeploy.worker.utils import check_safetensors_model


 class GPUModelRunner(ModelRunnerBase):
@@ -81,16 +79,7 @@ class GPUModelRunner(ModelRunnerBase):

         # VL model config:
         if self.enable_mm:
-            model_path = os.path.dirname(self.parallel_config.model_name_or_path)
-            self.is_safetensors_model = check_safetensors_model(
-                self.parallel_config.model_name_or_path)
-            if not self.is_safetensors_model:
-                self.tokenizer_path = self.image_preprocessor_path = model_path
-            else:
-                self.tokenizer_path = self.parallel_config.model_name_or_path
-                self.image_preprocessor_path = self.parallel_config.model_name_or_path
-            self.vision_model_name_or_path = os.path.join(
-                model_path, "DFNRopeVisionTransformer")
+            self._init_image_preprocess()

         self.amp_black = [
             "reduce_sum",
@@ -734,8 +723,6 @@ class GPUModelRunner(ModelRunnerBase):
             f"Starting to load model {self.model_config.architectures[0]}")
         time_before_load = time.perf_counter()
         # 1. Load original model
-        if self.enable_mm:
-            self.load_mm_config_and_image_preprocess()
         self.model = get_model_from_loader(fd_config=self.fd_config)
         # 1.1 Load RL dynamic model
         if self.fd_config.load_config.dynamic_load_weight:
@@ -1440,8 +1427,8 @@ class GPUModelRunner(ModelRunnerBase):

     def _init_image_preprocess(self) -> None:
         processor = DataProcessor(
-            tokenizer_name=self.tokenizer_path,
-            image_preprocessor_name=str(self.image_preprocessor_path),
+            tokenizer_name=self.parallel_config.model_name_or_path,
+            image_preprocessor_name=str(self.parallel_config.model_name_or_path),
         )
         processor.eval()
         image_preprocess = processor.image_preprocessor
@@ -1459,31 +1446,6 @@ class GPUModelRunner(ModelRunnerBase):
                 -1)
         self.image_preprocess = image_preprocess

-    def load_mm_config_and_image_preprocess(self) -> None:
-        tokenizer = ErnieBotTokenizer.from_pretrained(
-            self.tokenizer_path,
-            model_max_length=self.parallel_config.max_model_len,
-            padding_side="right",
-            use_fast=False,
-        )
-        tokenizer.ignored_index = -100
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.unk_token
-
-        self.fd_config.model_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_size
-        self.fd_config.model_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank
-        vision_config = self.fd_config.model_config.vision_config
-        vision_config.dtype = self.fd_config.model_config.dtype
-        vision_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_size
-        vision_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank
-        self.fd_config.model_config.im_patch_id = tokenizer.get_vocab()[
-            "<|IMAGE_PLACEHOLDER|>"
-        ]
-        self.fd_config.model_config.think_end_id = tokenizer.get_vocab()["</think>"]
-        self.fd_config.model_config.sequence_parallel = self.parallel_config.sequence_parallel
-        self.model_config = self.fd_config.model_config
-        self._init_image_preprocess()
-
     def _preprocess_mm_task(self, one: dict) -> None:
         """process batch"""

@@ -26,6 +26,7 @@ import paddle.distributed.fleet as fleet
 from fastdeploy.config import (DecodingConfig, DeviceConfig, FDConfig,
                                GraphOptimizationConfig, LoadConfig,
                                ModelConfig, ParallelConfig, SpeculativeConfig)
+from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
 from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
 from fastdeploy.inter_communicator import IPCSignal
 from fastdeploy.model_executor.layers.quantization import \
@@ -83,6 +84,30 @@ def init_distributed_environment(seed: int = 20) -> List[int]:

     return ranks, local_rank

+
+def update_fd_config_for_mm(fd_config: FDConfig) -> None:
+    if fd_config.model_config.enable_mm:
+        tokenizer = ErnieBotTokenizer.from_pretrained(
+            fd_config.parallel_config.model_name_or_path,
+            model_max_length=fd_config.parallel_config.max_model_len,
+            padding_side="right",
+            use_fast=False,
+        )
+        tokenizer.ignored_index = -100
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.unk_token
+
+        fd_config.model_config.tensor_parallel_degree = fd_config.parallel_config.tensor_parallel_size
+        fd_config.model_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
+        vision_config = fd_config.model_config.vision_config
+        vision_config.dtype = fd_config.model_config.dtype
+        # vision_config.tensor_parallel_degree = fd_config.parallel_config.tensor_parallel_size
+        # vision_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
+        fd_config.model_config.im_patch_id = tokenizer.get_vocab()[
+            "<|IMAGE_PLACEHOLDER|>"
+        ]
+        fd_config.model_config.think_end_id = tokenizer.get_vocab()["</think>"]
+        fd_config.model_config.sequence_parallel = fd_config.parallel_config.sequence_parallel
+

 class PaddleDisWorkerProc():
     """
     Paddle Distrubuted wrapper for fastdeploy.worker.Worker,
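Note: this new update_fd_config_for_mm helper takes over the tokenizer and vision-config wiring that the removed GPUModelRunner.load_mm_config_and_image_preprocess used to perform, and the last hunk of this diff calls it right after FDConfig is assembled in initialize_fd_config. A toy sketch of the guard semantics only, using stand-in objects rather than the real config classes:

    from types import SimpleNamespace

    def update_for_mm(fd_config) -> None:              # toy stand-in for update_fd_config_for_mm
        if fd_config.model_config.enable_mm:
            # real code: load ErnieBotTokenizer, set im_patch_id / think_end_id,
            # and copy tensor-parallel fields onto model_config (see the hunk above)
            fd_config.model_config.im_patch_id = 101   # illustrative value only

    text_only = SimpleNamespace(model_config=SimpleNamespace(enable_mm=False))
    update_for_mm(text_only)                           # text-only configs pass through unchanged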
@@ -504,9 +529,9 @@ def parse_args():
                         type=int,
                         default=1,
                         help="expert parallel size")
-    parser.add_argument("--enable_expert_parallell",
+    parser.add_argument("--enable_expert_parallel",
                         action='store_true',
-                        help="enable expert parallell")
+                        help="enable expert parallel")
     parser.add_argument("--ori_vocab_size", type=int, default=None)

     parser.add_argument("--quantization",
@@ -517,7 +542,7 @@ def parse_args():
                         "default is None. The priority of this configuration "\
                         "is lower than that of the config file. " \
                         "More complex quantization methods need to be configured via the config file.")
-    parser.add_argument("--graph_optimiaztion_config",
+    parser.add_argument("--graph_optimization_config",
                         type=json.loads,
                         default=None,
                         help=" Configation of Graph optimization backend. "
@@ -541,9 +566,8 @@ def parse_args():
                         "'ipc': real-time IPC streaming with automatic resharding, "
                         "'ipc_snapshot': load from disk snapshot of IPC weights.")
     parser.add_argument("--enable_mm",
-                        type=str,
-                        default="false",
-                        help="Whether to use vl")
+                        action='store_true',
+                        help="Whether to enable vl model")
     parser.add_argument("--enable_logprob",
                         action='store_true',
                         help="Enable output of token-level log probabilities.")
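Note: --enable_mm changes from a string option (type=str, default="false") to a plain boolean switch, which is what lets the later hunk assign model_config.enable_mm = args.enable_mm directly instead of lower-casing and comparing strings. A minimal argparse sketch of the new behaviour:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--enable_mm", action='store_true',
                        help="Whether to enable vl model")

    print(parser.parse_args([]).enable_mm)                # False
    print(parser.parse_args(["--enable_mm"]).enable_mm)   # True, no value string needed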
@@ -572,10 +596,12 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     parallel_config.expert_parallel_rank = int(local_rank / ranks)
     load_config = LoadConfig(vars(args))

-    graph_opt_config = GraphOptimizationConfig(
-        use_cudagraph=args.graph_optimiaztion_config["use_cudagraph"],
-        graph_opt_level=args.graph_optimiaztion_config["graph_opt_level"],
-        cudagraph_capture_sizes=args.graph_optimiaztion_config["cudagraph_capture_sizes"]
-    )
+    graph_opt_config = GraphOptimizationConfig()
+    if args.graph_optimization_config is not None:
+        graph_opt_config = GraphOptimizationConfig(
+            use_cudagraph=args.graph_optimization_config["use_cudagraph"],
+            graph_opt_level=args.graph_optimization_config["graph_opt_level"],
+            cudagraph_capture_sizes=args.graph_optimization_config["cudagraph_capture_sizes"]
+        )

     # Note(tangbinhan): used for load_checkpoint
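Note: with type=json.loads, the worker receives --graph_optimization_config as an already-parsed dict (now also guarded against being absent), and the engine side emits the same spelling via to_json_string() in the LLMEngine hunk above. A minimal sketch of the round trip; the key names come from this hunk, the values are illustrative:

    import argparse
    import json

    parser = argparse.ArgumentParser()
    parser.add_argument("--graph_optimization_config", type=json.loads, default=None)

    args = parser.parse_args([
        "--graph_optimization_config",
        '{"use_cudagraph": false, "graph_opt_level": 0, "cudagraph_capture_sizes": [1, 2, 4]}',
    ])
    print(args.graph_optimization_config["use_cudagraph"])   # False: parsed into a dict, not a string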
@@ -650,7 +676,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     )

     # Set VL tag
-    model_config.enable_mm = getattr(args, 'enable_mm', 'false').lower() == 'true'
+    model_config.enable_mm = args.enable_mm
     logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
     logger.info(f"- Load strategy: {load_config.load_strategy}")

@@ -662,6 +688,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
                         decoding_config=decoding_config,
                         quant_config=quant_config,
                         graph_opt_config=graph_opt_config)
+    update_fd_config_for_mm(fd_config)

     return fd_config
