mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Others] upgrade paddleformer to 0.4.0 (#5599)
This commit is contained in:
@@ -106,13 +106,13 @@ def _install_dependency_stubs():
|
||||
|
||||
conversion_utils = types.ModuleType("paddleformers.transformers.conversion_utils")
|
||||
|
||||
def _split_or_merge_func(is_split, tensor_parallel_degree, tensor_parallel_rank, **_kwargs):
|
||||
def _split_or_merge_func(is_split, tensor_model_parallel_size, tensor_parallel_rank, **_kwargs):
|
||||
axis = -1
|
||||
|
||||
def _fn(weight, *, is_column=True, **_kwargs):
|
||||
current_axis = axis if is_column else 0
|
||||
if is_split:
|
||||
chunks = np.array_split(weight, tensor_parallel_degree, axis=current_axis)
|
||||
chunks = np.array_split(weight, tensor_model_parallel_size, axis=current_axis)
|
||||
if tensor_parallel_rank is None:
|
||||
return chunks
|
||||
return chunks[tensor_parallel_rank]
|
||||
|
||||
Reference in New Issue
Block a user