【Feature】support some qwen2 functions (#2740)

* add RL Qwen model support

* fix

* fix
gaoziyuan
2025-07-08 12:03:04 +08:00
committed by GitHub
parent fefbd65cf8
commit 26d5d737dd
13 changed files with 438 additions and 174 deletions

View File

@@ -21,7 +21,11 @@ from dataclasses import dataclass, field
 from typing import List, Optional
 
 import paddle
-from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
+try:
+    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
+except:
+    flash_attention_v3_varlen = None
 
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
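
The hunk above makes the FlashAttention-3 varlen kernel optional: on Paddle builds that do not ship it, the name resolves to None instead of the import failing. The sketch below shows how a call site can branch on that; the helper names are illustrative assumptions, not FastDeploy code, and the arguments are simply forwarded to avoid guessing the kernel's signature.

# Optional-import guard, mirroring the diff: older Paddle builds may not ship
# the FA3 varlen kernel, so the name degrades to None instead of raising.
try:
    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
except ImportError:
    flash_attention_v3_varlen = None


def flash_attn_v3_available() -> bool:
    # Illustrative helper: report whether the FA3 varlen kernel was importable.
    return flash_attention_v3_varlen is not None


def run_varlen_attention(*args, **kwargs):
    # Hypothetical call site: use the FA3 kernel only when it was imported,
    # otherwise surface a clear error (a real backend would pick another kernel).
    if flash_attention_v3_varlen is None:
        raise RuntimeError("flash_attention_v3_varlen is unavailable in this Paddle build")
    return flash_attention_v3_varlen(*args, **kwargs)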

View File

@@ -294,7 +294,7 @@ class ColumnParallelLinear(LinearBase):
         )
         if self.nranks > 0:
             # col parallel
-            _set_var_distributed(self.linear_weight, split_axis=-1)
+            _set_var_distributed(self.linear_weight, split_axis=1)
 
         self.linear_bias = None
         if self.with_bias:
@@ -305,7 +305,7 @@ class ColumnParallelLinear(LinearBase):
             )
             if self.nranks > 0:
                 # col parallel
-                _set_var_distributed(self.linear_bias, split_axis=-1)
+                _set_var_distributed(self.linear_bias, split_axis=1)
 
         # smooth quant
        self.linear_shift = None
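
For context on the split_axis change: in a column-parallel linear layer the output dimension is the one that gets sharded, and with Paddle's [in_features, out_features] weight layout that is axis 1, which the diff now states explicitly instead of relying on -1. Below is a minimal NumPy sketch of that sharding, purely for illustration and not FastDeploy code.

import numpy as np

# Toy column-parallel sharding: a [in_features, out_features] weight is split
# along axis 1 (the output dimension), one slice per tensor-parallel rank.
# Sizes are made up for illustration.
in_features, out_features, nranks = 4, 6, 2
weight = np.arange(in_features * out_features, dtype=np.float32).reshape(in_features, out_features)

shards = np.split(weight, nranks, axis=1)            # each rank holds a [4, 3] slice
rank0_cols = weight[:, : out_features // nranks]     # columns owned by rank 0

assert shards[0].shape == (in_features, out_features // nranks)
assert np.array_equal(shards[0], rank0_cols)

A row-parallel layer would instead shard axis 0 (the input dimension), which is why the split axis is recorded per layer type.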

View File

@@ -89,6 +89,7 @@ class FusedMoE(nn.Layer):
         self.routed_scaling_factor = routed_scaling_factor
 
         moe_quant_config = fd_config.quant_config
+        self.moe_quant_type = None
         if moe_quant_config:
             self.quant_method = moe_quant_config.get_quant_method(self)
             self.moe_quant_type = moe_quant_config.name()
@@ -142,7 +143,7 @@ class FusedMoE(nn.Layer):
         if self.moe_quant_type == "fp8":
             #(TODO:gaoziyuan)
             pass
-        else:
+        elif self.moe_quant_type == "wint8":
             self.weight_dtype = "int8"
             self.init_weight_only_scale()
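
A stripped-down sketch of the pattern these FusedMoE hunks move toward: moe_quant_type is always defined (defaulting to None), so later branches on it cannot hit an undefined attribute, and each quantization scheme gets an explicit branch instead of a catch-all else. The class below is illustrative only, not the real FusedMoE.

class ToyMoELayer:
    def __init__(self, quant_config=None):
        # Define the attribute unconditionally so later checks are safe even
        # when no quantization config is supplied.
        self.moe_quant_type = None
        self.weight_dtype = "bfloat16"
        if quant_config:
            self.moe_quant_type = quant_config["name"]

    def init_weights(self):
        if self.moe_quant_type == "fp8":
            pass  # fp8 path still pending, mirroring the TODO in the diff
        elif self.moe_quant_type == "wint8":
            # weight-only int8: store int8 weights (scales omitted in this toy)
            self.weight_dtype = "int8"
        # any other value (including None) keeps the unquantized default


layer = ToyMoELayer()      # no quant config: moe_quant_type stays None
layer.init_weights()       # falls through, weight_dtype remains "bfloat16"
assert layer.weight_dtype == "bfloat16"

wint8_layer = ToyMoELayer({"name": "wint8"})
wint8_layer.init_weights()
assert wint8_layer.weight_dtype == "int8"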