qwen loader (#3057)

bukejiyu
2025-07-30 19:09:38 +08:00
committed by GitHub
parent 28fff1b035
commit db698bda01
22 changed files with 494 additions and 92 deletions

View File

@@ -19,7 +19,8 @@ from __future__ import annotations
import json
import os
from dataclasses import dataclass, field
from typing import Literal, Optional
from enum import Enum
from typing import Literal, Optional, Union
from paddleformers.transformers.configuration_utils import PretrainedConfig
@@ -650,6 +651,14 @@ class EarlyStopConfig:
argument = self.enable_early_stop
class LoadChoices(str, Enum):
"""LoadChoices"""
DEFAULT = "default"
# currently only supports Qwen3 bf16
NEW_LOADER = "new_loader"
class LoadConfig:
"""
Configuration for dynamic weight loading strategies
@@ -666,6 +675,7 @@ class LoadConfig:
self,
args,
):
self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value
self.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
self.dynamic_load_weight: bool = False
self.load_strategy: Optional[Literal["ipc", "ipc_snapshot"]] = None
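
Because LoadChoices subclasses str, plain strings and enum members compare equal, which is what lets LoadConfig store the raw value while the loader factory later compares it against LoadChoices.NEW_LOADER. A self-contained sketch of that semantics:

from enum import Enum

class LoadChoices(str, Enum):
    DEFAULT = "default"
    NEW_LOADER = "new_loader"

# str subclassing makes plain-string comparisons work:
assert LoadChoices.NEW_LOADER == "new_loader"
# and raw values round-trip through the enum constructor:
assert LoadChoices("default") is LoadChoices.DEFAULT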

View File

@@ -326,6 +326,13 @@ class EngineArgs:
Configuration for early stop.
"""
load_choices: str = "default"
"""The format of the model weights to load.
Options include:
- "default": default loader.
- "new_loader": new loader.
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -543,6 +550,16 @@ class EngineArgs:
help="Enable expert parallelism.",
)
# Load group
load_group = parser.add_argument_group("Load Configuration")
load_group.add_argument(
"--load_choices",
type=str,
default=EngineArgs.load_choices,
help="The format of the model weights to load.\
default/new_loader.",
)
# CacheConfig parameters group
cache_group = parser.add_argument_group("Cache Configuration")
@@ -897,4 +914,5 @@ class EngineArgs:
disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
enable_logprob=self.enable_logprob,
early_stop_config=early_stop_cfg,
load_choices=self.load_choices,
)
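
With the flag wired into the Load Configuration group above, it can be passed at launch time; an illustrative invocation (the entry point and model name are placeholders, not taken from this change):

python -m fastdeploy.entrypoints.openai.api_server \
    --model Qwen/Qwen3-8B \
    --load_choices new_loader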

View File

@@ -54,6 +54,7 @@ class Config:
splitwise_role (str): Splitwise role.
innode_prefill_ports (Optional[List[int]]): Innode prefill ports.
Temporary configuration, will be removed in the future.
load_choices (str): The format of the model weights to load. Default is "default".
"""
def __init__(
@@ -88,6 +89,7 @@ class Config:
disable_any_whitespace: bool = False,
enable_logprob: bool = False,
early_stop_config: Optional[Dict[str, Any]] = None,
load_choices: str = "default",
):
"""
Initialize the Config class.
@@ -118,6 +120,7 @@ class Config:
Default is False.
enable_logprob(bool): Enable logprob. Default is False.
early_stop_config (Optional[Dict[str, Any]]): Early stop configuration. Default is None.
load_choices (str): The format of the model weights to load. Default is "default".
"""
self.model_config = model_config
self.cache_config = cache_config
@@ -167,6 +170,7 @@ class Config:
self.guided_decoding_backend = guided_decoding_backend
self.disable_any_whitespace = disable_any_whitespace
self._str_to_list("innode_prefill_ports", int)
self.load_choices = load_choices
assert self.splitwise_role in ["mixed", "prefill", "decode"]

View File

@@ -1089,6 +1089,7 @@ class LLMEngine:
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
f" --load_strategy {self.cfg.load_config.load_strategy}"
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
f" --load_choices {self.cfg.load_choices}"
)
worker_append_flag = {

View File

@@ -22,6 +22,7 @@ from paddle import nn
from paddle.distributed import fleet
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.models.utils import set_weight_attrs
from .utils import get_tensor
@@ -80,6 +81,7 @@ class VocabParallelEmbedding(nn.Layer):
initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
),
)
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
else:
# column cut embedding
self.embeddings = nn.Embedding(
@@ -89,6 +91,7 @@ class VocabParallelEmbedding(nn.Layer):
self.embeddings.weight.is_distributed = True
self.embeddings.weight.split_axis = 1
set_weight_attrs(self.embeddings.weight, {"output_dim": True})
self.prefix = prefix
self.dropout = nn.Dropout(self.hidden_dropout_prob)

View File

@@ -14,11 +14,17 @@
# limitations under the License.
"""
from typing import Optional
import paddle
from paddle import nn
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.models.utils import (
default_weight_loader,
set_weight_attrs,
)
from fastdeploy.platforms import current_platform
from .utils import _set_var_distributed, divide, get_tensor
@@ -107,6 +113,15 @@ class LinearBase(nn.Layer):
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
self.weight,
{
"weight_loader": (
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
)
},
)
self.bias = None
if self.with_bias:
self.bias = self.create_parameter(
@@ -115,6 +130,15 @@ class LinearBase(nn.Layer):
is_bias=True,
)
set_weight_attrs(
self.bias,
{
"weight_loader": (
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
)
},
)
# smooth quant
self.linear_shift = None
self.linear_smooth = None
@@ -273,6 +297,7 @@ class ColumnParallelLinear(LinearBase):
add_bias=add_bias,
skip_quant=skip_quant,
)
self.fd_config = fd_config
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.input_size = input_size
self.output_size = divide(output_size, self.nranks) # Split the output_size using TP inference.
@@ -300,6 +325,15 @@ class ColumnParallelLinear(LinearBase):
if self.nranks > 0:
# col parallel
_set_var_distributed(self.weight, split_axis=1)
set_weight_attrs(
self.weight,
{
"output_dim": True,
"weight_loader": (
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
),
},
)
self.bias = None
if self.with_bias:
@@ -311,6 +345,17 @@ class ColumnParallelLinear(LinearBase):
if self.nranks > 0:
# col parallel
_set_var_distributed(self.bias, split_axis=1)
set_weight_attrs(
self.bias,
{
"output_dim": True,
"weight_loader": (
self.weight_loader
if hasattr(self, "weight_loader")
else default_weight_loader(self.fd_config)
),
},
)
# smooth quant
self.linear_shift = None
@@ -354,6 +399,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self.activation = activation
self.hidden_size = fd_config.model_config.hidden_size
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.output_size = output_size
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
super().__init__(
fd_config=fd_config,
@@ -365,6 +412,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_quant=skip_quant,
)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
# gate_proj and up_proj are stored as separate tensors on disk;
# TP-split each shard and copy it into its half of the fused parameter.
assert loaded_shard_id in ["gate", "up"]
output_dim = getattr(param, "output_dim", None)
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
dim = -1
size = loaded_weight.get_shape()[dim]
block_size = size // self.nranks
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size
loaded_weight = loaded_weight[..., shard_offset:shard_size]
loaded_weight = get_tensor(loaded_weight)
if loaded_shard_id == "gate":
param[:, : self.output_size // 2] = loaded_weight
elif loaded_shard_id == "up":
param[:, self.output_size // 2 :] = loaded_weight
def load_state_dict(self, state_dict: dict):
"""
Load the checkpoint state dictionary into the layer.
@@ -415,6 +483,7 @@ class QKVParallelLinear(ColumnParallelLinear):
self.hidden_size = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
self.num_heads_per_rank = divide(self.num_heads, self.nranks)
if self.kv_num_heads < self.nranks and self.nranks % self.kv_num_heads == 0:
self.kv_num_heads_per_rank = 1
@@ -432,6 +501,34 @@ class QKVParallelLinear(ColumnParallelLinear):
add_bias=add_bias,
)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
# q_proj, k_proj and v_proj are stored as separate tensors on disk;
# TP-split each shard and copy it into its slice of the fused qkv parameter.
assert loaded_shard_id in ["q", "k", "v"]
output_dim = getattr(param, "output_dim", None)
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
dim = -1
size = loaded_weight.get_shape()[dim]
block_size = size // self.nranks
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size
loaded_weight = loaded_weight[..., shard_offset:shard_size]
loaded_weight = get_tensor(loaded_weight)
if loaded_shard_id == "q":
param[:, : self.num_heads_per_rank * self.head_dim] = loaded_weight
elif loaded_shard_id == "k":
param[
:,
self.num_heads_per_rank
* self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
* self.head_dim,
] = loaded_weight
elif loaded_shard_id == "v":
param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :] = loaded_weight
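
For concreteness, a worked example of the fused-QKV column layout above, with assumed Qwen3-like sizes (32 query heads, 8 KV heads, head_dim 128, tensor_parallel_size 4):

num_heads, kv_num_heads, head_dim, nranks = 32, 8, 128, 4
num_heads_per_rank = num_heads // nranks        # 8
kv_num_heads_per_rank = kv_num_heads // nranks  # 2
q_end = num_heads_per_rank * head_dim                            # columns [0, 1024): q
k_end = (num_heads_per_rank + kv_num_heads_per_rank) * head_dim  # columns [1024, 1280): k
# remaining columns [1280, 1536) hold v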
def load_weight(self, state_dict: dict):
"""
Load the weight from the state dictionary.
@@ -588,6 +685,18 @@ class RowParallelLinear(LinearBase):
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
if self.nranks > 0:
# row parallel
set_weight_attrs(
self.weight,
{
"output_dim": False,
"weight_loader": (
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
),
},
)
_set_var_distributed(self.weight, split_axis=0)
self.bias = None
if self.with_bias:
@@ -596,10 +705,18 @@ class RowParallelLinear(LinearBase):
dtype=self._dtype,
is_bias=True,
)
if self.nranks > 0:
# row parallel
set_weight_attrs(
self.bias,
{
"output_dim": False,
"weight_loader": (
self.weight_loader
if hasattr(self, "weight_loader")
else default_weight_loader(self.fd_config)
),
},
)
# smooth quant
self.linear_shift = None

View File

@@ -22,6 +22,7 @@ from paddle import nn
from paddle.distributed import fleet
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.models.utils import set_weight_attrs
from .utils import get_tensor
@@ -83,6 +84,7 @@ class ParallelLMHead(nn.Layer):
gather_output=need_gather,
fuse_matmul_bias=False,  # False yields a smaller numerical diff
)
set_weight_attrs(self.linear.weight, {"output_dim": True})
else:
self.linear = RowParallelLinear(
embedding_dim,
@@ -93,6 +95,7 @@ class ParallelLMHead(nn.Layer):
input_is_parallel=False,
fuse_matmul_bias=False,  # False yields a smaller numerical diff
)
set_weight_attrs(self.linear.weight, {"output_dim": False})
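
The output_dim attribute attached here and in the layers above tells default_weight_loader which axis to shard under tensor parallelism. Summarizing the convention (a sketch inferred from the call sites in this diff):

# output_dim=True  -> shard along the last dim (ColumnParallelLinear,
#                     column-cut embeddings, column-parallel lm_head)
# output_dim=False -> shard along dim 0 (RowParallelLinear,
#                     row-parallel embeddings and lm_head)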
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
"""

View File

@@ -16,12 +16,14 @@
import json
import os
import time
import paddle
import paddle.distributed as dist
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
from paddleformers.transformers import PretrainedModel
from paddleformers.transformers.model_utils import load_tp_checkpoint
from paddleformers.utils.log import logger
from safetensors import safe_open
from tqdm import tqdm
@@ -32,6 +34,17 @@ from fastdeploy.model_executor.models.tp_utils import (
from fastdeploy.platforms import current_platform
def measure_time(func):
def wrapper(*args, **kwargs):
time_before_load = time.time()
result = func(*args, **kwargs)
time_after_load = time.time()
logger.info(f"Model loading took {time_after_load - time_before_load} seconds")
return result
return wrapper
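
measure_time is applied to the load_weights methods introduced below; a hedged usage sketch (the wrapped function is hypothetical):

@measure_time
def load_all_shards():
    ...  # logs "Model loading took N seconds" on return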
def load_reordered_experts(model_path: str, key_name: str):
from safetensors import safe_open
@@ -152,9 +165,11 @@ def safetensors_weights_iterator(
safe_tensor_list,
desc="Loading safetensors checkpoint shards",
):
with safe_open(st_file, framework="np") as f:
from paddleformers.utils.safetensors import fast_safe_open
with fast_safe_open(st_file, framework="np") as f:
for name in f.keys():
param = f.get_tensor(name)
param = f.get_slice(name)
yield name, param
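
Switching from get_tensor to get_slice keeps the iterator lazy: consumers receive slice handles and decide when to materialize. A sketch of consuming it, assuming the fast_safe_open slice API mirrors safetensors' (get_shape is the accessor the weight loaders in this change rely on):

for name, weight_slice in safetensors_weights_iterator(safe_tensor_list):
    shape = weight_slice.get_shape()  # metadata only; get_tensor() copies later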

View File

@@ -0,0 +1,32 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from fastdeploy.config import LoadChoices, LoadConfig
from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
from fastdeploy.model_executor.model_loader.default_loader import DefaultModelLoader
from fastdeploy.model_executor.model_loader.new_loader import NewModelLoader
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""get_model_loader"""
if load_config.load_choices == LoadChoices.NEW_LOADER:
return NewModelLoader(load_config)
return DefaultModelLoader(load_config)
__all__ = ["get_model_loader"]
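
End to end, the parsed flag selects the loader. A minimal sketch, assuming load_config.load_choices carries the string parsed from --load_choices (the str-Enum makes the plain-string comparison in the factory work):

load_config.load_choices = "new_loader"         # e.g. parsed from --load_choices
loader = get_model_loader(load_config)          # returns NewModelLoader
model = loader.load_model(fd_config=fd_config)  # streams safetensors weights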

View File

@@ -0,0 +1,38 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from abc import ABC, abstractmethod
from paddle import nn
from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
class BaseModelLoader(ABC):
"""Base class for model loaders."""
def __init__(self, load_config: LoadConfig):
self.load_config = load_config
@abstractmethod
def download_model(self, load_config: ModelConfig) -> None:
"""Download a model so that it can be immediately loaded."""
raise NotImplementedError
@abstractmethod
def load_model(self, fd_config: FDConfig) -> nn.Layer:
"""Load a model with the given configurations."""
raise NotImplementedError

View File

@@ -14,68 +14,30 @@
# limitations under the License.
"""
from abc import ABC, abstractmethod
import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
from fastdeploy.model_executor.load_weight_utils import load_composite_checkpoint
from fastdeploy.model_executor.models.deepseek_v3 import DeepSeekV3PretrainedModel
from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_PretrainedModel
from fastdeploy.model_executor.models.ernie4_5_mtp import Ernie4_5_MTPPretrainedModel
from fastdeploy.model_executor.models.ernie4_5_vl.ernie4_5_vl_moe import (
Ernie4_5_VLPretrainedModel,
from fastdeploy.model_executor.load_weight_utils import (
load_composite_checkpoint,
measure_time,
)
from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
from fastdeploy.model_executor.model_loader.utils import get_pretrain_cls
from fastdeploy.model_executor.models.model_base import ModelRegistry
from fastdeploy.model_executor.models.qwen2 import Qwen2PretrainedModel
from fastdeploy.model_executor.models.qwen3 import Qwen3PretrainedModel
from fastdeploy.model_executor.models.qwen3moe import Qwen3MoePretrainedModel
from fastdeploy.platforms import current_platform
MODEL_CLASSES = {
"Ernie4_5_MoeForCausalLM": Ernie4_5_PretrainedModel,
"Ernie4_5_MTPForCausalLM": Ernie4_5_MTPPretrainedModel,
"Qwen2ForCausalLM": Qwen2PretrainedModel,
"Qwen3ForCausalLM": Qwen3PretrainedModel,
"Qwen3MoeForCausalLM": Qwen3MoePretrainedModel,
"Ernie4_5_ForCausalLM": Ernie4_5_PretrainedModel,
"DeepseekV3ForCausalLM": DeepSeekV3PretrainedModel,
"Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLPretrainedModel,
}
def get_model_from_loader(fd_config: FDConfig) -> nn.Layer:
"""load or download model"""
model_loader = DefaultModelLoader(fd_config.load_config)
model = model_loader.load_model(fd_config)
return model
class BaseModelLoader(ABC):
"""Base class for model loaders."""
def __init__(self, load_config: LoadConfig):
self.load_config = load_config
@abstractmethod
def download_model(self, load_config: ModelConfig) -> None:
"""Download a model so that it can be immediately loaded."""
raise NotImplementedError
@abstractmethod
def load_model(self, fd_config: FDConfig) -> nn.Layer:
"""Load a model with the given configurations."""
raise NotImplementedError
class DefaultModelLoader(BaseModelLoader):
"""ModelLoader that can load registered models"""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
logger.info("Load the model and weights using DefaultModelLoader")
def download_model(self, model_config: ModelConfig) -> None:
"""download_model"""
pass
def clean_memory_fragments(self, state_dict: dict) -> None:
@@ -88,9 +50,22 @@ class DefaultModelLoader(BaseModelLoader):
paddle.device.cuda.empty_cache()
paddle.device.synchronize()
@measure_time
def load_weights(self, model, fd_config: FDConfig, architectures: str) -> None:
model_class = get_pretrain_cls(architectures)
state_dict = load_composite_checkpoint(
fd_config.model_config.model,
model_class,
fd_config,
return_numpy=True,
)
model.set_state_dict(state_dict)
self.clean_memory_fragments(state_dict)
def load_model(self, fd_config: FDConfig) -> nn.Layer:
context = paddle.LazyGuard()
architectures = fd_config.model_config.architectures[0]
logger.info(f"Starting to load model {architectures}")
if fd_config.load_config.dynamic_load_weight:
# register rl model
@@ -109,13 +84,5 @@ class DefaultModelLoader(BaseModelLoader):
return model
# TODO(gongshaotian): currently only safetensors is supported
model_class = MODEL_CLASSES[architectures]
state_dict = load_composite_checkpoint(
fd_config.model_config.model,
model_class,
fd_config,
return_numpy=True,
)
model.set_state_dict(state_dict)
self.clean_memory_fragments(state_dict)
self.load_weights(model, fd_config, architectures)
return model

View File

@@ -0,0 +1,74 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
from fastdeploy.model_executor.load_weight_utils import (
get_all_safetensors,
measure_time,
safetensors_weights_iterator,
)
from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
from fastdeploy.model_executor.models.model_base import ModelRegistry
from fastdeploy.platforms import current_platform
class NewModelLoader(BaseModelLoader):
"""ModelLoader that can load registered models"""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
def download_model(self, model_config: ModelConfig) -> None:
pass
def clean_memory_fragments(self) -> None:
"""clean_memory_fragments"""
if current_platform.is_cuda():
paddle.device.cuda.empty_cache()
paddle.device.synchronize()
@measure_time
def load_weights(self, model, fd_config: FDConfig) -> None:
_, safetensor_files = get_all_safetensors(fd_config.model_config.model)
weights_iterator = safetensors_weights_iterator(safetensor_files)
model.load_weights(weights_iterator)
self.clean_memory_fragments()
def load_model(self, fd_config: FDConfig) -> nn.Layer:
architectures = fd_config.model_config.architectures[0]
logger.info(f"Starting to load model {architectures}")
if fd_config.load_config.dynamic_load_weight:
# register rl model
import fastdeploy.rl # noqa
architectures = architectures + "RL"
model_cls = ModelRegistry.get_class(architectures)
model = model_cls(fd_config)
model.eval()
# RL models do not need set_state_dict
if fd_config.load_config.dynamic_load_weight:
return model
self.load_weights(model, fd_config)
return model

View File

@@ -0,0 +1,43 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from paddleformers.transformers import PretrainedModel
from fastdeploy.model_executor.models.deepseek_v3 import DeepSeekV3PretrainedModel
from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_PretrainedModel
from fastdeploy.model_executor.models.ernie4_5_mtp import Ernie4_5_MTPPretrainedModel
from fastdeploy.model_executor.models.ernie4_5_vl.ernie4_5_vl_moe import (
Ernie4_5_VLPretrainedModel,
)
from fastdeploy.model_executor.models.qwen2 import Qwen2PretrainedModel
from fastdeploy.model_executor.models.qwen3 import Qwen3PretrainedModel
from fastdeploy.model_executor.models.qwen3moe import Qwen3MoePretrainedModel
MODEL_CLASSES = {
"Ernie4_5_MoeForCausalLM": Ernie4_5_PretrainedModel,
"Ernie4_5_MTPForCausalLM": Ernie4_5_MTPPretrainedModel,
"Qwen2ForCausalLM": Qwen2PretrainedModel,
"Qwen3ForCausalLM": Qwen3PretrainedModel,
"Qwen3MoeForCausalLM": Qwen3MoePretrainedModel,
"Ernie4_5_ForCausalLM": Ernie4_5_PretrainedModel,
"DeepseekV3ForCausalLM": DeepSeekV3PretrainedModel,
"Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLPretrainedModel,
}
def get_pretrain_cls(architectures: str) -> PretrainedModel:
"""get_pretrain_cls"""
return MODEL_CLASSES[architectures]

View File

@@ -228,7 +228,7 @@ class Qwen3ForCausalLM(ModelForCasualLM):
fd_config (FDConfig): Configurations for the LLM model.
"""
super(Qwen3ForCausalLM, self).__init__(fd_config)
self.fd_config = fd_config
self.model = Qwen3Model(fd_config=fd_config)
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
@@ -245,6 +245,47 @@ class Qwen3ForCausalLM(ModelForCasualLM):
""" """
return "Qwen3ForCausalLM"
@paddle.no_grad()
def load_weights(self, weights_iterator) -> None:
"""
Load model parameters from a given weights_iterator object.
Args:
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""
from fastdeploy.model_executor.models.utils import default_weight_loader
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("up_gate_proj", "gate_proj", "gate"),
("up_gate_proj", "up_proj", "up"),
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
]
params_dict = dict(self.named_parameters())
for loaded_weight_name, loaded_weight in weights_iterator:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in loaded_weight_name:
continue
model_param_name = loaded_weight_name.replace(weight_name, param_name)
if model_param_name not in params_dict:
continue
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight, shard_id)
break
else:
if loaded_weight_name not in params_dict:
continue
param = params_dict[loaded_weight_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight)
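
One concrete pass through the mapping loop above (the checkpoint key is a typical Qwen3 name, assumed for illustration):

loaded_weight_name = "model.layers.0.self_attn.q_proj.weight"
# matches ("qkv_proj", "q_proj", "q"), so the target parameter becomes:
model_param_name = loaded_weight_name.replace("q_proj", "qkv_proj")
# -> "model.layers.0.self_attn.qkv_proj.weight", loaded with shard_id "q"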
@paddle.no_grad()
def set_state_dict(self, state_dict):
"""

View File

@@ -24,7 +24,7 @@ import random
import re
import struct
from functools import partial
from typing import NamedTuple, Optional
from typing import Any, Callable, NamedTuple, Optional, Union
import numpy as np
import paddle
@@ -40,10 +40,51 @@ from paddleformers.utils.env import (
from paddleformers.utils.log import logger
from tqdm import tqdm
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.utils import get_tensor
MAX_BSZ = 512
MAX_DRAFT_TOKENS = 6
def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]):
    """Attach custom attributes (e.g. output_dim, weight_loader) to a parameter."""
if param_attr_map is None:
return
for key, value in param_attr_map.items():
setattr(param, key, value)
def default_weight_loader(fd_config: FDConfig) -> Callable:
"""Default weight loader"""
def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
"""fn"""
try:
output_dim = getattr(param, "output_dim", None)
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
dim = -1 if output_dim else 0
size = loaded_weight.get_shape()[dim]
block_size = size // fd_config.parallel_config.tensor_parallel_size
shard_offset = fd_config.parallel_config.tensor_parallel_rank * block_size
shard_size = (fd_config.parallel_config.tensor_parallel_rank + 1) * block_size
if output_dim:
loaded_weight = loaded_weight[..., shard_offset:shard_size]
else:
loaded_weight = loaded_weight[shard_offset:shard_size, ...]
loaded_weight = get_tensor(loaded_weight)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
param.copy_(loaded_weight, False)
except Exception:
raise
return fn
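
A worked example of the shard arithmetic in fn, with made-up sizes: a [4096, 11008] weight sharded along output_dim=True across 4 ranks:

size = 11008                                  # loaded_weight.get_shape()[-1]
tensor_parallel_size, tensor_parallel_rank = 4, 1
block_size = size // tensor_parallel_size     # 2752
shard_offset = tensor_parallel_rank * block_size        # 2752
shard_size = (tensor_parallel_rank + 1) * block_size    # 5504
# rank 1 keeps loaded_weight[..., 2752:5504]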
class LayerIdPlaceholder(str, enum.Enum):
"""LayerIdPlaceholder"""

View File

@@ -20,7 +20,6 @@ import paddle
from paddle import nn
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.model_loader import ModelRegistry
from fastdeploy.model_executor.models.ernie4_5_moe import (
Ernie4_5_MoeForCausalLM,
Ernie4_5_PretrainedModel,
@@ -29,6 +28,7 @@ from fastdeploy.model_executor.models.ernie4_5_vl.ernie4_5_vl_moe import (
Ernie4_5_VLMoeForConditionalGeneration,
Ernie4_5_VLPretrainedModel,
)
from fastdeploy.model_executor.models.model_base import ModelRegistry
from fastdeploy.model_executor.models.qwen2 import (
Qwen2ForCausalLM,
Qwen2PretrainedModel,

View File

@@ -84,9 +84,10 @@ class MTPProposer(Proposer):
"""
Load MTP Layer
"""
from fastdeploy.model_executor.model_loader import get_model_from_loader
from fastdeploy.model_executor.model_loader import get_model_loader
self.model = get_model_from_loader(self.cfg)
model_loader = get_model_loader(load_config=self.cfg.load_config)
self.model = model_loader.load_model(fd_config=self.cfg)
def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
"""Set dummy prefill inputs to model_inputs"""

View File

@@ -41,7 +41,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
from fastdeploy.model_executor.model_loader import get_model_from_loader
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.ops.gcu import set_value_by_flags_and_idx
from fastdeploy.model_executor.pre_and_post_process import (
post_process,
@@ -547,10 +547,9 @@ class GCUModelRunner(ModelRunnerBase):
def load_model(self) -> None:
"""load or download model"""
logger.info(f"Starting to load model {self.model_config.architectures[0]}")
time_before_load = time.perf_counter()
# 1. Load original model
self.model = get_model_from_loader(fd_config=self.fd_config)
model_loader = get_model_loader(load_config=self.fd_config.load_config)
self.model = model_loader.load_model(fd_config=self.fd_config)
# 1.1 Load RL dynamic model
if self.fd_config.load_config.dynamic_load_weight:
from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
@@ -561,9 +560,6 @@ class GCUModelRunner(ModelRunnerBase):
# 3. Load drafter model(for speculative decoding)
time_after_load = time.perf_counter()
logger.info(f"Model loading took {time_after_load - time_before_load} seconds")
# 4. Init proposer for speculative method
self._init_speculative_proposer()

View File

@@ -40,7 +40,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
from fastdeploy.model_executor.model_loader import get_model_from_loader
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.ops.gpu import (
recover_decode_task,
set_value_by_flags_and_idx,
@@ -813,9 +813,9 @@ class GPUModelRunner(ModelRunnerBase):
def load_model(self) -> None:
"""load or download model"""
logger.info(f"Starting to load model {self.model_config.architectures[0]}")
time_before_load = time.perf_counter()
# 1. Load original model
self.model = get_model_from_loader(fd_config=self.fd_config)
model_loader = get_model_loader(load_config=self.fd_config.load_config)
self.model = model_loader.load_model(fd_config=self.fd_config)
# 1.1 Load RL dynamic model
if self.fd_config.load_config.dynamic_load_weight:
from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
@@ -826,9 +826,6 @@ class GPUModelRunner(ModelRunnerBase):
# 3. Load drafter model(for speculative decoding)
time_after_load = time.perf_counter()
logger.info(f"Model loading took {time_after_load - time_before_load} seconds")
# 4. Init proposer for speculative method
self._init_speculative_proposer()

View File

@@ -37,7 +37,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
from fastdeploy.model_executor.model_loader import get_model_from_loader
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx
from fastdeploy.model_executor.pre_and_post_process import (
post_process,
@@ -519,17 +519,14 @@ class IluvatarModelRunner(ModelRunnerBase):
def load_model(self) -> None:
"""load or download model"""
logger.info(f"Starting to load model {self.model_config.architectures[0]}")
time_before_load = time.perf_counter()
# 1. Load original model
self.model = get_model_from_loader(fd_config=self.fd_config)
model_loader = get_model_loader(load_config=self.fd_config.load_config)
self.model = model_loader.load_model(fd_config=self.fd_config)
# 2. Load lora model
# 3. Load drafter model(for speculative decoding)
time_after_load = time.perf_counter()
logger.info(f"Model loading took {time_after_load - time_before_load} seconds")
def get_model(self) -> nn.Layer:
"""get current model"""
return self.model

View File

@@ -573,6 +573,13 @@ def parse_args():
help="Configuration of early stop.",
)
parser.add_argument(
"--load_choices",
type=str,
default="default",
help="The format of the model weights to load. default/new_loader.",
)
args = parser.parse_args()
return args

View File

@@ -37,7 +37,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.model_executor.model_loader import get_model_from_loader
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.ops.xpu import (
adjust_batch,
get_infer_param,
@@ -686,17 +686,14 @@ class XPUModelRunner(ModelRunnerBase):
def load_model(self) -> None:
"""load or download model"""
logger.info(f"Starting to load model {self.model_config.architectures[0]}")
time_before_load = time.perf_counter()
# 1. Load original model
self.model = get_model_from_loader(fd_config=self.fd_config)
model_loader = get_model_loader(load_config=self.fd_config.load_config)
self.model = model_loader.load_model(fd_config=self.fd_config)
# 2. Load lora model
# 3. Load drafter model(for speculative decoding)
time_after_load = time.perf_counter()
logger.info(f"Model loading took {time_after_load - time_before_load} seconds")
def get_model(self) -> nn.Layer:
"""get current model"""
return self.model