mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[V1 Loader]Ernie VL support loader v1 (#3494)
Some checks failed:
- CE Compile Job / ce_job_pre_check (push) — cancelled
- CE Compile Job / print_ce_job_pre_check_outputs (push) — cancelled
- CE Compile Job / FD-Clone-Linux (push) — cancelled
- CE Compile Job / Show Code Archive Output (push) — cancelled
- CE Compile Job / BUILD_SM8090 (push) — cancelled
- CE Compile Job / BUILD_SM8689 (push) — cancelled
- CE Compile Job / CE_UPLOAD (push) — cancelled
- Deploy GitHub Pages / deploy (push) — cancelled
* ernie vl support new loader
* add unittest
* fix test
This commit is contained in:
```diff
@@ -191,7 +191,7 @@ class FusedMoE(nn.Layer):
                 loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
                 self.weight_loader(param, loaded_weight_shard, expert_id, shard_id)
             else:
-                expert_param = param[expert_id]
+                expert_param = param[expert_id - self.expert_id_offset]
                 loaded_weight = get_tensor(loaded_weight)
                 expert_param.copy_(loaded_weight, False)
         else:
@@ -262,7 +262,7 @@ class FusedMoE(nn.Layer):
         loaded_weight,
         shard_id,
     ):
-        expert_param = param[expert_id]
+        expert_param = param[expert_id - self.expert_id_offset]
         if shard_id == "down":
             self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id)
         elif shard_id in ["gate", "up"]:
@@ -279,6 +279,7 @@ class FusedMoE(nn.Layer):
         param_gate_up_proj_name: Optional[str] = None,
         param_down_proj_name: Optional[str] = None,
         ckpt_expert_key_name: str = "experts",
+        experts_offset: int = 0,
     ) -> list[tuple[str, str, int, str]]:
         param_name_maping = []

@@ -303,7 +304,7 @@ class FusedMoE(nn.Layer):
             expert_id,
             shard_id,
         )
-        for expert_id in range(num_experts)
+        for expert_id in range(experts_offset, experts_offset + num_experts)
         for shard_id, weight_name in param_name_maping
     ]
```
Reference in New Issue
Block a user