refactor rl get_name_mappings_to_training (#2847)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

* refactor rl get_name_mappings_to_training

* fix tp>1

* change variable name(ffn1->up_gate_proj/ffn2->down_proj)

* change variable name(linear_weight->weight/linear_bias->bias)

* add rl names mapping for vl

* fix ernie 0.3B error

* fix develop code

* fix
This commit is contained in:
Yuanle Liu
2025-07-15 22:31:42 +08:00
committed by GitHub
parent e7bcbbab52
commit 61b3997b85
47 changed files with 1591 additions and 1629 deletions

View File

@@ -19,12 +19,10 @@ from paddle import nn
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map,
get_tensor)
from fastdeploy.model_executor.layers.quantization.quant_base import \
QuantMethodBase
from fastdeploy.utils import ceil_div
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
"""
@@ -36,9 +34,9 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
Triton Group Gemm to compute Fused MoE.
"""
self.quant_method = quant_method
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
"up_gate_proj_weight_scale", "down_proj_weight_scale"
]
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
@@ -49,26 +47,26 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
"""
Triton MoE create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(ffn1_weights) == layer.num_local_experts
assert len(ffn2_weights) == layer.num_local_experts
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
assert self.quant_method.name() == "wint8"
assert ffn1_weights[0].shape == [
assert up_gate_proj_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]
ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_tensor = paddle.stack(down_proj_weights, axis=0)
if self.quant_method.name() == "wint8":
max_bound = 127
elif self.quant_method.name() == "wint4":
max_bound = 7
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
@@ -150,10 +148,10 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
fused_moe_kernel_paddle[grid](
x,
layer.moe_ffn1_weight,
layer.up_gate_proj_weight,
intermediate_cache1,
None,
layer.moe_ffn1_weight_scale,
layer.up_gate_proj_weight_scale,
None,
sorted_token_ids,
expert_ids,
@@ -164,17 +162,17 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
token_num * top_k,
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
stride_bk=layer.moe_ffn1_weight.strides[1],
stride_bn=layer.moe_ffn1_weight.strides[2],
stride_be=layer.up_gate_proj_weight.strides[0],
stride_bk=layer.up_gate_proj_weight.strides[1],
stride_bn=layer.up_gate_proj_weight.strides[2],
stride_cm=intermediate_cache1.strides[0],
stride_cn=intermediate_cache1.strides[1],
#
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
@@ -197,10 +195,10 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid](
intermediate_cache2,
layer.moe_ffn2_weight,
layer.down_proj_weight,
intermediate_cache3,
None,
layer.moe_ffn2_weight_scale,
layer.down_proj_weight_scale,
topk_weights,
sorted_token_ids,
expert_ids,
@@ -211,16 +209,16 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
token_num * top_k,
stride_am=intermediate_cache2.strides[0],
stride_ak=intermediate_cache2.strides[1],
stride_be=layer.moe_ffn2_weight.strides[0],
stride_bk=layer.moe_ffn2_weight.strides[1],
stride_bn=layer.moe_ffn2_weight.strides[2],
stride_be=layer.down_proj_weight.strides[0],
stride_bk=layer.down_proj_weight.strides[1],
stride_bn=layer.down_proj_weight.strides[2],
stride_cm=intermediate_cache3.strides[0],
stride_cn=intermediate_cache3.strides[1],
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
stride_bse=layer.down_proj_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
stride_bsn=layer.down_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters

View File

@@ -16,8 +16,8 @@
import paddle
from paddle.nn.quant import weight_dequantize
from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig, GPUWeightOnlyLinearMethod
from fastdeploy.model_executor.layers.quantization.weight_only import (
GPUWeightOnlyLinearMethod, WeightOnlyConfig)
class DCUWeightOnlyLinearMethod(GPUWeightOnlyLinearMethod):
@@ -35,12 +35,12 @@ class DCUWeightOnlyLinearMethod(GPUWeightOnlyLinearMethod):
def apply(self, layer, x):
dequant_out = weight_dequantize(
x=layer.linear_weight,
scale=layer.linear_weight_scale,
x=layer.weight,
scale=layer.weight_scale,
algo=self.quant_config.algo,
out_dtype=paddle.get_default_dtype()
)
linear_out = paddle.matmul(x, dequant_out)
if layer.linear_bias is not None:
linear_out = paddle.add(linear_out, layer.linear_bias)
if layer.bias is not None:
linear_out = paddle.add(linear_out, layer.bias)
return linear_out

View File

@@ -50,11 +50,11 @@ class GCUFusedMoeMethod(MoEMethodBase):
Paddle gcu create weight process.
"""
# bf16
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
for idx, weight_tensor in enumerate(
[stacked_ffn1_weights, stacked_ffn2_weights]):
[stacked_up_gate_proj_weights, stacked_down_proj_weights]):
# shape [E, K, N] -> [E, N, K]
weight_tensor = paddle.transpose(weight_tensor, [0, 2, 1])
weight_name = self.added_weight_attrs[idx]
@@ -117,16 +117,16 @@ class GCUFusedMoeMethod(MoEMethodBase):
dtype=x.dtype,
)
ffn1_B_scale = layer.moe_ffn1_weight_scale if enable_quant else None
ffn1_B_zeros = layer.moe_ffn1_weight_zeros if enable_quant else None
up_gate_proj_B_scale = layer.up_gate_proj_weight_scale if enable_quant else None
up_gate_proj_B_zeros = layer.up_gate_proj_weight_zeros if enable_quant else None
invoke_fused_moe_kernel(
x, # input
layer.moe_ffn1_weight, # weight
layer.up_gate_proj_weight, # weight
intermediate_cache1, # output
None, # A_scale
ffn1_B_scale, # B_scale
ffn1_B_zeros, # B_zp
up_gate_proj_B_scale, # B_scale
up_gate_proj_B_zeros, # B_zp
topk_weights,
topk_indices,
sorted_token_ids,
@@ -154,16 +154,16 @@ class GCUFusedMoeMethod(MoEMethodBase):
dtype=x.dtype,
)
ffn2_B_scale = layer.moe_ffn2_weight_scale if enable_quant else None
ffn2_B_zeros = layer.moe_ffn2_weight_zeros if enable_quant else None
down_proj_B_scale = layer.down_proj_weight_scale if enable_quant else None
down_proj_B_zeros = layer.down_proj_weight_zeros if enable_quant else None
invoke_fused_moe_kernel(
intermediate_cache2, # input
layer.moe_ffn2_weight, # weight
layer.down_proj_weight, # weight
intermediate_cache3, # output
None, # A_scale
ffn2_B_scale, # B_scale
ffn2_B_zeros, # B_zp
down_proj_B_scale, # B_scale
down_proj_B_zeros, # B_zp
topk_weights,
topk_indices,
sorted_token_ids,
@@ -251,7 +251,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
"GCUWeightOnlyMoEMethod only support weight_only_int4, but got:{self.quant_config.algo}"
self.added_qzeros_attrs = [
"moe_ffn1_weight_zeros", "moe_ffn2_weight_zeros"
"up_gate_proj_weight_zeros", "down_proj_weight_zeros"
]
self.group_size = 64
@@ -265,41 +265,41 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
"""
Paddle gcu process prequanted weights.
"""
ffn1_expert_weight_key = layer.weight_key_map.get(
"ffn1_expert_weight_key", None)
ffn2_expert_weight_key = layer.weight_key_map.get(
"ffn2_expert_weight_key", None)
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
"ffn1_expert_weight_scale_key", None)
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
"ffn2_expert_weight_scale_key", None)
up_gate_proj_expert_weight_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = layer.weight_key_map.get(
"down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get(
"down_proj_expert_weight_scale_key", None)
ffn1_weights, ffn2_weights = layer.load_experts_weight(
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
# self.check(layer, ffn1_weights, ffn2_weights)
ffn1_weight_scale = []
ffn2_weight_scale = []
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
# self.check(layer, up_gate_proj_weights, down_proj_weights)
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
for i in range(layer.num_experts):
expert_idx = layer.expert_id_offset + i
ffn1_weight_scale.append(
up_gate_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn1_expert_weight_scale_key.format(expert_idx))))
ffn2_weight_scale.append(
up_gate_proj_expert_weight_scale_key.format(expert_idx))))
down_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_scale_key.format(expert_idx))))
down_proj_expert_weight_scale_key.format(expert_idx))))
ffn1_weight = paddle.stack(ffn1_weights, axis=0)
ffn2_weight = paddle.stack(ffn2_weights, axis=0)
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
name_tensor_map = {
"moe_ffn1_weight": ffn1_weight,
"moe_ffn2_weight": ffn2_weight,
"moe_ffn1_weight_scale": ffn1_weight_scale,
"moe_ffn2_weight_scale": ffn2_weight_scale
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale
}
for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor)
@@ -310,8 +310,8 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
"""
Paddle cutlass create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, ffn1_weights, ffn2_weights)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights)
def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size):
@@ -329,7 +329,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
)
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
zeros_name = self.added_qzeros_attrs[idx]
@@ -365,8 +365,8 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
dict_ = dict(shared_dict)
for k, v in dict_.items():
weight_list[k] = v[0].to(ffn1_weights[0].place)
weight_scale_list[k] = v[1].to(ffn1_weights[0].place)
weight_list[k] = v[0].to(up_gate_proj_weights[0].place)
weight_scale_list[k] = v[1].to(up_gate_proj_weights[0].place)
else:
remain_weights_start_idx = 0

View File

@@ -38,14 +38,14 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
def create_weights(self, layer):
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
weight_scale_shape = [layer.weight_shape[1]]
layer.linear_weight_shape.reverse()
layer.weight_shape.reverse()
if self.quant_config.name() == "wint4":
layer.linear_weight_shape[0] //= 2
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype=layer._dtype,
is_bias=False,
)
@@ -61,8 +61,8 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
layer.linear_weight.set_value(quant_weight)
layer.linear_weight_scale.set_value(
layer.weight.set_value(quant_weight)
layer.weight_scale.set_value(
weight_scale.astype(paddle.get_default_dtype()))
@@ -73,8 +73,8 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
self.group_size, # group_size
)
layer.linear_weight.set_value(quanted_weight_tensor)
layer.linear_weight_scale.set_value(
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(
weight_scale_tensor.astype(paddle.get_default_dtype()))
@@ -82,8 +82,8 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
def apply(self, layer, x):
linear_out = linear_quant(
lhs=x,
rhs=layer.linear_weight,
scale=layer.linear_weight_scale,
rhs=layer.weight,
scale=layer.weight_scale,
bias=None,
group_size=self.group_size,
)

View File

@@ -37,13 +37,13 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
Create weights for linear layer on XPU
"""
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
layer.linear_weight_shape.reverse()
weight_scale_shape = [layer.weight_shape[1]]
layer.weight_shape.reverse()
if self.quant_config.name() == "weight_only_int4":
layer.linear_weight_shape[0] //= 2
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype="float32",
is_bias=False,
)
@@ -55,6 +55,6 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu(
weight, self.quant_config.algo, -1, -1)
layer.linear_weight.set_value(
layer.weight.set_value(
paddle.transpose(quanted_weight_tensor, [1, 0]))
layer.linear_weight_scale.set_value(weight_scale_tensor)
layer.weight_scale.set_value(weight_scale_tensor)

View File

@@ -68,13 +68,13 @@ class VocabParallelEmbedding(nn.Layer):
self.params_dtype: str = params_dtype
if self.use_ep:
self.word_embeddings = nn.Embedding(
self.embeddings = nn.Embedding(
num_embeddings,
embedding_dim,
)
else:
if not self.column_cut:
self.word_embeddings = fleet.meta_parallel.VocabParallelEmbedding(
self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
num_embeddings,
embedding_dim,
mp_group=fleet.get_hybrid_communicate_group().
@@ -85,13 +85,13 @@ class VocabParallelEmbedding(nn.Layer):
)
else:
# column cut embedding
self.word_embeddings = nn.Embedding(
self.embeddings = nn.Embedding(
num_embeddings,
embedding_dim // self.world_size,
)
self.word_embeddings.weight.is_distributed = True
self.word_embeddings.weight.split_axis = 1
self.embeddings.weight.is_distributed = True
self.embeddings.weight.split_axis = 1
if not self.use_rope:
self.position_embeddings = nn.Embedding(
@@ -112,13 +112,12 @@ class VocabParallelEmbedding(nn.Layer):
Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
a = state_dict[self.prefix + ".weight"]
if self.tie_word_embeddings:
self.word_embeddings.weight.set_value(
self.embeddings.weight.set_value(
get_tensor(state_dict[self.prefix + ".weight"]).astype(
paddle.get_default_dtype()))
else:
self.word_embeddings.weight.set_value(
self.embeddings.weight.set_value(
get_tensor(state_dict.pop(self.prefix + ".weight")).astype(
paddle.get_default_dtype()))
@@ -134,10 +133,10 @@ class VocabParallelEmbedding(nn.Layer):
Tensor: Embedded tensor representation of the input IDs.
"""
if self.use_ep:
input_embedings = self.word_embeddings(ids_remove_padding)
input_embedings = self.embeddings(ids_remove_padding)
else:
if self.column_cut:
input_embedings = self.word_embeddings(ids_remove_padding)
input_embedings = self.embeddings(ids_remove_padding)
inputs_embeds_temp = []
paddle.distributed.all_gather(
inputs_embeds_temp,
@@ -148,6 +147,6 @@ class VocabParallelEmbedding(nn.Layer):
)
input_embedings = paddle.concat(inputs_embeds_temp, -1)
else:
input_embedings = self.word_embeddings(ids_remove_padding)
input_embedings = self.embeddings(ids_remove_padding)
return input_embedings

View File

@@ -14,16 +14,13 @@
# limitations under the License.
"""
from paddleformers.utils.log import logger
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import (
ColumnParallelLinear,
VocabParallelEmbedding,
)
from paddle.distributed.fleet.meta_parallel import (ColumnParallelLinear,
VocabParallelEmbedding)
from paddleformers.utils.log import logger
from .utils import get_tensor
@@ -130,7 +127,7 @@ class HydraHead(nn.Layer):
]
)
self.word_embeddings = VocabParallelEmbedding(
self.embeddings = VocabParallelEmbedding(
vocab_size,
hidden_size,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
@@ -170,8 +167,8 @@ class HydraHead(nn.Layer):
get_tensor(state_dict.pop(f"1.{hydra_head_idx}.weight"))
)
self.word_embeddings.weight.set_value(
get_tensor(state_dict.pop("word_embeddings.weight"))
self.embeddings.weight.set_value(
get_tensor(state_dict.pop("embeddings.weight"))
)
def set_state_dict(self, state_dict):
@@ -183,7 +180,7 @@ class HydraHead(nn.Layer):
"""
is_custom = True
for key in state_dict.keys():
if key != "word_embeddings.weight" and (
if key != "embeddings.weight" and (
"hydra_mlp" in key or "hydra_head" in key
):
is_custom = False
@@ -207,7 +204,7 @@ class HydraHead(nn.Layer):
hidden_states: [batch_size, hidden_size] The hidden_states of the last accept_tokens
"""
hydra_inputs = [hidden_states]
input_embeds = self.word_embeddings(input_ids)
input_embeds = self.embeddings(input_ids)
for hydra_head_idx in range(self.hydra_num_heads):
hydra_inputs.append(input_embeds)
head_input = paddle.concat(hydra_inputs, axis=-1)
@@ -217,4 +214,4 @@ class HydraHead(nn.Layer):
_, topk_tokens = paddle.topk(probs, k=1, axis=-1)
next_tokens[:, 1 + hydra_head_idx : 2 + hydra_head_idx] = topk_tokens[:]
input_embeds = self.word_embeddings(next_tokens[:, 1 + hydra_head_idx])
input_embeds = self.embeddings(next_tokens[:, 1 + hydra_head_idx])

View File

@@ -79,7 +79,7 @@ class LinearBase(nn.Layer):
self._dtype = self._helper.get_default_dtype()
self.weight_dtype = self._dtype
self.linear_weight_shape = [
self.weight_shape = [
self.input_size,
self.output_size,
]
@@ -96,16 +96,16 @@ class LinearBase(nn.Layer):
"""
if self.skip_quant:
self.weight_dtype = self._dtype
self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
self.weight = self.create_parameter(
shape=self.weight_shape,
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
self.linear_bias = None
self.bias = None
if self.with_bias:
self.linear_bias = self.create_parameter(
self.bias = self.create_parameter(
shape=[self.output_size],
dtype=self._dtype,
is_bias=True,
@@ -136,7 +136,7 @@ class LinearBase(nn.Layer):
if self.fd_config.quant_config:
self.quant_method.process_loaded_weights(self, weight_tensor)
else:
self.linear_weight.set_value(weight_tensor)
self.weight.set_value(weight_tensor)
def load_state_dict(self, state_dict: dict):
"""
@@ -157,7 +157,7 @@ class LinearBase(nn.Layer):
if self.with_bias:
bias_tensor = paddle.to_tensor(
get_tensor(state_dict.pop(self.bias_key)))
self.linear_bias.set_value(bias_tensor)
self.bias.set_value(bias_tensor)
def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
"""
@@ -175,9 +175,9 @@ class LinearBase(nn.Layer):
if self.fd_config.quant_config:
linear_out = self.quant_method.apply(self, x)
else:
linear_out = paddle.matmul(x, self.linear_weight)
linear_out = paddle.matmul(x, self.weight)
if self.with_bias:
linear_out = paddle.add(linear_out, self.linear_bias)
linear_out = paddle.add(linear_out, self.bias)
return linear_out
@@ -219,7 +219,7 @@ class ReplicatedLinear(LinearBase):
skip_quant=skip_quant)
self.hidden_size = fd_config.model_config.hidden_size
self.linear_weight_shape = [
self.weight_shape = [
self.input_size,
self.output_size,
]
@@ -272,7 +272,7 @@ class ColumnParallelLinear(LinearBase):
output_size,
self.nranks) # Split the output_size using TP inference.
self.hidden_size = fd_config.model_config.hidden_size
self.linear_weight_shape = [
self.weight_shape = [
self.input_size,
self.output_size,
]
@@ -286,26 +286,26 @@ class ColumnParallelLinear(LinearBase):
"""
if self.skip_quant:
self.weight_dtype = self._dtype
self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
self.weight = self.create_parameter(
shape=self.weight_shape,
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
if self.nranks > 0:
# col parallel
_set_var_distributed(self.linear_weight, split_axis=1)
_set_var_distributed(self.weight, split_axis=1)
self.linear_bias = None
self.bias = None
if self.with_bias:
self.linear_bias = self.create_parameter(
self.bias = self.create_parameter(
shape=[self.output_size],
dtype=self._dtype,
is_bias=True,
)
if self.nranks > 0:
# col parallel
_set_var_distributed(self.linear_bias, split_axis=1)
_set_var_distributed(self.bias, split_axis=1)
# smooth quant
self.linear_shift = None
@@ -333,7 +333,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_quant: bool = False,
):
"""
Initialize the fused ffn1 Linear layer with given parameters.
Initialize the fused up_gate_proj Linear layer with given parameters.
Args:
fd_config (FDConfig): Inference-related parameters.
@@ -443,7 +443,7 @@ class QKVParallelLinear(ColumnParallelLinear):
q_tensor = get_tensor(state_dict.pop(q_weight_key))
k_tensor = get_tensor(state_dict.pop(k_weight_key))
v_tensor = get_tensor(state_dict.pop(v_weight_key))
if self.kv_num_heads < self.nranks:
sharedkv_index = (self.fd_config.parallel_config.tensor_parallel_rank * self.kv_num_heads) // self.nranks
sharedkv_start = sharedkv_index * self.head_dim
@@ -462,7 +462,7 @@ class QKVParallelLinear(ColumnParallelLinear):
if self.fd_config.quant_config:
self.quant_method.process_loaded_weights(self, weight_tensor)
else:
self.linear_weight.set_value(weight_tensor)
self.weight.set_value(weight_tensor)
def load_state_dict(self, state_dict: dict):
"""
@@ -485,7 +485,7 @@ class QKVParallelLinear(ColumnParallelLinear):
if self.bias_key in state_dict.keys():
bias_tensor = paddle.to_tensor(
get_tensor(state_dict.pop(self.bias_key)))
self.linear_bias.set_value(bias_tensor)
self.bias.set_value(bias_tensor)
else:
q_bias_key = self.bias_key.replace("qkv_proj", "q_proj")
k_bias_key = self.bias_key.replace("qkv_proj", "k_proj")
@@ -494,7 +494,7 @@ class QKVParallelLinear(ColumnParallelLinear):
k_bias = get_tensor(state_dict.pop(k_bias_key))
v_bias = get_tensor(state_dict.pop(v_bias_key))
qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1)
self.linear_bias.set_value(qkv_bias)
self.bias.set_value(qkv_bias)
class RowParallelLinear(LinearBase):
@@ -554,7 +554,7 @@ class RowParallelLinear(LinearBase):
self.input_size = divide(input_size, self.nranks)
self.output_size = output_size
self.linear_weight_shape = [
self.weight_shape = [
self.input_size,
self.output_size,
]
@@ -574,16 +574,16 @@ class RowParallelLinear(LinearBase):
if self.skip_quant:
self.weight_dtype = self._dtype
self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
self.weight = self.create_parameter(
shape=self.weight_shape,
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
self.linear_bias = None
self.bias = None
if self.with_bias:
self.linear_bias = self.create_parameter(
self.bias = self.create_parameter(
shape=[self.hidden_size],
dtype=self._dtype,
is_bias=True,
@@ -591,7 +591,7 @@ class RowParallelLinear(LinearBase):
if self.nranks > 0:
# row parallel
_set_var_distributed(self.linear_weight, split_axis=0)
_set_var_distributed(self.weight, split_axis=0)
# smooth quant
self.linear_shift = None
@@ -601,7 +601,7 @@ class RowParallelLinear(LinearBase):
if self.fd_config.quant_config:
out = self.quant_method.apply(self, x)
else:
out = paddle.matmul(x, self.linear_weight)
out = paddle.matmul(x, self.weight)
if self.reduce_results and self.nranks > 1:
tensor_model_parallel_all_reduce(out)

View File

@@ -52,11 +52,11 @@ class ParallelLMHead(nn.Layer):
with_bias (bool): whether to have bias. Default: False.
"""
super(ParallelLMHead, self).__init__()
self.linear_weight_key: str = prefix + ".weight"
self.weight_key: str = prefix + ".weight"
if with_bias:
self.linear_bias_key: Optional[str] = prefix + ".bias"
self.bias_key: Optional[str] = prefix + ".bias"
else:
self.linear_bias_key: Optional[str] = None
self.bias_key: Optional[str] = None
self.use_ep: bool = fd_config.parallel_config.use_ep
self.column_cut = True
@@ -74,26 +74,26 @@ class ParallelLMHead(nn.Layer):
else:
if self.column_cut:
need_gather = True
self.out_linear = ColumnParallelLinear(
self.linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
if self.bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False, # False diff更小
)
else:
self.out_linear = RowParallelLinear(
self.linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
if self.bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False, # False diff更小
)
@@ -109,25 +109,25 @@ class ParallelLMHead(nn.Layer):
if self.use_ep:
self.weight.set_value(
get_tensor(state_dict.pop(self.linear_weight_key)).astype(
get_tensor(state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype()))
else:
if self.tie_word_embeddings:
self.out_linear.weight.set_value(
get_tensor(state_dict.pop(self.linear_weight_key)).astype(
self.linear.weight.set_value(
get_tensor(state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype()).transpose([1, 0]))
else:
weight_tensor = get_tensor(
state_dict.pop(self.linear_weight_key)).astype(
state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype())
if self.out_linear.weight.shape != weight_tensor.shape:
if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.out_linear.weight.set_value(weight_tensor)
self.linear.weight.set_value(weight_tensor)
if self.linear_bias_key is not None:
bias = get_tensor(state_dict.pop(self.linear_bias_key)).astype(
if self.bias_key is not None:
bias = get_tensor(state_dict.pop(self.bias_key)).astype(
paddle.get_default_dtype())
self.out_linear.bias.set_value(bias)
self.linear.bias.set_value(bias)
def forward(self, input: paddle.Tensor) -> paddle.Tensor:
"""
@@ -143,5 +143,5 @@ class ParallelLMHead(nn.Layer):
if self.use_ep:
logits = paddle.matmul(logits, self.weight)
else:
logits = self.out_linear(logits)
logits = self.linear(logits)
return logits

View File

@@ -34,9 +34,9 @@ class MoEMethodBase(QuantMethodBase):
self.moe_quant_type = "w16a16"
else:
self.quant_config = quant_config
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
"up_gate_proj_weight_scale", "down_proj_weight_scale"
]
self.pack_num = 1
@@ -63,14 +63,14 @@ class MoEMethodBase(QuantMethodBase):
"""
pass
def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights):
def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
"""
check layer is valid for this method
"""
assert ffn1_weights[0].shape == [
assert up_gate_proj_weights[0].shape == [
layer.hidden_size // self.pack_num, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size // self.pack_num, layer.hidden_size
]

View File

@@ -31,7 +31,8 @@ if current_platform.is_cuda() and not current_platform.is_dcu():
from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch,
moe_expert_reduce, noaux_tc)
elif current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import moe_expert_dispatch, moe_expert_reduce
from fastdeploy.model_executor.ops.iluvatar import (moe_expert_dispatch,
moe_expert_reduce)
# used for deepseek_v3
@@ -65,11 +66,11 @@ class CutlassMoEMethod(MoEMethodBase):
Paddle cutlass create weight process.
"""
# bf16
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
for idx, weight_tensor in enumerate(
[stacked_ffn1_weights, stacked_ffn2_weights]):
[stacked_up_gate_proj_weights, stacked_down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
setattr(
layer, weight_name,
@@ -95,15 +96,15 @@ class CutlassMoEMethod(MoEMethodBase):
return fastdeploy.model_executor.ops.iluvatar.moe_expert_ffn(
permute_input,
token_nums_per_expert,
layer.moe_ffn1_weight,
layer.moe_ffn2_weight,
layer.up_gate_proj_weight,
layer.down_proj_weight,
None,
(layer.moe_ffn1_weight_scale if hasattr(
layer, "moe_ffn1_weight_scale") else None),
(layer.moe_ffn2_weight_scale if hasattr(
layer, "moe_ffn2_weight_scale") else None),
(layer.moe_ffn2_in_scale
if hasattr(layer, "moe_ffn2_in_scale") else None),
(layer.up_gate_proj_weight_scale if hasattr(
layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(
layer, "down_proj_weight_scale") else None),
(layer.down_proj_in_scale
if hasattr(layer, "down_proj_in_scale") else None),
expert_idx_per_token,
self.moe_quant_type,
used_in_ep_low_latency,
@@ -111,15 +112,15 @@ class CutlassMoEMethod(MoEMethodBase):
return fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
permute_input,
token_nums_per_expert,
layer.moe_ffn1_weight,
layer.moe_ffn2_weight,
layer.up_gate_proj_weight,
layer.down_proj_weight,
None,
(layer.moe_ffn1_weight_scale
if hasattr(layer, "moe_ffn1_weight_scale") else None),
(layer.moe_ffn2_weight_scale
if hasattr(layer, "moe_ffn2_weight_scale") else None),
(layer.moe_ffn2_in_scale
if hasattr(layer, "moe_ffn2_in_scale") else None),
(layer.up_gate_proj_weight_scale
if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale
if hasattr(layer, "down_proj_weight_scale") else None),
(layer.down_proj_in_scale
if hasattr(layer, "down_proj_in_scale") else None),
expert_idx_per_token,
self.moe_quant_type,
used_in_ep_low_latency,
@@ -163,8 +164,8 @@ class CutlassMoEMethod(MoEMethodBase):
recv_x,
recv_topk_idx,
recv_topk_weights,
(self.moe_ffn1_in_scale
if hasattr(self, "moe_ffn1_in_scale") else None),
(self.up_gate_proj_in_scale
if hasattr(self, "up_gate_proj_in_scale") else None),
recv_num_tokens_per_expert_list,
token_all_num,
self.moe_quant_type,
@@ -186,7 +187,7 @@ class CutlassMoEMethod(MoEMethodBase):
dst_weights,
permute_indices_per_token,
dst_indices,
None, # moe_ffn2_bias,
None, # down_proj_bias,
False, # norm_topk_prob
1.0,
)[0]
@@ -256,7 +257,7 @@ class CutlassMoEMethod(MoEMethodBase):
x,
gate_out,
None, # Use layer.gate_correction_bias in get_moe_scores.
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale")
else None), # if set, permute_input will be int8_t
layer.top_k,
False,
@@ -274,7 +275,7 @@ class CutlassMoEMethod(MoEMethodBase):
x,
gate_out,
layer.gate_correction_bias,
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale")
else None), # if set, permute_input will be int8_t
layer.top_k,
False,
@@ -323,9 +324,9 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
"""
Paddle cutlass create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, ffn1_weights, ffn2_weights)
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
weight_list = []
for i in range(layer.num_local_experts):
@@ -366,26 +367,26 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
create_and_set_parameter(layer, name, processed_weight_scale)
# 1. Init scale containers and maps
moe_ffn1_weight_scales = []
moe_ffn2_weight_scales = []
moe_ffn1_in_scales = []
moe_ffn2_in_scales = []
up_gate_proj_weight_scales = []
down_proj_weight_scales = []
up_gate_proj_in_scales = []
down_proj_in_scales = []
scale_weight_map = {
"moe_ffn1_weight_scale": moe_ffn1_weight_scales,
"moe_ffn2_weight_scale": moe_ffn2_weight_scales,
"moe_ffn1_in_scale": moe_ffn1_in_scales,
"moe_ffn2_in_scale": moe_ffn2_in_scales,
"up_gate_proj_weight_scale": up_gate_proj_weight_scales,
"down_proj_weight_scale": down_proj_weight_scales,
"up_gate_proj_in_scale": up_gate_proj_in_scales,
"down_proj_in_scale": down_proj_in_scales,
}
scale_key_map = {
"moe_ffn1_weight_scale":
weight_key_map.get("ffn1_expert_weight_scale_key", None),
"moe_ffn2_weight_scale":
weight_key_map.get("ffn2_expert_weight_scale_key", None),
"moe_ffn1_in_scale":
weight_key_map.get("ffn1_expert_in_scale_key", None),
"moe_ffn2_in_scale":
weight_key_map.get("ffn2_expert_in_scale_key", None),
"up_gate_proj_weight_scale":
weight_key_map.get("up_gate_proj_expert_weight_scale_key", None),
"down_proj_weight_scale":
weight_key_map.get("down_proj_expert_weight_scale_key", None),
"up_gate_proj_in_scale":
weight_key_map.get("up_gate_proj_expert_in_scale_key", None),
"down_proj_in_scale":
weight_key_map.get("down_proj_expert_in_scale_key", None),
}
for name, value in scale_key_map.items():
if value is None:
@@ -404,13 +405,13 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
# 3. Process scale tensor and set to layer
in_scales = []
for in_scale_name in ["moe_ffn1_in_scale", "moe_ffn2_in_scale"]:
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
in_scales.append(
_process_in_scale(in_scale_name,
scale_weight_map[in_scale_name]))
for i, weight_scale_name in enumerate(
["moe_ffn1_weight_scale", "moe_ffn2_weight_scale"]):
["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
_process_weight_scale(weight_scale_name,
scale_weight_map[weight_scale_name],
in_scales[i])
@@ -431,41 +432,41 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
"""
Paddle cutlass process prequanted weights.
"""
ffn1_expert_weight_key = layer.weight_key_map.get(
"ffn1_expert_weight_key", None)
ffn2_expert_weight_key = layer.weight_key_map.get(
"ffn2_expert_weight_key", None)
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
"ffn1_expert_weight_scale_key", None)
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
"ffn2_expert_weight_scale_key", None)
up_gate_proj_expert_weight_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = layer.weight_key_map.get(
"down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get(
"down_proj_expert_weight_scale_key", None)
ffn1_weights, ffn2_weights = layer.load_experts_weight(
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
# self.check(layer, ffn1_weights, ffn2_weights)
ffn1_weight_scale = []
ffn2_weight_scale = []
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
# self.check(layer, up_gate_proj_weights, down_proj_weights)
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
for i in range(layer.num_local_experts):
expert_idx = layer.expert_id_offset + i
ffn1_weight_scale.append(
up_gate_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn1_expert_weight_scale_key.format(expert_idx))))
ffn2_weight_scale.append(
up_gate_proj_expert_weight_scale_key.format(expert_idx))))
down_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_scale_key.format(expert_idx))))
down_proj_expert_weight_scale_key.format(expert_idx))))
ffn1_weight = paddle.stack(ffn1_weights, axis=0)
ffn2_weight = paddle.stack(ffn2_weights, axis=0)
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
name_tensor_map = {
"moe_ffn1_weight": ffn1_weight,
"moe_ffn2_weight": ffn2_weight,
"moe_ffn1_weight_scale": ffn1_weight_scale,
"moe_ffn2_weight_scale": ffn2_weight_scale
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale
}
for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor)
@@ -474,10 +475,10 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
"""
Paddle cutlass create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, ffn1_weights, ffn2_weights)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]

View File

@@ -39,11 +39,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
deepgemm create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, ffn1_weights, ffn2_weights)
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
@@ -70,41 +70,41 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
"""
Paddle cutlass process prequanted weights.
"""
ffn1_expert_weight_key = layer.weight_key_map.get(
"ffn1_expert_weight_key", None)
ffn2_expert_weight_key = layer.weight_key_map.get(
"ffn2_expert_weight_key", None)
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
"ffn1_expert_weight_scale_key", None)
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
"ffn2_expert_weight_scale_key", None)
up_gate_proj_expert_weight_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = layer.weight_key_map.get(
"down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get(
"down_proj_expert_weight_scale_key", None)
ffn1_weights, ffn2_weights = layer.load_experts_weight(
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
# self.check(layer, ffn1_weights, ffn2_weights)
ffn1_weight_scale = []
ffn2_weight_scale = []
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
# self.check(layer, up_gate_proj_weights, down_proj_weights)
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
for i in range(layer.num_local_experts):
expert_idx = layer.expert_id_offset + i
ffn1_weight_scale.append(
up_gate_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn1_expert_weight_scale_key.format(expert_idx))))
ffn2_weight_scale.append(
up_gate_proj_expert_weight_scale_key.format(expert_idx))))
down_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_scale_key.format(expert_idx))))
down_proj_expert_weight_scale_key.format(expert_idx))))
ffn1_weight = paddle.stack(ffn1_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
ffn2_weight = paddle.stack(ffn2_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
down_proj_weight = paddle.stack(down_proj_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
name_tensor_map = {
"moe_ffn1_weight": ffn1_weight,
"moe_ffn2_weight": ffn2_weight,
"moe_ffn1_weight_scale": ffn1_weight_scale,
"moe_ffn2_weight_scale": ffn2_weight_scale
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale
}
for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor)
@@ -143,10 +143,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
if token_all_num > 0:
logger.info(f"token_all_num {token_all_num}")
(recv_x, recv_x_scale) = recv_x
token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
token_nums_this_rank_padded = sum(token_nums_this_rank[1].numpy().tolist())
(
permute_input,
permute_scale,
@@ -171,21 +171,21 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
permute_scale = permute_scale.transpose([1, 0]).contiguous()
permute_scale = permute_scale.transpose([1, 0])
# ffn1
# up_gate_proj
ffn_out = paddle.empty(
(permute_input.shape[0], layer.moe_ffn1_weight.shape[1]),
(permute_input.shape[0], layer.up_gate_proj_weight.shape[1]),
dtype=paddle.bfloat16,
)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(permute_input, permute_scale),
(layer.moe_ffn1_weight, layer.moe_ffn1_weight_scale),
(layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale),
ffn_out,
m_indices,
)
# swiglu
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out, None)
# ffn2
# down_proj
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, self.quant_config.weight_block_size[0])
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose(
@@ -193,11 +193,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
ffn_out = paddle.empty(
(ffn_out.shape[0], layer.moe_ffn2_weight.shape[1]),
(ffn_out.shape[0], layer.down_proj_weight.shape[1]),
dtype=paddle.bfloat16)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(ffn_in_x, ffn_in_x_scale_tensor),
(layer.moe_ffn2_weight, layer.moe_ffn2_weight_scale),
(layer.down_proj_weight, layer.down_proj_weight_scale),
ffn_out,
m_indices,
)
@@ -207,7 +207,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
dst_weights,
permute_indices_per_token,
dst_indices,
None, # moe_ffn2_bias
None, # down_proj_bias
False, # norm_topk_prob
1.0,
)[0]
@@ -237,7 +237,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
# 3. Compute ffn
assert isinstance(permute_input, tuple)
ffn1_out = paddle.empty(
up_gate_proj_out = paddle.empty(
[
layer.num_local_experts,
layer.ep_size *
@@ -261,16 +261,16 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
permute_input,
(
layer.moe_ffn1_weight,
layer.moe_ffn1_weight_scale,
layer.up_gate_proj_weight,
layer.up_gate_proj_weight_scale,
),
ffn1_out,
up_gate_proj_out,
token_nums_per_expert,
expected_m,
)
act_out = fastdeploy.model_executor.ops.gpu.group_swiglu_with_masked(
ffn1_out, token_nums_per_expert)
up_gate_proj_out, token_nums_per_expert)
act_out_fp8, scale = fastdeploy.model_executor.ops.gpu.masked_per_token_quant(
act_out, token_nums_per_expert,
@@ -279,8 +279,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
(act_out_fp8, scale),
(
layer.moe_ffn2_weight,
layer.moe_ffn2_weight_scale,
layer.down_proj_weight,
layer.down_proj_weight_scale,
),
ffn_out,
token_nums_per_expert,
@@ -339,21 +339,21 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
permute_scale = permute_scale.transpose([1, 0]).contiguous()
permute_scale = permute_scale.transpose([1, 0])
# ffn1
# up_gate_proj
ffn_out = paddle.empty(
(permute_input.shape[0], layer.moe_ffn1_weight.shape[1]),
(permute_input.shape[0], layer.up_gate_proj_weight.shape[1]),
dtype=paddle.bfloat16,
)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(permute_input, permute_scale),
(layer.moe_ffn1_weight, layer.moe_ffn1_weight_scale),
(layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale),
ffn_out,
m_indices,
)
# swiglu
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out)
# ffn2
# down_proj
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, self.quant_config.weight_block_size[0])
@@ -362,11 +362,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
ffn_out = paddle.empty(
(ffn_out.shape[0], layer.moe_ffn2_weight.shape[1]),
(ffn_out.shape[0], layer.down_proj_weight.shape[1]),
dtype=paddle.bfloat16)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(ffn_in_x, ffn_in_x_scale_tensor),
(layer.moe_ffn2_weight, layer.moe_ffn2_weight_scale),
(layer.down_proj_weight, layer.down_proj_weight_scale),
ffn_out,
m_indices,
)

View File

@@ -103,9 +103,9 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
Marlin Group Gemm to compute Fused MoE.
"""
self.quant_method = quant_method
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
"up_gate_proj_weight_scale", "down_proj_weight_scale"
]
self.added_zeros_attrs = ["zeros0", "zeros1"]
@@ -113,22 +113,22 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
"""
Marlin MoE create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(ffn1_weights) == layer.num_local_experts
assert len(ffn2_weights) == layer.num_local_experts
assert ffn1_weights[0].shape == [
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
assert up_gate_proj_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]
ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_tensor = paddle.stack(down_proj_weights, axis=0)
max_bound = 7
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
@@ -221,8 +221,8 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
ffn_out = MoeWna16MarlinGemmApi(
x,
c_or_none=None,
b_q_weight=layer.moe_ffn1_weight,
b_scales=layer.moe_ffn1_weight_scale,
b_q_weight=layer.up_gate_proj_weight,
b_scales=layer.up_gate_proj_weight_scale,
global_scale_or_none=None,
b_zeros_or_none=None,
g_idx_or_none=None,
@@ -250,8 +250,8 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
ffn_out = MoeWna16MarlinGemmApi(
swiglu_out,
c_or_none=None,
b_q_weight=layer.moe_ffn2_weight,
b_scales=layer.moe_ffn2_weight_scale,
b_q_weight=layer.down_proj_weight,
b_scales=layer.down_proj_weight_scale,
global_scale_or_none=None,
b_zeros_or_none=None,
g_idx_or_none=None,

View File

@@ -30,7 +30,7 @@ try:
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
from .triton_moe_kernels import fused_moe_kernel_paddle
except:
except ImportError:
pass
@@ -44,9 +44,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
Triton Group Gemm to compute Fused MoE.
"""
self.quant_config = quant_config
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
"up_gate_proj_weight_scale", "down_proj_weight_scale"
]
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
@@ -57,30 +57,30 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
"""
Triton MoE create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(ffn1_weights) == layer.num_local_experts
assert len(ffn2_weights) == layer.num_local_experts
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
assert ffn1_weights[0].shape == [
assert up_gate_proj_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]
ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_tensor = paddle.stack(down_proj_weights, axis=0)
if algo == "wint8":
max_bound = 127
elif algo == "wint4":
max_bound = 7
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
@@ -130,7 +130,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
True, # apply_norm_weight,
False,
)
ffn1_out = paddle.empty(
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
)
@@ -150,10 +150,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
fused_moe_kernel_paddle[grid](
x,
layer.moe_ffn1_weight,
ffn1_out,
layer.up_gate_proj_weight,
up_gate_proj_out,
None,
layer.moe_ffn1_weight_scale,
layer.up_gate_proj_weight_scale,
None,
sorted_token_ids,
expert_ids,
@@ -164,17 +164,17 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
K=hidden_size,
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
stride_bk=layer.moe_ffn1_weight.strides[1],
stride_bn=layer.moe_ffn1_weight.strides[2],
stride_cm=ffn1_out.strides[0],
stride_cn=ffn1_out.strides[1],
stride_be=layer.up_gate_proj_weight.strides[0],
stride_bk=layer.up_gate_proj_weight.strides[1],
stride_bn=layer.up_gate_proj_weight.strides[2],
stride_cm=up_gate_proj_out.strides[0],
stride_cn=up_gate_proj_out.strides[1],
#
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
@@ -190,10 +190,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
)
ffn2_input = paddle.incubate.nn.functional.swiglu(
ffn1_out)
down_proj_input = paddle.incubate.nn.functional.swiglu(
up_gate_proj_out)
ffn2_out = paddle.empty(
down_proj_out = paddle.empty(
(token_num * top_k, hidden_size),
dtype=x.dtype,
)
@@ -202,11 +202,11 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid](
ffn2_input,
layer.moe_ffn2_weight,
ffn2_out,
down_proj_input,
layer.down_proj_weight,
down_proj_out,
None,
layer.moe_ffn2_weight_scale,
layer.down_proj_weight_scale,
topk_weights,
sorted_token_ids,
expert_ids,
@@ -215,18 +215,18 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
token_num * top_k,
N=hidden_size,
K=moe_intermediate_size,
stride_am=ffn2_input.strides[0],
stride_ak=ffn2_input.strides[1],
stride_be=layer.moe_ffn2_weight.strides[0],
stride_bk=layer.moe_ffn2_weight.strides[1],
stride_bn=layer.moe_ffn2_weight.strides[2],
stride_cm=ffn2_out.strides[0],
stride_cn=ffn2_out.strides[1],
stride_am=down_proj_input.strides[0],
stride_ak=down_proj_input.strides[1],
stride_be=layer.down_proj_weight.strides[0],
stride_bk=layer.down_proj_weight.strides[1],
stride_bn=layer.down_proj_weight.strides[2],
stride_cm=down_proj_out.strides[0],
stride_cn=down_proj_out.strides[1],
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
stride_bse=layer.down_proj_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
stride_bsn=layer.down_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
@@ -242,8 +242,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
)
ffn2_out.reshape_([token_num, top_k, hidden_size])
out = ffn2_out.sum(axis=1)
down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
return out
@@ -261,20 +261,20 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
"""process_prequanted_weights"""
ffn1_tensor, ffn2_tensor = layer.extract_moe_ffn_weights(state_dict)
assert ffn1_tensor[0].shape == [
up_gate_proj_tensor, down_proj_tensor = layer.extract_moe_ffn_weights(state_dict)
assert up_gate_proj_tensor[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_tensor[0].shape == [
assert down_proj_tensor[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]
ffn1_tensor = paddle.stack(ffn1_tensor, axis=0).view(paddle.float8_e4m3fn)
ffn2_tensor = paddle.stack(ffn2_tensor, axis=0).view(paddle.float8_e4m3fn)
up_gate_proj_tensor = paddle.stack(up_gate_proj_tensor, axis=0).view(paddle.float8_e4m3fn)
down_proj_tensor = paddle.stack(down_proj_tensor, axis=0).view(paddle.float8_e4m3fn)
added_wfp8afp8_attrs = [
"moe_ffn1_weight", "moe_ffn2_weight", "moe_ffn1_weight_scale",
"moe_ffn2_weight_scale", "moe_ffn1_in_scale", "moe_ffn2_in_scale"
"up_gate_proj_weight", "down_proj_weight", "up_gate_proj_weight_scale",
"down_proj_weight_scale", "up_gate_proj_in_scale", "down_proj_in_scale"
]
def _extract_scale_tensor(key_template):
@@ -285,18 +285,18 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
return paddle.concat(result).cast("float32")
weight_key_map = layer.weight_key_map
moe_ffn1_weight_scale = _extract_scale_tensor(
weight_key_map["ffn1_expert_weight_scale_key"])
moe_ffn2_weight_scale = _extract_scale_tensor(
weight_key_map["ffn2_expert_weight_scale_key"])
moe_ffn1_in_scale = _extract_scale_tensor(
weight_key_map["ffn1_expert_in_scale_key"])
moe_ffn2_in_scale = _extract_scale_tensor(
weight_key_map["ffn2_expert_in_scale_key"])
up_gate_proj_weight_scale = _extract_scale_tensor(
weight_key_map["up_gate_proj_expert_weight_scale_key"])
down_proj_weight_scale = _extract_scale_tensor(
weight_key_map["down_proj_expert_weight_scale_key"])
up_gate_proj_in_scale = _extract_scale_tensor(
weight_key_map["up_gate_proj_expert_in_scale_key"])
down_proj_in_scale = _extract_scale_tensor(
weight_key_map["down_proj_expert_in_scale_key"])
for idx, weight_tensor in enumerate([
ffn1_tensor, ffn2_tensor, moe_ffn1_weight_scale,
moe_ffn2_weight_scale, moe_ffn1_in_scale, moe_ffn2_in_scale
up_gate_proj_tensor, down_proj_tensor, up_gate_proj_weight_scale,
down_proj_weight_scale, up_gate_proj_in_scale, down_proj_in_scale
]):
name = added_wfp8afp8_attrs[idx]
setattr(
@@ -341,12 +341,12 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
False,
)
ffn1_out = paddle.empty(
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
)
config_ffn1 = {
config_up_gate_proj = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
@@ -354,15 +354,15 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
}
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
topk_ids, num_local_experts, config_ffn1["BLOCK_SIZE_M"])
topk_ids, num_local_experts, config_up_gate_proj["BLOCK_SIZE_M"])
max_possible_num_post_padded = sorted_token_ids.shape[0]
grid = (
ceil_div(max_possible_num_post_padded, config_ffn1["BLOCK_SIZE_M"]) *
ceil_div(moe_intermediate_size * 2, config_ffn1["BLOCK_SIZE_N"]), )
ceil_div(max_possible_num_post_padded, config_up_gate_proj["BLOCK_SIZE_M"]) *
ceil_div(moe_intermediate_size * 2, config_up_gate_proj["BLOCK_SIZE_N"]), )
permute_x = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
x,
scale=layer.moe_ffn1_in_scale,
scale=layer.up_gate_proj_in_scale,
topk_ids=topk_ids,
top_k=top_k,
intermediate_size=hidden_size,
@@ -370,10 +370,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
fused_moe_kernel_paddle[grid](
permute_x,
layer.moe_ffn1_weight,
ffn1_out,
layer.moe_ffn1_in_scale,
layer.moe_ffn1_weight_scale,
layer.up_gate_proj_weight,
up_gate_proj_out,
layer.up_gate_proj_in_scale,
layer.up_gate_proj_weight_scale,
None,
sorted_token_ids,
expert_ids,
@@ -384,11 +384,11 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
K=hidden_size,
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
stride_bk=layer.moe_ffn1_weight.strides[1],
stride_bn=layer.moe_ffn1_weight.strides[2],
stride_cm=ffn1_out.strides[0],
stride_cn=ffn1_out.strides[1],
stride_be=layer.up_gate_proj_weight.strides[0],
stride_bk=layer.up_gate_proj_weight.strides[1],
stride_bn=layer.up_gate_proj_weight.strides[2],
stride_cm=up_gate_proj_out.strides[0],
stride_cn=up_gate_proj_out.strides[1],
#
stride_asm=-1, # only used in blockwise fp8
stride_ask=-1, # only used in blockwise fp8
@@ -398,51 +398,51 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
group_n=-1,
group_k=-1,
# Meta-parameters
BLOCK_SIZE_M=config_ffn1["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config_ffn1["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config_ffn1["BLOCK_SIZE_K"],
GROUP_SIZE_M=config_ffn1["GROUP_SIZE_M"],
BLOCK_SIZE_M=config_up_gate_proj["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config_up_gate_proj["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config_up_gate_proj["BLOCK_SIZE_K"],
GROUP_SIZE_M=config_up_gate_proj["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=False,
top_k=1,
compute_type_enum=1,
use_fp8_w8a8=True,
use_int8_w8a16=False,
even_Ks=hidden_size % config_ffn1["BLOCK_SIZE_K"] == 0,
even_Ks=hidden_size % config_up_gate_proj["BLOCK_SIZE_K"] == 0,
)
ffn2_input = paddle.incubate.nn.functional.swiglu(
ffn1_out)
down_proj_input = paddle.incubate.nn.functional.swiglu(
up_gate_proj_out)
ffn2_input = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
ffn2_input,
scale=layer.moe_ffn2_in_scale,
down_proj_input = fastdeploy.model_executor.ops.gpu.moe_fused_hadamard_quant_fp8(
down_proj_input,
scale=layer.down_proj_in_scale,
topk_ids=topk_ids,
top_k=top_k,
intermediate_size=moe_intermediate_size,
tiled=True)
config_ffn2 = {
config_down_proj = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
ffn2_out = paddle.empty(
down_proj_out = paddle.empty(
(token_num * top_k, hidden_size),
dtype=x.dtype,
)
grid = (
ceil_div(max_possible_num_post_padded, config_ffn2["BLOCK_SIZE_M"]) *
ceil_div(hidden_size, config_ffn2["BLOCK_SIZE_N"]), )
ceil_div(max_possible_num_post_padded, config_down_proj["BLOCK_SIZE_M"]) *
ceil_div(hidden_size, config_down_proj["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid](
ffn2_input,
layer.moe_ffn2_weight,
ffn2_out,
layer.moe_ffn2_in_scale,
layer.moe_ffn2_weight_scale,
down_proj_input,
layer.down_proj_weight,
down_proj_out,
layer.down_proj_in_scale,
layer.down_proj_weight_scale,
topk_weights,
sorted_token_ids,
expert_ids,
@@ -451,13 +451,13 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
token_num * top_k,
N=hidden_size,
K=moe_intermediate_size,
stride_am=ffn2_input.strides[0],
stride_ak=ffn2_input.strides[1],
stride_be=layer.moe_ffn2_weight.strides[0],
stride_bk=layer.moe_ffn2_weight.strides[1],
stride_bn=layer.moe_ffn2_weight.strides[2],
stride_cm=ffn2_out.strides[0],
stride_cn=ffn2_out.strides[1],
stride_am=down_proj_input.strides[0],
stride_ak=down_proj_input.strides[1],
stride_be=layer.down_proj_weight.strides[0],
stride_bk=layer.down_proj_weight.strides[1],
stride_bn=layer.down_proj_weight.strides[2],
stride_cm=down_proj_out.strides[0],
stride_cn=down_proj_out.strides[1],
stride_asm=-1,
stride_ask=-1,
stride_bse=-1,
@@ -466,20 +466,20 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
group_n=-1,
group_k=-1,
# Meta-parameters
BLOCK_SIZE_M=config_ffn2["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config_ffn2["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config_ffn2["BLOCK_SIZE_K"],
GROUP_SIZE_M=config_ffn2["GROUP_SIZE_M"],
BLOCK_SIZE_M=config_down_proj["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config_down_proj["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config_down_proj["BLOCK_SIZE_K"],
GROUP_SIZE_M=config_down_proj["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=True,
top_k=1,
compute_type_enum=1,
use_fp8_w8a8=True,
use_int8_w8a16=False,
even_Ks=moe_intermediate_size % config_ffn2["BLOCK_SIZE_K"] == 0,
even_Ks=moe_intermediate_size % config_down_proj["BLOCK_SIZE_K"] == 0,
)
ffn2_out.reshape_([token_num, top_k, hidden_size])
out = ffn2_out.sum(axis=1)
down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
if layer.tp_size > 1:
tensor_model_parallel_all_reduce(out)
@@ -496,9 +496,9 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
Triton Group Gemm to compute Fused MoE.
"""
self.quant_config = quant_config
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
"up_gate_proj_weight_scale", "down_proj_weight_scale"
]
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
@@ -510,11 +510,11 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
"""
Triton MoE create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, ffn1_weights, ffn2_weights)
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
@@ -537,14 +537,14 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
[0, 2, 1]).contiguous()
create_and_set_parameter(layer, scale_name, quanted_weight_scale)
def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights):
def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
"""
check layer is valid for this method
"""
assert ffn1_weights[0].shape == [
assert up_gate_proj_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]
@@ -563,8 +563,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
num_local_experts = layer.num_local_experts
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size
E, N1, _ = layer.moe_ffn1_weight.shape
N2 = layer.moe_ffn2_weight.shape[1]
E, N1, _ = layer.up_gate_proj_weight.shape
N2 = layer.down_proj_weight.shape[1]
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
@@ -605,10 +605,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
fused_moe_kernel_paddle[grid](
x_q,
layer.moe_ffn1_weight.view(paddle.float8_e4m3fn),
layer.up_gate_proj_weight.view(paddle.float8_e4m3fn),
intermediate_cache1,
x_scale,
layer.moe_ffn1_weight_scale,
layer.up_gate_proj_weight_scale,
None,
sorted_token_ids,
expert_ids,
@@ -619,17 +619,17 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
K=hidden_size,
stride_am=x_q.strides[0],
stride_ak=x_q.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
stride_bk=layer.moe_ffn1_weight.strides[2],
stride_bn=layer.moe_ffn1_weight.strides[1],
stride_be=layer.up_gate_proj_weight.strides[0],
stride_bk=layer.up_gate_proj_weight.strides[2],
stride_bn=layer.up_gate_proj_weight.strides[1],
stride_cm=intermediate_cache1.strides[0],
stride_cn=intermediate_cache1.strides[1],
#
stride_asm=x_scale.strides[0], # only used in blockwise fp8
stride_ask=x_scale.strides[1], # only used in blockwise fp8
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
stride_bsk=layer.moe_ffn1_weight_scale.strides[2],
stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
stride_bsk=layer.up_gate_proj_weight_scale.strides[2],
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
group_n=self.quant_config.weight_block_size[1],
group_k=self.quant_config.weight_block_size[0],
# Meta-parameters
@@ -656,10 +656,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
fused_moe_kernel_paddle[grid](
x_q,
layer.moe_ffn2_weight.view(paddle.float8_e4m3fn),
layer.down_proj_weight.view(paddle.float8_e4m3fn),
intermediate_cache3,
x_scale,
layer.moe_ffn2_weight_scale,
layer.down_proj_weight_scale,
topk_weights,
sorted_token_ids,
expert_ids,
@@ -670,16 +670,16 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
K=moe_intermediate_size,
stride_am=x_q.strides[0],
stride_ak=x_q.strides[1],
stride_be=layer.moe_ffn2_weight.strides[0],
stride_bk=layer.moe_ffn2_weight.strides[2],
stride_bn=layer.moe_ffn2_weight.strides[1],
stride_be=layer.down_proj_weight.strides[0],
stride_bk=layer.down_proj_weight.strides[2],
stride_bn=layer.down_proj_weight.strides[1],
stride_cm=intermediate_cache3.strides[0],
stride_cn=intermediate_cache3.strides[1],
stride_asm=x_scale.strides[0], # only used in blockwise fp8
stride_ask=x_scale.strides[1], # only used in blockwise fp8
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
stride_bsk=layer.moe_ffn2_weight_scale.strides[2],
stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
stride_bse=layer.down_proj_weight_scale.strides[0],
stride_bsk=layer.down_proj_weight_scale.strides[2],
stride_bsn=layer.down_proj_weight_scale.strides[1],
group_n=self.quant_config.weight_block_size[1],
group_k=self.quant_config.weight_block_size[0],
# Meta-parameters

View File

@@ -41,16 +41,16 @@ class Wint2MoeMethod(QuantMethodBase):
"""
pass
def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights):
def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
"""
check layer is valid for this method
"""
assert len(
ffn1_weights
) == layer.num_local_experts, "ffn1_weights length should be equal to num_local_experts."
up_gate_proj_weights
) == layer.num_local_experts, "up_gate_proj_weights length should be equal to num_local_experts."
assert len(
ffn2_weights
) == layer.num_local_experts, "ffn2_weights length should be equal to num_local_experts."
down_proj_weights
) == layer.num_local_experts, "down_proj_weights length should be equal to num_local_experts."
def create_weights(self, layer: nn.Layer, state_dict):
"""
@@ -78,96 +78,96 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
"""
Paddle cutlass process prequanted weights.
"""
ffn1_expert_weight_key = layer.weight_key_map.get(
"ffn1_expert_weight_key", None)
ffn2_expert_weight_key = layer.weight_key_map.get(
"ffn2_expert_weight_key", None)
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
"ffn1_expert_weight_scale_key", None)
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
"ffn2_expert_weight_scale_key", None)
ffn1_expert_super_scales_key = layer.weight_key_map.get(
"ffn1_expert_super_scales_key", None)
ffn2_expert_super_scales_key = layer.weight_key_map.get(
"ffn2_expert_super_scales_key", None)
ffn1_expert_code_scale_key = layer.weight_key_map.get(
"ffn1_expert_code_scale_key", None)
ffn2_expert_code_scale_key = layer.weight_key_map.get(
"ffn2_expert_code_scale_key", None)
ffn1_expert_code_zp_key = layer.weight_key_map.get(
"ffn1_expert_code_zp_key", None)
ffn2_expert_code_zp_key = layer.weight_key_map.get(
"ffn2_expert_code_zp_key", None)
up_gate_proj_expert_weight_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = layer.weight_key_map.get(
"down_proj_expert_weight_key", None)
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get(
"up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get(
"down_proj_expert_weight_scale_key", None)
up_gate_proj_expert_super_scales_key = layer.weight_key_map.get(
"up_gate_proj_expert_super_scales_key", None)
down_proj_expert_super_scales_key = layer.weight_key_map.get(
"down_proj_expert_super_scales_key", None)
up_gate_proj_expert_code_scale_key = layer.weight_key_map.get(
"up_gate_proj_expert_code_scale_key", None)
down_proj_expert_code_scale_key = layer.weight_key_map.get(
"down_proj_expert_code_scale_key", None)
up_gate_proj_expert_code_zp_key = layer.weight_key_map.get(
"up_gate_proj_expert_code_zp_key", None)
down_proj_expert_code_zp_key = layer.weight_key_map.get(
"down_proj_expert_code_zp_key", None)
ffn1_weights, ffn2_weights = layer.load_experts_weight(
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
# self.check(layer, ffn1_weights, ffn2_weights)
up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
# self.check(layer, up_gate_proj_weights, down_proj_weights)
ffn1_weight_scale = []
ffn2_weight_scale = []
ffn1_super_scales = []
ffn2_super_scales = []
ffn1_code_scale = []
ffn2_code_scale = []
ffn1_code_zp = []
ffn2_code_zp = []
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
up_gate_proj_super_scales = []
down_proj_super_scales = []
up_gate_proj_code_scale = []
down_proj_code_scale = []
up_gate_proj_code_zp = []
down_proj_code_zp = []
for i in range(layer.num_experts):
expert_idx = layer.expert_id_offset + i
ffn1_weight_scale.append(
up_gate_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn1_expert_weight_scale_key.format(expert_idx))))
ffn2_weight_scale.append(
up_gate_proj_expert_weight_scale_key.format(expert_idx))))
down_proj_weight_scale.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_scale_key.format(expert_idx))))
ffn1_super_scales.append(
down_proj_expert_weight_scale_key.format(expert_idx))))
up_gate_proj_super_scales.append(
get_tensor(
state_dict.pop(
ffn1_expert_super_scales_key.format(expert_idx))))
ffn2_super_scales.append(
up_gate_proj_expert_super_scales_key.format(expert_idx))))
down_proj_super_scales.append(
get_tensor(
state_dict.pop(
ffn2_expert_super_scales_key.format(expert_idx))))
ffn1_code_scale.append(
down_proj_expert_super_scales_key.format(expert_idx))))
up_gate_proj_code_scale.append(
get_tensor(
state_dict.pop(
ffn1_expert_code_scale_key.format(expert_idx))))
ffn2_code_scale.append(
up_gate_proj_expert_code_scale_key.format(expert_idx))))
down_proj_code_scale.append(
get_tensor(
state_dict.pop(
ffn2_expert_code_scale_key.format(expert_idx))))
ffn1_code_zp.append(
down_proj_expert_code_scale_key.format(expert_idx))))
up_gate_proj_code_zp.append(
get_tensor(
state_dict.pop(
ffn1_expert_code_zp_key.format(expert_idx))))
ffn2_code_zp.append(
up_gate_proj_expert_code_zp_key.format(expert_idx))))
down_proj_code_zp.append(
get_tensor(
state_dict.pop(
ffn2_expert_code_zp_key.format(expert_idx))))
down_proj_expert_code_zp_key.format(expert_idx))))
ffn1_weight = paddle.stack(ffn1_weights, axis=0)
ffn2_weight = paddle.stack(ffn2_weights, axis=0)
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)
ffn1_super_scales = paddle.stack(ffn1_super_scales, axis=0)
ffn2_super_scales = paddle.stack(ffn2_super_scales, axis=0)
ffn1_code_scale = paddle.stack(ffn1_code_scale, axis=0)
ffn2_code_scale = paddle.stack(ffn2_code_scale, axis=0)
ffn1_code_zp = paddle.stack(ffn1_code_zp, axis=0)
ffn2_code_zp = paddle.stack(ffn2_code_zp, axis=0)
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
up_gate_proj_super_scales = paddle.stack(up_gate_proj_super_scales, axis=0)
down_proj_super_scales = paddle.stack(down_proj_super_scales, axis=0)
up_gate_proj_code_scale = paddle.stack(up_gate_proj_code_scale, axis=0)
down_proj_code_scale = paddle.stack(down_proj_code_scale, axis=0)
up_gate_proj_code_zp = paddle.stack(up_gate_proj_code_zp, axis=0)
down_proj_code_zp = paddle.stack(down_proj_code_zp, axis=0)
name_tensor_map = {
"moe_ffn1_weight": ffn1_weight,
"moe_ffn2_weight": ffn2_weight,
"moe_ffn1_weight_scale": ffn1_weight_scale,
"moe_ffn2_weight_scale": ffn2_weight_scale,
"moe_ffn1_super_scales": ffn1_super_scales,
"moe_ffn2_super_scales": ffn2_super_scales,
"moe_ffn1_code_scale": ffn1_code_scale,
"moe_ffn2_code_scale": ffn2_code_scale,
"moe_ffn1_code_zp": ffn1_code_zp,
"moe_ffn2_code_zp": ffn2_code_zp
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,
"up_gate_proj_weight_scale": up_gate_proj_weight_scale,
"down_proj_weight_scale": down_proj_weight_scale,
"up_gate_proj_super_scales": up_gate_proj_super_scales,
"down_proj_super_scales": down_proj_super_scales,
"up_gate_proj_code_scale": up_gate_proj_code_scale,
"down_proj_code_scale": down_proj_code_scale,
"up_gate_proj_code_zp": up_gate_proj_code_zp,
"down_proj_code_zp": down_proj_code_zp
}
for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor)
@@ -200,7 +200,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
x,
gate_out,
layer.gate_correction_bias,
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale")
else None), # if set, permute_input will be int8_t
layer.top_k,
False,
@@ -210,17 +210,17 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
ffn_out = fastdeploy.model_executor.ops.gpu.moe_expert_ffn_wint2(
permute_input,
token_nums_per_expert,
layer.moe_ffn1_weight,
layer.moe_ffn2_weight,
layer.up_gate_proj_weight,
layer.down_proj_weight,
None,
layer.moe_ffn1_super_scales,
layer.moe_ffn2_super_scales,
layer.moe_ffn1_weight_scale,
layer.moe_ffn1_code_scale,
layer.moe_ffn1_code_zp,
layer.moe_ffn2_weight_scale,
layer.moe_ffn2_code_scale,
layer.moe_ffn2_code_zp,
layer.up_gate_proj_super_scales,
layer.down_proj_super_scales,
layer.up_gate_proj_weight_scale,
layer.up_gate_proj_code_scale,
layer.up_gate_proj_code_zp,
layer.down_proj_weight_scale,
layer.down_proj_code_scale,
layer.down_proj_code_zp,
False,
)
@@ -271,7 +271,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
)
num_tokens, K = x.shape
E, _, N = layer.moe_ffn1_weight.shape
E, _, N = layer.up_gate_proj_weight.shape
M = num_tokens
top_k = topk_ids.shape[1]
@@ -308,12 +308,12 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
moe_wint2_ffn_kernel[grid](
x,
layer.moe_ffn1_weight,
layer.up_gate_proj_weight,
intermediate_cache1,
layer.moe_ffn1_weight_scale,
layer.moe_ffn1_super_scales,
layer.moe_ffn1_code_scale,
layer.moe_ffn1_code_zp,
layer.up_gate_proj_weight_scale,
layer.up_gate_proj_super_scales,
layer.up_gate_proj_code_scale,
layer.up_gate_proj_code_zp,
topk_weights,
sorted_token_ids,
expert_ids,
@@ -321,7 +321,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
num_valid_tokens,
max_possible_num_post_padded,
# Matrix dimensions
N=layer.moe_ffn1_weight.shape[-1],
N=layer.up_gate_proj_weight.shape[-1],
K=x.shape[-1],
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
@@ -329,15 +329,15 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
# (A has M rows).
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
stride_bk=layer.moe_ffn1_weight.strides[1],
stride_be=layer.up_gate_proj_weight.strides[0],
stride_bk=layer.up_gate_proj_weight.strides[1],
stride_bn=1,
stride_cm=intermediate_cache1.strides[-2],
stride_cn=1,
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
stride_bsk=layer.moe_ffn1_weight_scale.strides[1],
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
stride_bsk=layer.up_gate_proj_weight_scale.strides[1],
stride_bsn=1,
stride_bce=layer.moe_ffn1_code_scale.strides[0],
stride_bce=layer.up_gate_proj_code_scale.strides[0],
stride_bck=1,
stride_bcn=1,
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
@@ -361,17 +361,17 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
}
grid = (ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) *
ceil_div(layer.moe_ffn2_weight.shape[-1], config["BLOCK_SIZE_N"]), )
ceil_div(layer.down_proj_weight.shape[-1], config["BLOCK_SIZE_N"]), )
moe_wint2_ffn_kernel[grid](
intermediate_cache2,
layer.moe_ffn2_weight,
layer.down_proj_weight,
intermediate_cache3,
layer.moe_ffn2_weight_scale,
layer.moe_ffn2_super_scales,
layer.moe_ffn2_code_scale,
layer.moe_ffn2_code_zp,
layer.down_proj_weight_scale,
layer.down_proj_super_scales,
layer.down_proj_code_scale,
layer.down_proj_code_zp,
topk_weights,
sorted_token_ids,
expert_ids,
@@ -379,7 +379,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
num_valid_tokens,
max_possible_num_post_padded,
# Matrix dimensions
N=layer.moe_ffn2_weight.shape[-1],
N=layer.down_proj_weight.shape[-1],
K=intermediate_cache2.shape[-1],
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
@@ -387,15 +387,15 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
# (A has M rows).
stride_am=intermediate_cache2.strides[0],
stride_ak=1,
stride_be=layer.moe_ffn2_weight.strides[0],
stride_bk=layer.moe_ffn2_weight.strides[1],
stride_be=layer.down_proj_weight.strides[0],
stride_bk=layer.down_proj_weight.strides[1],
stride_bn=1,
stride_cm=intermediate_cache3.strides[-2],
stride_cn=1,
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
stride_bsk=layer.moe_ffn2_weight_scale.strides[1],
stride_bse=layer.down_proj_weight_scale.strides[0],
stride_bsk=layer.down_proj_weight_scale.strides[1],
stride_bsn=1,
stride_bce=layer.moe_ffn2_code_scale.strides[0],
stride_bce=layer.down_proj_code_scale.strides[0],
stride_bck=1,
stride_bcn=1,
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],

View File

@@ -38,14 +38,14 @@ class XPUMoEMethod(MoEMethodBase):
Paddle cutlass create weight process.
"""
# bf16
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
for weights in [ffn1_weights, ffn2_weights]:
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
for weights in [up_gate_proj_weights, down_proj_weights]:
for idx, weight in enumerate(weights):
weights[idx] = weight.transpose([1, 0])
stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
for idx, weight_tensor in enumerate(
[stacked_ffn1_weights, stacked_ffn2_weights]):
[stacked_up_gate_proj_weights, stacked_down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
setattr(
layer, weight_name,
@@ -71,13 +71,13 @@ class XPUMoEMethod(MoEMethodBase):
x,
layer.gate_weight.transpose([1, 0]),
layer.gate_correction_bias,
layer.moe_ffn1_weight,
layer.moe_ffn2_weight,
None, # ffn1 bias
None, # ffn2 bias
None, # ffn1 scale
None, # ffn2 scale
None, # ffn1_in_scale
layer.up_gate_proj_weight,
layer.down_proj_weight,
None, # up_gate_proj bias
None, # down_proj bias
None, # up_gate_proj scale
None, # down_proj scale
None, # up_gate_proj_in_scale
"", # moe_quant_type
layer.top_k,
False, # moe group, used in deepseek
@@ -129,20 +129,20 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
"""
Paddle cutlass create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(ffn1_weights) == layer.num_local_experts
assert len(ffn2_weights) == layer.num_local_experts
assert ffn1_weights[0].shape == [
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
assert up_gate_proj_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]
added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
added_scale_attrs = ["moe_ffn1_weight_scale", "moe_ffn2_weight_scale"]
added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
added_scale_attrs = ["up_gate_proj_weight_scale", "down_proj_weight_scale"]
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = added_weight_attrs[idx]
scale_name = added_scale_attrs[idx]
@@ -189,16 +189,16 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
x,
layer.gate_weight.transpose([1, 0]),
layer.gate_correction_bias,
layer.moe_ffn1_weight,
layer.moe_ffn2_weight,
None, # ffn1 bias
None, # ffn2 bias
(layer.moe_ffn1_weight_scale
if hasattr(layer, "moe_ffn1_weight_scale") else None),
(layer.moe_ffn2_weight_scale
if hasattr(layer, "moe_ffn2_weight_scale") else None),
(layer.moe_ffn2_in_scale
if hasattr(layer, "moe_ffn2_in_scale") else None),
layer.up_gate_proj_weight,
layer.down_proj_weight,
None, # up_gate_proj bias
None, # down_proj bias
(layer.up_gate_proj_weight_scale
if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale
if hasattr(layer, "down_proj_weight_scale") else None),
(layer.down_proj_in_scale
if hasattr(layer, "down_proj_in_scale") else None),
self.moe_quant_type,
layer.top_k,
False, # moe group, used in deepseek

View File

@@ -145,13 +145,13 @@ class FusedMoE(nn.Layer):
shape=gate_correction_bias_shape,
dtype="float32",
)
ffn1_output_dim = self.moe_intermediate_size * 2
up_gate_proj_output_dim = self.moe_intermediate_size * 2
if self.moe_quant_type in ["fp8", "wint8"]:
ffn1_weight_shape = [self.num_local_experts, ffn1_output_dim, self.hidden_size]
ffn2_weight_shape = [self.num_local_experts, self.hidden_size, self.moe_intermediate_size]
up_gate_proj_weight_shape = [self.num_local_experts, up_gate_proj_output_dim, self.hidden_size]
down_proj_weight_shape = [self.num_local_experts, self.hidden_size, self.moe_intermediate_size]
else:
ffn1_weight_shape = [self.num_local_experts, self.hidden_size, ffn1_output_dim]
ffn2_weight_shape = [self.num_local_experts, self.moe_intermediate_size, self.hidden_size]
up_gate_proj_weight_shape = [self.num_local_experts, self.hidden_size, up_gate_proj_output_dim]
down_proj_weight_shape = [self.num_local_experts, self.moe_intermediate_size, self.hidden_size]
# Create parameters
if self.moe_quant_type == "fp8":
@@ -161,15 +161,15 @@ class FusedMoE(nn.Layer):
self.weight_dtype = "int8"
self.init_weight_only_scale()
# FFN1 parameters
self.moe_ffn1_weight = self.create_parameter(
shape=ffn1_weight_shape,
# up_gate_proj parameters
self.up_gate_proj_weight = self.create_parameter(
shape=up_gate_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
# FFN2 parameters
self.moe_ffn2_weight = self.create_parameter(
shape=ffn2_weight_shape,
# down_proj parameters
self.down_proj_weight = self.create_parameter(
shape=down_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
@@ -178,44 +178,44 @@ class FusedMoE(nn.Layer):
"""
Initialize the weight scale.
"""
self.moe_ffn1_weight_scale = self.create_parameter(
self.up_gate_proj_weight_scale = self.create_parameter(
shape=[self.num_local_experts, self.moe_intermediate_size * 2],
dtype=self._dtype,
)
self.moe_ffn2_weight_scale = self.create_parameter(
self.down_proj_weight_scale = self.create_parameter(
shape=[self.num_local_experts, self.hidden_size],
dtype=self._dtype,
)
def load_experts_weight(self, state_dict: dict,
ffn1_expert_weight_key: str,
ffn2_expert_weight_key: str):
up_gate_proj_expert_weight_key: str,
down_proj_expert_weight_key: str):
"""
Load experts weight from state_dict.
Args:
state_dict (dict): The state_dict of model.
ffn1_expert_weight_key (str): The key of ffn1 expert weight.
ffn2_expert_weight_key (str): The key of ffn2 expert weight.
up_gate_proj_expert_weight_key (str): The key of up_gate_proj expert weight.
down_proj_expert_weight_key (str): The key of down_proj expert weight.
"""
ffn1_weights = []
ffn2_weights = []
is_ffn_merged = ffn1_expert_weight_key.format(
up_gate_proj_weights = []
down_proj_weights = []
is_ffn_merged = up_gate_proj_expert_weight_key.format(
self.expert_id_offset) in state_dict
if is_ffn_merged:
for i in range(self.num_local_experts):
expert_idx = self.expert_id_offset + i
ffn1_weights.append(
up_gate_proj_weights.append(
get_tensor(
state_dict.pop(
ffn1_expert_weight_key.format(expert_idx))))
ffn2_weights.append(
up_gate_proj_expert_weight_key.format(expert_idx))))
down_proj_weights.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_key.format(expert_idx))))
down_proj_expert_weight_key.format(expert_idx))))
else:
gate_expert_weight_key = ffn1_expert_weight_key.replace(
gate_expert_weight_key = up_gate_proj_expert_weight_key.replace(
"up_gate_proj", "gate_proj")
up_expert_weight_key = ffn1_expert_weight_key.replace(
up_expert_weight_key = up_gate_proj_expert_weight_key.replace(
"up_gate_proj", "up_proj")
for j in range(self.num_local_experts):
expert_idx = self.expert_id_offset + j
@@ -223,12 +223,12 @@ class FusedMoE(nn.Layer):
state_dict.pop(gate_expert_weight_key.format(expert_idx)))
up = get_tensor(
state_dict.pop(up_expert_weight_key.format(expert_idx)))
ffn1_weights.append(paddle.concat([gate, up], axis=-1))
ffn2_weights.append(
up_gate_proj_weights.append(paddle.concat([gate, up], axis=-1))
down_proj_weights.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_key.format(expert_idx))))
return ffn1_weights, ffn2_weights
down_proj_expert_weight_key.format(expert_idx))))
return up_gate_proj_weights, down_proj_weights
def extract_moe_ffn_weights(self, state_dict: dict):
"""
@@ -239,30 +239,30 @@ class FusedMoE(nn.Layer):
Returns:
tuple: A tuple containing two lists:
- ffn1_weights: List of tensors for first FFN layer weights
- ffn2_weights: List of tensors for second FFN layer weights
- up_gate_proj_weights: List of tensors for first FFN layer weights
- down_proj_weights: List of tensors for second FFN layer weights
Raises:
AssertionError: If required weight keys are missing or number of weights
doesn't match number of local experts.
"""
ffn1_expert_weight_key = self.weight_key_map.get(
"ffn1_expert_weight_key", None)
ffn2_expert_weight_key = self.weight_key_map.get(
"ffn2_expert_weight_key", None)
assert ffn1_expert_weight_key is not None, "ffn1_expert_weight_key should not be none."
assert ffn2_expert_weight_key is not None, "ffn2_expert_weight_key should not be none."
up_gate_proj_expert_weight_key = self.weight_key_map.get(
"up_gate_proj_expert_weight_key", None)
down_proj_expert_weight_key = self.weight_key_map.get(
"down_proj_expert_weight_key", None)
assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."
ffn1_weights, ffn2_weights = self.load_experts_weight(
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
up_gate_proj_weights, down_proj_weights = self.load_experts_weight(
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
assert len(
ffn1_weights
) == self.num_local_experts, "ffn1_weights length should be equal to num_local_experts."
up_gate_proj_weights
) == self.num_local_experts, "up_gate_proj_weights length should be equal to num_local_experts."
assert len(
ffn2_weights
) == self.num_local_experts, "ffn2_weights length should be equal to num_local_experts."
down_proj_weights
) == self.num_local_experts, "down_proj_weights length should be equal to num_local_experts."
return ffn1_weights, ffn2_weights
return up_gate_proj_weights, down_proj_weights
def extract_gate_correction_bias(self, gate_correction_bias_key,
state_dict):

View File

@@ -46,11 +46,11 @@ class ParallelEHProjection(nn.Layer):
prefix (str): full name of the layer in the state dict
"""
super(ParallelEHProjection, self).__init__()
self.linear_weight_key = prefix + ".weight"
self.weight_key = prefix + ".weight"
if with_bias:
self.linear_bias_key = prefix + ".bias"
self.bias_key = prefix + ".bias"
else:
self.linear_bias_key = None
self.bias_key = None
self.use_ep = fd_config.parallel_config.use_ep
self.column_cut = True
@@ -66,26 +66,26 @@ class ParallelEHProjection(nn.Layer):
else:
if self.column_cut:
need_gather = True
self.out_linear = ColumnParallelLinear(
self.linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
if self.bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False, # False diff更小
)
else:
self.out_linear = RowParallelLinear(
self.linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
if self.bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False, # False diff更小
)
@@ -100,20 +100,20 @@ class ParallelEHProjection(nn.Layer):
if self.use_ep:
self.weight.set_value(
get_tensor(state_dict.pop(self.linear_weight_key)).astype(
get_tensor(state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype()))
else:
weight_tensor = get_tensor(
state_dict.pop(self.linear_weight_key)).astype(
state_dict.pop(self.weight_key)).astype(
paddle.get_default_dtype())
if self.out_linear.weight.shape != weight_tensor.shape:
if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.out_linear.weight.set_value(weight_tensor)
self.linear.weight.set_value(weight_tensor)
if self.linear_bias_key is not None:
bias = get_tensor(state_dict.pop(self.linear_bias_key)).astype(
if self.bias_key is not None:
bias = get_tensor(state_dict.pop(self.bias_key)).astype(
paddle.get_default_dtype())
self.out_linear.bias.set_value(bias)
self.linear.bias.set_value(bias)
def forward(self, input):
"""
@@ -129,5 +129,5 @@ class ParallelEHProjection(nn.Layer):
if self.use_ep:
logits = paddle.matmul(logits, self.weight)
else:
logits = self.out_linear(logits)
logits = self.linear(logits)
return logits

View File

@@ -43,7 +43,7 @@ class RMSNorm(nn.Layer):
hidden_size: int,
eps: float = 1e-5,
prefix: str = "",
linear_bias: paddle.Tensor = None,
bias: paddle.Tensor = None,
quant_scale: float = None,
begin_norm_axis: int = 1,
) -> None:
@@ -57,7 +57,7 @@ class RMSNorm(nn.Layer):
hidden_size (int) : size of hidden state.
eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
prefix(str,optional):The name of current layer. Defaults to "".
linear_bias (paddle.Tensor,optional): Initial bias value for the linear layer (if used). Defaults to None.
bias (paddle.Tensor,optional): Initial bias value for the linear layer (if used). Defaults to None.
quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
begin_norm_axis (int, optional): The axis along which to perform normalization. Defaults to 1.
@@ -78,7 +78,7 @@ class RMSNorm(nn.Layer):
self.norm_func: Callable = fused_add_rms_norm
else:
self.norm_func: Callable = fused_rms_norm
self.linear_bias: Optional[paddle.Tensor] = linear_bias
self.bias: Optional[paddle.Tensor] = bias
self.quant_scale: Optional[float] = quant_scale
self._dtype: str = self._helper.get_default_dtype()
self._norm_weight_dtype: str = self._dtype
@@ -94,9 +94,9 @@ class RMSNorm(nn.Layer):
Initialize the weights and biases.
"""
self.ln_weight = None
self.weight = None
if self.with_weight:
self.ln_weight = self.create_parameter(
self.weight = self.create_parameter(
shape=[self.hidden_size],
default_initializer=nn.initializer.Constant(value=1.0),
dtype=self._norm_weight_dtype,
@@ -115,7 +115,7 @@ class RMSNorm(nn.Layer):
weight_tensor = paddle.cast(
get_tensor(state_dict.pop(self.weight_key)),
self._norm_weight_dtype)
self.ln_weight.set_value(weight_tensor)
self.weight.set_value(weight_tensor)
def forward(
self,
@@ -139,18 +139,18 @@ class RMSNorm(nn.Layer):
"""
if current_platform.is_gcu():
if residual_input is None:
return rms_norm(x, self.ln_weight, self.eps)
return rms_norm(x, self.weight, self.eps)
norm_out = self.norm_func(
x, residual_input, self.ln_weight, self.eps
x, residual_input, self.weight, self.eps
)
else:
norm_out = self.norm_func(
x,
norm_weight=self.ln_weight,
norm_weight=self.weight,
norm_bias=None,
epsilon=self.eps,
begin_norm_axis=self.begin_norm_axis,
bias=self.linear_bias,
bias=self.bias,
residual=residual_input,
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
quant_round_type=self.quant_round_type,
@@ -174,7 +174,7 @@ class LayerNorm(nn.Layer):
hidden_size: int,
eps: float = 1e-5,
prefix="",
linear_bias: paddle.Tensor = None,
bias: paddle.Tensor = None,
quant_scale: float = None,
with_bias: bool = False,
):
@@ -189,7 +189,7 @@ class LayerNorm(nn.Layer):
eps:(float, optional): Small value added to the variance to avoid division by zero. Defaults to 1e-5.
prefix (str): Unique name of the layer, used for naming internal attributes,
you can give it any name you like.
linear_bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None.
bias (float, optional): Initial bias value for the linear layer (if used). Defaults to None.
quant_scale(float,optional):Quantization scale, used in quantization scenarios. Defaults to -1, indicating no quantization.
with_bias (bool):Whether to include bias or not. Defaults to False.
Raises:
@@ -212,7 +212,7 @@ class LayerNorm(nn.Layer):
self.norm_func: Callable = paddle.nn.functional.layer_norm
else:
self.norm_func: Callable = fused_layer_norm
self.linear_bias: Optional[paddle.Tensor] = linear_bias
self.bias: Optional[paddle.Tensor] = bias
self._dtype: str = self._helper.get_default_dtype()
self._norm_weight_dtype: str = "float32"
@@ -227,16 +227,16 @@ class LayerNorm(nn.Layer):
Initialize the weights and biases.
"""
self.ln_weight = None
self.weight = None
if self.with_weight:
self.ln_weight = self.create_parameter(
self.weight = self.create_parameter(
shape=[self.hidden_size],
default_initializer=nn.initializer.Constant(value=1.0),
dtype=self._norm_weight_dtype,
)
self.ln_bias = None
self.bias = None
if self.with_bias:
self.ln_bias = self.create_parameter(
self.bias = self.create_parameter(
shape=[self.hidden_size],
is_bias=True,
dtype=self._norm_weight_dtype,
@@ -255,14 +255,14 @@ class LayerNorm(nn.Layer):
weight_tensor = paddle.cast(
get_tensor(state_dict.pop(self.weight_key)),
self._norm_weight_dtype)
self.ln_weight.set_value(weight_tensor)
self.weight.set_value(weight_tensor)
# bias
if self.with_bias:
bias_tensor = paddle.cast(
get_tensor(state_dict.pop(self.bias_key)),
self._norm_weight_dtype)
self.ln_bias.set_value(bias_tensor)
self.bias.set_value(bias_tensor)
def forward(
self,
@@ -285,10 +285,10 @@ class LayerNorm(nn.Layer):
operations (like linear transformation) on the `residual_input`.
"""
if current_platform.is_iluvatar():
if self.ln_weight is None and self.ln_bias is None:
if self.weight is None and self.bias is None:
out = x
if self.linear_bias is not None:
out += self.linear_bias
if self.bias is not None:
out += self.bias
if residual_input is not None:
out += residual_input
return out, out
@@ -303,8 +303,8 @@ class LayerNorm(nn.Layer):
out = self.norm_func(
x=y,
normalized_shape=y.shape[1:],
weight=self.ln_weight,
bias=self.linear_bias,
weight=self.weight,
bias=self.bias,
epsilon=self.eps,
)
return out, y
@@ -312,19 +312,19 @@ class LayerNorm(nn.Layer):
out = self.norm_func(
x=x,
normalized_shape=x.shape[1:],
weight=self.ln_weight,
bias=self.linear_bias,
weight=self.weight,
bias=self.bias,
epsilon=self.eps,
)
return out
else:
norm_out = self.norm_func(
x,
norm_weight=self.ln_weight,
norm_bias=self.ln_bias,
norm_weight=self.weight,
norm_bias=self.bias,
epsilon=self.eps,
begin_norm_axis=1,
bias=self.linear_bias,
bias=self.bias,
residual=residual_input,
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
quant_round_type=self.quant_round_type,

View File

@@ -78,8 +78,8 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
self.quant_config = quant_config
def create_weights(self, layer):
layer.linear_weight_shape.reverse()
layer.linear_weight_scale = layer.create_parameter(
layer.weight_shape.reverse()
layer.weight_scale = layer.create_parameter(
shape=[
(layer.output_size + self.quant_config.weight_block_size[0] -
1) // self.quant_config.weight_block_size[0],
@@ -95,8 +95,8 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
weight_tensor = weights.transpose([1, 0])
quanted_weight_tensor, weight_block_scale_tensor = (
per_block_cast_to_fp8(weight_tensor))
layer.linear_weight.copy_(quanted_weight_tensor, False)
layer.linear_weight_scale.set_value(weight_block_scale_tensor)
layer.weight.copy_(quanted_weight_tensor, False)
layer.weight_scale.set_value(weight_block_scale_tensor)
def process_prequanted_weights(self, layer, state_dict):
"""
@@ -106,10 +106,10 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
quant_weight = quant_weight.transpose([1, 0]).contiguous()
layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False)
layer.weight.copy_(quant_weight.view("float8_e4m3fn"), False)
weight_scale = weight_scale.transpose([1, 0])
layer.linear_weight_scale.set_value(weight_scale)
layer.weight_scale.set_value(weight_scale)
def apply(self, layer, x):
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding(
@@ -119,9 +119,9 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
deep_gemm.gemm_fp8_fp8_bf16_nt(
(x, x_scale_tensor),
(layer.linear_weight, layer.linear_weight_scale),
(layer.weight, layer.weight_scale),
linear_out,
)
if layer.with_bias:
linear_out = paddle.add(linear_out, layer.linear_bias)
linear_out = paddle.add(linear_out, layer.bias)
return linear_out

View File

@@ -96,7 +96,7 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
act_scale = get_tensor(state_dict.pop(layer.act_scale_key))
quant_weight = quant_weight.transpose([1, 0]).contiguous()
layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False)
layer.weight.copy_(quant_weight.view("float8_e4m3fn"), False)
self.act_scale = act_scale.item()
self.total_scale = (act_scale * weight_scale).item()
@@ -118,7 +118,7 @@ class TensorWiseFP8LinearMethod(QuantMethodBase):
linear_out = cutlass_fp8_fp8_half_gemm_fused(
fp8_x,
layer.linear_weight,
layer.weight,
transpose_x=False,
transpose_y=True,
bias=None,

View File

@@ -63,8 +63,8 @@ class W4AFP8LinearMethod(QuantMethodBase):
self.quant_config = quant_config
def create_weights(self, layer):
layer.linear_weight_shape.reverse()
layer.linear_weight_shape[0] //= 2
layer.weight_shape.reverse()
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
pass
@@ -77,16 +77,16 @@ class W4AFP8LinearMethod(QuantMethodBase):
scale_dtype="float16",
))
weight_scale_tensor = paddle.view(weight_scale_tensor, layer._dtype)
layer.linear_weight.set_value(quanted_weight_tensor)
layer.linear_weight_scale.set_value(weight_scale_tensor)
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(weight_scale_tensor)
def apply(self, layer, x):
linear_out = fastdeploy.model_executor.ops.gpu.scaled_gemm_f8_i4_f16(
x,
layer.linear_weight,
layer.linear_weight_scale,
layer.weight,
layer.weight_scale,
zero_points=None,
bias=layer.linear_bias if layer.add_bias else None,
bias=layer.bias if layer.add_bias else None,
out_scale=self.quant_config.weight_scale_dict.get(layer.prefix +
".weight_scale")
/ (self.quant_config.act_scale_dict.get(layer.prefix +

View File

@@ -69,7 +69,7 @@ class W8A8LinearMethod(QuantMethodBase):
self.smooth_quant_method = SmoothQuantLinearMethod(quant_config)
def create_weights(self, layer):
layer.linear_weight_shape.reverse()
layer.weight_shape.reverse()
layer.weight_dtype = "int8"
if self.quant_config.use_smooth_quant:
self.smooth_quant_method.create_weights(layer)
@@ -101,21 +101,21 @@ class W8A8LinearMethod(QuantMethodBase):
if self.skip_quant:
logger.debug(f"{layer.prefix} skip quant")
weight_tensor = weights.cast(layer._dtype)
layer.linear_weight.set_value(weight_tensor)
layer.weight.set_value(weight_tensor)
else:
weight_tensor = weights.transpose([1, 0])
weight_tensor = paddle.cast(weight_tensor, "int8")
layer.linear_weight.set_value(weight_tensor)
layer.weight.set_value(weight_tensor)
def apply(self, layer, x):
if self.skip_quant:
linear_out = paddle.matmul(x, layer.linear_weight, False, True)
linear_out = paddle.matmul(x, layer.weight, False, True)
return linear_out
if self.quant_config.use_gemm_dequant:
linear_out = fastdeploy.model_executor.ops.gpu.gemm_dequant(
x, layer.linear_weight, layer.linear_out_scale, layer._dtype)
x, layer.weight, layer.linear_out_scale, layer._dtype)
else:
linear_out = paddle.matmul(x, layer.linear_weight, False, True)
linear_out = paddle.matmul(x, layer.weight, False, True)
linear_out = fastdeploy.model_executor.ops.gpu.dequant_int8(
linear_out, layer.linear_out_scale, layer._dtype)
return linear_out

View File

@@ -77,12 +77,12 @@ class WeightOnlyConfig(QuantConfigBase):
return GCUWeightOnlyLinearMethod(self)
elif current_platform.is_dcu():
if isinstance(layer, FusedMoE):
from fastdeploy.model_executor.layers.backends import (
DCUTritonWeightOnlyMoEMethod)
from fastdeploy.model_executor.layers.backends import \
DCUTritonWeightOnlyMoEMethod
return DCUTritonWeightOnlyMoEMethod(self)
else:
from fastdeploy.model_executor.layers.backends import (
DCUWeightOnlyLinearMethod)
from fastdeploy.model_executor.layers.backends import \
DCUWeightOnlyLinearMethod
return DCUWeightOnlyLinearMethod(self)
else:
if isinstance(layer, FusedMoE):
@@ -152,14 +152,14 @@ class WeightOnlyLinearMethod(QuantMethodBase):
def create_weights(self, layer):
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
weight_scale_shape = [layer.weight_shape[1]]
layer.linear_weight_shape.reverse()
layer.weight_shape.reverse()
if self.quant_config.name() == "wint4":
layer.linear_weight_shape[0] //= 2
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype=layer._dtype,
is_bias=False,
)
@@ -171,9 +171,9 @@ class WeightOnlyLinearMethod(QuantMethodBase):
def apply(self, layer, x):
linear_out = weight_only_linear(
x,
weight=layer.linear_weight,
bias=layer.linear_bias if layer.add_bias else None,
weight_scale=layer.linear_weight_scale,
weight=layer.weight,
bias=layer.bias if layer.add_bias else None,
weight_scale=layer.weight_scale,
weight_dtype="int8"
if self.quant_config.name() == "wint8" else "int4",
arch=self.quant_config.weight_only_linear_arch,
@@ -204,8 +204,8 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
layer.linear_weight.set_value(quant_weight)
layer.linear_weight_scale.set_value(
layer.weight.set_value(quant_weight)
layer.weight_scale.set_value(
weight_scale.astype(paddle.get_default_dtype()))
def process_loaded_weights(self, layer, weight) -> None:
@@ -216,6 +216,6 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
arch=self.quant_config.weight_only_linear_arch,
)
layer.linear_weight.set_value(quanted_weight_tensor)
layer.linear_weight_scale.set_value(
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(
weight_scale_tensor.astype(paddle.get_default_dtype()))

View File

@@ -70,11 +70,11 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
def create_weights(self, layer):
"""
"""
layer.linear_weight_shape.reverse()
layer.weight_shape.reverse()
layer.weight_dtype = "float8_e4m3fn"
# TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
self.skip_quant = False
layer.linear_weight_scale = layer.create_parameter(
layer.weight_scale = layer.create_parameter(
shape=[1],
dtype="float32",
is_bias=False,
@@ -86,7 +86,7 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
"""
if self.skip_quant:
weight_tensor = weights.cast(layer._dtype)
layer.linear_weight.set_value(weight_tensor)
layer.weight.set_value(weight_tensor)
return
if weights.dtype != paddle.float8_e4m3fn:
self.use_per_token_if_dynamic = True
@@ -95,22 +95,22 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
weight_tensor,
use_per_token_if_dynamic=False,
)
layer.linear_weight.copy_(qweight, False)
layer.linear_weight_scale.set_value(weight_scale)
layer.weight.copy_(qweight, False)
layer.weight_scale.set_value(weight_scale)
def apply(self, layer, x):
"""
"""
if self.skip_quant:
linear_out = paddle.matmul(x, layer.linear_weight, False, True)
linear_out = paddle.matmul(x, layer.weight, False, True)
return linear_out
if self.use_per_token_if_dynamic:
out_type = x.dtype
a_q, a_scales = scaled_fp8_quant(
x, use_per_token_if_dynamic=self.use_per_token_if_dynamic)
linear_out = cutlass_scaled_mm(a_q, layer.linear_weight, a_scales,
layer.linear_weight_scale, out_type,
layer.linear_bias)
linear_out = cutlass_scaled_mm(a_q, layer.weight, a_scales,
layer.weight_scale, out_type,
layer.bias)
else:
raise NotImplementedError
return linear_out