""" # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ from __future__ import annotations import math import re from functools import partial import paddle from paddle import nn from paddleformers.transformers import PretrainedModel from paddleformers.utils.log import logger from fastdeploy.config import FDConfig from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.graph_optimization.decorator import ( support_graph_optimization, ) from fastdeploy.model_executor.layers.activation import SiluAndMul from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding from fastdeploy.model_executor.layers.linear import ( ColumnParallelLinear, KVBatchLinear, MergedColumnParallelLinear, ReplicatedLinear, RowParallelLinear, ) from fastdeploy.model_executor.layers.lm_head import ParallelLMHead from fastdeploy.model_executor.layers.moe.moe import FusedMoE from fastdeploy.model_executor.layers.normalization import RMSNorm from fastdeploy.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, ) from fastdeploy.model_executor.models.model_base import ModelForCasualLM from fastdeploy.platforms import current_platform if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import ( get_position_ids_and_mask_encoder_batch, ) class DeepSeekV3MLP(nn.Layer): """ DeepSeekV3MLP, for Dense FFN and Shared Experts Layer. """ def __init__( self, fd_config: FDConfig, intermediate_size: int, prefix: str = "", reduce_results: bool = True, ) -> None: super().__init__() self.up_gate_proj = MergedColumnParallelLinear( fd_config=fd_config, prefix=f"{prefix}.up_gate_proj", input_size=fd_config.model_config.hidden_size, output_size=intermediate_size * 2, with_bias=False, activation=fd_config.model_config.hidden_act, ) self.down_proj = RowParallelLinear( fd_config=fd_config, prefix=f"{prefix}.down_proj", input_size=intermediate_size, output_size=fd_config.model_config.hidden_size, with_bias=False, reduce_results=reduce_results, ) self.act_fn = SiluAndMul( fd_config=fd_config, bias=None, act_method=fd_config.model_config.hidden_act, ) def load_state_dict(self, state_dict): """ """ self.up_gate_proj.load_state_dict(state_dict) self.down_proj.load_state_dict(state_dict) def forward(self, x): """ """ gate_up_out = self.up_gate_proj(x) act_out = self.act_fn(gate_up_out) down_out = self.down_proj(act_out) return down_out class DeepSeekV3MoE(nn.Layer): """ DeepSeekV3MoE, for MoE Layer. """ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None: super().__init__() self.tp_size = fd_config.parallel_config.tensor_parallel_size self.norm_topk_prob = fd_config.model_config.norm_topk_prob weight_key_map = { "gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias", "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight", "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", } self.gate = ReplicatedLinear( fd_config=fd_config, prefix=f"{prefix}.gate", input_size=fd_config.model_config.hidden_size, output_size=fd_config.model_config.n_routed_experts, with_bias=False, skip_quant=True, weight_dtype="float32", ) if fd_config.model_config.topk_method == "noaux_tc": self.gate.e_score_correction_bias = self.create_parameter( shape=[1, fd_config.model_config.n_routed_experts], dtype="float32", default_initializer=paddle.nn.initializer.Constant(0), ) else: self.gate.e_score_correction_bias = None self.experts = FusedMoE( fd_config=fd_config, reduce_results=False, renormalize=self.norm_topk_prob, moe_intermediate_size=fd_config.model_config.moe_intermediate_size, num_experts=fd_config.model_config.n_routed_experts, top_k=fd_config.model_config.num_experts_per_tok, topk_method=fd_config.model_config.topk_method, topk_group=fd_config.model_config.topk_group, n_group=fd_config.model_config.n_group, routed_scaling_factor=fd_config.model_config.routed_scaling_factor, layer_idx=layer_id, gate_correction_bias=self.gate.e_score_correction_bias, weight_key_map=weight_key_map, ) self.num_shared_experts = fd_config.model_config.n_shared_experts shared_experts_intermediate_size = self.num_shared_experts * fd_config.model_config.moe_intermediate_size self.shared_experts = DeepSeekV3MLP( fd_config=fd_config, intermediate_size=shared_experts_intermediate_size, prefix=f"{prefix}.shared_experts", reduce_results=False, ) def load_state_dict(self, state_dict): """ """ self.gate.load_state_dict(state_dict) self.experts.load_state_dict(state_dict) self.shared_experts.load_state_dict(state_dict) def forward(self, hidden_states: paddle.Tensor): """ """ shared_experts_out = self.shared_experts(hidden_states) moe_out = self.experts(hidden_states, self.gate) moe_out = moe_out + shared_experts_out # We do to TP all reduce after the sum of experts. if self.tp_size > 1: tensor_model_parallel_all_reduce(moe_out) return moe_out class DeepseekV3MLAAttention(nn.Layer): """ DeepseekV3MLAAttention """ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None: super().__init__() self.tp_size = fd_config.parallel_config.tensor_parallel_size self.hidden_size = fd_config.model_config.hidden_size self.num_attention_heads = fd_config.model_config.num_attention_heads self.num_attention_heads_tp = self.num_attention_heads // self.tp_size # MLA self.qk_nope_head_dim = fd_config.model_config.qk_nope_head_dim self.qk_rope_head_dim = fd_config.model_config.qk_rope_head_dim self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim self.v_head_dim = fd_config.model_config.v_head_dim self.q_lora_rank = fd_config.model_config.q_lora_rank self.kv_lora_rank = fd_config.model_config.kv_lora_rank self.attn_softmax_scale = self.qk_head_dim**-0.5 self.rope_theta = fd_config.model_config.rope_theta self.rms_norm_eps = fd_config.model_config.rms_norm_eps if self.q_lora_rank is not None: # NOTE: (changwenbin) qkv_a_proj horizontal fusion self.qkv_a_proj_with_mqa = ReplicatedLinear( fd_config=fd_config, prefix=f"{prefix}.qkv_a_proj_with_mqa", input_size=self.hidden_size, output_size=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim, with_bias=False, ) self.q_a_layernorm = RMSNorm( fd_config, hidden_size=self.q_lora_rank, eps=self.rms_norm_eps, prefix=f"{prefix}.q_a_layernorm", ) self.q_b_proj = ColumnParallelLinear( fd_config=fd_config, prefix=f"{prefix}.q_b_proj", input_size=self.q_lora_rank, output_size=self.num_attention_heads * self.qk_head_dim, with_bias=False, ) else: assert self.q_lora_rank is not None, "self.q_lora_rank is None, Please Check your config." self.kv_a_layernorm = RMSNorm( fd_config, hidden_size=self.kv_lora_rank, eps=self.rms_norm_eps, prefix=f"{prefix}.kv_a_layernorm", ) self.kv_b_proj = ColumnParallelLinear( fd_config=fd_config, prefix=f"{prefix}.kv_b_proj", input_size=self.kv_lora_rank, output_size=self.num_attention_heads * (self.qk_nope_head_dim + self.v_head_dim), with_bias=False, ) self.o_proj = RowParallelLinear( fd_config, prefix=f"{prefix}.o_proj", input_size=self.num_attention_heads * self.v_head_dim, output_size=self.hidden_size, with_bias=False, ) self.kv_b_proj_bmm = KVBatchLinear( fd_config=fd_config, kv_b_proj=self.kv_b_proj, prefix=f"{prefix}.kv_b_proj", kv_lora_rank=self.kv_lora_rank, num_attention_heads=self.num_attention_heads, qk_nope_head_dim=self.qk_nope_head_dim, v_head_dim=self.v_head_dim, ) self.rope_scaling = fd_config.model_config.rope_scaling if self.rope_scaling: mscale_all_dim = self.rope_scaling.get("mscale_all_dim", False) scaling_factor = self.rope_scaling["factor"] mscale = self.yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale rope_scaling_kwargs = { key: self.rope_scaling[key] for key in [ "beta_fast", "beta_slow", "mscale", "mscale_all_dim", ] if key in self.rope_scaling } self.rope_scaling_factor = self.rope_scaling["factor"] self.rope_scaling_original_max_position_embeddings = self.rope_scaling["original_max_position_embeddings"] self.rotary_emb = DeepseekScalingRotaryEmbedding( self.qk_rope_head_dim, max_position_embeddings=self.rope_scaling_original_max_position_embeddings, base=self.rope_theta, scaling_factor=self.rope_scaling_factor, **rope_scaling_kwargs, ) self.mla_attn = Attention( fd_config=fd_config, layer_id=layer_id, prefix=prefix, use_neox_rotary_style=False, ) self.prefix = prefix @staticmethod def yarn_get_mscale(scale=1, mscale=1): """ """ if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 def forward( self, forward_meta: ForwardMeta, hidden_states: paddle.Tensor, position_ids: paddle.Tensor, mask_encoder_batch: paddle.Tensor, ): """ """ # NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation. fmha_out = None # NOTE: (changwenbin) qkv_a_proj horizontal fusion qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states) query, compressed_kv, key_pe = qkv_a_out.split( [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1 ) query = self.q_a_layernorm(query) query = self.q_b_proj(query) query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim]) query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1) key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim]) compressed_kv = self.kv_a_layernorm(compressed_kv) query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe) if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time key_value = self.kv_b_proj(compressed_kv) key_value = key_value.reshape( [ -1, self.num_attention_heads_tp, self.qk_nope_head_dim + self.v_head_dim, ] ) key_nope, value = key_value.split([self.qk_nope_head_dim, self.v_head_dim], axis=-1) query[..., self.qk_nope_head_dim :] = query_pe key = paddle.empty_like(query) key[..., : self.qk_nope_head_dim] = key_nope key[..., self.qk_nope_head_dim :] = key_pe value = paddle.nn.functional.pad(value, [0, self.qk_head_dim - self.v_head_dim], value=0) fmha_out_prefill = self.mla_attn( q=query, k=key, v=value, qkv=None, compressed_kv=compressed_kv, k_pe=key_pe, forward_meta=forward_meta, ) fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim]) fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim] fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim]) fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype) fmha_out = fmha_out_prefill if forward_meta.max_len_tensor_cpu[2]: # max_dec_len_this_time q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2]) q_input = paddle.concat([q_nope_out, query_pe], axis=-1) q_input = q_input.reshape( [ -1, self.num_attention_heads_tp * (self.kv_lora_rank + self.qk_rope_head_dim), ] ) fmha_out_decode = self.mla_attn( q=q_input, k=None, v=None, qkv=None, compressed_kv=compressed_kv, k_pe=key_pe, forward_meta=forward_meta, ) fmha_out_decode = fmha_out_decode.reshape([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose( [1, 0, 2] ) fmha_out_decode = ( self.kv_b_proj_bmm(fmha_out_decode, proj_type="v") .transpose([1, 0, 2]) .reshape([-1, self.num_attention_heads_tp * self.v_head_dim]) ) if fmha_out is None: fmha_out = fmha_out_decode else: fmha_out = fmha_out + fmha_out_decode output = self.o_proj(fmha_out) return output def load_state_dict(self, state_dict): """ """ self.q_a_layernorm.load_state_dict(state_dict) self.qkv_a_proj_with_mqa.load_state_dict(state_dict) self.kv_a_layernorm.load_state_dict(state_dict) self.q_b_proj.load_state_dict(state_dict) self.kv_b_proj_bmm.load_state_dict(state_dict) self.kv_b_proj.load_state_dict(state_dict) # NOTE(Ryan):Make sure kv_b_proj_bmm loaded before kv_b_proj, # The same weight key will be poped after kv_b_proj. self.o_proj.load_state_dict(state_dict) self.mla_attn.load_state_dict(state_dict) class DeepSeekV3DecoderLayer(nn.Layer): """ DeepSeekV3DecoderLayer """ def __init__( self, fd_config: FDConfig, prefix: str = "", ) -> None: super().__init__() layer_id = int(prefix.split(sep=".")[-1]) self.self_attn = DeepseekV3MLAAttention( fd_config=fd_config, layer_id=layer_id, prefix=f"{prefix}.self_attn", ) if ( fd_config.model_config.n_routed_experts is not None and layer_id >= fd_config.model_config.first_k_dense_replace ): self.mlp = DeepSeekV3MoE( fd_config=fd_config, layer_id=layer_id, prefix=f"{prefix}.mlp", ) else: self.mlp = DeepSeekV3MLP( fd_config=fd_config, intermediate_size=fd_config.model_config.intermediate_size, prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm( fd_config, hidden_size=fd_config.model_config.hidden_size, eps=fd_config.model_config.rms_norm_eps, prefix=f"{prefix}.input_layernorm", ) self.post_attention_layernorm = RMSNorm( fd_config, hidden_size=fd_config.model_config.hidden_size, eps=fd_config.model_config.rms_norm_eps, prefix=f"{prefix}.post_attention_layernorm", ) def load_state_dict(self, state_dict): """ """ self.self_attn.load_state_dict(state_dict) self.mlp.load_state_dict(state_dict) self.input_layernorm.load_state_dict(state_dict) self.post_attention_layernorm.load_state_dict(state_dict) def forward( self, forward_meta: ForwardMeta, hidden_states: paddle.Tensor, residual: paddle.Tensor, position_ids: paddle.Tensor, mask_encoder_batch: paddle.Tensor, ): """ """ if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_graph_optimization class DeepSeekV3Model(nn.Layer): """ DeepSeekV3Model """ def __init__( self, fd_config: FDConfig = None, ): """ Initializer for the DeepSeekV3Model class. """ super().__init__() self.num_layers = fd_config.model_config.num_hidden_layers fd_config.model_config.pretrained_config.prefix_name = "deepseek_v3" self.embed_tokens = VocabParallelEmbedding( fd_config, num_embeddings=fd_config.model_config.vocab_size, embedding_dim=fd_config.model_config.hidden_size, params_dtype=paddle.get_default_dtype(), prefix="deepseek_v3.embed_tokens", ) self.layers = nn.LayerList( [ DeepSeekV3DecoderLayer( fd_config, prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}", ) for i in range(self.num_layers) ] ) self.norm = RMSNorm( fd_config, hidden_size=fd_config.model_config.hidden_size, eps=fd_config.model_config.rms_norm_eps, prefix="deepseek_v3.norm", ) def load_state_dict(self, state_dict): """ Load model parameters from a given state dictionary. """ self.embed_tokens.load_state_dict(state_dict) self.norm.load_state_dict(state_dict) for i in range(self.num_layers): logger.info(f"Start load layer {i}") self.layers[i].load_state_dict(state_dict) def forward( self, ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta, position_ids: paddle.Tensor, mask_encoder_batch: paddle.Tensor, ): """ """ hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding) residual = None for i in range(self.num_layers): hidden_states, residual = self.layers[i]( forward_meta, hidden_states, residual, position_ids, mask_encoder_batch, ) hidden_states = hidden_states + residual out = self.norm(hidden_states) return out class DeepseekV3ForCausalLM(ModelForCasualLM): """ DeepseekV3ForCausalLM """ def __init__(self, fd_config: FDConfig): """ Args: fd_config (FDConfig): Configurations for the LLM model. """ super().__init__(fd_config) self.model = DeepSeekV3Model(fd_config) self.ori_vocab_size = fd_config.model_config.ori_vocab_size self.lm_head = ParallelLMHead( fd_config, embedding_dim=fd_config.model_config.hidden_size, num_embeddings=fd_config.model_config.vocab_size, prefix="lm_head", ) self.position_ids_buffer = paddle.empty([fd_config.parallel_config.max_num_batched_tokens], dtype=paddle.int32) self.mask_encoder_batch_buffer = paddle.empty( [fd_config.parallel_config.max_num_batched_tokens, 1], dtype=paddle.int32 ) @classmethod def name(cls): """ """ return "DeepseekV3ForCausalLM" @paddle.no_grad() def set_state_dict(self, state_dict): """ Load model parameters from a given state dictionary. """ self.model.load_state_dict(state_dict) self.lm_head.load_state_dict(state_dict) @paddle.no_grad() def load_weights(self, weights_iterator) -> None: """ Load model parameters from a given weights_iterator object. Args: weights_iterator (Iterator): An iterator yielding (name, weight) pairs. """ from fastdeploy.model_executor.utils import ( default_weight_loader, process_weights_after_loading, ) stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("up_gate_proj", "gate_proj", "gate"), ("up_gate_proj", "up_proj", "up"), ("embed_tokens.embeddings", "embed_tokens", None), ("lm_head.linear", "lm_head", None), ("experts.gate_correction_bias", "gate.e_score_correction_bias", None), ] # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = FusedMoE.make_expert_params_mapping( num_experts=self.fd_config.model_config.n_routed_experts, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", param_gate_up_proj_name="experts.up_gate_proj_", param_down_proj_name="experts.down_proj_", ) params_dict = dict(self.named_parameters()) process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers())) for loaded_weight_name, loaded_weight in weights_iterator: loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model") for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in loaded_weight_name: continue if "mlp.experts." in loaded_weight_name: continue model_param_name = loaded_weight_name.replace(weight_name, param_name) if model_param_name not in params_dict: continue param = params_dict[model_param_name] weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) weight_loader(param, loaded_weight, shard_id) break else: for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in loaded_weight_name: continue model_param_name = loaded_weight_name.replace(weight_name, param_name) if model_param_name not in params_dict: continue param = params_dict[model_param_name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id) break else: model_param_name = loaded_weight_name if model_param_name not in params_dict: continue param = params_dict[model_param_name] weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) weight_loader(param, loaded_weight) model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name) if "kv_b_proj" in model_sublayer_name: kv_model_sublayer_name = model_sublayer_name.replace("kv_b_proj", "kv_b_proj_bmm") process_weights_after_loading_fn(kv_model_sublayer_name) process_weights_after_loading_fn(model_sublayer_name, param) def compute_logits(self, hidden_states: paddle.Tensor): """ """ logits = self.lm_head(hidden_states) logits = logits.astype(paddle.float32) logits[:, self.ori_vocab_size :] = -float("inf") return logits def pre_process(self, forward_meta): """ """ seq_lens_encoder = forward_meta.seq_lens_encoder seq_lens_decoder = forward_meta.seq_lens_decoder seq_lens_this_time = forward_meta.seq_lens_this_time current_total_tokens = paddle.sum(seq_lens_this_time) position_ids = self.position_ids_buffer[:current_total_tokens] mask_encoder_batch = self.mask_encoder_batch_buffer[:current_total_tokens] get_position_ids_and_mask_encoder_batch( seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids, mask_encoder_batch, ) return position_ids, mask_encoder_batch def forward( self, ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta, ): """ """ position_ids, mask_encoder_batch = self.pre_process(forward_meta) hidden_states = self.model( ids_remove_padding=ids_remove_padding, forward_meta=forward_meta, position_ids=position_ids, mask_encoder_batch=mask_encoder_batch, ) return hidden_states def clear_grpah_opt_backend(self): """Clear graph optimization bakcend, the captured cuda graph will be cleaned""" self.model.clear_grpah_opt_backend(fd_config=self.fd_config) class DeepSeekV3PretrainedModel(PretrainedModel): """ DeepSeekV3PretrainedModel """ config_class = FDConfig def _init_weight(self, layer): """ _init_weight """ return None @classmethod def arch_name(self): return "DeepseekV3ForCausalLM" @classmethod def _get_tensor_parallel_mappings(cls, config, is_split=True): logger.info("DeepseekV3 inference model _get_tensor_parallel_mappings") from paddleformers.transformers.conversion_utils import split_or_merge_func fn = split_or_merge_func( is_split=is_split, tensor_parallel_degree=config.tensor_parallel_degree, tensor_parallel_rank=config.tensor_parallel_rank, num_attention_heads=config.num_attention_heads, ) def get_tensor_parallel_split_mappings(num_layers): final_actions = {} base_actions = { "lm_head.weight": partial(fn, is_column=True), "embed_tokens.weight": partial(fn, is_column=False), "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), } # Self Attention Layer which are need TP. base_actions["layers.0.self_attn.q_b_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.self_attn.kv_b_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.self_attn.q_b_proj.weight_scale_inv"] = partial(fn, is_column=True) base_actions["layers.0.self_attn.kv_b_proj.weight_scale_inv"] = partial(fn, is_column=True) # MLP Layer base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.mlp.down_proj.weight"] = partial(fn, is_column=False) # Moe Layer for expert_idx in range(config.n_routed_experts): base_actions[f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial(fn, is_column=True) base_actions[f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial(fn, is_column=True) base_actions[f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial(fn, is_column=False) # Shared Expert Layer base_actions["layers.0.mlp.shared_experts.up_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.mlp.shared_experts.gate_proj.weight"] = partial(fn, is_column=True) base_actions["layers.0.mlp.shared_experts.down_proj.weight"] = partial(fn, is_column=False) # MTP parts base_actions["layers.61.embed_tokens.weight"] = partial(fn, is_column=False) base_actions["layers.61.eh_proj.weight"] = partial(fn, is_column=True) base_actions["layers.61.shared_head.head.weight"] = partial(fn, is_column=True) for key, action in base_actions.items(): if "layers.0." in key: for i in range(num_layers): final_actions[key.replace("layers.0.", f"layers.{i}.")] = action final_actions[key] = action return final_actions mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) return mappings