Fix rollout_model init (#2881)

This commit is contained in:
Yuanle Liu
2025-07-17 13:36:21 +08:00
committed by GitHub
parent 1f15ca21e4
commit dbb9e2506b
9 changed files with 76 additions and 312 deletions

View File

@@ -46,7 +46,6 @@ PRETRAINED_INIT_CONFIGURATION = {
"num_max_dispatch_tokens_per_rank" : 256,
"moe_use_aux_free" : False,
"vocab_size" : -1,
"use_rope": True,
"hidden_dropout_prob" : 0.0,
"initializer_range" : 0.02,
"max_position_embeddings" : 512,
@@ -89,6 +88,7 @@ class ModelConfig:
if hasattr(self, key):
setattr(self, key, value)
assert self.model_name_or_path != ""
pretrained_config, _ = PretrainedConfig.get_config_dict(self.model_name_or_path)
self.pretrained_config = PretrainedConfig.from_dict(pretrained_config)

View File

@@ -32,6 +32,7 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
import paddle
import zmq
from opentelemetry import trace
from tqdm import tqdm
from fastdeploy.engine.args_utils import EngineArgs
@@ -42,13 +43,13 @@ from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import (EngineCacheQueue, EngineWorkerQueue,
IPCSignal, ZmqClient)
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.metrics.trace_util import start_span, start_span_request
from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.output.token_processor import (TokenProcessor,
WarmUpTokenProcessor)
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, console_logger, llm_logger
from fastdeploy.metrics.trace_util import extract_from_metadata, start_span, start_span_request
from opentelemetry import trace
class LLMEngine(object):
"""
@@ -1032,10 +1033,9 @@ class LLMEngine(object):
f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}"
f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
f" --graph_optimiaztion_config '{self.cfg.graph_optimization_config.to_json_string()}'"
f" --graph_optimization_config '{self.cfg.graph_optimization_config.to_json_string()}'"
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
f" --load_strategy {self.cfg.model_config.load_strategy}"
f" --enable_mm {self.cfg.enable_mm}")
f" --load_strategy {self.cfg.model_config.load_strategy}")
worker_append_flag = {
@@ -1050,6 +1050,7 @@ class LLMEngine(object):
"disable_any_whitespace": self.cfg.disable_any_whitespace,
"enable-custom-all-reduce": self.cfg.parallel_config.enable_custom_all_reduce,
"enable_logprob": self.cfg.enable_logprob,
"enable_mm": self.cfg.enable_mm,
}
for worker_flag, value in worker_append_flag.items():
if value:

View File

@@ -58,7 +58,6 @@ class VocabParallelEmbedding(nn.Layer):
self.column_cut = False
self.world_size: int = hcg.get_model_parallel_world_size()
self.ring_id: int = hcg.get_model_parallel_group().id
self.use_rope: bool = fd_config.model_config.use_rope
self.use_ep: bool = fd_config.parallel_config.use_ep
self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
self.initializer_range: float = fd_config.model_config.initializer_range
@@ -92,14 +91,6 @@ class VocabParallelEmbedding(nn.Layer):
self.embeddings.weight.is_distributed = True
self.embeddings.weight.split_axis = 1
if not self.use_rope:
self.position_embeddings = nn.Embedding(
self.max_position_embeddings,
embedding_dim,
weight_attr=paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range), ),
)
self.prefix = prefix
self.dropout = nn.Dropout(self.hidden_dropout_prob)

View File

@@ -1,217 +0,0 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import (ColumnParallelLinear,
VocabParallelEmbedding)
from paddleformers.utils.log import logger
from .utils import get_tensor
class ResBlock(nn.Layer):
"""
A Residual Block module.
This module performs a linear transformation followed by a SiLU activation,
and then adds the result to the original input, creating a residual connection.
Args:
hidden_size (int): The size of the hidden layers in the block.
"""
def __init__(self, hidden_size, num_condition=0):
super().__init__()
self.linear = nn.Linear(hidden_size * (num_condition + 1), hidden_size)
if num_condition > 0:
self.res_connection = nn.Linear(
hidden_size * (num_condition + 1), hidden_size
)
else:
self.res_connection = nn.Identity()
# Initialize as an identity mapping
# _no_grad_fill_(self.linear.weight, 0)
# Use SiLU activation to keep consistent with the Llama model
self.act = nn.Silu()
@paddle.no_grad()
def forward(self, x):
"""
Forward pass of the ResBlock.
Args:
x (paddle.Tensor): Input tensor.
Returns:
paddle.Tensor: Output after the residual connection and activation.
"""
return self.res_connection(x) + self.act(self.linear(x))
class HydraHead(nn.Layer):
"""
A Hydra Head module.
This module performs multi hydra head layers,
each of which is a hydra_lm_head followed by a head
Args:
hydra_num_heads (int): The number of hyhra heads.
hydra_num_layers (int): The number of layers.
hidden_size (int): The size of the hidden layers in the block.
tensor_parallel_degree(int): TP degree.
vocab_size (int): The size of vocabulary.
"""
def __init__(
self,
hydra_num_heads,
hydra_num_layers,
hidden_size,
tensor_parallel_degree,
vocab_size,
):
super().__init__()
self.hydra_num_heads = hydra_num_heads
self.hydra_num_layers = hydra_num_layers
self.hidden_size = hidden_size
self.tensor_parallel_degree = tensor_parallel_degree
self.vocab_size = vocab_size
self.hydra_mlp = nn.LayerList(
[
nn.Sequential(
ResBlock(self.hidden_size, hydra_head_idx + 1),
*([ResBlock(self.hidden_size)] * (self.hydra_num_layers - 1)),
)
for hydra_head_idx in range(self.hydra_num_heads)
]
)
if self.tensor_parallel_degree > 1:
self.hydra_lm_head = nn.LayerList(
[
ColumnParallelLinear(
self.hidden_size,
self.vocab_size,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=0.0)
),
gather_output=True,
has_bias=False,
)
for _ in range(self.hydra_num_heads)
]
)
else:
self.hydra_lm_head = nn.LayerList(
[
nn.Linear(self.hidden_size, self.vocab_size, bias_attr=False)
for _ in range(self.hydra_num_heads)
]
)
self.embeddings = VocabParallelEmbedding(
vocab_size,
hidden_size,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
weight_attr=paddle.ParamAttr(initializer=nn.initializer.Normal(mean=0.0)),
)
def custom_set_state_dict(self, state_dict):
"""
Load Parameter of Hydra Head from state_dict with custom names.
Args:
state_dict (dict): KV pair of name and parameters.
"""
for hydra_head_idx in range(self.hydra_num_heads):
self.hydra_mlp[hydra_head_idx][0].res_connection.weight.set_value(
get_tensor(
state_dict.pop(f"0.{hydra_head_idx}.0.res_connection.weight")
)
)
self.hydra_mlp[hydra_head_idx][0].res_connection.bias.set_value(
get_tensor(state_dict.pop(f"0.{hydra_head_idx}.0.res_connection.bias"))
)
for layer_idx in range(self.hydra_num_layers):
self.hydra_mlp[hydra_head_idx][layer_idx].linear.weight.set_value(
get_tensor(
state_dict.pop(f"0.{hydra_head_idx}.{layer_idx}.linear.weight")
)
)
self.hydra_mlp[hydra_head_idx][layer_idx].linear.bias.set_value(
get_tensor(
state_dict.pop(f"0.{hydra_head_idx}.{layer_idx}.linear.bias")
)
)
self.hydra_lm_head[hydra_head_idx].weight.set_value(
get_tensor(state_dict.pop(f"1.{hydra_head_idx}.weight"))
)
self.embeddings.weight.set_value(
get_tensor(state_dict.pop("embeddings.weight"))
)
def set_state_dict(self, state_dict):
"""
Load Parameter of Hydra Head from state_dict.
Args:
state_dict (dict): KV pair of name and parameters.
"""
is_custom = True
for key in state_dict.keys():
if key != "embeddings.weight" and (
"hydra_mlp" in key or "hydra_head" in key
):
is_custom = False
break
if is_custom:
logger.info("Hydra use custom set_state_dict")
self.custom_set_state_dict(state_dict)
else:
logger.info("Hydra use default set_state_dict")
super().set_state_dict(state_dict)
@paddle.no_grad()
def forward(self, input_ids, hidden_states, next_tokens):
"""
Forward pass of Hydra Head
Args:
input_ids: [batch_size, 1] The tokens sampled by the previous head go through the embedding,
starting with the last accept token
hidden_states: [batch_size, hidden_size] The hidden_states of the last accept_tokens
"""
hydra_inputs = [hidden_states]
input_embeds = self.embeddings(input_ids)
for hydra_head_idx in range(self.hydra_num_heads):
hydra_inputs.append(input_embeds)
head_input = paddle.concat(hydra_inputs, axis=-1)
hidden_states = self.hydra_mlp[hydra_head_idx](head_input)
logits = self.hydra_lm_head[hydra_head_idx](hidden_states)
probs = F.softmax(logits)
_, topk_tokens = paddle.topk(probs, k=1, axis=-1)
next_tokens[:, 1 + hydra_head_idx : 2 + hydra_head_idx] = topk_tokens[:]
input_embeds = self.embeddings(next_tokens[:, 1 + hydra_head_idx])

View File

@@ -606,8 +606,8 @@ class Ernie4_5_PretrainedModel(PretrainedModel):
return final_actions
mappings = get_tensor_parallel_split_mappings(
config.num_hidden_layers,
config.moe_num_experts,
config.moe_layer_start_index,
getattr(config, "moe_num_experts", 0),
getattr(config, "moe_layer_start_index", -1),
config.prefix_name,
)
return mappings

View File

@@ -51,12 +51,13 @@ class RolloutModelConfig:
enable_prefix_caching: bool = False,
splitwise_role: str = "mixed",
expert_parallel_size: int = 1,
enable_expert_parallell: bool = False,
enable_expert_parallel: bool = False,
ori_vocab_size: int = None,
quantization: str = "None",
guided_decoding_backend: str = "off",
disable_any_whitespace: bool = True,
enable_logprob: bool = False,
graph_optimization_config: str = None,
):
# Required parameters
self.model_name_or_path = model_name_or_path
@@ -90,12 +91,13 @@ class RolloutModelConfig:
self.enable_prefix_caching = enable_prefix_caching
self.splitwise_role = splitwise_role
self.expert_parallel_size = expert_parallel_size
self.enable_expert_parallell = enable_expert_parallell
self.enable_expert_parallel = enable_expert_parallel
self.ori_vocab_size = ori_vocab_size
self.quantization = quantization
self.guided_decoding_backend = guided_decoding_backend
self.disable_any_whitespace = disable_any_whitespace
self.enable_logprob = enable_logprob
self.graph_optimization_config = graph_optimization_config
def __str__(self):
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())

View File

@@ -39,17 +39,17 @@ class RolloutModel(nn.Layer):
"""Initialize with FastDeploy configuration."""
super(RolloutModel, self).__init__()
self.fd_config = rollout_model_config.initialize()
self._init_model()
self.rollout_model = self._init_model()
def _init_model(self):
def _init_model(self) -> nn.Layer:
"""Load model from loader based on config."""
context = paddle.LazyGuard()
architectures = f"{self.fd_config.model_config.architectures[0]}RL"
with context:
model_cls = ModelRegistry.get_class(architectures)
model = model_cls(self.fd_config)
self.rollout_model = model.eval()
model.eval()
return model
def get_name_mappings_to_training(self) -> Dict[str, str]:
"""Get parameter name mappings between rollout and training models."""
@@ -74,15 +74,14 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
super(Ernie4_5_MoeForCausalLMRL, self).__init__(fd_config)
@classmethod
def name(self):
def name(self) -> str:
"""name"""
return "Ernie4_5_MoeForCausalLMRL"
def get_name_mappings_to_training(self):
def get_name_mappings_to_training(self) -> Dict[str, str]:
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
have_bias = self.fd_config.model_config.get("have_norm_bias", False)
# Prepare placeholders
place_holders = ["weight"] + (["bias"] if have_bias else [])
place_holders = ["weight"]
# Initialize mapping dictionary
infer_to_train = {}
@@ -94,7 +93,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
f"{base_name}.embed_tokens.weight",
"lm_head.linear.weight": "lm_head.weight"
}
if self.fd_config.model_config.get("tie_word_embeddings", False):
if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
# Support tie_word_embeddings
logger.debug("enable tie_word_embeddings")
static_mappings.pop("lm_head.linear.weight")
@@ -153,15 +152,14 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
super(Ernie4_5_VLMoeForConditionalGenerationRL, self).__init__(fd_config)
@classmethod
def name(self):
def name(self) -> str:
"""name"""
return "Ernie4_5_VLMoeForConditionalGenerationRL"
def get_name_mappings_to_training(self):
def get_name_mappings_to_training(self) -> Dict[str, str]:
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
have_bias = self.fd_config.model_config.get("have_norm_bias", False)
# Prepare placeholders
place_holders = ["weight"] + (["bias"] if have_bias else [])
place_holders = ["weight"]
# Initialize mapping dictionary
infer_to_train = {}
@@ -173,7 +171,7 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
f"{base_name}.embed_tokens.weight",
"lm_head.linear.weight": "lm_head.weight"
}
if self.fd_config.model_config.get("tie_word_embeddings", False):
if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
# Support tie_word_embeddings
logger.debug("enable tie_word_embeddings")
static_mappings.pop("lm_head.linear.weight")
@@ -257,11 +255,11 @@ class Qwen2ForCausalLMRL(Qwen2ForCausalLM):
super(Qwen2ForCausalLMRL, self).__init__(fd_config)
@classmethod
def name(self):
def name(self) -> str:
"""name"""
return "Qwen2ForCausalLMRL"
def get_name_mappings_to_training(self):
def get_name_mappings_to_training(self) -> Dict[str, str]:
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
# Prepare placeholders
place_holders = ["weight"]
@@ -307,11 +305,11 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
super(Qwen3MoeForCausalLMRL, self).__init__(fd_config)
@classmethod
def name(self):
def name(self) -> str:
"""name"""
return "Qwen3MoeForCausalLMRL"
def get_name_mappings_to_training(self):
def get_name_mappings_to_training(self) -> Dict[str, str]:
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
# Prepare placeholders
place_holders = ["weight"]
@@ -379,6 +377,6 @@ class Qwen3ForCausalLMRL(Qwen3ForCausalLM):
super(Qwen3ForCausalLMRL, self).__init__(fd_config)
@classmethod
def name(self):
def name(self) -> str:
"""name"""
return "Qwen3ForCausalLMRL"

View File

@@ -47,14 +47,12 @@ from fastdeploy.platforms import current_platform
if not current_platform.is_dcu():
from fastdeploy.spec_decode import MTPProposer, NgramProposer
from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
from fastdeploy.input.mm_processor import DataProcessor
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import \
ScatterOp
from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
from fastdeploy.worker.utils import check_safetensors_model
class GPUModelRunner(ModelRunnerBase):
@@ -81,16 +79,7 @@ class GPUModelRunner(ModelRunnerBase):
# VL model config:
if self.enable_mm:
model_path = os.path.dirname(self.parallel_config.model_name_or_path)
self.is_safetensors_model = check_safetensors_model(
self.parallel_config.model_name_or_path)
if not self.is_safetensors_model:
self.tokenizer_path = self.image_preprocessor_path = model_path
else:
self.tokenizer_path = self.parallel_config.model_name_or_path
self.image_preprocessor_path = self.parallel_config.model_name_or_path
self.vision_model_name_or_path = os.path.join(
model_path, "DFNRopeVisionTransformer")
self._init_image_preprocess()
self.amp_black = [
"reduce_sum",
@@ -734,8 +723,6 @@ class GPUModelRunner(ModelRunnerBase):
f"Starting to load model {self.model_config.architectures[0]}")
time_before_load = time.perf_counter()
# 1. Load original model
if self.enable_mm:
self.load_mm_config_and_image_preprocess()
self.model = get_model_from_loader(fd_config=self.fd_config)
# 1.1 Load RL dynamic model
if self.fd_config.load_config.dynamic_load_weight:
@@ -1440,8 +1427,8 @@ class GPUModelRunner(ModelRunnerBase):
def _init_image_preprocess(self) -> None:
processor = DataProcessor(
tokenizer_name=self.tokenizer_path,
image_preprocessor_name=str(self.image_preprocessor_path),
tokenizer_name=self.parallel_config.model_name_or_path,
image_preprocessor_name=str(self.parallel_config.model_name_or_path),
)
processor.eval()
image_preprocess = processor.image_preprocessor
@@ -1459,31 +1446,6 @@ class GPUModelRunner(ModelRunnerBase):
-1)
self.image_preprocess = image_preprocess
def load_mm_config_and_image_preprocess(self) -> None:
tokenizer = ErnieBotTokenizer.from_pretrained(
self.tokenizer_path,
model_max_length=self.parallel_config.max_model_len,
padding_side="right",
use_fast=False,
)
tokenizer.ignored_index = -100
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.unk_token
self.fd_config.model_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_size
self.fd_config.model_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank
vision_config = self.fd_config.model_config.vision_config
vision_config.dtype = self.fd_config.model_config.dtype
vision_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_size
vision_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank
self.fd_config.model_config.im_patch_id = tokenizer.get_vocab()[
"<|IMAGE_PLACEHOLDER|>"
]
self.fd_config.model_config.think_end_id = tokenizer.get_vocab()["</think>"]
self.fd_config.model_config.sequence_parallel = self.parallel_config.sequence_parallel
self.model_config = self.fd_config.model_config
self._init_image_preprocess()
def _preprocess_mm_task(self, one: dict) -> None:
"""process batch"""

View File

@@ -26,6 +26,7 @@ import paddle.distributed.fleet as fleet
from fastdeploy.config import (DecodingConfig, DeviceConfig, FDConfig,
GraphOptimizationConfig, LoadConfig,
ModelConfig, ParallelConfig, SpeculativeConfig)
from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.layers.quantization import \
@@ -83,6 +84,30 @@ def init_distributed_environment(seed: int = 20) -> List[int]:
return ranks, local_rank
def update_fd_config_for_mm(fd_config: FDConfig) -> None:
if fd_config.model_config.enable_mm:
tokenizer = ErnieBotTokenizer.from_pretrained(
fd_config.parallel_config.model_name_or_path,
model_max_length=fd_config.parallel_config.max_model_len,
padding_side="right",
use_fast=False,
)
tokenizer.ignored_index = -100
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.unk_token
fd_config.model_config.tensor_parallel_degree = fd_config.parallel_config.tensor_parallel_size
fd_config.model_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
vision_config = fd_config.model_config.vision_config
vision_config.dtype = fd_config.model_config.dtype
# vision_config.tensor_parallel_degree = fd_config.parallel_config.tensor_parallel_size
# vision_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
fd_config.model_config.im_patch_id = tokenizer.get_vocab()[
"<|IMAGE_PLACEHOLDER|>"
]
fd_config.model_config.think_end_id = tokenizer.get_vocab()["</think>"]
fd_config.model_config.sequence_parallel = fd_config.parallel_config.sequence_parallel
class PaddleDisWorkerProc():
"""
Paddle Distrubuted wrapper for fastdeploy.worker.Worker,
@@ -504,9 +529,9 @@ def parse_args():
type=int,
default=1,
help="expert parallel size")
parser.add_argument("--enable_expert_parallell",
parser.add_argument("--enable_expert_parallel",
action='store_true',
help="enable expert parallell")
help="enable expert parallel")
parser.add_argument("--ori_vocab_size", type=int, default=None)
parser.add_argument("--quantization",
@@ -517,7 +542,7 @@ def parse_args():
"default is None. The priority of this configuration "\
"is lower than that of the config file. " \
"More complex quantization methods need to be configured via the config file.")
parser.add_argument("--graph_optimiaztion_config",
parser.add_argument("--graph_optimization_config",
type=json.loads,
default=None,
help=" Configation of Graph optimization backend. "
@@ -541,9 +566,8 @@ def parse_args():
"'ipc': real-time IPC streaming with automatic resharding, "
"'ipc_snapshot': load from disk snapshot of IPC weights.")
parser.add_argument("--enable_mm",
type=str,
default="false",
help="Whether to use vl")
action='store_true',
help="Whether to enable vl model")
parser.add_argument("--enable_logprob",
action='store_true',
help="Enable output of token-level log probabilities.")
@@ -572,10 +596,12 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
parallel_config.expert_parallel_rank = int(local_rank / ranks)
load_config = LoadConfig(vars(args))
graph_opt_config = GraphOptimizationConfig()
if args.graph_optimization_config is not None:
graph_opt_config = GraphOptimizationConfig(
use_cudagraph=args.graph_optimiaztion_config["use_cudagraph"],
graph_opt_level=args.graph_optimiaztion_config["graph_opt_level"],
cudagraph_capture_sizes=args.graph_optimiaztion_config["cudagraph_capture_sizes"]
use_cudagraph=args.graph_optimization_config["use_cudagraph"],
graph_opt_level=args.graph_optimization_config["graph_opt_level"],
cudagraph_capture_sizes=args.graph_optimization_config["cudagraph_capture_sizes"]
)
# Note(tangbinhan): used for load_checkpoint
@@ -650,7 +676,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
)
# Set VL tag
model_config.enable_mm = getattr(args, 'enable_mm', 'false').lower() == 'true'
model_config.enable_mm = args.enable_mm
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
logger.info(f"- Load strategy: {load_config.load_strategy}")
@@ -662,6 +688,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
decoding_config=decoding_config,
quant_config=quant_config,
graph_opt_config=graph_opt_config)
update_fd_config_for_mm(fd_config)
return fd_config