mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
support dynamic activation quant for w4afp8 (#5117)
This commit is contained in:
@@ -1090,7 +1090,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
"down_proj_in_scale": weight_key_map.get("down_proj_expert_in_scale_key", None),
|
||||
}
|
||||
for name, value in scale_key_map.items():
|
||||
if value is None:
|
||||
if hasattr(layer, name) and value is None:
|
||||
raise ValueError(f"scale {name} should not be none in w4a8 mode.")
|
||||
|
||||
# 2. Extract scale tensor from state dict
|
||||
@@ -1111,8 +1111,9 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
||||
|
||||
for expert_idx in logical_expert_ids:
|
||||
for name, scale_key_template in scale_key_map.items():
|
||||
scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
|
||||
scale_weight_map[name].append(scale_tensor)
|
||||
if hasattr(layer, name):
|
||||
scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
|
||||
scale_weight_map[name].append(scale_tensor)
|
||||
|
||||
for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
|
||||
in_scale_name = weight_scale_name.replace("_weight_scale", "_in_scale")
|
||||
|
||||
Reference in New Issue
Block a user