[Others] upgrade paddleformer to 0.4.0 (#5599)

This commit is contained in:
bukejiyu
2025-12-23 21:08:01 +08:00
committed by GitHub
parent 85db9d5e56
commit d1c6e57341
21 changed files with 32 additions and 184 deletions

View File

@@ -106,13 +106,13 @@ def _install_dependency_stubs():
conversion_utils = types.ModuleType("paddleformers.transformers.conversion_utils")
def _split_or_merge_func(is_split, tensor_parallel_degree, tensor_parallel_rank, **_kwargs):
def _split_or_merge_func(is_split, tensor_model_parallel_size, tensor_parallel_rank, **_kwargs):
axis = -1
def _fn(weight, *, is_column=True, **_kwargs):
current_axis = axis if is_column else 0
if is_split:
chunks = np.array_split(weight, tensor_parallel_degree, axis=current_axis)
chunks = np.array_split(weight, tensor_model_parallel_size, axis=current_axis)
if tensor_parallel_rank is None:
return chunks
return chunks[tensor_parallel_rank]