[Features] support hugging face qwen3 dense and qwen2 model (#3574)

* support qwen2 and qwen3 hugging face

* fix moe

* default_v1 loader

* hugging_face_format deprecated

* modify hugging_face_format to model_format

* model_format auto

* fix environment

* fix bug

* fix qwen3-0.6 bug

* model_format is str

* fix
Author: lizexu123
Date: 2025-08-26 10:54:53 +08:00 (committed by GitHub)
Parent: 66c5addce4
Commit: c43a4bec00
10 changed files with 182 additions and 11 deletions

View File

@@ -128,6 +128,7 @@ class ModelConfig:
         self.quantization = None
         self.pad_token_id: int = -1
         self.eos_tokens_lens: int = 2
+        self.model_format = "auto"
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)
@@ -165,6 +166,7 @@ class ModelConfig:
         self.override_name_from_config()
         self.read_from_env()
+        self.read_model_config()

     def override_name_from_config(self):
         """
@@ -206,6 +208,29 @@ class ModelConfig:
         reset_config_value("COMPRESSION_RATIO", 1.0)
         reset_config_value("ROPE_THETA", 10000)

+    def read_model_config(self):
+        config_path = os.path.join(self.model, "config.json")
+        if os.path.exists(config_path):
+            self.model_config = json.load(open(config_path, "r", encoding="utf-8"))
+            if "torch_dtype" in self.model_config and "dtype" in self.model_config:
+                raise ValueError(
+                    "Only one of 'torch_dtype' or 'dtype' should be present in config.json. "
+                    "Found both, which indicates an ambiguous model format. "
+                    "Please ensure your config.json contains only one dtype field."
+                )
+            elif "torch_dtype" in self.model_config:
+                self.model_format = "torch"
+                logger.info("The model format is Hugging Face")
+            elif "dtype" in self.model_config:
+                self.model_format = "paddle"
+                logger.info("The model format is Paddle")
+            else:
+                raise ValueError(
+                    "Unknown model format. Please ensure your config.json contains "
+                    "either 'torch_dtype' (for Hugging Face models) or 'dtype' (for Paddle models) field. "
+                    f"Config file path: {config_path}"
+                )
+
     def _get_download_model(self, model_name, model_type="default"):
         # TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
         pass
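The detection rule above is the core of the change: Hugging Face configs serialize the weight dtype under `torch_dtype`, while Paddle configs use `dtype`. A standalone sketch of the same decision logic (the function name and error texts are illustrative, not from this commit):

    import json
    import os

    def detect_model_format(model_dir: str) -> str:
        """Classify a checkpoint as 'torch' (Hugging Face) or 'paddle' from its config.json."""
        with open(os.path.join(model_dir, "config.json"), "r", encoding="utf-8") as f:
            cfg = json.load(f)
        has_torch, has_paddle = "torch_dtype" in cfg, "dtype" in cfg
        if has_torch and has_paddle:
            raise ValueError("ambiguous config.json: both 'torch_dtype' and 'dtype' present")
        if has_torch:
            return "torch"
        if has_paddle:
            return "paddle"
        raise ValueError("unknown format: neither 'torch_dtype' nor 'dtype' found")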
@@ -1034,6 +1059,9 @@ class FDConfig:
         self.disable_any_whitespace = disable_any_whitespace
         self._str_to_list("innode_prefill_ports", int)

+        if envs.FD_FOR_TORCH_MODEL_FORMAT:
+            self.model_config.model_format = "torch"
+
         # TODO
         self.max_prefill_batch = 3
         if current_platform.is_xpu():

View File

@@ -357,7 +357,7 @@ class EngineArgs:
"""The format of the model weights to load. """The format of the model weights to load.
Options include: Options include:
- "default": default loader. - "default": default loader.
- "new_loader": new loader. - "default_v1": default_v1 loader.
""" """
def __post_init__(self): def __post_init__(self):

View File

@@ -86,6 +86,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"), "FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"),
# support max connections # support max connections
"FD_SUPPORT_MAX_CONNECTIONS": lambda: int(os.getenv("FD_SUPPORT_MAX_CONNECTIONS", "1024")), "FD_SUPPORT_MAX_CONNECTIONS": lambda: int(os.getenv("FD_SUPPORT_MAX_CONNECTIONS", "1024")),
"FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))),
} }
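The variable is parsed as an integer-valued boolean, so `FD_FOR_TORCH_MODEL_FORMAT=1` forces `model_format = "torch"` (see the `FDConfig` hunk above), while `0` or an unset variable leaves auto-detection in effect. A quick sketch of the parsing behaviour:

    import os

    os.environ["FD_FOR_TORCH_MODEL_FORMAT"] = "1"
    force_torch = bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0")))
    print(force_torch)  # True; "0" or an unset variable yields False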

View File

@@ -57,6 +57,7 @@ class UnquantizedLinearMethod(QuantMethodBase):
             {
                 **extra_weight_attrs,
                 "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
+                "model_format": extra_weight_attrs.get("model_format", ""),
             },
         )
@@ -343,7 +344,9 @@ class ColumnParallelLinear(LinearBase):
             weight_loader=(
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
+            model_format=fd_config.model_config.model_format,
         )
         if self.nranks > 0:
             if self.with_bias:
                 # col parallel
@@ -402,6 +405,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         )

     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        model_format = getattr(param, "model_format", "")
+        if model_format == "torch":
+            loaded_weight = loaded_weight.transpose([1, 0])
         output_dim = getattr(param, "output_dim", None)
         assert output_dim is not None
         shard_dim = -1 if output_dim else 0
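The transpose is needed because torch's `nn.Linear` stores weights as `[out_features, in_features]`, whereas Paddle's linear layers expect `[in_features, out_features]`. A minimal numpy sketch of the layout conversion (shapes are illustrative):

    import numpy as np

    # Hugging Face (torch) layout: [out_features, in_features]
    hf_weight = np.zeros((4096, 1024), dtype=np.float16)

    # Paddle layout: [in_features, out_features]
    paddle_weight = hf_weight.transpose([1, 0])
    assert paddle_weight.shape == (1024, 4096)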
@@ -424,7 +430,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         # Tensor parallelism splits the weight along the output_dim
         if self.nranks != 1:
             dim = -1 if output_dim else 0
-            if isinstance(loaded_weight, np.ndarray):
+            if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                 size = loaded_weight.shape[dim]
             else:
                 size = loaded_weight.get_shape()[dim]
@@ -523,6 +529,9 @@ class QKVParallelLinear(ColumnParallelLinear):
         assert output_dim is not None
         dim = -1 if output_dim else 0
         head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
+        model_format = getattr(param, "model_format", "")
+        if model_format == "torch":
+            loaded_weight = loaded_weight.transpose([1, 0])
         if loaded_shard_id is None:
             # Loaded weight is already fused on disk
             shard_offsets = [
@@ -541,7 +550,8 @@ class QKVParallelLinear(ColumnParallelLinear):
             assert loaded_shard_id in ["q", "k", "v"]
             # Tensor parallelism splits the weight along the output_dim
             if self.nranks != 1:
-                if isinstance(loaded_weight, np.ndarray):
+                dim = -1 if output_dim else 0
+                if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                     size = loaded_weight.shape[dim]
                 else:
                     size = loaded_weight.get_shape()[dim]
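For context, the fused QKV weight concatenates the query, key, and value projections along the output dimension; that is how `head_dim` is recovered from the parameter shape and how each shard's offset is derived. A worked sketch of the arithmetic (the numbers are illustrative):

    num_heads_per_rank = 32
    kv_num_heads_per_rank = 8
    fused_out_dim = 6144  # q + k + v output size on this rank

    head_dim = fused_out_dim // (num_heads_per_rank + 2 * kv_num_heads_per_rank)  # 128
    q_size = num_heads_per_rank * head_dim        # 4096
    kv_size = kv_num_heads_per_rank * head_dim    # 1024
    shard_offsets = {"q": 0, "k": q_size, "v": q_size + kv_size}  # {'q': 0, 'k': 4096, 'v': 5120}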
@@ -712,6 +722,7 @@ class RowParallelLinear(LinearBase):
             weight_loader=(
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
+            model_format=fd_config.model_config.model_format,
         )
         if self.nranks > 0:
             if self.with_bias:

View File

@@ -22,7 +22,7 @@ from paddle import nn
 from paddle import nn
 from paddle.distributed import fleet

 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.utils import set_weight_attrs
+from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs

 from .utils import get_tensor
@@ -61,6 +61,7 @@ class ParallelLMHead(nn.Layer):
         self.use_ep: bool = fd_config.parallel_config.use_ep
         self.column_cut = True
         self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.fd_config = fd_config

         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -90,7 +91,14 @@ class ParallelLMHead(nn.Layer):
                 weight_attr=None,
                 has_bias=True if self.bias_key is not None else False,
                 gather_output=need_gather,
-                fuse_matmul_bias=False,  # False keeps the numerical diff smaller
+                fuse_matmul_bias=False,
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
             )
             if self.nranks > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": True})
@@ -102,8 +110,16 @@ class ParallelLMHead(nn.Layer):
                 weight_attr=None,
                 has_bias=True if self.bias_key is not None else False,
                 input_is_parallel=False,
-                fuse_matmul_bias=False,  # False keeps the numerical diff smaller
+                fuse_matmul_bias=False,
             )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
             if self.nranks > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": False})
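`set_weight_attrs` stashes loader metadata directly on the parameter object; that is how `default_weight_loader` can later recover the format via `getattr(param, "model_format", "")`. A simplified, standalone sketch of the attach-then-read pattern (not FastDeploy's actual implementation):

    class FakeParam:
        """Stand-in for a paddle weight parameter."""

    def set_weight_attrs(param, attrs: dict) -> None:
        # Simplified helper: copy each metadata entry onto the parameter.
        for key, value in attrs.items():
            setattr(param, key, value)

    w = FakeParam()
    set_weight_attrs(w, {"model_format": "torch", "output_dim": True})
    assert getattr(w, "model_format", "") == "torch"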

View File

@@ -219,7 +219,7 @@ class FusedMoE(nn.Layer):
     def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
         dim = -1 if shard_dim else 0
         if self.tp_size > 1:
-            if isinstance(loaded_weight, np.ndarray):
+            if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                 size = loaded_weight.shape[dim]
             else:
                 size = loaded_weight.get_shape()[dim]
@@ -259,7 +259,7 @@ class FusedMoE(nn.Layer):
     def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
         if self.tp_size > 1 and shard_dim is not None:
             dim = -1 if shard_dim else 0
-            if isinstance(loaded_weight, np.ndarray):
+            if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                 size = loaded_weight.shape[dim]
             else:
                 size = loaded_weight.get_shape()[dim]
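The widened `isinstance` check matters because the weight iterator can now yield a materialized `np.ndarray`/`paddle.Tensor` (which exposes `.shape`) as well as a lazy safetensors slice (which only exposes `.get_shape()`). The recurring pattern, extracted into a hypothetical helper:

    import numpy as np
    import paddle

    def dim_size(loaded_weight, dim: int) -> int:
        """Size along `dim` for arrays/tensors and for lazy safetensors slices."""
        if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
            return loaded_weight.shape[dim]
        return loaded_weight.get_shape()[dim]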

View File

@@ -29,6 +29,7 @@ from safetensors import safe_open
 from tqdm import tqdm

 from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.models.tp_utils import (
     check_tensor_parallel_prerequisites,
 )
@@ -180,8 +181,9 @@ def fast_weights_iterator(safe_tensor_list: list[str]):
     ):
         with fast_safe_open(st_file, framework="np") as f:
             for name in f.keys():
-                param = f.get_slice(name)
-                yield name, param
+                param_slice = f.get_slice(name)
+                paddle_tensor = get_tensor(param_slice)
+                yield name, paddle_tensor


 def fastsafetensors_weights_iterator(
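With this change the iterator materializes each lazy slice into a `paddle.Tensor` via `get_tensor` before yielding, so downstream weight loaders can rely on tensor semantics (`.transpose`, `.shape`). A usage sketch (the file name is an assumption):

    # Consumers now receive real tensors rather than lazy slices.
    for name, tensor in fast_weights_iterator(["model-00001-of-00002.safetensors"]):
        print(name, tensor.shape)  # paddle.Tensor, already loaded into memory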

View File

@@ -334,6 +334,10 @@ class Qwen2ForCausalLM(ModelForCasualLM):
         params_dict = dict(self.named_parameters())
         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
         for loaded_weight_name, loaded_weight in weights_iterator:
+            model_format = self.fd_config.model_config.model_format
+            # The Paddle checkpoint prefix is "qwen2", while the Hugging Face prefix is "model".
+            if model_format == "torch":
+                loaded_weight_name = loaded_weight_name.replace("model", "qwen2")
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in loaded_weight_name:
                     continue
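The remap is a plain substring replacement: a Hugging Face key such as `model.layers.0.self_attn.q_proj.weight` becomes `qwen2.layers.0.self_attn.q_proj.weight`, matching Paddle's parameter naming. Note that `str.replace` rewrites every occurrence of `model`; a more defensive variant (an assumption on my part, not what the commit does) would anchor the prefix:

    import re

    name = "model.layers.0.self_attn.q_proj.weight"
    print(name.replace("model", "qwen2"))       # what the commit does
    print(re.sub(r"^model\.", "qwen2.", name))  # prefix-anchored alternative
    # Both print: qwen2.layers.0.self_attn.q_proj.weight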

View File

@@ -16,6 +16,8 @@
 from typing import Any, Optional, Union

+import paddle
+
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.utils import get_tensor
@@ -155,9 +157,15 @@ def default_weight_loader(fd_config: FDConfig) -> None:
     def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
         """fn"""
         output_dim = getattr(param, "output_dim", None)
+        model_format = getattr(param, "model_format", "")
+        if model_format == "torch":
+            loaded_weight = loaded_weight.transpose([1, 0])
         # Tensor parallelism splits the weight along the output_dim
         if output_dim is not None and fd_config.parallel_config.tensor_parallel_size > 1:
             dim = -1 if output_dim else 0
-            size = loaded_weight.get_shape()[dim]
+            if isinstance(loaded_weight, paddle.Tensor):
+                size = loaded_weight.shape[dim]
+            else:
+                size = loaded_weight.get_shape()[dim]
             block_size = size // fd_config.parallel_config.tensor_parallel_size
             shard_offset = fd_config.parallel_config.tensor_parallel_rank * block_size
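The tensor-parallel path shards the weight into equal contiguous blocks along `output_dim`: each rank takes `size // tensor_parallel_size` elements starting at `rank * block_size`. Worked numbers for a 4-way split:

    size = 4096             # weight length along the sharded dimension
    tp_size, tp_rank = 4, 2

    block_size = size // tp_size          # 1024 elements per rank
    shard_offset = tp_rank * block_size   # rank 2 starts at 2048
    shard_range = (shard_offset, shard_offset + block_size)  # (2048, 3072)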

View File

@@ -27,6 +27,25 @@ TokensIdText = list[tuple[list[int], str]]
 # (token_ids, text)


+def get_model_paths(base_model_name: str) -> tuple[str, str]:
+    """return (fastdeploy_path, huggingface_path)"""
+    # FastDeploy model path
+    fd_base_path = os.getenv("MODEL_PATH")
+    if fd_base_path:
+        fd_model_path = os.path.join(fd_base_path, base_model_name)
+    else:
+        fd_model_path = base_model_name
+
+    # HuggingFace model path
+    torch_model_path = os.path.join(
+        fd_base_path,
+        "torch",
+        base_model_name,
+    )
+    return fd_model_path, torch_model_path
+
+
 def check_tokens_id_and_text_close(
     *,
     outputs_0_lst: TokensIdText,
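One caveat in `get_model_paths`: when `MODEL_PATH` is unset, `fd_base_path` is `None` and the later `os.path.join(fd_base_path, "torch", ...)` raises a `TypeError`. A defensive variant (an assumption on my part, not part of this commit) would fall back for both paths:

    import os

    def get_model_paths_safe(base_model_name: str) -> tuple[str, str]:
        """Hypothetical variant of get_model_paths that tolerates an unset MODEL_PATH."""
        fd_base_path = os.getenv("MODEL_PATH", "")
        if not fd_base_path:
            # Fall back to the bare model name for both loaders.
            return base_model_name, base_model_name
        return (
            os.path.join(fd_base_path, base_model_name),
            os.path.join(fd_base_path, "torch", base_model_name),
        )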
@@ -104,6 +123,7 @@ model_param_map = {
     },
 }

+
 params = []
 for model, cfg in model_param_map.items():
     for q in cfg["quantizations"]:
@@ -176,3 +196,84 @@ def test_common_model(
name_0="default loader", name_0="default loader",
name_1="default_v1 loader", name_1="default_v1 loader",
) )
hugging_face_model_param_map = {
"Qwen2.5-7B-Instruct": {
"tensor_parallel_size": 2,
"quantizations": ["None"],
},
}
hf_params = []
for model, cfg in hugging_face_model_param_map.items():
for q in cfg["quantizations"]:
hf_params.append(
pytest.param(
model,
cfg.get("tensor_parallel_size", 1),
cfg.get("max_model_len", 1024),
q,
cfg.get("max_tokens", 32),
marks=[pytest.mark.core_model],
)
)
@pytest.mark.parametrize(
"model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens",
hf_params,
)
def test_paddle_vs_torch_model(
fd_runner,
model_name_or_path: str,
tensor_parallel_size: int,
max_model_len: int,
max_tokens: int,
quantization: str,
) -> None:
fd_model_path, torch_model_path = get_model_paths(model_name_or_path)
result_queue = Queue()
p_paddle = Process(
target=form_model_get_output,
args=(
fd_runner,
fd_model_path,
tensor_parallel_size,
max_model_len,
max_tokens,
quantization,
"default",
result_queue,
),
)
p_paddle.start()
p_paddle.join()
paddle_outputs = result_queue.get(timeout=60)
p_hf = Process(
target=form_model_get_output,
args=(
fd_runner,
torch_model_path,
tensor_parallel_size,
max_model_len,
max_tokens,
quantization,
"default_v1",
result_queue,
),
)
p_hf.start()
p_hf.join()
hf_outputs = result_queue.get(timeout=60)
check_tokens_id_and_text_close(
outputs_0_lst=paddle_outputs,
outputs_1_lst=hf_outputs,
name_0="Paddle model (default loader)",
name_1="HuggingFace model (default_v1 loader)",
)