[Feature] support pooling model dummy_run (#4345)

* support qwen3-embedding

* fix ci bug

* support pooling dummy_run

* fix

* delete print

* parallel_config.max_model_len

* delete is_pooling_model in dummy_run

* fix

* fd_model

* fix embedding load

* fix

* fix post_process
This commit is contained in:
lizexu123
2025-10-17 13:30:55 +08:00
committed by GitHub
parent 15b6b8dc25
commit c234b995ab
10 changed files with 291 additions and 126 deletions

View File

@@ -157,7 +157,7 @@ def free_tensor(tensor):
del tensor
def default_weight_loader(fd_config: FDConfig) -> None:
def default_weight_loader(fd_config: FDConfig = None) -> None:
"""Default weight loader"""
def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
@@ -169,7 +169,7 @@ def default_weight_loader(fd_config: FDConfig) -> None:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None and fd_config.parallel_config.tensor_parallel_size > 1:
if output_dim is not None and fd_config is not None and fd_config.parallel_config.tensor_parallel_size > 1:
dim = -1 if output_dim else 0
if isinstance(loaded_weight, paddle.Tensor):
size = loaded_weight.shape[dim]