diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 4450c3937..112ae40db 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -128,6 +128,7 @@ class ModelConfig:
         self.quantization = None
         self.pad_token_id: int = -1
         self.eos_tokens_lens: int = 2
+        self.model_format = "auto"
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)
@@ -165,6 +166,7 @@ class ModelConfig:
 
         self.override_name_from_config()
         self.read_from_env()
+        self.read_model_config()
 
     def override_name_from_config(self):
         """
@@ -206,6 +208,30 @@ class ModelConfig:
         reset_config_value("COMPRESSION_RATIO", 1.0)
         reset_config_value("ROPE_THETA", 10000)
 
+    def read_model_config(self):
+        config_path = os.path.join(self.model, "config.json")
+        if os.path.exists(config_path):
+            with open(config_path, "r", encoding="utf-8") as f:
+                self.model_config = json.load(f)
+            if "torch_dtype" in self.model_config and "dtype" in self.model_config:
+                raise ValueError(
+                    "Only one of 'torch_dtype' or 'dtype' should be present in config.json. "
+                    "Found both, which indicates an ambiguous model format. "
+                    "Please ensure your config.json contains only one dtype field."
+                )
+            elif "torch_dtype" in self.model_config:
+                self.model_format = "torch"
+                logger.info("The model format is Hugging Face")
+            elif "dtype" in self.model_config:
+                self.model_format = "paddle"
+                logger.info("The model format is Paddle")
+            else:
+                raise ValueError(
+                    "Unknown model format. Please ensure your config.json contains "
+                    "either a 'torch_dtype' (Hugging Face) or a 'dtype' (Paddle) field. "
+                    f"Config file path: {config_path}"
+                )
+
     def _get_download_model(self, model_name, model_type="default"):
         # TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
         pass
@@ -1034,6 +1059,9 @@ class FDConfig:
         self.disable_any_whitespace = disable_any_whitespace
         self._str_to_list("innode_prefill_ports", int)
 
+        if envs.FD_FOR_TORCH_MODEL_FORMAT:
+            self.model_config.model_format = "torch"
+
         # TODO
         self.max_prefill_batch = 3
         if current_platform.is_xpu():
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 8bb5695e7..cbdb0cecf 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -357,7 +357,7 @@
     """The format of the model weights to load.
     Options include:
         - "default": default loader.
-        - "new_loader": new loader.
+        - "default_v1": default_v1 loader.
""" def __post_init__(self): diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 0155e260f..790af9552 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -86,6 +86,7 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"), # support max connections "FD_SUPPORT_MAX_CONNECTIONS": lambda: int(os.getenv("FD_SUPPORT_MAX_CONNECTIONS", "1024")), + "FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))), } diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index d6958a919..b864e4aa3 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -57,6 +57,7 @@ class UnquantizedLinearMethod(QuantMethodBase): { **extra_weight_attrs, "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)), + "model_format": extra_weight_attrs.get("model_format", ""), }, ) @@ -343,7 +344,9 @@ class ColumnParallelLinear(LinearBase): weight_loader=( self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config) ), + model_format=fd_config.model_config.model_format, ) + if self.nranks > 0: if self.with_bias: # col parallel @@ -402,6 +405,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ) def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): + model_format = getattr(param, "model_format", "") + if model_format == "torch": + loaded_weight = loaded_weight.transpose([1, 0]) output_dim = getattr(param, "output_dim", None) assert output_dim is not None shard_dim = -1 if output_dim else 0 @@ -424,7 +430,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): # Tensor parallelism splits the weight along the output_dim if self.nranks != 1: dim = -1 if output_dim else 0 - if isinstance(loaded_weight, np.ndarray): + if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)): size = loaded_weight.shape[dim] else: size = loaded_weight.get_shape()[dim] @@ -523,6 +529,9 @@ class QKVParallelLinear(ColumnParallelLinear): assert output_dim is not None dim = -1 if output_dim else 0 head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) + model_format = getattr(param, "model_format", "") + if model_format == "torch": + loaded_weight = loaded_weight.transpose([1, 0]) if loaded_shard_id is None: # Loaded weight is already fused on disk shard_offsets = [ @@ -541,7 +550,8 @@ class QKVParallelLinear(ColumnParallelLinear): assert loaded_shard_id in ["q", "k", "v"] # Tensor parallelism splits the weight along the output_dim if self.nranks != 1: - if isinstance(loaded_weight, np.ndarray): + dim = -1 if output_dim else 0 + if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)): size = loaded_weight.shape[dim] else: size = loaded_weight.get_shape()[dim] @@ -712,6 +722,7 @@ class RowParallelLinear(LinearBase): weight_loader=( self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config) ), + model_format=fd_config.model_config.model_format, ) if self.nranks > 0: if self.with_bias: diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py index f71f828eb..d976c2e3a 100644 --- a/fastdeploy/model_executor/layers/lm_head.py +++ b/fastdeploy/model_executor/layers/lm_head.py @@ -22,7 +22,7 @@ from paddle import nn from paddle.distributed import fleet from fastdeploy.config import FDConfig -from fastdeploy.model_executor.utils import 
+from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs

 from .utils import get_tensor
@@ -61,6 +61,7 @@ class ParallelLMHead(nn.Layer):
         self.use_ep: bool = fd_config.parallel_config.use_ep
         self.column_cut = True
         self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.fd_config = fd_config

         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -90,7 +91,14 @@
                 weight_attr=None,
                 has_bias=True if self.bias_key is not None else False,
                 gather_output=need_gather,
-                fuse_matmul_bias=False,  # False diff更小
+                fuse_matmul_bias=False,
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
             )
             if self.nranks > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": True})
@@ -102,8 +110,16 @@
                 weight_attr=None,
                 has_bias=True if self.bias_key is not None else False,
                 input_is_parallel=False,
-                fuse_matmul_bias=False,  # False diff更小
+                fuse_matmul_bias=False,
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+
             if self.nranks > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": False})
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index 475b3015c..6db31a988 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -219,7 +219,7 @@ class FusedMoE(nn.Layer):
     def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
         dim = -1 if shard_dim else 0
         if self.tp_size > 1:
-            if isinstance(loaded_weight, np.ndarray):
+            if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                 size = loaded_weight.shape[dim]
             else:
                 size = loaded_weight.get_shape()[dim]
@@ -259,7 +259,7 @@ class FusedMoE(nn.Layer):
     def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
         if self.tp_size > 1 and shard_dim is not None:
             dim = -1 if shard_dim else 0
-            if isinstance(loaded_weight, np.ndarray):
+            if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                 size = loaded_weight.shape[dim]
             else:
                 size = loaded_weight.get_shape()[dim]
diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py
index 6aacb3a59..9c5503f24 100644
--- a/fastdeploy/model_executor/load_weight_utils.py
+++ b/fastdeploy/model_executor/load_weight_utils.py
@@ -29,6 +29,7 @@ from safetensors import safe_open
 from tqdm import tqdm

 from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.models.tp_utils import (
     check_tensor_parallel_prerequisites,
 )
@@ -180,8 +181,9 @@ def fast_weights_iterator(safe_tensor_list: list[str]):
     ):
         with fast_safe_open(st_file, framework="np") as f:
             for name in f.keys():
-                param = f.get_slice(name)
-                yield name, param
+                param_slice = f.get_slice(name)
+                paddle_tensor = get_tensor(param_slice)
+                yield name, paddle_tensor


 def fastsafetensors_weights_iterator(
diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py
index eaa1e26a8..c9c15a33f 100644
--- a/fastdeploy/model_executor/models/qwen2.py
+++ b/fastdeploy/model_executor/models/qwen2.py
@@ -334,6 +334,10 @@ class Qwen2ForCausalLM(ModelForCasualLM):
         params_dict = dict(self.named_parameters())
         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
         for loaded_weight_name, loaded_weight in weights_iterator:
+            model_format = self.fd_config.model_config.model_format
+            # Paddle checkpoints prefix Qwen2 weights with "qwen2"; Hugging Face checkpoints use "model".
+            if model_format == "torch":
+                loaded_weight_name = loaded_weight_name.replace("model", "qwen2")
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in loaded_weight_name:
                     continue
diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py
index 31cd67172..ded21574f 100644
--- a/fastdeploy/model_executor/utils.py
+++ b/fastdeploy/model_executor/utils.py
@@ -16,6 +16,8 @@

 from typing import Any, Optional, Union

+import paddle
+
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.utils import get_tensor

@@ -155,10 +157,16 @@ def default_weight_loader(fd_config: FDConfig) -> None:
     def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
         """fn"""
         output_dim = getattr(param, "output_dim", None)
+        model_format = getattr(param, "model_format", "")
+        if model_format == "torch":
+            loaded_weight = loaded_weight.transpose([1, 0])  # HF layout [out, in] -> Paddle [in, out]
         # Tensor parallelism splits the weight along the output_dim
         if output_dim is not None and fd_config.parallel_config.tensor_parallel_size > 1:
             dim = -1 if output_dim else 0
-            size = loaded_weight.get_shape()[dim]
+            if isinstance(loaded_weight, paddle.Tensor):
+                size = loaded_weight.shape[dim]
+            else:
+                size = loaded_weight.get_shape()[dim]
             block_size = size // fd_config.parallel_config.tensor_parallel_size
             shard_offset = fd_config.parallel_config.tensor_parallel_rank * block_size
             shard_size = (fd_config.parallel_config.tensor_parallel_rank + 1) * block_size
diff --git a/tests/model_loader/test_common_model.py b/tests/model_loader/test_common_model.py
index b8b005f02..3b3f4fbf7 100644
--- a/tests/model_loader/test_common_model.py
+++ b/tests/model_loader/test_common_model.py
@@ -27,6 +27,25 @@
 TokensIdText = list[tuple[list[int], str]]  # (token_ids, text)


+def get_model_paths(base_model_name: str) -> tuple[str, str]:
+    """Return (fastdeploy_path, huggingface_path)."""
+    # FastDeploy model path
+    fd_base_path = os.getenv("MODEL_PATH")
+    if fd_base_path:
+        fd_model_path = os.path.join(fd_base_path, base_model_name)
+    else:
+        fd_model_path = base_model_name
+
+    # HuggingFace model path
+    torch_model_path = os.path.join(
+        fd_base_path or "",  # MODEL_PATH may be unset; fall back to a relative path
+        "torch",
+        base_model_name,
+    )
+
+    return fd_model_path, torch_model_path
+
+
 def check_tokens_id_and_text_close(
     *,
     outputs_0_lst: TokensIdText,
@@ -104,6 +123,7 @@
     },
 }

+
 params = []
 for model, cfg in model_param_map.items():
     for q in cfg["quantizations"]:
@@ -176,3 +196,84 @@
         name_0="default loader",
         name_1="default_v1 loader",
     )
+
+
+hugging_face_model_param_map = {
+    "Qwen2.5-7B-Instruct": {
+        "tensor_parallel_size": 2,
+        "quantizations": ["None"],
+    },
+}
+
+hf_params = []
+for model, cfg in hugging_face_model_param_map.items():
+    for q in cfg["quantizations"]:
+        hf_params.append(
+            pytest.param(
+                model,
+                cfg.get("tensor_parallel_size", 1),
+                cfg.get("max_model_len", 1024),
+                q,
+                cfg.get("max_tokens", 32),
+                marks=[pytest.mark.core_model],
+            )
+        )
+
+
+@pytest.mark.parametrize(
+    "model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens",
+    hf_params,
+)
+def test_paddle_vs_torch_model(
+    fd_runner,
+    model_name_or_path: str,
+    tensor_parallel_size: int,
+    max_model_len: int,
+    max_tokens: int,
+    quantization: str,
+) -> None:
+
+    fd_model_path, torch_model_path = get_model_paths(model_name_or_path)
+
+    result_queue = Queue()
+
+    p_paddle = Process(
+        target=form_model_get_output,
+        args=(
+            fd_runner,
+            fd_model_path,
+            tensor_parallel_size,
+            max_model_len,
+            max_tokens,
+            quantization,
+            "default",
+            result_queue,
+        ),
+    )
+    p_paddle.start()
+    p_paddle.join()
+    paddle_outputs = result_queue.get(timeout=60)
+
+    p_hf = Process(
+        target=form_model_get_output,
+        args=(
+            fd_runner,
+            torch_model_path,
+            tensor_parallel_size,
+            max_model_len,
+            max_tokens,
+            quantization,
+            "default_v1",
+            result_queue,
+        ),
+    )
+    p_hf.start()
+    p_hf.join()
+    hf_outputs = result_queue.get(timeout=60)
+
+    check_tokens_id_and_text_close(
+        outputs_0_lst=paddle_outputs,
+        outputs_1_lst=hf_outputs,
+        name_0="Paddle model (default loader)",
+        name_1="HuggingFace model (default_v1 loader)",
+    )
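
----------------------------------------------------------------------
Reviewer notes (supplementary sketches, not part of the patch)

1. The format detection added in ModelConfig.read_model_config keys entirely off which
dtype field config.json carries: Hugging Face exports write "torch_dtype", Paddle
exports write "dtype", and finding both (or neither) raises a ValueError. A minimal
standalone sketch of the same rule; detect_model_format is a hypothetical helper used
for illustration only, not a function introduced by this patch:

    import json
    import os


    def detect_model_format(model_dir: str) -> str:
        """Sketch of the detection rule in ModelConfig.read_model_config."""
        config_path = os.path.join(model_dir, "config.json")
        if not os.path.exists(config_path):
            return "auto"  # the default set in ModelConfig.__init__
        with open(config_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        has_torch = "torch_dtype" in cfg  # Hugging Face convention
        has_paddle = "dtype" in cfg  # Paddle convention
        if has_torch and has_paddle:
            raise ValueError("ambiguous model format: both dtype fields present")
        if has_torch:
            return "torch"
        if has_paddle:
            return "paddle"
        raise ValueError("unknown model format: no dtype field in config.json")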
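
2. The transpose([1, 0]) calls that now appear in the weight loaders exist because the
two ecosystems store Linear weights in opposite layouts: torch's nn.Linear keeps
[out_features, in_features], while Paddle's nn.Linear keeps [in_features, out_features].
A quick numpy illustration of why the transposed Hugging Face weight drops into the
Paddle matmul unchanged (illustrative only):

    import numpy as np

    # A linear layer mapping 4 input features to 3 outputs.
    hf_weight = np.arange(12.0).reshape(3, 4)  # torch layout: [out, in]
    pd_weight = hf_weight.transpose([1, 0])  # Paddle layout: [in, out]

    x = np.ones((2, 4))  # a batch of 2 inputs
    # torch computes x @ W.T, Paddle computes x @ W; after the transpose they agree.
    assert np.array_equal(x @ hf_weight.T, x @ pd_weight)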
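
3. The FD_FOR_TORCH_MODEL_FORMAT override added to envs.py is parsed with
bool(int(...)), so it only accepts integer-style strings such as "0" or "1"; any
non-zero value forces model_config.model_format = "torch" in FDConfig regardless of
what config.json says. A small usage sketch (values are assumed for illustration):

    import os

    # Force the torch weight layout before FDConfig is constructed.
    os.environ["FD_FOR_TORCH_MODEL_FORMAT"] = "1"

    # Mirrors the parsing in envs.py: bool(int("1")) -> True.
    # A value such as "true" would raise ValueError inside int().
    enabled = bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0")))
    assert enabled is True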