[V1 Loader] support weight_only (#3413)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled

* support wint4/wint8

* delete smoe case

* update ci

* print log
This commit is contained in:
bukejiyu
2025-08-23 13:13:41 +08:00
committed by GitHub
parent 93e1b63200
commit 77514e3e1e
24 changed files with 1055 additions and 524 deletions

View File

@@ -17,6 +17,7 @@
from __future__ import annotations
import math
import re
from functools import partial
import paddle
@@ -122,6 +123,25 @@ class DeepSeekV3MoE(nn.Layer):
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
}
self.gate = ReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.gate",
input_size=fd_config.model_config.hidden_size,
output_size=fd_config.model_config.n_routed_experts,
with_bias=False,
skip_quant=True,
weight_dtype="float32",
)
if fd_config.model_config.topk_method == "noaux_tc":
self.gate.e_score_correction_bias = self.create_parameter(
shape=[1, fd_config.model_config.n_routed_experts],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
)
else:
self.gate.e_score_correction_bias = None
self.experts = FusedMoE(
fd_config=fd_config,
reduce_results=False,
@@ -133,19 +153,10 @@ class DeepSeekV3MoE(nn.Layer):
n_group=fd_config.model_config.n_group,
routed_scaling_factor=fd_config.model_config.routed_scaling_factor,
layer_idx=layer_id,
gate_correction_bias=self.gate.e_score_correction_bias,
weight_key_map=weight_key_map,
)
self.gate = ReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.gate",
input_size=fd_config.model_config.hidden_size,
output_size=fd_config.model_config.n_routed_experts,
with_bias=False,
skip_quant=True,
weight_dtype="float32",
)
self.num_shared_experts = fd_config.model_config.n_shared_experts
shared_experts_intermediate_size = self.num_shared_experts * fd_config.model_config.moe_intermediate_size
@@ -258,6 +269,7 @@ class DeepseekV3MLAAttention(nn.Layer):
self.kv_b_proj_bmm = KVBatchLinear(
fd_config=fd_config,
kv_b_proj=self.kv_b_proj,
prefix=f"{prefix}.kv_b_proj",
kv_lora_rank=self.kv_lora_rank,
num_attention_heads=self.num_attention_heads,
@@ -617,7 +629,10 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
Args:
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""
from fastdeploy.model_executor.models.utils import default_weight_loader
from fastdeploy.model_executor.utils import (
default_weight_loader,
process_weights_after_loading,
)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
@@ -637,7 +652,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
param_down_proj_name="experts.down_proj_",
)
params_dict = dict(self.named_parameters())
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
for loaded_weight_name, loaded_weight in weights_iterator:
loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model")
@@ -668,19 +683,18 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id)
break
else:
if loaded_weight_name not in params_dict:
model_param_name = loaded_weight_name
if model_param_name not in params_dict:
continue
param = params_dict[loaded_weight_name]
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight)
if "kv_b_proj.weight" in loaded_weight_name:
# handle kv_b_proj_bmm
model_param_name = loaded_weight_name.replace(
"kv_b_proj.weight", "kv_b_proj_bmm.k_b_proj_weight"
)
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", None)
weight_loader(param, loaded_weight, shard_id)
model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
if "kv_b_proj" in model_sublayer_name:
kv_model_sublayer_name = model_sublayer_name.replace("kv_b_proj", "kv_b_proj_bmm")
process_weights_after_loading_fn(kv_model_sublayer_name)
process_weights_after_loading_fn(model_sublayer_name, param)
def compute_logits(self, hidden_states: paddle.Tensor):
""" """

View File

@@ -17,6 +17,7 @@
from __future__ import annotations
import inspect
import re
from functools import partial
from typing import Dict, Union
@@ -149,15 +150,6 @@ class Ernie4_5_MoE(nn.Layer):
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
}
self.experts = FusedMoE(
fd_config=fd_config,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
num_experts=fd_config.model_config.moe_num_experts,
top_k=fd_config.model_config.moe_k,
layer_idx=layer_id,
weight_key_map=weight_key_map,
)
self.gate = ReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.gate",
@@ -168,6 +160,25 @@ class Ernie4_5_MoE(nn.Layer):
weight_dtype="float32",
)
self.experts = FusedMoE(
fd_config=fd_config,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
num_experts=fd_config.model_config.moe_num_experts,
top_k=fd_config.model_config.moe_k,
layer_idx=layer_id,
gate_correction_bias=None,
weight_key_map=weight_key_map,
)
if fd_config.model_config.moe_use_aux_free:
self.experts.gate_correction_bias = self.create_parameter(
shape=[1, fd_config.model_config.moe_num_experts],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
)
else:
self.experts.gate_correction_bias = None
self.num_shared_experts = fd_config.model_config.moe_num_shared_experts
if self.num_shared_experts > 0:
shared_experts_hidden_dim = self.num_shared_experts * fd_config.model_config.moe_intermediate_size
@@ -180,6 +191,13 @@ class Ernie4_5_MoE(nn.Layer):
def load_state_dict(self, state_dict):
self.gate.load_state_dict(state_dict)
self.experts.load_state_dict(state_dict)
if self.experts.gate_correction_bias is not None:
gate_correction_bias_tensor = state_dict.pop(self.experts.gate_correction_bias_key)
if self.experts.gate_correction_bias.shape != gate_correction_bias_tensor.shape:
gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(
self.experts.gate_correction_bias.shape
)
self.experts.gate_correction_bias.set_value(gate_correction_bias_tensor)
if self.num_shared_experts > 0:
self.shared_experts.load_state_dict(state_dict)
@@ -441,12 +459,16 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""
from fastdeploy.model_executor.models.utils import default_weight_loader
from fastdeploy.model_executor.utils import (
default_weight_loader,
process_weights_after_loading,
)
general_params_mapping = [
# (param_name, weight_name, expert_id, shard_id)
("embed_tokens.embeddings", "embed_tokens", None, None),
("lm_head.linear", "lm_head", None, None),
("experts.gate_correction_bias", "moe_statics.e_score_correction_bias", None, None),
]
expert_params_mapping = []
@@ -458,13 +480,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
param_gate_up_proj_name="experts.up_gate_proj_",
param_down_proj_name="experts.down_proj_",
)
expert_params_mapping.append(
("experts.gate_correction_bias", "moe_statics.e_score_correction_bias", None, "gate_bias")
)
logger.info(f"expert params mapping:{expert_params_mapping}")
all_param_mapping = general_params_mapping + expert_params_mapping
params_dict = dict(self.named_parameters())
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
expert_id = None
shard_id = None
@@ -478,9 +497,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
shard_id = shard_id
break
else:
if loaded_weight_name not in params_dict.keys():
model_param_name = loaded_weight_name
if model_param_name not in params_dict.keys():
continue
param = params_dict[loaded_weight_name]
param = params_dict[model_param_name]
# Get weight loader from parameter and set weight
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
@@ -490,6 +510,8 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
else:
weight_loader(param, loaded_weight)
model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))

View File

@@ -34,7 +34,7 @@ from paddle.nn.functional.flash_attention import (
from paddleformers.transformers.model_utils import PretrainedModel
from fastdeploy.model_executor.layers.utils import divide, get_tensor
from fastdeploy.model_executor.models.utils import set_weight_attrs
from fastdeploy.model_executor.utils import set_weight_attrs
from .activation import ACT2FN
from .configuration import DFNRopeVisionTransformerConfig

View File

@@ -17,6 +17,7 @@
from __future__ import annotations
import inspect
import re
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Union
@@ -38,7 +39,6 @@ from fastdeploy.model_executor.layers.linear import ReplicatedLinear
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.models.ernie4_5_moe import (
Ernie4_5_Attention,
Ernie4_5_MLP,
@@ -75,7 +75,15 @@ class VLMoEMeta:
class Ernie4_5_VLMoeBlock(nn.Layer):
def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str, moe_tag: str, expert_id_offset: int) -> None:
def __init__(
self,
fd_config: FDConfig,
layer_id: int,
prefix: str,
moe_tag: str,
expert_id_offset: int,
gate_correction_bias=None,
) -> None:
super().__init__()
moe_quant_type = ""
if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
@@ -120,6 +128,7 @@ class Ernie4_5_VLMoeBlock(nn.Layer):
layer_idx=layer_id,
moe_tag=moe_tag,
weight_key_map=weight_key_map,
gate_correction_bias=gate_correction_bias,
)
self.gate = ReplicatedLinear(
@@ -133,29 +142,10 @@ class Ernie4_5_VLMoeBlock(nn.Layer):
weight_key="weight" if moe_tag == "Text" else "weight_1",
)
if moe_tag == "Text":
self.experts.extract_gate_correction_bias = self.extract_gate_correction_bias_text
elif moe_tag == "Image":
self.experts.extract_gate_correction_bias = self.extract_gate_correction_bias_image
def forward(self, hidden_states: paddle.Tensor):
out = self.experts(hidden_states, self.gate)
return out
def extract_gate_correction_bias_text(self, gate_correction_bias_key, state_dict):
"""
extract_gate_correction_bias function.
"""
gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32")
return gate_correction_bias_tensor[0].unsqueeze(0)
def extract_gate_correction_bias_image(self, gate_correction_bias_key, state_dict):
"""
extract_gate_correction_bias function.
"""
gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32")
return gate_correction_bias_tensor[1].unsqueeze(0)
def load_state_dict(self, state_dict):
self.experts.load_state_dict(state_dict)
self.gate.load_state_dict(state_dict)
@@ -186,10 +176,25 @@ class Ernie4_5_VLMoE(nn.Layer):
image_moe_layer_end_index = moe_layer_end_index[1]
assert text_moe_layer_start_index <= text_moe_layer_end_index
if fd_config.model_config.moe_use_aux_free:
self.gate_correction_bias = self.create_parameter(
shape=[2, fd_config.model_config.moe_num_experts[0]],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
)
if not self.gate_correction_bias._is_initialized():
self.gate_correction_bias.initialize()
else:
self.gate_correction_bias = None
if layer_id >= text_moe_layer_start_index and layer_id <= text_moe_layer_end_index:
self.text_fused_moe = Ernie4_5_VLMoeBlock(
fd_config=fd_config, layer_id=layer_id, prefix=f"{prefix}", moe_tag="Text", expert_id_offset=0
fd_config=fd_config,
layer_id=layer_id,
prefix=f"{prefix}",
moe_tag="Text",
expert_id_offset=0,
gate_correction_bias=self.gate_correction_bias[0] if fd_config.model_config.moe_use_aux_free else None,
)
else:
self.text_fused_moe = Ernie4_5_VLMLP(
@@ -207,6 +212,7 @@ class Ernie4_5_VLMoE(nn.Layer):
prefix=f"{prefix}",
moe_tag="Image",
expert_id_offset=fd_config.model_config.moe_num_experts[0],
gate_correction_bias=self.gate_correction_bias[1] if fd_config.model_config.moe_use_aux_free else None,
)
else:
self.image_fused_moe = Ernie4_5_VLMLP(
@@ -226,10 +232,13 @@ class Ernie4_5_VLMoE(nn.Layer):
)
def load_state_dict(self, state_dict):
if self.gate_correction_bias is not None:
gate_correction_bias_tensor = state_dict.pop(self.text_fused_moe.experts.gate_correction_bias_key)
if self.gate_correction_bias.shape != gate_correction_bias_tensor.shape:
gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(self.gate_correction_bias.shape)
self.gate_correction_bias.set_value(gate_correction_bias_tensor)
self.text_fused_moe.load_state_dict(state_dict)
self.image_fused_moe.load_state_dict(state_dict)
if self.text_fused_moe.experts.moe_use_gate_correction_bias:
state_dict.pop(self.text_fused_moe.experts.gate_correction_bias_key)
if self.num_shared_experts > 0:
self.shared_experts.load_state_dict(state_dict)
@@ -563,19 +572,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
def name(self):
return "Ernie4_5_VLMoeForConditionalGeneration"
def gate_correction_bias_loader(self, params_dict, loaded_weight_name, loaded_weight):
text_param_name = loaded_weight_name.replace(
"moe_statics.e_score_correction_bias", "text_fused_moe.experts.gate_correction_bias"
)
image_param_name = loaded_weight_name.replace(
"moe_statics.e_score_correction_bias", "image_fused_moe.experts.gate_correction_bias"
)
text_param = params_dict[text_param_name]
image_param = params_dict[image_param_name]
loaded_weight = get_tensor(loaded_weight)
text_param.copy_(loaded_weight[0].unsqueeze(0), False)
image_param.copy_(loaded_weight[1].unsqueeze(0), False)
@paddle.no_grad()
def load_weights(self, weights_iterator) -> None:
"""
@@ -585,7 +581,10 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""
from fastdeploy.model_executor.models.utils import default_weight_loader
from fastdeploy.model_executor.utils import (
default_weight_loader,
process_weights_after_loading,
)
general_params_mapping = [
# (param_name, weight_name, expert_id, shard_id)
@@ -594,6 +593,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
("mlp.image_fused_moe.gate.weight", "mlp.gate.weight_1", None, "gate"),
("mlp.text_fused_moe.gate.weight", "mlp.gate.weight", None, "gate"),
("resampler_model", "ernie.resampler_model", None, None),
("vision_model", "ernie.vision_model", None, None),
("gate_correction_bias", "moe_statics.e_score_correction_bias", None, None),
]
text_expert_params_mapping = []
@@ -617,6 +618,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
all_param_mapping = general_params_mapping + text_expert_params_mapping + image_expert_params_mapping
params_dict = dict(self.named_parameters())
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
expert_id = None
shard_id = None
for loaded_weight_name, loaded_weight in weights_iterator:
@@ -629,10 +631,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
shard_id = shard_id
break
else:
# text and image gate_correction_bias is fused in ckpt and need load independently
if "moe_statics.e_score_correction_bias" in loaded_weight_name:
self.gate_correction_bias_loader(params_dict, loaded_weight_name, loaded_weight)
continue
if loaded_weight_name not in params_dict.keys():
continue
model_param_name = loaded_weight_name
@@ -646,7 +644,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
else:
weight_loader(param, loaded_weight)
model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))

View File

@@ -30,7 +30,7 @@ from fastdeploy.model_executor.models.ernie4_5_vl.dist_utils import (
reduce_scatter_group,
scatter_axis,
)
from fastdeploy.model_executor.models.utils import set_weight_attrs
from fastdeploy.model_executor.utils import set_weight_attrs
class ScatterOp(PyLayer):

View File

@@ -16,6 +16,7 @@
from __future__ import annotations
import re
from functools import partial
import paddle
@@ -254,7 +255,10 @@ class Qwen3ForCausalLM(ModelForCasualLM):
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""
from fastdeploy.model_executor.models.utils import default_weight_loader
from fastdeploy.model_executor.utils import (
default_weight_loader,
process_weights_after_loading,
)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
@@ -266,8 +270,8 @@ class Qwen3ForCausalLM(ModelForCasualLM):
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
]
params_dict = dict(self.named_parameters())
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
for loaded_weight_name, loaded_weight in weights_iterator:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in loaded_weight_name:
@@ -280,11 +284,14 @@ class Qwen3ForCausalLM(ModelForCasualLM):
weight_loader(param, loaded_weight, shard_id)
break
else:
if loaded_weight_name not in params_dict:
model_param_name = loaded_weight_name
if model_param_name not in params_dict:
continue
param = params_dict[loaded_weight_name]
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight)
model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))

View File

@@ -16,6 +16,7 @@
from __future__ import annotations
import re
from functools import partial
import paddle
@@ -334,7 +335,10 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""
from fastdeploy.model_executor.models.utils import default_weight_loader
from fastdeploy.model_executor.utils import (
default_weight_loader,
process_weights_after_loading,
)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
@@ -348,6 +352,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
]
expert_params_mapping = self.get_expert_mapping()
params_dict = dict(self.named_parameters())
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
for loaded_weight_name, loaded_weight in weights_iterator:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in loaded_weight_name:
@@ -374,12 +379,16 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id)
break
else:
if loaded_weight_name not in params_dict:
model_param_name = loaded_weight_name
if model_param_name not in params_dict:
continue
param = params_dict[loaded_weight_name]
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight)
model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
process_weights_after_loading_fn(model_sublayer_name, param)
@paddle.no_grad()
def set_state_dict(self, state_dict):
"""

View File

@@ -24,7 +24,7 @@ import random
import re
import struct
from functools import partial
from typing import Any, NamedTuple, Optional, Union
from typing import NamedTuple, Optional
import numpy as np
import paddle
@@ -40,73 +40,10 @@ from paddleformers.utils.env import (
from paddleformers.utils.log import logger
from tqdm import tqdm
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.utils import get_tensor
MAX_BSZ = 512
MAX_DRAFT_TOKENS = 6
def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]):
if param_attr_map is None:
return
for key, value in param_attr_map.items():
setattr(param, key, value)
def slice_fn(weight_or_paramter, output_dim, start, end, step=1):
if hasattr(weight_or_paramter, "get_shape"):
shape = weight_or_paramter.get_shape()
else:
shape = weight_or_paramter.shape
if len(shape) == 1:
weight_or_paramter = weight_or_paramter[start:end]
elif output_dim:
weight_or_paramter = weight_or_paramter[..., start:end]
else:
weight_or_paramter = weight_or_paramter[start:end, ...]
return weight_or_paramter
def default_weight_loader(fd_config: FDConfig) -> None:
"""Default weight loader"""
def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
"""fn"""
try:
output_dim = getattr(param, "output_dim", None)
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
dim = -1 if output_dim else 0
size = loaded_weight.get_shape()[dim]
block_size = size // fd_config.parallel_config.tensor_parallel_size
shard_offset = fd_config.parallel_config.tensor_parallel_rank * block_size
shard_size = (fd_config.parallel_config.tensor_parallel_rank + 1) * block_size
if output_dim:
loaded_weight = loaded_weight[..., shard_offset:shard_size]
else:
loaded_weight = loaded_weight[shard_offset:shard_size, ...]
loaded_weight = get_tensor(loaded_weight)
# mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
if param.dtype != loaded_weight.dtype:
loaded_weight = loaded_weight.cast(param.dtype)
if param.shape != loaded_weight.shape:
try:
param = param.reshape(loaded_weight.shape)
except ValueError as e:
raise ValueError(
f" Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape}). {e}"
)
param.copy_(loaded_weight, False)
except Exception:
raise
return fn
class LayerIdPlaceholder(str, enum.Enum):
"""LayerIdPlaceholder"""