polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

View File

@@ -25,26 +25,31 @@ from paddleformers.transformers import PretrainedModel
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
ColumnParallelLinear, KVBatchLinear, MergedColumnParallelLinear,
ReplicatedLinear, RowParallelLinear)
ColumnParallelLinear,
KVBatchLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.layers.rotary_embedding import \
DeepseekScalingRotaryEmbedding
from fastdeploy.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding,
)
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import \
get_position_ids_and_mask_encoder_batch
from fastdeploy.model_executor.ops.gpu import (
get_position_ids_and_mask_encoder_batch,
)
class DeepSeekV3MLP(nn.Layer):
@@ -86,14 +91,12 @@ class DeepSeekV3MLP(nn.Layer):
)
def load_state_dict(self, state_dict):
"""
"""
""" """
self.up_gate_proj.load_state_dict(state_dict)
self.down_proj.load_state_dict(state_dict)
def forward(self, x):
"""
"""
""" """
gate_up_out = self.up_gate_proj(x)
act_out = self.act_fn(gate_up_out)
down_out = self.down_proj(act_out)
@@ -105,42 +108,34 @@ class DeepSeekV3MoE(nn.Layer):
DeepSeekV3MoE, for MoE Layer.
"""
def __init__(self, fd_config: FDConfig, layer_id: int,
prefix: str) -> None:
def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None:
super().__init__()
self.tp_size = fd_config.parallel_config.tensor_parallel_size
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight",
"gate_correction_bias_key":
f"{prefix}.gate.e_score_correction_bias",
"up_gate_proj_expert_weight_key":
f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key":
f"{prefix}.experts.{{}}.down_proj.weight",
"gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
}
self.fused_moe = FusedMoE(
fd_config=fd_config,
reduce_results=False,
moe_intermediate_size=fd_config.model_config.
moe_intermediate_size,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
num_experts=fd_config.model_config.n_routed_experts,
top_k=fd_config.model_config.num_experts_per_tok,
topk_method=fd_config.model_config.topk_method,
topk_group=fd_config.model_config.topk_group,
n_group=fd_config.model_config.n_group,
routed_scaling_factor=fd_config.model_config.
routed_scaling_factor,
routed_scaling_factor=fd_config.model_config.routed_scaling_factor,
layer_idx=layer_id,
weight_key_map=weight_key_map,
)
self.num_shared_experts = fd_config.model_config.n_shared_experts
shared_experts_intermediate_size = (
self.num_shared_experts *
fd_config.model_config.moe_intermediate_size)
shared_experts_intermediate_size = self.num_shared_experts * fd_config.model_config.moe_intermediate_size
self.shared_experts = DeepSeekV3MLP(
fd_config=fd_config,
@@ -150,14 +145,12 @@ class DeepSeekV3MoE(nn.Layer):
)
def load_state_dict(self, state_dict):
"""
"""
""" """
self.fused_moe.load_state_dict(state_dict)
self.shared_experts.load_state_dict(state_dict)
def forward(self, hidden_states: paddle.Tensor):
"""
"""
""" """
shared_experts_out = self.shared_experts(hidden_states)
moe_out = self.fused_moe(hidden_states)
moe_out = moe_out + shared_experts_out
@@ -172,10 +165,7 @@ class DeepseekV3MLAAttention(nn.Layer):
DeepseekV3MLAAttention
"""
def __init__(self,
fd_config: FDConfig,
layer_id: int,
prefix: str = "") -> None:
def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None:
super().__init__()
self.tp_size = fd_config.parallel_config.tensor_parallel_size
@@ -196,16 +186,20 @@ class DeepseekV3MLAAttention(nn.Layer):
self.rms_norm_eps = fd_config.model_config.rms_norm_eps
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(fd_config=fd_config,
prefix=f"{prefix}.q_a_proj",
input_size=self.hidden_size,
output_size=self.q_lora_rank,
with_bias=False)
self.q_a_proj = ReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.q_a_proj",
input_size=self.hidden_size,
output_size=self.q_lora_rank,
with_bias=False,
)
self.q_a_layernorm = RMSNorm(fd_config,
hidden_size=self.q_lora_rank,
eps=self.rms_norm_eps,
prefix=f"{prefix}.q_a_layernorm")
self.q_a_layernorm = RMSNorm(
fd_config,
hidden_size=self.q_lora_rank,
eps=self.rms_norm_eps,
prefix=f"{prefix}.q_a_layernorm",
)
self.q_b_proj = ColumnParallelLinear(
fd_config=fd_config,
@@ -215,8 +209,7 @@ class DeepseekV3MLAAttention(nn.Layer):
with_bias=False,
)
else:
assert (self.q_lora_rank is not None
), "self.q_lora_rank is None, Please Check your config."
assert self.q_lora_rank is not None, "self.q_lora_rank is None, Please Check your config."
# 不切TP,跑 W4A16 Gemm
self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -224,28 +217,31 @@ class DeepseekV3MLAAttention(nn.Layer):
prefix=f"{prefix}.kv_a_proj_with_mqa",
input_size=self.hidden_size,
output_size=self.kv_lora_rank + self.qk_rope_head_dim,
with_bias=False)
with_bias=False,
)
self.kv_a_layernorm = RMSNorm(fd_config,
hidden_size=self.kv_lora_rank,
eps=self.rms_norm_eps,
prefix=f"{prefix}.kv_a_layernorm")
self.kv_a_layernorm = RMSNorm(
fd_config,
hidden_size=self.kv_lora_rank,
eps=self.rms_norm_eps,
prefix=f"{prefix}.kv_a_layernorm",
)
self.kv_b_proj = ColumnParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.kv_b_proj",
input_size=self.kv_lora_rank,
output_size=self.num_attention_heads *
(self.qk_nope_head_dim + self.v_head_dim),
output_size=self.num_attention_heads * (self.qk_nope_head_dim + self.v_head_dim),
with_bias=False,
)
self.o_proj = RowParallelLinear(fd_config,
prefix=f"{prefix}.o_proj",
input_size=self.num_attention_heads *
self.v_head_dim,
output_size=self.hidden_size,
with_bias=False)
self.o_proj = RowParallelLinear(
fd_config,
prefix=f"{prefix}.o_proj",
input_size=self.num_attention_heads * self.v_head_dim,
output_size=self.hidden_size,
with_bias=False,
)
self.kv_b_proj_bmm = KVBatchLinear(
fd_config=fd_config,
@@ -253,14 +249,14 @@ class DeepseekV3MLAAttention(nn.Layer):
kv_lora_rank=self.kv_lora_rank,
num_attention_heads=self.num_attention_heads,
qk_nope_head_dim=self.qk_nope_head_dim,
v_head_dim=self.v_head_dim)
v_head_dim=self.v_head_dim,
)
self.rope_scaling = fd_config.model_config.rope_scaling
if self.rope_scaling:
mscale_all_dim = self.rope_scaling.get("mscale_all_dim", False)
scaling_factor = self.rope_scaling["factor"]
mscale = self.yarn_get_mscale(scaling_factor,
float(mscale_all_dim))
mscale = self.yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale
rope_scaling_kwargs = {
@@ -270,15 +266,14 @@ class DeepseekV3MLAAttention(nn.Layer):
"beta_slow",
"mscale",
"mscale_all_dim",
] if key in self.rope_scaling
]
if key in self.rope_scaling
}
self.rope_scaling_factor = self.rope_scaling["factor"]
self.rope_scaling_original_max_position_embeddings = self.rope_scaling[
"original_max_position_embeddings"]
self.rope_scaling_original_max_position_embeddings = self.rope_scaling["original_max_position_embeddings"]
self.rotary_emb = DeepseekScalingRotaryEmbedding(
self.qk_rope_head_dim,
max_position_embeddings=self.
rope_scaling_original_max_position_embeddings,
max_position_embeddings=self.rope_scaling_original_max_position_embeddings,
base=self.rope_theta,
scaling_factor=self.rope_scaling_factor,
**rope_scaling_kwargs,
@@ -295,8 +290,7 @@ class DeepseekV3MLAAttention(nn.Layer):
@staticmethod
def yarn_get_mscale(scale=1, mscale=1):
"""
"""
""" """
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
@@ -308,63 +302,61 @@ class DeepseekV3MLAAttention(nn.Layer):
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
):
"""
"""
""" """
layernorm_out = hidden_states
fmha_out = paddle.zeros(shape=[
layernorm_out.shape[0],
self.num_attention_heads_tp * self.v_head_dim
],
dtype=layernorm_out.dtype)
fmha_out = paddle.zeros(
shape=[
layernorm_out.shape[0],
self.num_attention_heads_tp * self.v_head_dim,
],
dtype=layernorm_out.dtype,
)
if forward_meta.max_enc_len_this_time:
query = self.q_a_proj(layernorm_out)
query = self.q_a_layernorm(query)
query = self.q_b_proj(query)
query = query.reshape(
[-1, self.num_attention_heads_tp, self.qk_head_dim])
query_nope, query_pe = query.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
compressed_kv, key_pe = compressed_kv.split(
[self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
compressed_kv = self.kv_a_layernorm(compressed_kv)
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
key_value = self.kv_b_proj(compressed_kv)
key_value = key_value.reshape([
-1, self.num_attention_heads_tp,
self.qk_nope_head_dim + self.v_head_dim
])
key_nope, value = key_value.split(
[self.qk_nope_head_dim, self.v_head_dim], axis=-1)
key_value = key_value.reshape(
[
-1,
self.num_attention_heads_tp,
self.qk_nope_head_dim + self.v_head_dim,
]
)
key_nope, value = key_value.split([self.qk_nope_head_dim, self.v_head_dim], axis=-1)
query[..., self.qk_nope_head_dim:] = query_pe
query[..., self.qk_nope_head_dim :] = query_pe
key = paddle.empty_like(query)
key[..., :self.qk_nope_head_dim] = key_nope
key[..., self.qk_nope_head_dim:] = key_pe
value = paddle.nn.functional.pad(
value, [0, self.qk_head_dim - self.v_head_dim], value=0)
key[..., : self.qk_nope_head_dim] = key_nope
key[..., self.qk_nope_head_dim :] = key_pe
value = paddle.nn.functional.pad(value, [0, self.qk_head_dim - self.v_head_dim], value=0)
fmha_out_prefill = self.mla_attn(q=query,
k=key,
v=value,
qkv=None,
compressed_kv=compressed_kv,
k_pe=key_pe,
forward_meta=forward_meta)
fmha_out_prefill = self.mla_attn(
q=query,
k=key,
v=value,
qkv=None,
compressed_kv=compressed_kv,
k_pe=key_pe,
forward_meta=forward_meta,
)
fmha_out_prefill = fmha_out_prefill.reshape(
[-1, self.num_attention_heads_tp, self.qk_head_dim])
fmha_out_prefill = fmha_out_prefill[:, :, :self.v_head_dim]
fmha_out_prefill = fmha_out_prefill.reshape(
[-1, self.num_attention_heads_tp * self.v_head_dim])
fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(
fmha_out_prefill.dtype)
fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
fmha_out = fmha_out + fmha_out_prefill
if forward_meta.max_dec_len_this_time:
@@ -373,51 +365,51 @@ class DeepseekV3MLAAttention(nn.Layer):
ln_out_or_q_c = query
compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
compressed_kv, key_pe = compressed_kv.split(
[self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
compressed_kv = self.kv_a_layernorm(compressed_kv)
query = self.q_b_proj(ln_out_or_q_c)
query = query.reshape(
[-1, self.num_attention_heads_tp, self.qk_head_dim])
query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
query_nope, query_pe = query.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]),
proj_type='k').transpose([1, 0, 2])
q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])
q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
q_input = q_input.reshape([
-1,
self.num_attention_heads_tp *
(self.kv_lora_rank + self.qk_rope_head_dim),
])
fmha_out_decode = self.mla_attn(q=q_input,
k=None,
v=None,
qkv=None,
compressed_kv=compressed_kv,
k_pe=key_pe,
forward_meta=forward_meta)
q_input = q_input.reshape(
[
-1,
self.num_attention_heads_tp * (self.kv_lora_rank + self.qk_rope_head_dim),
]
)
fmha_out_decode = self.mla_attn(
q=q_input,
k=None,
v=None,
qkv=None,
compressed_kv=compressed_kv,
k_pe=key_pe,
forward_meta=forward_meta,
)
fmha_out_decode = fmha_out_decode.reshape(
[-1, self.num_attention_heads_tp,
self.kv_lora_rank]).transpose([1, 0, 2])
fmha_out_decode = fmha_out_decode.reshape([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose(
[1, 0, 2]
)
fmha_out_decode = (self.kv_b_proj_bmm(
fmha_out_decode, proj_type='v').transpose([1, 0, 2]).reshape(
[-1, self.num_attention_heads_tp * self.v_head_dim]))
fmha_out_decode = (
self.kv_b_proj_bmm(fmha_out_decode, proj_type="v")
.transpose([1, 0, 2])
.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
)
fmha_out = fmha_out + fmha_out_decode
output = self.o_proj(fmha_out)
return output
def load_state_dict(self, state_dict):
"""
"""
""" """
self.q_a_proj.load_state_dict(state_dict)
self.q_a_layernorm.load_state_dict(state_dict)
self.kv_a_proj_with_mqa.load_state_dict(state_dict)
@@ -441,7 +433,7 @@ class DeepSeekV3DecoderLayer(nn.Layer):
prefix: str = "",
) -> None:
super().__init__()
layer_id = int(prefix.split(sep='.')[-1])
layer_id = int(prefix.split(sep=".")[-1])
self.self_attn = DeepseekV3MLAAttention(
fd_config=fd_config,
@@ -449,9 +441,10 @@ class DeepSeekV3DecoderLayer(nn.Layer):
prefix=f"{prefix}.self_attn",
)
if (fd_config.model_config.n_routed_experts is not None
and layer_id
>= fd_config.model_config.first_k_dense_replace):
if (
fd_config.model_config.n_routed_experts is not None
and layer_id >= fd_config.model_config.first_k_dense_replace
):
self.mlp = DeepSeekV3MoE(
fd_config=fd_config,
layer_id=layer_id,
@@ -479,8 +472,7 @@ class DeepSeekV3DecoderLayer(nn.Layer):
)
def load_state_dict(self, state_dict):
"""
"""
""" """
self.self_attn.load_state_dict(state_dict)
self.mlp.load_state_dict(state_dict)
self.input_layernorm.load_state_dict(state_dict)
@@ -494,20 +486,16 @@ class DeepSeekV3DecoderLayer(nn.Layer):
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
):
"""
"""
""" """
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(forward_meta, hidden_states,
position_ids, mask_encoder_batch)
hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@@ -536,12 +524,15 @@ class DeepSeekV3Model(nn.Layer):
prefix="deepseek_v3.embed_tokens",
)
self.decoder_layers = nn.LayerList([
DeepSeekV3DecoderLayer(
fd_config,
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}")
for i in range(self.num_layers)
])
self.decoder_layers = nn.LayerList(
[
DeepSeekV3DecoderLayer(
fd_config,
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}",
)
for i in range(self.num_layers)
]
)
self.norm = RMSNorm(
fd_config,
@@ -567,15 +558,18 @@ class DeepSeekV3Model(nn.Layer):
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
):
"""
"""
""" """
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
residual = None
for i in range(self.num_layers):
hidden_states, residual = self.decoder_layers[i](
forward_meta, hidden_states, residual, position_ids,
mask_encoder_batch)
forward_meta,
hidden_states,
residual,
position_ids,
mask_encoder_batch,
)
hidden_states = hidden_states + residual
out = self.norm(hidden_states)
@@ -604,8 +598,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
@classmethod
def name(cls):
"""
"""
""" """
return "DeepseekV3ForCausalLM"
@paddle.no_grad()
@@ -617,31 +610,28 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
self.lm_head.load_state_dict(state_dict)
def compute_logits(self, hidden_states: paddle.Tensor):
"""
"""
""" """
logits = self.lm_head(hidden_states)
logits = paddle.cast(logits, paddle.float32)
logits[:, self.ori_vocab_size:] = -float("inf")
logits[:, self.ori_vocab_size :] = -float("inf")
return logits
def pre_process(self, forward_meta):
"""
"""
""" """
seq_lens_encoder = forward_meta.seq_lens_encoder
seq_lens_decoder = forward_meta.seq_lens_decoder
seq_lens_this_time = forward_meta.seq_lens_this_time
position_ids_shape = paddle.sum(seq_lens_this_time)
position_ids = paddle.empty(shape=position_ids_shape,
dtype=seq_lens_encoder.dtype)
mask_encoder_batch = paddle.empty(
shape=position_ids_shape,
dtype=seq_lens_encoder.dtype).unsqueeze(1)
position_ids = paddle.empty(shape=position_ids_shape, dtype=seq_lens_encoder.dtype)
mask_encoder_batch = paddle.empty(shape=position_ids_shape, dtype=seq_lens_encoder.dtype).unsqueeze(1)
get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
position_ids,
mask_encoder_batch)
get_position_ids_and_mask_encoder_batch(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
position_ids,
mask_encoder_batch,
)
return position_ids, mask_encoder_batch
@@ -650,11 +640,14 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
ids_remove_padding: paddle.Tensor,
forward_meta: ForwardMeta,
):
"""
"""
""" """
position_ids, mask_encoder_batch = self.pre_process(forward_meta)
hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta,
position_ids=position_ids, mask_encoder_batch=mask_encoder_batch)
hidden_states = self.model(
ids_remove_padding=ids_remove_padding,
forward_meta=forward_meta,
position_ids=position_ids,
mask_encoder_batch=mask_encoder_batch,
)
return hidden_states
@@ -676,8 +669,8 @@ class DeepSeekV3PretrainedModel(PretrainedModel):
logger.info("DeepseekV3 inference model _get_tensor_parallel_mappings")
from paddleformers.transformers.conversion_utils import \
split_or_merge_func
from paddleformers.transformers.conversion_utils import split_or_merge_func
fn = split_or_merge_func(
is_split=is_split,
tensor_parallel_degree=config.tensor_parallel_degree,
@@ -691,66 +684,40 @@ class DeepSeekV3PretrainedModel(PretrainedModel):
base_actions = {
"lm_head.weight": partial(fn, is_column=True),
"embed_tokens.weight": partial(fn, is_column=False),
"layers.0.self_attn.o_proj.weight": partial(fn,
is_column=False),
"layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
}
# Self Attention Layer which are need TP.
base_actions["layers.0.self_attn.q_b_proj.weight"] = partial(
fn, is_column=True)
base_actions["layers.0.self_attn.kv_b_proj.weight"] = partial(
fn, is_column=True)
base_actions[
"layers.0.self_attn.q_b_proj.weight_scale_inv"] = partial(
fn, is_column=True)
base_actions[
"layers.0.self_attn.kv_b_proj.weight_scale_inv"] = partial(
fn, is_column=True)
base_actions["layers.0.self_attn.q_b_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.kv_b_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.q_b_proj.weight_scale_inv"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.kv_b_proj.weight_scale_inv"] = partial(fn, is_column=True)
# MLP Layer
base_actions["layers.0.mlp.gate_proj.weight"] = partial(
fn, is_column=True)
base_actions["layers.0.mlp.up_proj.weight"] = partial(
fn, is_column=True)
base_actions["layers.0.mlp.down_proj.weight"] = partial(
fn, is_column=False)
base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.mlp.down_proj.weight"] = partial(fn, is_column=False)
# Moe Layer
for expert_idx in range(config.n_routed_experts):
base_actions[
f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial(
fn, is_column=True)
base_actions[
f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial(
fn, is_column=True)
base_actions[
f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial(
fn, is_column=False)
base_actions[f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial(fn, is_column=True)
base_actions[f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial(fn, is_column=True)
base_actions[f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial(fn, is_column=False)
# Shared Expert Layer
base_actions[
"layers.0.mlp.shared_experts.up_proj.weight"] = partial(
fn, is_column=True)
base_actions[
"layers.0.mlp.shared_experts.gate_proj.weight"] = partial(
fn, is_column=True)
base_actions[
"layers.0.mlp.shared_experts.down_proj.weight"] = partial(
fn, is_column=False)
base_actions["layers.0.mlp.shared_experts.up_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.mlp.shared_experts.gate_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.mlp.shared_experts.down_proj.weight"] = partial(fn, is_column=False)
# MTP parts
base_actions["layers.61.embed_tokens.weight"] = partial(
fn, is_column=False)
base_actions["layers.61.eh_proj.weight"] = partial(fn,
is_column=True)
base_actions["layers.61.shared_head.head.weight"] = partial(
fn, is_column=True)
base_actions["layers.61.embed_tokens.weight"] = partial(fn, is_column=False)
base_actions["layers.61.eh_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.61.shared_head.head.weight"] = partial(fn, is_column=True)
for key, action in base_actions.items():
if "layers.0." in key:
for i in range(num_layers):
final_actions[key.replace("layers.0.",
f"layers.{i}.")] = action
final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
final_actions[key] = action
return final_actions