[Features] support hugging face qwen3 dense and qwen2 model (#3574)
* support qwen2 and qwen3 hugging face
* fix moe
* default_v1 loader
* hugging_face_format deprecated
* modify hugging_face_format to model_format
* model_format auto
* fix environment
* fix bug
* fix qwen3-0.6 bug
* model_format is str
* fix

@@ -128,6 +128,7 @@ class ModelConfig:
         self.quantization = None
         self.pad_token_id: int = -1
         self.eos_tokens_lens: int = 2
+        self.model_format = "auto"
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)

@@ -165,6 +166,7 @@ class ModelConfig:

         self.override_name_from_config()
         self.read_from_env()
+        self.read_model_config()

     def override_name_from_config(self):
         """

@@ -206,6 +208,29 @@ class ModelConfig:
         reset_config_value("COMPRESSION_RATIO", 1.0)
         reset_config_value("ROPE_THETA", 10000)

+    def read_model_config(self):
+        config_path = os.path.join(self.model, "config.json")
+        if os.path.exists(config_path):
+            self.model_config = json.load(open(config_path, "r", encoding="utf-8"))
+            if "torch_dtype" in self.model_config and "dtype" in self.model_config:
+                raise ValueError(
+                    "Only one of 'torch_dtype' or 'dtype' should be present in config.json. "
+                    "Found both, which indicates an ambiguous model format. "
+                    "Please ensure your config.json contains only one dtype field."
+                )
+            elif "torch_dtype" in self.model_config:
+                self.model_format = "torch"
+                logger.info("The model format is Hugging Face")
+            elif "dtype" in self.model_config:
+                self.model_format = "paddle"
+                logger.info("The model format is Paddle")
+            else:
+                raise ValueError(
+                    "Unknown model format. Please ensure your config.json contains "
+                    "either 'torch_dtype' (for Hugging Face models) or 'dtype' (for Paddle models) field. "
+                    f"Config file path: {config_path}"
+                )
+
     def _get_download_model(self, model_name, model_type="default"):
         # TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
         pass
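
The new read_model_config reduces format detection to a three-way branch on the keys of config.json: torch_dtype means a Hugging Face checkpoint, dtype means a Paddle checkpoint, and both together are rejected as ambiguous. A minimal standalone sketch of the same rule (the function name and the "auto" fallback are illustrative, not part of the patch):

import json
import os


def detect_model_format(model_dir: str) -> str:
    # Illustrative re-statement of ModelConfig.read_model_config's branch.
    config_path = os.path.join(model_dir, "config.json")
    if not os.path.exists(config_path):
        return "auto"  # the patch leaves self.model_format at its default here
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    if "torch_dtype" in cfg and "dtype" in cfg:
        raise ValueError("ambiguous: both 'torch_dtype' and 'dtype' present")
    if "torch_dtype" in cfg:
        return "torch"  # Hugging Face checkpoint
    if "dtype" in cfg:
        return "paddle"  # Paddle checkpoint
    raise ValueError("unknown model format: neither dtype key present")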

@@ -1034,6 +1059,9 @@ class FDConfig:
         self.disable_any_whitespace = disable_any_whitespace
         self._str_to_list("innode_prefill_ports", int)

+        if envs.FD_FOR_TORCH_MODEL_FORMAT:
+            self.model_config.model_format = "torch"
+
         # TODO
         self.max_prefill_batch = 3
         if current_platform.is_xpu():

@@ -357,7 +357,7 @@ class EngineArgs:
         """The format of the model weights to load.
         Options include:
         - "default": default loader.
-        - "new_loader": new loader.
+        - "default_v1": default_v1 loader.
         """

     def __post_init__(self):

@@ -86,6 +86,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"),
     # support max connections
     "FD_SUPPORT_MAX_CONNECTIONS": lambda: int(os.getenv("FD_SUPPORT_MAX_CONNECTIONS", "1024")),
+    "FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))),
 }
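
FD_FOR_TORCH_MODEL_FORMAT is parsed with bool(int(...)), so any non-zero integer string enables the override, and FDConfig then forces model_config.model_format = "torch" regardless of what the config.json detection found. A hedged round-trip sketch (standalone, not FastDeploy code):

import os

os.environ["FD_FOR_TORCH_MODEL_FORMAT"] = "1"
# Same parsing as the new environment_variables entry above:
force_torch = bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0")))
assert force_torch  # FDConfig will now set model_config.model_format = "torch"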

@@ -57,6 +57,7 @@ class UnquantizedLinearMethod(QuantMethodBase):
             {
                 **extra_weight_attrs,
                 "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
+                "model_format": extra_weight_attrs.get("model_format", ""),
             },
         )

@@ -343,7 +344,9 @@ class ColumnParallelLinear(LinearBase):
             weight_loader=(
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
+            model_format=fd_config.model_config.model_format,
         )
+
         if self.nranks > 0:
             if self.with_bias:
                 # col parallel

@@ -402,6 +405,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         )

     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        model_format = getattr(param, "model_format", "")
+        if model_format == "torch":
+            loaded_weight = loaded_weight.transpose([1, 0])
         output_dim = getattr(param, "output_dim", None)
         assert output_dim is not None
         shard_dim = -1 if output_dim else 0

@@ -424,7 +430,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         # Tensor parallelism splits the weight along the output_dim
         if self.nranks != 1:
             dim = -1 if output_dim else 0
-            if isinstance(loaded_weight, np.ndarray):
+            if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                 size = loaded_weight.shape[dim]
             else:
                 size = loaded_weight.get_shape()[dim]
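
The isinstance widening matters because loaded_weight may now be a NumPy array, a materialized Paddle tensor (see the fast_weights_iterator hunk below), or a lazy safetensors slice: the first two expose .shape, while the slice only exposes .get_shape(). A small illustration with a stub slice type (the stub is hypothetical; only the shape/get_shape split mirrors the patch):

import numpy as np
import paddle


class _SliceStub:
    """Stand-in for a lazy safetensors slice, which has get_shape()."""

    def get_shape(self):
        return [4, 8]


for loaded_weight in (np.zeros((4, 8)), paddle.zeros([4, 8]), _SliceStub()):
    dim = -1
    if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
        size = loaded_weight.shape[dim]
    else:
        size = loaded_weight.get_shape()[dim]
    assert size == 8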

@@ -523,6 +529,9 @@ class QKVParallelLinear(ColumnParallelLinear):
         assert output_dim is not None
         dim = -1 if output_dim else 0
         head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
+        model_format = getattr(param, "model_format", "")
+        if model_format == "torch":
+            loaded_weight = loaded_weight.transpose([1, 0])
         if loaded_shard_id is None:
             # Loaded weight is already fused on disk
             shard_offsets = [

@@ -541,7 +550,8 @@ class QKVParallelLinear(ColumnParallelLinear):
         assert loaded_shard_id in ["q", "k", "v"]
         # Tensor parallelism splits the weight along the output_dim
         if self.nranks != 1:
-            if isinstance(loaded_weight, np.ndarray):
+            dim = -1 if output_dim else 0
+            if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                 size = loaded_weight.shape[dim]
             else:
                 size = loaded_weight.get_shape()[dim]

@@ -712,6 +722,7 @@ class RowParallelLinear(LinearBase):
             weight_loader=(
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
+            model_format=fd_config.model_config.model_format,
         )
         if self.nranks > 0:
             if self.with_bias:

@@ -22,7 +22,7 @@ from paddle import nn
 from paddle.distributed import fleet

 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.utils import set_weight_attrs
+from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs

 from .utils import get_tensor

@@ -61,6 +61,7 @@ class ParallelLMHead(nn.Layer):
         self.use_ep: bool = fd_config.parallel_config.use_ep
         self.column_cut = True
         self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.fd_config = fd_config

         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear

@@ -90,7 +91,14 @@ class ParallelLMHead(nn.Layer):
                 weight_attr=None,
                 has_bias=True if self.bias_key is not None else False,
                 gather_output=need_gather,
-                fuse_matmul_bias=False,  # False keeps the diff smaller
+                fuse_matmul_bias=False,
             )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
             if self.nranks > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": True})
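
set_weight_attrs attaches loader metadata directly to the parameter object; the weight_loader hunks above read it back with getattr(param, "model_format", ""). A minimal model of that pattern with a stand-in setter (assuming FastDeploy's set_weight_attrs amounts to a plain setattr loop, which the getattr-based readers imply):

import paddle


def _set_weight_attrs(param, attrs):  # hypothetical stand-in for set_weight_attrs
    for key, value in attrs.items():
        setattr(param, key, value)


w = paddle.create_parameter(shape=[8, 8], dtype="float32")
_set_weight_attrs(w, {"model_format": "torch", "output_dim": True})
assert getattr(w, "model_format", "") == "torch"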

@@ -102,8 +110,16 @@ class ParallelLMHead(nn.Layer):
                 weight_attr=None,
                 has_bias=True if self.bias_key is not None else False,
                 input_is_parallel=False,
-                fuse_matmul_bias=False,  # False keeps the diff smaller
+                fuse_matmul_bias=False,
             )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+
             if self.nranks > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": False})

@@ -219,7 +219,7 @@ class FusedMoE(nn.Layer):
     def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
         dim = -1 if shard_dim else 0
         if self.tp_size > 1:
-            if isinstance(loaded_weight, np.ndarray):
+            if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                 size = loaded_weight.shape[dim]
             else:
                 size = loaded_weight.get_shape()[dim]

@@ -259,7 +259,7 @@ class FusedMoE(nn.Layer):
     def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
         if self.tp_size > 1 and shard_dim is not None:
             dim = -1 if shard_dim else 0
-            if isinstance(loaded_weight, np.ndarray):
+            if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                 size = loaded_weight.shape[dim]
             else:
                 size = loaded_weight.get_shape()[dim]

@@ -29,6 +29,7 @@ from safetensors import safe_open
 from tqdm import tqdm

 from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.models.tp_utils import (
     check_tensor_parallel_prerequisites,
 )

@@ -180,8 +181,9 @@ def fast_weights_iterator(safe_tensor_list: list[str]):
     ):
         with fast_safe_open(st_file, framework="np") as f:
             for name in f.keys():
-                param = f.get_slice(name)
-                yield name, param
+                param_slice = f.get_slice(name)
+                paddle_tensor = get_tensor(param_slice)
+                yield name, paddle_tensor


 def fastsafetensors_weights_iterator(

@@ -334,6 +334,10 @@ class Qwen2ForCausalLM(ModelForCasualLM):
         params_dict = dict(self.named_parameters())
         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
         for loaded_weight_name, loaded_weight in weights_iterator:
+            model_format = self.fd_config.model_config.model_format
+            # Because the prefix for Paddle is qwen2, and for Hugging Face it is model.
+            if model_format == "torch":
+                loaded_weight_name = loaded_weight_name.replace("model", "qwen2")
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in loaded_weight_name:
                     continue
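
The rename exists because the Paddle graph roots Qwen2 parameters at qwen2. while Hugging Face checkpoints root them at model.; the patch does a plain substring replace on each incoming name. For example:

# Illustrative name only; the mapping is a plain str.replace.
hf_name = "model.layers.0.self_attn.q_proj.weight"
paddle_name = hf_name.replace("model", "qwen2")
assert paddle_name == "qwen2.layers.0.self_attn.q_proj.weight"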

@@ -16,6 +16,8 @@

 from typing import Any, Optional, Union

+import paddle
+
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.utils import get_tensor

@@ -155,9 +157,15 @@ def default_weight_loader(fd_config: FDConfig) -> None:
     def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
         """fn"""
         output_dim = getattr(param, "output_dim", None)
+        model_format = getattr(param, "model_format", "")
+        if model_format == "torch":
+            loaded_weight = loaded_weight.transpose([1, 0])
         # Tensor parallelism splits the weight along the output_dim
         if output_dim is not None and fd_config.parallel_config.tensor_parallel_size > 1:
             dim = -1 if output_dim else 0
-            size = loaded_weight.get_shape()[dim]
+            if isinstance(loaded_weight, paddle.Tensor):
+                size = loaded_weight.shape[dim]
+            else:
+                size = loaded_weight.get_shape()[dim]
             block_size = size // fd_config.parallel_config.tensor_parallel_size
             shard_offset = fd_config.parallel_config.tensor_parallel_rank * block_size
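
The transpose([1, 0]) handles the layout difference between frameworks: PyTorch nn.Linear checkpoints store weights as [out_features, in_features], while Paddle's linear layers expect [in_features, out_features]. A hedged numeric illustration (sizes are arbitrary):

import paddle

hf_weight = paddle.randn([128, 64])      # HF layout: [out_features, in_features]
paddle_weight = hf_weight.transpose([1, 0])
assert paddle_weight.shape == [64, 128]  # Paddle layout: [in_features, out_features]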
|
@@ -27,6 +27,25 @@ TokensIdText = list[tuple[list[int], str]]
|
|||||||
# (token_ids, text)
|
# (token_ids, text)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_paths(base_model_name: str) -> tuple[str, str]:
|
||||||
|
"""return (fastdeploy_path, huggingface_path)"""
|
||||||
|
# FastDeploy model path
|
||||||
|
fd_base_path = os.getenv("MODEL_PATH")
|
||||||
|
if fd_base_path:
|
||||||
|
fd_model_path = os.path.join(fd_base_path, base_model_name)
|
||||||
|
else:
|
||||||
|
fd_model_path = base_model_name
|
||||||
|
|
||||||
|
# HuggingFace model path
|
||||||
|
torch_model_path = os.path.join(
|
||||||
|
fd_base_path,
|
||||||
|
"torch",
|
||||||
|
base_model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
return fd_model_path, torch_model_path
|
||||||
|
|
||||||
|
|
||||||
def check_tokens_id_and_text_close(
|
def check_tokens_id_and_text_close(
|
||||||
*,
|
*,
|
||||||
outputs_0_lst: TokensIdText,
|
outputs_0_lst: TokensIdText,
|
||||||
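
Assuming MODEL_PATH is set, the helper resolves a FastDeploy checkpoint beside a torch/ subdirectory holding the Hugging Face copy. Illustrative usage (the base directory is made up; note the helper as written would fail at os.path.join for the torch path if MODEL_PATH were unset, since fd_base_path would be None):

import os

os.environ["MODEL_PATH"] = "/opt/models"  # hypothetical base directory
fd_path, torch_path = get_model_paths("Qwen2.5-7B-Instruct")
assert fd_path == "/opt/models/Qwen2.5-7B-Instruct"
assert torch_path == "/opt/models/torch/Qwen2.5-7B-Instruct"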

@@ -104,6 +123,7 @@ model_param_map = {
     },
 }

+
 params = []
 for model, cfg in model_param_map.items():
     for q in cfg["quantizations"]:

@@ -176,3 +196,84 @@ def test_common_model(
         name_0="default loader",
         name_1="default_v1 loader",
     )
+
+
+hugging_face_model_param_map = {
+    "Qwen2.5-7B-Instruct": {
+        "tensor_parallel_size": 2,
+        "quantizations": ["None"],
+    },
+}
+
+hf_params = []
+for model, cfg in hugging_face_model_param_map.items():
+    for q in cfg["quantizations"]:
+        hf_params.append(
+            pytest.param(
+                model,
+                cfg.get("tensor_parallel_size", 1),
+                cfg.get("max_model_len", 1024),
+                q,
+                cfg.get("max_tokens", 32),
+                marks=[pytest.mark.core_model],
+            )
+        )
+
+
+@pytest.mark.parametrize(
+    "model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens",
+    hf_params,
+)
+def test_paddle_vs_torch_model(
+    fd_runner,
+    model_name_or_path: str,
+    tensor_parallel_size: int,
+    max_model_len: int,
+    max_tokens: int,
+    quantization: str,
+) -> None:
+
+    fd_model_path, torch_model_path = get_model_paths(model_name_or_path)
+
+    result_queue = Queue()
+
+    p_paddle = Process(
+        target=form_model_get_output,
+        args=(
+            fd_runner,
+            fd_model_path,
+            tensor_parallel_size,
+            max_model_len,
+            max_tokens,
+            quantization,
+            "default",
+            result_queue,
+        ),
+    )
+    p_paddle.start()
+    p_paddle.join()
+    paddle_outputs = result_queue.get(timeout=60)
+
+    p_hf = Process(
+        target=form_model_get_output,
+        args=(
+            fd_runner,
+            torch_model_path,
+            tensor_parallel_size,
+            max_model_len,
+            max_tokens,
+            quantization,
+            "default_v1",
+            result_queue,
+        ),
+    )
+    p_hf.start()
+    p_hf.join()
+    hf_outputs = result_queue.get(timeout=60)
+
+    check_tokens_id_and_text_close(
+        outputs_0_lst=paddle_outputs,
+        outputs_1_lst=hf_outputs,
+        name_0="Paddle model (default loader)",
+        name_1="HuggingFace model (default_v1 loader)",
+    )