Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-03 07:46:50 +08:00)
[v1loader]Reduce EB300B model loading time (#3700)
* speed up eb45
* update
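Reading the hunks below, the reduction in load time appears to come from two ideas: the per-rank shard is cut out with slice_fn before get_tensor materializes the weight, so a tensor-parallel rank never pulls the full fused tensor into host memory, and the destination parameter is only allocated on demand via param._is_initialized() / param.initialize(). A minimal sketch of the first idea, with a hypothetical LazyWeight handle standing in for whatever the checkpoint reader actually hands to weight_loader (this is not the FastDeploy API):

import numpy as np

class LazyWeight:
    """Hypothetical lazy handle: slicing stays lazy, materializing copies bytes."""

    def __init__(self, array):
        self._array = array  # stand-in for an mmap'd / safetensors view

    def slice(self, dim, start, end):
        # Narrow the view first; nothing is read or copied yet.
        index = [slice(None)] * self._array.ndim
        index[dim] = slice(start, end)
        return LazyWeight(self._array[tuple(index)])

    def materialize(self):
        # Only now are the selected bytes actually touched.
        return np.ascontiguousarray(self._array)

def load_rank_shard(lazy_weight, output_dim, local_rank, block_size):
    # Slice to this rank's shard BEFORE materializing, mirroring
    # slice_fn(...) followed by get_tensor(...) in the diff.
    start = local_rank * block_size
    end = (local_rank + 1) * block_size
    return lazy_weight.slice(output_dim, start, end).materialize()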
@@ -415,6 +415,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        model_format = getattr(param, "model_format", "")
        if model_format == "torch":
            loaded_weight = get_tensor(loaded_weight)
            loaded_weight = loaded_weight.transpose([1, 0])
        output_dim = getattr(param, "output_dim", None)
        assert output_dim is not None
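For context on the model_format == "torch" branch above: torch checkpoints store Linear weights as [out_features, in_features], while Paddle's nn.Linear keeps them as [in_features, out_features], which is why the loader materializes and then flips the two axes. A tiny illustration in plain Paddle (not the FastDeploy helpers):

import paddle

# torch-style layout: [out_features, in_features]
w_torch_layout = paddle.rand([4096, 1024])

# Paddle nn.Linear expects [in_features, out_features], hence transpose([1, 0]).
w_paddle_layout = w_torch_layout.transpose([1, 0])
assert w_paddle_layout.shape == [1024, 4096]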
@@ -446,7 +447,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
            shard_offset = self.local_rank * block_size
            shard_size = (self.local_rank + 1) * block_size
            loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)

        loaded_weight = get_tensor(loaded_weight)
        if not param._is_initialized():
            param.initialize()
        param_shard_size = output_size // 2
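A quick worked example of the offset arithmetic in this hunk (the numbers are illustrative, and how block_size is derived upstream is an assumption); note that shard_size here is really the end index passed to slice_fn, not a length:

# Illustrative values only; block_size is assumed to be the per-rank column count.
output_size = 8192                          # fused gate+up output columns
nranks = 4                                  # tensor-parallel degree
local_rank = 1
block_size = output_size // nranks          # 2048 columns per rank (assumed)

shard_offset = local_rank * block_size      # 2048: start column for this rank
shard_end = (local_rank + 1) * block_size   # 4096: "shard_size" in the diff is an end index

param_shard_size = output_size // 2         # 4096: split point between the merged gate/up halves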
@@ -548,6 +549,7 @@ class QKVParallelLinear(ColumnParallelLinear):
        head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
        model_format = getattr(param, "model_format", "")
        if model_format == "torch":
            loaded_weight = get_tensor(loaded_weight)
            loaded_weight = loaded_weight.transpose([1, 0])
        if loaded_shard_id is None:
            # Loaded weight is already fused on disk
@@ -568,12 +570,13 @@ class QKVParallelLinear(ColumnParallelLinear):
            # Tensor parallelism splits the weight along the output_dim
            if self.nranks != 1:
                block_size = self._get_shard_size_mapping(loaded_shard_id)
                dim = -1 if output_dim else 0
                shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
                shard_offset = shard_id * block_size
                shard_size = (shard_id + 1) * block_size
                loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)

            loaded_weight = get_tensor(loaded_weight)

            if not param._is_initialized():
                param.initialize()
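The shard_id line above is the interesting part: q shards one-to-one with the tensor-parallel rank, while k/v use local_rank // self.num_kv_head_replicas, the usual grouped-query-attention arrangement where several ranks share one KV head shard. A small worked example with made-up head counts (not EB300B's real configuration):

# Made-up GQA configuration for illustration only.
num_attention_heads = 64
num_key_value_heads = 8
nranks = 16                                            # tensor-parallel degree

num_heads_per_rank = num_attention_heads // nranks     # 4 query heads per rank
num_kv_head_replicas = nranks // num_key_value_heads   # each KV shard is reused by 2 ranks

for local_rank in range(4):
    q_shard_id = local_rank                            # 0, 1, 2, 3
    kv_shard_id = local_rank // num_kv_head_replicas   # 0, 0, 1, 1
    print(f"rank {local_rank}: q shard {q_shard_id}, k/v shard {kv_shard_id}")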