mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ..moe import FusedMoE
|
||||
@@ -79,29 +80,22 @@ class WINT2Config(QuantConfigBase):
|
||||
"""
|
||||
|
||||
dense_quant_type = config.get("dense_quant_config", "wint8")
|
||||
dense_quant_granularity = config.get("dense_quant_granularity",
|
||||
"per_channel")
|
||||
dense_quant_granularity = config.get("dense_quant_granularity", "per_channel")
|
||||
|
||||
moe_quant_config = config.get("moe_quant_config", {})
|
||||
moe_quant_type = moe_quant_config.get("quant_type", "w4w2")
|
||||
|
||||
moe_w4_quant_config = moe_quant_config.get("moe_w4_quant_config", {})
|
||||
moe_w4_quant_type = moe_w4_quant_config.get("quant_type",
|
||||
"wint4")
|
||||
moe_w4_quant_granularity = moe_w4_quant_config.get(
|
||||
"quant_granularity", "per_channel")
|
||||
moe_w4_quant_start_layer = moe_w4_quant_config.get(
|
||||
"quant_start_layer", 0)
|
||||
moe_w4_quant_type = moe_w4_quant_config.get("quant_type", "wint4")
|
||||
moe_w4_quant_granularity = moe_w4_quant_config.get("quant_granularity", "per_channel")
|
||||
moe_w4_quant_start_layer = moe_w4_quant_config.get("quant_start_layer", 0)
|
||||
moe_w4_quant_end_layer = moe_w4_quant_config.get("quant_end_layer", 6)
|
||||
|
||||
moe_w2_quant_config = moe_quant_config.get("moe_w2_quant_config", {})
|
||||
moe_w2_quant_type = moe_w2_quant_config.get("quant_type", "wint2")
|
||||
moe_w2_quant_granularity = moe_w2_quant_config.get(
|
||||
"quant_granularity", "pp_acc")
|
||||
moe_w2_quant_group_size = moe_w2_quant_config.get(
|
||||
"quant_group_size", 0)
|
||||
moe_w2_quant_start_layer = moe_w2_quant_config.get(
|
||||
"quant_start_layer", 0)
|
||||
moe_w2_quant_granularity = moe_w2_quant_config.get("quant_granularity", "pp_acc")
|
||||
moe_w2_quant_group_size = moe_w2_quant_config.get("quant_group_size", 0)
|
||||
moe_w2_quant_start_layer = moe_w2_quant_config.get("quant_start_layer", 0)
|
||||
moe_w2_quant_end_layer = moe_w2_quant_config.get("quant_end_layer", 0)
|
||||
|
||||
return cls(
|
||||
@@ -130,13 +124,12 @@ class WINT2Config(QuantConfigBase):
|
||||
"""
|
||||
if isinstance(layer, FusedMoE):
|
||||
if layer.layer_idx <= self.moe_w4_quant_end_layer:
|
||||
return get_quantization_config(
|
||||
self.moe_w4_quant_type).from_config(
|
||||
{}).get_quant_method(layer)
|
||||
return get_quantization_config(self.moe_w4_quant_type).from_config({}).get_quant_method(layer)
|
||||
else:
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_wint2_backend import \
|
||||
CutlassWint2FusedMoeMethod
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_wint2_backend import (
|
||||
CutlassWint2FusedMoeMethod,
|
||||
)
|
||||
|
||||
return CutlassWint2FusedMoeMethod(self)
|
||||
else:
|
||||
return get_quantization_config(self.dense_quant_type).from_config(
|
||||
{}).get_quant_method(layer)
|
||||
return get_quantization_config(self.dense_quant_type).from_config({}).get_quant_method(layer)
|
||||
|
Reference in New Issue
Block a user