Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-07 01:22:59 +08:00)

Commit: Sync v2.0 version of code to github repo
@@ -15,10 +15,13 @@
|
||||
"""
|
||||
|
||||
# cipher_token=WjI1fQOvhN # do not edit this line
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.incubate.nn.functional import fused_bias_act
|
||||
|
||||
from fastdeploy.config import LLMConfig
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
@@ -29,28 +32,27 @@ class SiluAndMul(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_config: LLMConfig,
|
||||
bias=None,
|
||||
act_method="gelu",
|
||||
dequant_scales=None,
|
||||
shift=None,
|
||||
smooth=None,
|
||||
quant_scale=-1,
|
||||
fd_config: FDConfig,
|
||||
bias: paddle.Tensor = None,
|
||||
act_method: str = "gelu",
|
||||
dequant_scales: Optional[paddle.Tensor] = None,
|
||||
shift: Optional[paddle.Tensor] = None,
|
||||
smooth: Optional[paddle.Tensor] = None,
|
||||
quant_scale: float = -1,
|
||||
):
|
||||
"""
|
||||
Initialize the activation layer with optional parameters for quantization, bias,
|
||||
activation method, and more.
|
||||
|
||||
Args:
|
||||
llm_config (Any): Arguments related to inference, including quantization
|
||||
fd_config (FDConfig): Arguments related to inference, including quantization
|
||||
settings.
|
||||
bias (Optional[Tensor]): Optional bias term to be added to the output.
|
||||
act_method (str, optional): Activation method to be applied.
|
||||
Defaults to "gelu".
|
||||
dequant_scales (Optional[List[float]]): Dequantization scales, used in
|
||||
act_method (str): Activation method to be applied. Defaults to "gelu".
|
||||
dequant_scales (Optional[Tensor]): Dequantization scales, used in
|
||||
quantization scenarios.
|
||||
shift (Optional[float]): Shift factor, used in quantization scenarios.
|
||||
smooth (Optional[float]): Smoothing factor, used for specific activation
|
||||
shift (Optional[Tensor]): Shift factor, used in quantization scenarios.
|
||||
smooth (Optional[Tensor]): Smoothing factor, used for specific activation
|
||||
functions.
|
||||
quant_scale (float, optional): Quantization scale, used in quantization
|
||||
scenarios. Defaults to -1, indicating no quantization.
|
||||
@@ -61,12 +63,13 @@ class SiluAndMul(nn.Layer):
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if current_platform.is_cuda():
|
||||
if current_platform.is_cuda() or current_platform.is_xpu():
|
||||
self.forward = self.forward_cuda
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self.bias = bias
|
||||
act_method = act_method.lower()
|
||||
if act_method == "silu":
|
||||
act_method = "swiglu"
|
||||
|
||||
@@ -75,9 +78,9 @@ class SiluAndMul(nn.Layer):
|
||||
self.shift = shift
|
||||
self.smooth = smooth
|
||||
self.quant_scale = quant_scale
|
||||
self.quant_round_type = llm_config.quant_config.quant_round_type
|
||||
self.quant_max_bound = llm_config.quant_config.quant_max_bound
|
||||
self.quant_min_bound = llm_config.quant_config.quant_min_bound
|
||||
self.quant_round_type = fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
|
||||
self.quant_max_bound = fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
|
||||
self.quant_min_bound = fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
|
||||
|
||||
self._dtype = self._helper.get_default_dtype()
|
||||
if self._dtype == "bfloat16":
|
||||
@@ -91,12 +94,12 @@ class SiluAndMul(nn.Layer):
|
||||
bfloat16 as default dtype, but received {self._dtype}")
|
||||
|
||||
# fp8 does not support smooth quantization
|
||||
if "float8" in llm_config.model_config.act_dtype:
|
||||
if fd_config.quant_config and "fp8" in fd_config.quant_config.name():
|
||||
self.dequant_scales = None
|
||||
self.shift = None
|
||||
self.smooth = None
|
||||
|
||||
def forward_cuda(self, x):
|
||||
def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
|
||||
"""
|
||||
Forward propagation of the custom activation layer.
|
||||
|
||||
|
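The hunk above swaps llm_config for fd_config and lets quant_config be absent, with the quantization bounds falling back to 0 in that case. A minimal, self-contained sketch of that defaulting pattern; QuantConfig and FDConfig below are illustrative stand-ins, not the real FastDeploy classes.

from dataclasses import dataclass
from typing import Optional

@dataclass
class QuantConfig:  # stand-in for illustration only
    quant_round_type: int = 0
    quant_max_bound: float = 127.0
    quant_min_bound: float = -127.0

@dataclass
class FDConfig:  # stand-in for illustration only
    quant_config: Optional[QuantConfig] = None

def resolve_quant_bounds(fd_config: FDConfig):
    # Same fall-back as the diff: use 0 when no quant config is attached.
    qc = fd_config.quant_config
    return (qc.quant_round_type if qc else 0,
            qc.quant_max_bound if qc else 0,
            qc.quant_min_bound if qc else 0)

print(resolve_quant_bounds(FDConfig()))               # (0, 0, 0)
print(resolve_quant_bounds(FDConfig(QuantConfig())))  # (0, 127.0, -127.0)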
@@ -13,15 +13,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .attention import Attention
|
||||
from .append_attn_backend import AppendAttentionBackend
|
||||
from .attention_selecter import get_attention_backend
|
||||
from .base_attention_backend import AttentionBackend
|
||||
from .native_paddle_backend import PaddleNativeAttnBackend
|
||||
from .attention_selecter import get_attention_backend
|
||||
from .append_attn_backend import AppendAttentionBackend
|
||||
from .xpu_attn_backend import XPUAttentionBackend
|
||||
|
||||
__all__ = [
|
||||
"Attention",
|
||||
"AttentionBackend",
|
||||
"PaddleNativeAttnBackend",
|
||||
"get_attention_backend",
|
||||
"AppendAttentionBackend",
|
||||
"Attention", "AttentionBackend", "PaddleNativeAttnBackend",
|
||||
"get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend"
|
||||
]
|
||||
|
@@ -16,25 +16,28 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
append_attention, get_block_shape_and_split_kv_block)
|
||||
append_attention, get_block_shape_and_split_kv_block,
|
||||
init_signal_layerwise, open_shm_and_get_meta_signal)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from paddle._typing.dtype_like import _DTypeLiteral
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import \
|
||||
AttentionBackend
|
||||
from fastdeploy.worker.model_runner import ForwardMeta
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppendAttentionMetadata:
|
||||
class AppendAttentionMetadata(AttentionMetadata):
|
||||
"""
|
||||
AppendAttentionMetadata
|
||||
"""
|
||||
@@ -60,40 +63,65 @@ class AppendAttentionMetadata:
|
||||
decoder_block_shape_q: Optional[paddle.Tensor] = None
|
||||
_fuse_kernel_compute_dtype: str = "bf16"
|
||||
|
||||
# pd_disaggregation
|
||||
kv_signal_metadata: Optional[paddle.Tensor] = None
|
||||
kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
|
||||
|
||||
|
||||
class AppendAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
AppendAttentionBackend backend implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_runner: "ModelRunner",
|
||||
):
|
||||
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
|
||||
head_dim: int) -> None:
|
||||
"""
|
||||
AppendAttentionBackend __init__
|
||||
"""
|
||||
super().__init__()
|
||||
self.attention_metadata: AppendAttentionMetadata = None
|
||||
self.block_size = model_runner.args.block_size
|
||||
self.max_seq_len = model_runner.args.max_model_len
|
||||
self.rope_theta = (10000.0 if model_runner.model_cfg.rope_theta is None
|
||||
else model_runner.model_cfg.rope_theta)
|
||||
self.rope_3d = getattr(model_runner.model_cfg, "rope_3d", False)
|
||||
self.causal = getattr(model_runner.model_cfg, "causal", True)
|
||||
self.speculate_method = model_runner.args.speculate_method
|
||||
self.speculate_max_draft_token_num = model_runner.args.speculate_max_draft_tokens
|
||||
self.num_heads = model_runner.model_cfg.num_attention_heads // model_runner.nranks
|
||||
self.kv_num_heads = int(
|
||||
model_runner.model_cfg.num_key_value_heads) // model_runner.nranks
|
||||
self.block_size: int = fd_config.parallel_config.block_size
|
||||
self.max_seq_len: int = fd_config.parallel_config.max_model_len
|
||||
self.rope_theta: float = (10000.0
|
||||
if fd_config.model_config.rope_theta is None
|
||||
else fd_config.model_config.rope_theta)
|
||||
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
|
||||
self.causal: bool = getattr(fd_config.model_config, "causal", True)
|
||||
self.speculative_method: str = fd_config.speculative_config.method
|
||||
self.use_speculate: bool = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
|
||||
|
||||
self.kv_num_heads: int = kv_num_heads
|
||||
self.num_heads: int = num_heads
|
||||
self.head_dim: int = fd_config.model_config.head_dim
|
||||
self.num_layers: int = fd_config.model_config.num_layers
|
||||
self.max_partition_size: int = int(
|
||||
os.getenv("FLAGS_max_partition_size", 32768))
|
||||
|
||||
# pd_disaggregation
|
||||
self.use_pd_disaggregation: int = int(
|
||||
os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
self.start_layer_index: int = fd_config.model_config.start_layer_index
|
||||
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
|
||||
|
||||
if fd_config.parallel_config.expert_parallel_rank is None:
|
||||
fd_config.parallel_config.expert_parallel_rank = 0
|
||||
device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
|
||||
fd_config.parallel_config.expert_parallel_rank
|
||||
if self.device_id is None:
|
||||
self.device_id = device_id
|
||||
else:
|
||||
self.device_id = self.device_id.split(",")[device_id]
|
||||
|
||||
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||
"""Initialize attention metadata so that all layers in the forward pass can reuse it."""
|
||||
metadata = AppendAttentionMetadata()
|
||||
metadata.encoder_block_shape_q = 64
|
||||
metadata.decoder_block_shape_q = 16
|
||||
metadata.max_partition_size = 32768
|
||||
metadata.encoder_max_partition_size = 32768
|
||||
metadata.max_partition_size = self.max_partition_size
|
||||
metadata.encoder_max_partition_size = self.max_seq_len
|
||||
metadata._dtype = paddle.get_default_dtype()
|
||||
if metadata._dtype == "bfloat16":
|
||||
metadata._fuse_kernel_compute_dtype = "bf16"
|
||||
@@ -128,38 +156,51 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
self.block_size,
|
||||
self.speculate_max_draft_token_num + 1,
|
||||
)
|
||||
self.attention_metadata = metadata
|
||||
|
||||
def get_attntion_meta(self):
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
|
||||
self.rank, int(self.device_id), self.keep_pd_step_flag)
|
||||
self.attention_metadata: AttentionMetadata = metadata
|
||||
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
|
||||
forward_meta.decoder_tile_ids_per_batch.copy_(
|
||||
metadata.decoder_tile_ids_per_batch, False)
|
||||
|
||||
def get_attntion_meta(self) -> AttentionMetadata:
|
||||
"""get_attntion_meta"""
|
||||
return self.attention_metadata
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
self,
|
||||
max_num_blocks: int,
|
||||
block_size: int,
|
||||
kv_num_head: int,
|
||||
head_dim: int,
|
||||
):
|
||||
) -> Tuple[int, int, int, int]:
|
||||
"""
|
||||
get_kv_cache_shape
|
||||
Calculate kv cache shape
|
||||
"""
|
||||
return (max_num_blocks, kv_num_head, block_size, head_dim)
|
||||
return (max_num_blocks, self.kv_num_heads, self.block_size,
|
||||
self.head_dim)
|
||||
|
||||
def forward_mixed(
|
||||
self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
qkv,
|
||||
q: paddle.Tensor,
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
forward_mixed
|
||||
"""
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
layer.layer_id + self.start_layer_index)
|
||||
|
||||
res = append_attention(
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
@@ -176,8 +217,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
metadata.kv_batch_ids,
|
||||
metadata.kv_tile_ids_per_batch,
|
||||
metadata.kv_num_blocks,
|
||||
metadata.decoder_batch_ids,
|
||||
metadata.decoder_tile_ids_per_batch,
|
||||
forward_meta.decoder_batch_ids, # from buffer
|
||||
forward_meta.decoder_tile_ids_per_batch, # from buffer
|
||||
metadata.decoder_num_blocks,
|
||||
metadata.set_max_lengths,
|
||||
metadata.max_len_kv,
|
||||
@@ -193,7 +234,7 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
getattr(layer, "cache_v_zp", None),
|
||||
layer.linear_shift,
|
||||
layer.linear_smooth,
|
||||
None, # kv_signal_data,
|
||||
metadata.kv_signal_data_list[layer.layer_id],
|
||||
metadata._fuse_kernel_compute_dtype,
|
||||
getattr(layer, "cache_quant_type_str", "none"),
|
||||
layer.use_neox_rotary_style,
|
||||
@@ -208,7 +249,6 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
metadata.encoder_max_partition_size,
|
||||
self.speculate_max_draft_token_num + 1,
|
||||
self.causal,
|
||||
self.speculate_method is not None,
|
||||
self.speculative_method is not None,
|
||||
)[0]
|
||||
|
||||
return res
|
||||
|
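For orientation, get_kv_cache_shape above returns (max_num_blocks, kv_num_heads, block_size, head_dim), and each layer keeps one such tensor for K and one for V. A quick size check with made-up numbers; the real values come from the cache and model configs.

max_num_blocks, kv_num_heads, block_size, head_dim = 1024, 8, 64, 128
kv_cache_shape = (max_num_blocks, kv_num_heads, block_size, head_dim)

elems = max_num_blocks * kv_num_heads * block_size * head_dim
bytes_per_layer = 2 * elems * 2  # one K and one V tensor, 2 bytes per bf16 element
print(kv_cache_shape, f"{bytes_per_layer / 2**20:.1f} MiB per layer")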
@@ -14,12 +14,17 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.worker.model_runner import ForwardMeta
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.quantization.quant_base import \
|
||||
QuantMethodBase
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
class Attention(nn.Layer):
|
||||
@@ -29,26 +34,24 @@ class Attention(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_config,
|
||||
fd_config: FDConfig,
|
||||
layer_id: int,
|
||||
logit_cap: float = 0.0,
|
||||
v_head_dim: int = -1,
|
||||
rope_type: str = "",
|
||||
qkv_bias: Optional[paddle.Tensor] = None,
|
||||
qkv_scale: Optional[paddle.Tensor] = None,
|
||||
prefix: str = "",
|
||||
out_scale: float = -1.,
|
||||
linear_shift=None,
|
||||
linear_smooth=None,
|
||||
use_neox_rotary_style=False,
|
||||
out_scale: float = -1.0,
|
||||
linear_shift: paddle.Tensor = None,
|
||||
linear_smooth: paddle.Tensor = None,
|
||||
use_neox_rotary_style: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the attention layer with the given parameters.
|
||||
|
||||
Args:
|
||||
llm_config (dict): The config of LM model.
|
||||
fd_config (FDConfig): The config of the LM model.
|
||||
layer_id (int): The id of current layer.
|
||||
logit_cap (float, optional): The cap for logits. Defaults to 0.0.
|
||||
v_head_dim (int, optional): The head dim of value. Defaults to -1.
|
||||
rope_type (str, optional): The type of RoPE. Defaults to "".
|
||||
qkv_bias (Optional[paddle.Tensor], optional): The bias of QKV. Defaults to None.
|
||||
@@ -61,34 +64,46 @@ class Attention(nn.Layer):
|
||||
ValueError: If the `v_head_dim` is less than 0.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_heads = llm_config.model_config.num_attention_heads // llm_config.parallel_config.mp_size
|
||||
self.head_dim = llm_config.model_config.hidden_size // llm_config.model_config.num_attention_heads
|
||||
self.kv_num_heads = llm_config.model_config.num_key_value_heads // llm_config.parallel_config.mp_size
|
||||
self.layer_id = layer_id
|
||||
self.logit_cap = logit_cap
|
||||
self.v_head_dim = v_head_dim if v_head_dim > 0 else self.head_dim
|
||||
self.rope_type = rope_type
|
||||
self.qk_head_dim = self.head_dim
|
||||
self.num_heads: int = fd_config.model_config.num_attention_heads // fd_config.parallel_config.tensor_parallel_degree
|
||||
self.head_dim: int = fd_config.model_config.head_dim
|
||||
self.kv_num_heads: int = \
|
||||
fd_config.model_config.num_key_value_heads // fd_config.parallel_config.tensor_parallel_degree
|
||||
self.layer_id: int = layer_id
|
||||
self.v_head_dim: int = v_head_dim if v_head_dim > 0 else self.head_dim
|
||||
self.rope_type: str = rope_type
|
||||
self.qk_head_dim: int = self.head_dim
|
||||
self.prefix: str = prefix
|
||||
# not used
|
||||
self.tp_q_head_num = self.num_heads
|
||||
self.tp_k_head_num = self.num_heads
|
||||
self.tp_v_head_num = self.num_heads
|
||||
# not used
|
||||
self.scaling = 1.0 / (self.head_dim**0.5)
|
||||
self.linear_shift = linear_shift
|
||||
self.linear_smooth = linear_smooth
|
||||
self.qkv_bias = qkv_bias
|
||||
self.qkv_scale = qkv_scale
|
||||
self.linear_shift: paddle.Tensor | None = linear_shift
|
||||
self.linear_smooth: paddle.Tensor | None = linear_smooth
|
||||
self.qkv_bias: paddle.Tensor | None = qkv_bias
|
||||
self.qkv_scale: paddle.Tensor | None = qkv_scale
|
||||
self._dtype = self._helper.get_default_dtype()
|
||||
self.out_scale = out_scale
|
||||
self.use_neox_rotary_style = use_neox_rotary_style
|
||||
if llm_config.kvcache_config is not None:
|
||||
self.kvcache_quant_method = llm_config.kvcache_config.kvcache_quant_config.get_quant_method(
|
||||
|
||||
self.out_scale: float = out_scale
|
||||
self.use_neox_rotary_style: bool = use_neox_rotary_style
|
||||
|
||||
if fd_config.quant_config and hasattr(fd_config.quant_config,
|
||||
"kv_cache_quant_type"):
|
||||
self.kvcache_quant_method: QuantMethodBase = fd_config.quant_config.get_quant_method(
|
||||
self)
|
||||
self.kvcache_quant_method.create_weights(self)
|
||||
if llm_config.quant_config is not None:
|
||||
self.quant_max_bound = llm_config.quant_config.quant_max_bound
|
||||
self.quant_min_bound = llm_config.quant_config.quant_min_bound
|
||||
else:
|
||||
self.kvcache_quant_method = None
|
||||
|
||||
if self.kvcache_quant_method is None:
|
||||
logger.info(f"Attention is running in cache kv {self._dtype} mode")
|
||||
else:
|
||||
logger.info(
|
||||
f"Attention is running in cache kv {self.kvcache_quant_method.cache_quant_config.quant_type} mode"
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str,
|
||||
paddle.Tensor | np.ndarray]):
|
||||
'''
|
||||
Attention only has quant-related scales, not other parameters.
|
||||
'''
|
||||
if self.kvcache_quant_method is not None:
|
||||
self.kvcache_quant_method.create_weights(self, state_dict)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -97,7 +112,7 @@ class Attention(nn.Layer):
|
||||
v: paddle.Tensor = None,
|
||||
qkv: paddle.Tensor = None,
|
||||
forward_meta: ForwardMeta = None,
|
||||
):
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
The forward function of attention layer.
|
||||
args:
|
||||
|
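The head bookkeeping in this hunk divides attention heads and KV heads by the tensor-parallel degree and takes head_dim from the model config. Illustrative arithmetic with assumed values, just to make the splits concrete.

num_attention_heads, num_key_value_heads, hidden_size = 32, 8, 4096
tensor_parallel_degree = 4

num_heads = num_attention_heads // tensor_parallel_degree     # 8 query heads per rank
kv_num_heads = num_key_value_heads // tensor_parallel_degree  # 2 KV heads per rank
head_dim = hidden_size // num_attention_heads                 # 128
print(num_heads, kv_num_heads, head_dim)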
@@ -14,26 +14,20 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
"""
|
||||
attention backend selector
|
||||
"""
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import AttentionBackend
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.utils import resolve_obj_from_strname
|
||||
from functools import cache
|
||||
from fastdeploy.platforms import _Backend
|
||||
|
||||
from fastdeploy.platforms import _Backend, current_platform
|
||||
from fastdeploy.utils import resolve_obj_from_strname
|
||||
|
||||
|
||||
def backend_name_to_enum(backend_name: str):
|
||||
def backend_name_to_enum(backend_name: str) -> _Backend:
|
||||
"""backend_name_to_enum """
|
||||
assert backend_name is not None
|
||||
return _Backend.__members__.get(backend_name)
|
||||
|
||||
|
||||
@cache
|
||||
def _get_attn_backend(
|
||||
selected_backend
|
||||
):
|
||||
def _get_attn_backend(selected_backend: str) -> object:
|
||||
"""_get_attn_backend """
|
||||
if isinstance(selected_backend, str):
|
||||
selected_backend = backend_name_to_enum(selected_backend)
|
||||
@@ -46,10 +40,6 @@ def _get_attn_backend(
|
||||
return resolve_obj_from_strname(attention_cls)
|
||||
|
||||
|
||||
def get_attention_backend(
|
||||
selected_backend
|
||||
):
|
||||
def get_attention_backend(selected_backend):
|
||||
"""Selects which attention backend to use."""
|
||||
return _get_attn_backend(
|
||||
selected_backend
|
||||
)
|
||||
return _get_attn_backend(selected_backend)
|
||||
|
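The selector above maps a backend name to an _Backend enum member, resolves the implementation class from a dotted path, and memoizes the result with functools.cache. A self-contained sketch of the same pattern; the enum members and the returned value are illustrative stand-ins rather than FastDeploy's.

from enum import Enum
from functools import cache

class _Backend(Enum):  # stand-in members for illustration
    NATIVE_ATTN = "native_attn"
    APPEND_ATTN = "append_attn"

def backend_name_to_enum(backend_name: str) -> _Backend:
    assert backend_name is not None
    return _Backend.__members__.get(backend_name)

@cache
def _get_attn_backend(selected_backend):
    if isinstance(selected_backend, str):
        selected_backend = backend_name_to_enum(selected_backend)
    # FastDeploy resolves a class from a dotted path here; the sketch just
    # returns the enum member.
    return selected_backend

print(_get_attn_backend("APPEND_ATTN"))  # _Backend.APPEND_ATTN, cached on repeat calls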
@@ -1,395 +0,0 @@
|
||||
"""
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
import fastdeploy
|
||||
|
||||
|
||||
class Attention(nn.Layer):
|
||||
"""
|
||||
Attention Layer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inference_args,
|
||||
prefix,
|
||||
out_scale=-1,
|
||||
use_neox_rotary_style=False,
|
||||
rope_theta=10000.0,
|
||||
rope_3d=False,
|
||||
qkv_scale=None,
|
||||
qkv_bias=None,
|
||||
linear_shift=None,
|
||||
linear_smooth=None,
|
||||
):
|
||||
"""
|
||||
Initialize the attention layer with various parameters.
|
||||
|
||||
Args:
|
||||
inference_args (dict or object): Contains arguments for inference, including
|
||||
number of key-value heads, weight data type, activation data type, etc.
|
||||
prefix (str): The name of the attention layer for identification purposes.
|
||||
out_scale (float, optional): Output scale factor. Defaults to -1.
|
||||
use_neox_rotary_style (bool, optional): Whether to use the NeoX rotary position
|
||||
encoding style. Defaults to False.
|
||||
rope_theta (float, optional): Theta value for the rope position encoding. Defaults to 10000.0.
|
||||
qkv_scale (float or None, optional): Quantization scale for QKV weights.
|
||||
Used only for certain quantization configurations. Defaults to None.
|
||||
qkv_bias (Tensor or None, optional): Bias for QKV linear layer. Defaults to None.
|
||||
linear_shift (float or None, optional): Linear shift factor used in
|
||||
quantization. Used only for certain quantization configurations.
|
||||
Defaults to None.
|
||||
linear_smooth (float or None, optional): Linear smooth factor used in
|
||||
quantization. Used only for certain quantization configurations.
|
||||
Defaults to None.
|
||||
"""
|
||||
super().__init__()
|
||||
self.inference_args = inference_args
|
||||
self.nranks = inference_args.mp_size
|
||||
self.kv_num_heads = inference_args.num_key_value_heads // self.nranks
|
||||
self.head_dim = self.inference_args.head_dim
|
||||
self.prefix = prefix
|
||||
self.cache_k_scale_name = prefix + ".cachek_matmul.activation_quanter"
|
||||
self.cache_v_scale_name = prefix + ".cachev_matmul.activation_quanter"
|
||||
self.out_scale = out_scale
|
||||
|
||||
self.cache_k_zp_name = self.cache_k_scale_name + ".zero_point"
|
||||
self.cache_v_zp_name = self.cache_v_scale_name + ".zero_point"
|
||||
|
||||
self.use_neox_rotary_style = use_neox_rotary_style
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_3d = rope_3d
|
||||
|
||||
self._dtype = self._helper.get_default_dtype()
|
||||
if self._dtype == "bfloat16":
|
||||
self._fuse_kernel_compute_dtype = "bf16"
|
||||
elif self._dtype == "float16":
|
||||
self._fuse_kernel_compute_dtype = "fp16"
|
||||
elif self._dtype == "float32":
|
||||
self._fuse_kernel_compute_dtype = "fp32"
|
||||
else:
|
||||
raise ValueError(f"Just support float32, float16 and \
|
||||
bfloat16 as default dtype, but received {self._dtype}")
|
||||
|
||||
self.cache_scale_dtype = (
|
||||
self._dtype if self.inference_args.use_append_attn else "float32")
|
||||
|
||||
self.qkv_bias = qkv_bias
|
||||
if inference_args.weight_dtype == "int8" and inference_args.act_dtype == "int8":
|
||||
self.qkv_scale = qkv_scale
|
||||
self.linear_shift = linear_shift
|
||||
self.linear_smooth = linear_smooth
|
||||
if (inference_args.cachekv_dtype == "int8"
|
||||
or inference_args.cachekv_dtype == "int4"
|
||||
or inference_args.cachekv_dtype == "float8_e4m3fn"):
|
||||
self.set_cachekv_scale()
|
||||
# qkv_bias fused with attention only when W8A8
|
||||
if not (inference_args.weight_dtype == "int8"
|
||||
and inference_args.act_dtype == "int8"):
|
||||
self.qkv_bias = None
|
||||
|
||||
def set_cachekv_scale(self):
|
||||
"""
|
||||
Set cache key (K) and value (V) scaling factors.
|
||||
|
||||
This method initializes and sets the scaling factors for cache key (K) and value (V)
|
||||
tensors, which are used in attention mechanisms to adjust the scale of the cache
|
||||
representations. Additionally, it calculates and sets the inverse of these scaling
|
||||
factors for the output cache K and V tensors.
|
||||
|
||||
Args:
|
||||
None - This method does not take any explicit arguments as it relies on the
|
||||
instance variables of the class, such as `self.kv_num_heads`,
|
||||
`self.cache_k_scale_name`, `self.cache_v_scale_name`, and
|
||||
`self.inference_args.cachekv_scale_dict` for its functionality.
|
||||
|
||||
Returns:
|
||||
None - This method modifies the instance variables directly and does not return
|
||||
any values.
|
||||
"""
|
||||
self.cache_k_scale = self.create_parameter(
|
||||
shape=([self.kv_num_heads *
|
||||
self.head_dim] if self.inference_args.is_channel_wise else
|
||||
[self.kv_num_heads]),
|
||||
dtype=self.cache_scale_dtype,
|
||||
is_bias=False,
|
||||
)
|
||||
self.cache_v_scale = self.create_parameter(
|
||||
shape=([self.kv_num_heads *
|
||||
self.head_dim] if self.inference_args.is_channel_wise else
|
||||
[self.kv_num_heads]),
|
||||
dtype=self.cache_scale_dtype,
|
||||
is_bias=False,
|
||||
)
|
||||
self.cache_k_out_scale = self.create_parameter(
|
||||
shape=([self.kv_num_heads *
|
||||
self.head_dim] if self.inference_args.is_channel_wise else
|
||||
[self.kv_num_heads]),
|
||||
attr=None,
|
||||
dtype=self.cache_scale_dtype,
|
||||
is_bias=False,
|
||||
)
|
||||
self.cache_v_out_scale = self.create_parameter(
|
||||
shape=([self.kv_num_heads *
|
||||
self.head_dim] if self.inference_args.is_channel_wise else
|
||||
[self.kv_num_heads]),
|
||||
attr=None,
|
||||
dtype=self.cache_scale_dtype,
|
||||
is_bias=False,
|
||||
)
|
||||
|
||||
if self.cache_k_scale_name in self.inference_args.cachekv_scale_dict:
|
||||
cache_k_scale = paddle.cast(
|
||||
paddle.to_tensor(self.inference_args.cachekv_scale_dict[
|
||||
self.cache_k_scale_name]),
|
||||
self.cache_scale_dtype,
|
||||
)
|
||||
cache_k_out_scale = 1.0 / cache_k_scale
|
||||
else:
|
||||
if os.getenv("EP_DECODER_PERF_TEST", "False") == "True":
|
||||
cache_k_scale = paddle.zeros(self.cache_k_scale.shape,
|
||||
self.cache_k_scale.dtype)
|
||||
cache_k_out_scale = paddle.zeros(self.cache_k_out_scale.shape,
|
||||
self.cache_k_out_scale.dtype)
|
||||
else:
|
||||
raise KeyError(
|
||||
f"{self.cache_k_scale_name} not found in scale dict")
|
||||
|
||||
if self.cache_v_scale_name in self.inference_args.cachekv_scale_dict:
|
||||
cache_v_scale = paddle.cast(
|
||||
paddle.to_tensor(self.inference_args.cachekv_scale_dict[
|
||||
self.cache_v_scale_name]),
|
||||
self.cache_scale_dtype,
|
||||
)
|
||||
cache_v_out_scale = 1.0 / cache_v_scale
|
||||
else:
|
||||
if os.getenv("EP_DECODER_PERF_TEST", "False") == "True":
|
||||
cache_v_scale = paddle.zeros(self.cache_v_scale.shape,
|
||||
self.cache_v_scale.dtype)
|
||||
cache_v_out_scale = paddle.zeros(self.cache_v_out_scale.shape,
|
||||
self.cache_v_out_scale.dtype)
|
||||
else:
|
||||
raise KeyError(
|
||||
f"{self.cache_v_scale_name} not found in scale dict")
|
||||
|
||||
self.cache_k_scale.set_value(cache_k_scale)
|
||||
self.cache_v_scale.set_value(cache_v_scale)
|
||||
self.cache_k_out_scale.set_value(cache_k_out_scale)
|
||||
self.cache_v_out_scale.set_value(cache_v_out_scale)
|
||||
|
||||
if self.inference_args.has_zero_point:
|
||||
self.cache_k_zp = self.create_parameter(
|
||||
shape=([self.kv_num_heads *
|
||||
self.head_dim] if self.inference_args.is_channel_wise
|
||||
else [self.kv_num_heads]),
|
||||
dtype=self.cache_scale_dtype,
|
||||
is_bias=False,
|
||||
)
|
||||
self.cache_v_zp = self.create_parameter(
|
||||
shape=([self.kv_num_heads *
|
||||
self.head_dim] if self.inference_args.is_channel_wise
|
||||
else [self.kv_num_heads]),
|
||||
dtype=self.cache_scale_dtype,
|
||||
is_bias=False,
|
||||
)
|
||||
if self.cache_k_zp_name in self.inference_args.cachekv_scale_dict:
|
||||
cache_k_zp = paddle.cast(
|
||||
paddle.to_tensor(self.inference_args.cachekv_scale_dict[
|
||||
self.cache_k_zp_name]),
|
||||
self.cache_scale_dtype,
|
||||
)
|
||||
else:
|
||||
cache_k_zp = paddle.zeros(
|
||||
([self.kv_num_heads *
|
||||
self.head_dim] if self.inference_args.is_channel_wise
|
||||
else [self.kv_num_heads]),
|
||||
dtype=self.cache_scale_dtype,
|
||||
)
|
||||
if self.cache_v_zp_name in self.inference_args.cachekv_scale_dict:
|
||||
cache_v_zp = paddle.cast(
|
||||
paddle.to_tensor(self.inference_args.cachekv_scale_dict[
|
||||
self.cache_v_zp_name]),
|
||||
self.cache_scale_dtype,
|
||||
)
|
||||
else:
|
||||
cache_v_zp = paddle.zeros(
|
||||
([self.kv_num_heads *
|
||||
self.head_dim] if self.inference_args.is_channel_wise
|
||||
else [self.kv_num_heads]),
|
||||
dtype=self.cache_scale_dtype,
|
||||
)
|
||||
self.cache_k_zp.set_value(cache_k_zp)
|
||||
self.cache_v_zp.set_value(cache_v_zp)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
qkv,
|
||||
input_ids,
|
||||
rotary_embs,
|
||||
rotary_emb_dims,
|
||||
key_cache,
|
||||
value_cache,
|
||||
pre_key_cache,
|
||||
pre_value_cache,
|
||||
pre_caches_length,
|
||||
attn_mask,
|
||||
kv_signal_data,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Compute the attention for a single time step.
|
||||
|
||||
Args:
|
||||
qkv (Tensor): The output of the linear transformation of query, key and value.
|
||||
Shape: [batch_size, num_heads, seq_len, embed_dim // num_heads].
|
||||
padding_offset (Tensor): The offset to be added to the sequence length when computing
|
||||
the attention mask. Shape: [batch_size, 1].
|
||||
input_ids (Tensor, optional): The input ids of the batch. Used for computing the
|
||||
attention mask. Default: None. Shape: [batch_size, max_sequence_length].
|
||||
rotary_embs (Tensor, optional): The rotary position embeddings. Default: None.
|
||||
Shape: [num_heads, rotary_emb_dims].
|
||||
rotary_emb_dims (int, optional): The dimension of the rotary position embeddings.
|
||||
Default: None.
|
||||
caches (List[Tensor], optional): The cache tensors used in the computation of the
|
||||
attention. Default: None.
|
||||
pre_caches (List[Tensor], optional): The pre-computed cache tensors used in the
|
||||
computation of the attention. Default: None.
|
||||
pre_caches_length (int, optional): The length of the pre-computed cache tensors.
|
||||
Default: None.
|
||||
attn_mask (Tensor, optional): The attention mask. Default: None.
|
||||
Shape: [batch_size, max_sequence_length].
|
||||
**kwargs (dict, optional): Additional keyword arguments passed along.
|
||||
|
||||
Returns:
|
||||
Tensor: The output of the linear transformation after applying the attention.
|
||||
Shape: [batch_size, embed_dim // num_heads].
|
||||
|
||||
Raises:
|
||||
None.
|
||||
"""
|
||||
k_quant_scale = kwargs.get("k_quant_scale", None)
|
||||
v_quant_scale = kwargs.get("v_quant_scale", None)
|
||||
k_dequant_scale = kwargs.get("k_dequant_scale", None)
|
||||
v_dequant_scale = kwargs.get("v_dequant_scale", None)
|
||||
|
||||
if not self.inference_args.use_dynamic_cachekv_quant:
|
||||
k_quant_scale = getattr(self, "cache_k_scale", None)
|
||||
v_quant_scale = getattr(self, "cache_v_scale", None)
|
||||
k_dequant_scale = getattr(self, "cache_k_out_scale", None)
|
||||
v_dequant_scale = getattr(self, "cache_v_out_scale", None)
|
||||
cache_quant_type_str = self.inference_args.cache_quant_type
|
||||
else:
|
||||
cache_quant_type_str = "none"
|
||||
|
||||
if self.inference_args.use_append_attn:
|
||||
out = fastdeploy.model_executor.ops.gpu.append_attention(
|
||||
qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kwargs.get("seq_lens_encoder", None),
|
||||
kwargs.get("seq_lens_decoder", None),
|
||||
kwargs.get("seq_lens_this_time", None),
|
||||
kwargs.get("padding_offsets", None),
|
||||
kwargs.get("cum_offsets", None),
|
||||
kwargs.get("block_tables", None),
|
||||
kwargs.get("encoder_batch_ids", None),
|
||||
kwargs.get("encoder_tile_ids_per_batch", None),
|
||||
kwargs.get("encoder_num_blocks", None),
|
||||
kwargs.get("kv_batch_ids", None),
|
||||
kwargs.get("kv_tile_ids_per_batch", None),
|
||||
kwargs.get("kv_num_blocks", None),
|
||||
kwargs.get("decoder_batch_ids", None),
|
||||
kwargs.get("decoder_tile_ids_per_batch", None),
|
||||
kwargs.get("decoder_num_blocks", None),
|
||||
kwargs.get("set_max_lengths", None),
|
||||
kwargs.get("max_len_kv", None),
|
||||
rotary_embs,
|
||||
attn_mask,
|
||||
getattr(self, "qkv_bias", None),
|
||||
getattr(self, "qkv_scale", None),
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale,
|
||||
v_dequant_scale,
|
||||
getattr(self, "cache_k_zp", None), # cache_k_zp
|
||||
getattr(self, "cache_v_zp", None), # cache_v_zp
|
||||
getattr(self, "linear_shift", None), # out_shifts
|
||||
getattr(self, "linear_smooth", None), # out_smooths
|
||||
kv_signal_data,
|
||||
self._fuse_kernel_compute_dtype,
|
||||
cache_quant_type_str, # cache_quant_type
|
||||
self.use_neox_rotary_style,
|
||||
self.rope_3d,
|
||||
kwargs.get("max_input_length", -1),
|
||||
self.inference_args.quant_max_bound,
|
||||
self.inference_args.quant_min_bound,
|
||||
self.out_scale, # out_linear_in_scale
|
||||
kwargs.get("encoder_block_shape_q", 64),
|
||||
kwargs.get("decoder_block_shape_q", 16),
|
||||
kwargs.get("max_partition_size", 32768),
|
||||
kwargs.get("encoder_max_partition_size", 32768),
|
||||
self.inference_args.speculate_max_draft_token_num +
|
||||
1, # speculate_max_draft_token_num
|
||||
True, # causal
|
||||
self.inference_args.speculate_method
|
||||
is not None, # speculate_decoder
|
||||
)[0]
|
||||
else:
|
||||
out = paddle.incubate.nn.functional.block_multihead_attention(
|
||||
qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kwargs.get("seq_lens_encoder", None),
|
||||
kwargs.get("seq_lens_decoder", None),
|
||||
kwargs.get("seq_lens_this_time", None),
|
||||
kwargs.get("padding_offsets", None),
|
||||
kwargs.get("cum_offsets", None),
|
||||
kwargs.get("cu_seqlens_q", None),
|
||||
kwargs.get("cu_seqlens_k", None),
|
||||
kwargs.get("block_tables", None),
|
||||
pre_key_cache,
|
||||
pre_value_cache,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale,
|
||||
v_dequant_scale,
|
||||
getattr(self, "qkv_scale", None),
|
||||
getattr(self, "qkv_bias", None),
|
||||
getattr(self, "linear_shift", None),
|
||||
getattr(self, "linear_smooth", None),
|
||||
kwargs.get("max_enc_len_this_time", None),
|
||||
kwargs.get("max_dec_len_this_time", None),
|
||||
rotary_embs,
|
||||
attn_mask,
|
||||
None, # tgt_mask
|
||||
kwargs.get("max_input_length", -1),
|
||||
kwargs.get("block_size", 64),
|
||||
self.use_neox_rotary_style,
|
||||
self.inference_args.use_dynamic_cachekv_quant,
|
||||
quant_round_type=self.inference_args.quant_round_type,
|
||||
quant_max_bound=self.inference_args.quant_max_bound,
|
||||
quant_min_bound=self.inference_args.quant_min_bound,
|
||||
out_scale=self.out_scale,
|
||||
compute_dtype=self._fuse_kernel_compute_dtype,
|
||||
rope_theta=self.rope_theta,
|
||||
)[0]
|
||||
|
||||
return out
|
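The deleted layer keeps both a cache scale and its inverse (cache_k_out_scale = 1 / cache_k_scale), so values scaled into the KV cache can be rescaled on the way out. A tiny illustration of that relationship, assuming paddle is available.

import paddle

cache_k_scale = paddle.to_tensor([0.5, 2.0, 4.0], dtype="float32")
cache_k_out_scale = 1.0 / cache_k_scale

x = paddle.to_tensor([1.0, 1.0, 1.0])
scaled_into_cache = x * cache_k_scale        # conceptually, before rounding to int8
restored = scaled_into_cache * cache_k_out_scale
print(restored.numpy())                      # [1. 1. 1.]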
@@ -20,10 +20,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.worker.model_runner import ForwardMeta
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionMetadata(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class AttentionBackend(ABC):
|
||||
@@ -42,7 +48,7 @@ class AttentionBackend(ABC):
|
||||
qkv: paddle.Tensor,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Run a forward.
|
||||
args:
|
||||
@@ -88,7 +94,7 @@ class AttentionBackend(ABC):
|
||||
qkv: paddle.Tensor,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
) -> paddle.Tensor:
|
||||
"""Run a forward for mix."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -100,7 +106,7 @@ class AttentionBackend(ABC):
|
||||
qkv: paddle.Tensor,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
) -> paddle.Tensor:
|
||||
"""Run a forward for decode."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -112,6 +118,6 @@ class AttentionBackend(ABC):
|
||||
qkv: paddle.Tensor,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
) -> paddle.Tensor:
|
||||
"""Run a forward for extend."""
|
||||
raise NotImplementedError()
|
||||
|
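The interface above leaves forward_mixed, forward_decode, and forward_extend for concrete backends to implement. A minimal sketch of plugging into an interface of that shape, using a stand-in ABC rather than the real AttentionBackend.

from abc import ABC, abstractmethod

import paddle

class AttentionBackendSketch(ABC):
    """Stand-in for the AttentionBackend interface; not the real class."""

    @abstractmethod
    def forward_mixed(self, q, k, v, qkv, layer, forward_meta) -> paddle.Tensor:
        ...

class NoopBackend(AttentionBackendSketch):
    def forward_mixed(self, q, k, v, qkv, layer, forward_meta) -> paddle.Tensor:
        # A real backend would invoke its fused attention kernel here.
        return q

print(NoopBackend().forward_mixed(paddle.ones([2, 4]), None, None, None, None, None).shape)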
@@ -1,4 +1,3 @@
|
||||
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
@@ -16,15 +15,14 @@
|
||||
|
||||
"""
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import paddle
|
||||
from paddle.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import AttentionBackend
|
||||
from fastdeploy.worker.model_runner import ForwardMeta, ForwardMode
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import \
|
||||
AttentionBackend
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
class PaddleNativeAttnBackend(AttentionBackend):
|
||||
@@ -33,10 +31,8 @@ class PaddleNativeAttnBackend(AttentionBackend):
|
||||
Which is used only for testing purposes.
|
||||
"""
|
||||
|
||||
def __init__(self, device):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.forward_metadata = None
|
||||
self.device = device
|
||||
|
||||
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||
"""Init the metadata for a forward pass."""
|
||||
@@ -53,8 +49,8 @@ class PaddleNativeAttnBackend(AttentionBackend):
|
||||
seq_lens: paddle.Tensor,
|
||||
extend_prefix_lens: paddle.Tensor,
|
||||
extend_seq_lens: paddle.Tensor,
|
||||
causal=False,
|
||||
):
|
||||
causal: bool = False,
|
||||
) -> paddle.Tensor:
|
||||
"""Run the extend forward using the paddle native sdpa op.
|
||||
|
||||
Args:
|
||||
@@ -111,18 +107,14 @@ class PaddleNativeAttnBackend(AttentionBackend):
|
||||
per_req_value = v_cache[per_req_tokens].transpose(
|
||||
[query.dim() - 2, 0])
|
||||
|
||||
per_req_out_redudant = (
|
||||
scaled_dot_product_attention(
|
||||
per_req_query_redudant.unsqueeze(0),
|
||||
per_req_key.unsqueeze(0),
|
||||
per_req_value.unsqueeze(0),
|
||||
is_causal=causal,
|
||||
)
|
||||
.squeeze(0)
|
||||
.transpose([query.dim() - 2, 0])
|
||||
)
|
||||
output[start_q:end_q, :,
|
||||
:] = per_req_out_redudant[prefill_seq_len_q:, :, :]
|
||||
per_req_out_redudant = (scaled_dot_product_attention(
|
||||
per_req_query_redudant.unsqueeze(0),
|
||||
per_req_key.unsqueeze(0),
|
||||
per_req_value.unsqueeze(0),
|
||||
is_causal=causal,
|
||||
).squeeze(0).transpose([query.dim() - 2, 0]))
|
||||
output[start_q:end_q, :, :] = per_req_out_redudant[
|
||||
prefill_seq_len_q:, :, :]
|
||||
start_q, start_kv = end_q, end_kv
|
||||
return output
|
||||
|
||||
@@ -132,7 +124,7 @@ class PaddleNativeAttnBackend(AttentionBackend):
|
||||
key: paddle.Tensor,
|
||||
value: paddle.Tensor,
|
||||
is_causal: bool = False,
|
||||
):
|
||||
) -> paddle.Tensor:
|
||||
"""Paddle implementation of scaled dot-product attention."""
|
||||
# query, key, value shape: [batch_size, num_heads, seq_len, head_size]
|
||||
d_k = query.shape[-1]
|
||||
@@ -159,8 +151,8 @@ class PaddleNativeAttnBackend(AttentionBackend):
|
||||
req_to_token: paddle.Tensor,
|
||||
req_pool_indices: paddle.Tensor,
|
||||
seq_lens: paddle.Tensor,
|
||||
causal=False,
|
||||
):
|
||||
causal: bool = False,
|
||||
) -> paddle.Tensor:
|
||||
"""Run the decode forward using the paddle native sdpa op.
|
||||
|
||||
Args:
|
||||
@@ -203,16 +195,12 @@ class PaddleNativeAttnBackend(AttentionBackend):
|
||||
per_req_value = v_cache[per_req_tokens].transpose(
|
||||
[query.dim() - 2, 0])
|
||||
|
||||
per_req_out = (
|
||||
self._scaled_dot_product_attention(
|
||||
per_req_query.unsqueeze(0),
|
||||
per_req_key.unsqueeze(0),
|
||||
per_req_value.unsqueeze(0),
|
||||
is_causal=causal,
|
||||
)
|
||||
.squeeze(0)
|
||||
.transpose([query.dim() - 2, 0])
|
||||
)
|
||||
per_req_out = (self._scaled_dot_product_attention(
|
||||
per_req_query.unsqueeze(0),
|
||||
per_req_key.unsqueeze(0),
|
||||
per_req_value.unsqueeze(0),
|
||||
is_causal=causal,
|
||||
).squeeze(0).transpose([query.dim() - 2, 0]))
|
||||
output[start_q:end_q, :, :] = per_req_out
|
||||
start_q, start_kv = end_q, end_kv
|
||||
|
||||
@@ -220,31 +208,28 @@ class PaddleNativeAttnBackend(AttentionBackend):
|
||||
|
||||
def forward_extend(
|
||||
self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
q: paddle.Tensor,
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
save_kv_cache=True,
|
||||
):
|
||||
save_kv_cache: bool = True,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Run the prefill and extend (prompt cache) attention forward using the paddle native sdpa op.
|
||||
"""
|
||||
if layer.qk_head_dim != layer.v_head_dim:
|
||||
o = q.new_empty(
|
||||
(q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||
(q.shape[0], layer.num_heads * layer.v_head_dim))
|
||||
else:
|
||||
o = paddle.empty_like(q)
|
||||
|
||||
if save_kv_cache:
|
||||
forward_meta.token_to_kv_pool.set_kv_buffer(
|
||||
layer, forward_meta.out_cache_loc, k, v
|
||||
)
|
||||
layer, forward_meta.out_cache_loc, k, v)
|
||||
|
||||
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
|
||||
|
||||
q_ = q.view([-1, layer.tp_q_head_num, layer.qk_head_dim])
|
||||
o_ = o.view([-1, layer.tp_q_head_num, layer.v_head_dim])
|
||||
q_ = q.view([-1, layer.num_heads, layer.qk_head_dim])
|
||||
o_ = o.view([-1, layer.num_heads, layer.v_head_dim])
|
||||
|
||||
causal = True
|
||||
|
||||
@@ -264,31 +249,29 @@ class PaddleNativeAttnBackend(AttentionBackend):
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
q: paddle.Tensor,
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
layer: paddle.nn.Layer,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Run the decoding attention forward using the paddle native sdpa op.
|
||||
"""
|
||||
q = q.reshape([-1, layer.tp_q_head_num * layer.qk_head_dim])
|
||||
q = q.reshape([-1, layer.num_heads * layer.qk_head_dim])
|
||||
|
||||
if layer.qk_head_dim != layer.v_head_dim:
|
||||
o = q.new_empty(
|
||||
(q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||
(q.shape[0], layer.num_heads * layer.v_head_dim))
|
||||
else:
|
||||
o = paddle.empty_like(q)
|
||||
|
||||
forward_meta.token_to_kv_pool.set_kv_buffer(
|
||||
layer, forward_meta.out_cache_loc, k, v
|
||||
)
|
||||
forward_meta.token_to_kv_pool.set_kv_buffer(layer,
|
||||
forward_meta.out_cache_loc,
|
||||
k, v)
|
||||
|
||||
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
|
||||
|
||||
q_ = q.view([-1, layer.tp_q_head_num, layer.qk_head_dim])
|
||||
o_ = o.view([-1, layer.tp_q_head_num, layer.v_head_dim])
|
||||
q_ = q.view([-1, layer.num_heads, layer.qk_head_dim])
|
||||
o_ = o.view([-1, layer.num_heads, layer.v_head_dim])
|
||||
|
||||
self._run_sdpa_forward_decode(
|
||||
q_,
|
||||
|
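PaddleNativeAttnBackend builds on paddle's scaled_dot_product_attention plus its own _scaled_dot_product_attention fallback. A plain-paddle reference version of that computation on [batch, num_heads, seq_len, head_dim] tensors, written as a sketch rather than the backend's exact code.

import paddle
import paddle.nn.functional as F

def naive_sdpa(query, key, value, is_causal=False):
    # query/key/value: [batch_size, num_heads, seq_len, head_dim]
    d_k = query.shape[-1]
    scores = paddle.matmul(query, key, transpose_y=True) / (d_k ** 0.5)
    if is_causal:
        seq_len = query.shape[-2]
        mask = paddle.triu(paddle.full([seq_len, seq_len], float("-inf")), diagonal=1)
        scores = scores + mask
    return paddle.matmul(F.softmax(scores, axis=-1), value)

q = paddle.rand([1, 2, 5, 8])
print(naive_sdpa(q, q, q, is_causal=True).shape)  # [1, 2, 5, 8]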
@@ -14,10 +14,13 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
|
||||
from .append_attention import append_attention
|
||||
from .get_block_shape_and_split_kv_block import \
|
||||
get_block_shape_and_split_kv_block
|
||||
from .init_signal_layerwise import init_signal_layerwise
|
||||
from .open_shm_and_get_meta_signal import open_shm_and_get_meta_signal
|
||||
|
||||
__all__ = [
|
||||
"get_block_shape_and_split_kv_block",
|
||||
"append_attention"
|
||||
]
|
||||
"get_block_shape_and_split_kv_block", "append_attention",
|
||||
"open_shm_and_get_meta_signal", "init_signal_layerwise"
|
||||
]
|
||||
|
@@ -14,10 +14,16 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import \
|
||||
append_attention as append_attention_gpu
|
||||
|
||||
|
||||
def append_attention(
|
||||
qkv: paddle.Tensor,
|
||||
@@ -68,14 +74,12 @@ def append_attention(
|
||||
speculate_max_draft_token_num: int = 1,
|
||||
causal: bool = True,
|
||||
speculate_decoder: bool = False,
|
||||
):
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Args:
|
||||
Returns:
|
||||
append_attention
|
||||
"""
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import append_attention
|
||||
out = append_attention(
|
||||
out = append_attention_gpu(
|
||||
qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
|
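The wrapper above imports the custom GPU op only when the platform is CUDA and raises elsewhere. A generic sketch of that lazy, platform-guarded dispatch; fused_op is a hypothetical name and the relu call merely marks where the custom kernel would go.

import paddle

def fused_op(x: paddle.Tensor) -> paddle.Tensor:
    if paddle.is_compiled_with_cuda():
        # In FastDeploy this branch would import and call the custom GPU op.
        return paddle.nn.functional.relu(x)
    raise NotImplementedError("fused_op is only available on CUDA builds")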
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
def init_signal_layerwise(
|
||||
kv_signal_metadata: paddle.Tensor,
|
||||
layer_id: int = 0,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
init_signal_layerwise
|
||||
"""
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import init_signal_layerwise
|
||||
out = init_signal_layerwise(kv_signal_metadata, layer_id)
|
||||
return out
|
||||
else:
|
||||
raise NotImplementedError()
|
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
import paddle
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
def open_shm_and_get_meta_signal(
|
||||
rank: int = 0,
|
||||
device_id: int = 0,
|
||||
keep_pd_step_flag: bool = False,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
open_shm_and_get_meta_signal
|
||||
"""
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import \
|
||||
open_shm_and_get_meta_signal
|
||||
out = open_shm_and_get_meta_signal(rank, device_id, keep_pd_step_flag)
|
||||
return out
|
||||
else:
|
||||
raise NotImplementedError()
|
fastdeploy/model_executor/layers/attention/xpu_attn_backend.py (new file, 188 lines)
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
init_signal_layerwise, open_shm_and_get_meta_signal)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from paddle._typing.dtype_like import _DTypeLiteral
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
@dataclass
|
||||
class XPUAttentionMetadata(AttentionMetadata):
|
||||
"""
|
||||
XPUAttentionMetadata
|
||||
"""
|
||||
max_len_kv: paddle.Tensor = None
|
||||
set_max_lengths: int = -1
|
||||
encoder_batch_ids: paddle.Tensor = None
|
||||
encoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
encoder_num_blocks: paddle.Tensor = None
|
||||
kv_batch_ids: paddle.Tensor = None
|
||||
kv_tile_ids_per_batch: paddle.Tensor = None
|
||||
kv_num_blocks: paddle.Tensor = None
|
||||
decoder_batch_ids: paddle.Tensor = None
|
||||
decoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
decoder_num_blocks: paddle.Tensor = None
|
||||
|
||||
_dtype: _DTypeLiteral = paddle.bfloat16
|
||||
encoder_max_partition_size: int = 32768
|
||||
max_partition_size: int = 32768
|
||||
block_tables: Optional[paddle.Tensor] = None
|
||||
rotary_embs: Optional[paddle.Tensor] = None
|
||||
attn_mask: Optional[paddle.Tensor] = None
|
||||
encoder_block_shape_q: Optional[paddle.Tensor] = None
|
||||
decoder_block_shape_q: Optional[paddle.Tensor] = None
|
||||
_fuse_kernel_compute_dtype: str = "bf16"
|
||||
|
||||
# pd_disaggregation
|
||||
kv_signal_metadata: Optional[paddle.Tensor] = None
|
||||
kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
|
||||
|
||||
|
||||
class XPUAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
XPUAttentionBackend backend implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
|
||||
head_dim: int):
|
||||
"""
|
||||
XPUAttentionBackend __init__
|
||||
"""
|
||||
super().__init__()
|
||||
self.attention_metadata: XPUAttentionMetadata = None
|
||||
# TODO(gongshaotian): Use fd_config parameters in the correct location
|
||||
self.block_size: int = fd_config.parallel_config.block_size
|
||||
self.max_seq_len: int = fd_config.parallel_config.max_model_len
|
||||
self.rope_theta: float = (10000.0
|
||||
if fd_config.model_config.rope_theta is None
|
||||
else fd_config.model_config.rope_theta)
|
||||
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
|
||||
self.causal: bool = getattr(fd_config.model_config, "causal", True)
|
||||
# self.speculate_method = fd_config.parallel_config.speculate_method
|
||||
# self.use_speculate = self.speculate_method is not None
|
||||
# self.speculate_max_draft_token_num = fd_config.parallel_config.speculate_max_draft_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
|
||||
|
||||
self.kv_num_heads: int = kv_num_heads
|
||||
self.num_heads: int = num_heads
|
||||
self.head_dim: int = head_dim
|
||||
self.num_layers: int = fd_config.model_config.num_layers
|
||||
|
||||
# pd_disaggregation
|
||||
self.use_pd_disaggregation: int = int(
|
||||
os.getenv("FLAGS_use_pd_disaggregation", 0))
|
||||
self.start_layer_index: int = fd_config.model_config.start_layer_index
|
||||
|
||||
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||
"""Initialize attention metadata so that all layers in the forward pass can reuse it."""
|
||||
metadata = XPUAttentionMetadata()
|
||||
metadata.encoder_block_shape_q = 64
|
||||
metadata.decoder_block_shape_q = 16
|
||||
metadata.max_partition_size = 32768
|
||||
metadata.encoder_max_partition_size = 32768
|
||||
metadata._dtype = paddle.get_default_dtype()
|
||||
if metadata._dtype == "bfloat16":
|
||||
metadata._fuse_kernel_compute_dtype = "bf16"
|
||||
elif metadata._dtype == "float16":
|
||||
metadata._fuse_kernel_compute_dtype = "fp16"
|
||||
elif metadata._dtype == "float32":
|
||||
metadata._fuse_kernel_compute_dtype = "fp32"
|
||||
metadata.block_tables = forward_meta.block_tables
|
||||
metadata.rotary_embs = forward_meta.rotary_embs
|
||||
metadata.attn_mask = forward_meta.attn_mask
|
||||
metadata.pre_caches_length = forward_meta.pre_caches_length
|
||||
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
|
||||
self.rank, self.keep_pd_step_flag)
|
||||
self.attention_metadata: AttentionMetadata = metadata
|
||||
|
||||
def get_attntion_meta(self) -> AttentionMetadata:
|
||||
"""get_attntion_meta"""
|
||||
return self.attention_metadata
|
||||
|
||||
def get_kv_cache_shape(
|
||||
self,
|
||||
max_num_blocks: int,
|
||||
) -> Tuple[int, int, int, int]:
|
||||
"""
|
||||
Calculate kv cache shape
|
||||
"""
|
||||
return (max_num_blocks, self.kv_num_heads, self.block_size,
|
||||
self.head_dim)
|
||||
|
||||
def forward_mixed(
|
||||
self,
|
||||
q: paddle.Tensor,
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
layer: Attention,
|
||||
forward_meta: ForwardMeta,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
forward_mixed
|
||||
"""
|
||||
metadata = self.attention_metadata
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
metadata.kv_signal_metadata,
|
||||
layer.layer_id + self.start_layer_index)
|
||||
|
||||
k_quant_scale = getattr(layer, "cache_k_scale", None)
|
||||
v_quant_scale = getattr(layer, "cache_v_scale", None)
|
||||
|
||||
from fastdeploy.model_executor.ops.xpu import block_attn
|
||||
res = block_attn(
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
forward_meta.cum_offsets,
|
||||
metadata.rotary_embs,
|
||||
metadata.block_tables,
|
||||
None,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
forward_meta.enc_batch,
|
||||
forward_meta.dec_batch,
|
||||
forward_meta.total_enc_len,
|
||||
forward_meta.encoder_seq_lod_cpu,
|
||||
forward_meta.encoder_batch_map_cpu,
|
||||
forward_meta.decoder_context_len_cpu,
|
||||
forward_meta.decoder_batch_map_cpu,
|
||||
)
|
||||
return res
|
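Both attention backends read integer feature flags from the environment with a literal default, e.g. FLAGS_use_pd_disaggregation and FLAGS_max_partition_size. The same pattern in isolation.

import os

use_pd_disaggregation = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
max_partition_size = int(os.getenv("FLAGS_max_partition_size", 32768))
print(use_pd_disaggregation, max_partition_size)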
@@ -16,6 +16,6 @@
xpu backend methods
"""

from .quantization.weight_only import XPUWeightOnlyLinearMethod
from .quantization.weight_only import XPUWeightOnlyLinearMethod, XPUWeightOnlyMoEMethod

__all__ = ['XPUWeightOnlyLinearMethod']
__all__ = ['XPUWeightOnlyLinearMethod', 'XPUWeightOnlyMoEMethod']
@@ -0,0 +1,16 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
xpu quantization methods
"""
@@ -13,15 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from abc import abstractmethod
from typing import Optional

from typing import Dict

import paddle
from paddle import nn

from .utils import xpu_quant_weight
from fastdeploy.model_executor.layers.quantization.quant_base import \
QuantMethodBase
from fastdeploy.model_executor.layers.quantization.weight_only import (
WeightOnlyConfig, WeightOnlyLinearMethod)
from fastdeploy.model_executor.ops.xpu import weight_quantize_xpu

from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig, WeightOnlyLinearMethod

class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
@@ -34,12 +37,133 @@ class XPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
) -> None:
super().__init__(quant_config)

def process_loaded_weights(self, layer, weight) -> None:
def create_weights(self, layer: nn.Layer) -> None:
"""
Create weights for linear layer on XPU
"""
layer.linear_weight_shape.reverse()
if self.quant_config.name() == "weight_only_int4":
layer.linear_weight_shape[0] //= 2
layer.weight_dtype = "int8"
linear_weight_scale_shape = [layer.embed_dim]
if hasattr(layer, "linear_weight_shape"):
if isinstance(layer.linear_weight_shape, list):
layer_weight_shape = layer.linear_weight_shape
linear_weight_scale_shape = layer_weight_shape[:1]

layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
dtype="float32",
is_bias=False,
)
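
The halving of `linear_weight_shape[0]` reflects that two int4 values are packed into one int8 element. A rough sketch of the resulting shapes, using purely illustrative sizes that are not taken from this diff:

# Assumed example: a [4096, 12288] bf16 weight stored as weight-only int4.
# After reverse() the layout is [out_features, in_features] = [12288, 4096];
# packing two int4 values per int8 byte halves the first dimension.
weight_shape_int4 = [12288 // 2, 4096]   # -> [6144, 4096], held as int8
weight_scale_shape = [12288]             # one dequantization scale per output channel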

def process_loaded_weights(self, layer: nn.Layer,
weight: paddle.Tensor) -> None:
"""
Load weights using XPU-specific quantization.
"""
quanted_weight_tensor, weight_scale_tensor = xpu_quant_weight(
weight.cpu().numpy())
layer.linear_weight.set_value(quanted_weight_tensor)
layer.linear_weight_scale.set_value(
weight_scale_tensor.astype(paddle.get_default_dtype()))
quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu(
weight, self.quant_config.algo, -1, -1)
layer.linear_weight.set_value(
paddle.transpose(quanted_weight_tensor, [1, 0]))
layer.linear_weight_scale.set_value(weight_scale_tensor)


class XPUWeightOnlyMoEMethod(QuantMethodBase):
"""
XPU Fused MoE Method.
"""

def __init__(
self,
quant_config: WeightOnlyConfig,
) -> None:
super().__init__()
self.quant_config = quant_config
self.moe_quant_type = self.quant_config.algo

def create_weights(self, layer: nn.Layer, state_dict: Dict[str,
paddle.Tensor]):
"""
XPU weight-only MoE create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(ffn1_weights) == layer.num_local_experts
assert len(ffn2_weights) == layer.num_local_experts
assert ffn1_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]

added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
added_scale_attrs = ["moe_ffn1_weight_scale", "moe_ffn2_weight_scale"]

for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
weight_name = added_weight_attrs[idx]
scale_name = added_scale_attrs[idx]

weight_list = []
weight_scale_list = []
for i in range(layer.num_local_experts):
quant_weight, scale = weight_quantize_xpu(
weight_tensor[i], self.moe_quant_type, -1,
-1)  # weight is [k,n]
weight_list.append(quant_weight.transpose(
[1, 0]))  # transpose weight to [n,k]
weight_scale_list.append(scale)
quanted_weight = paddle.stack(weight_list, axis=0)
setattr(
layer, weight_name,
layer.create_parameter(
shape=quanted_weight.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
))
getattr(layer, weight_name).set_value(quanted_weight)

quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
setattr(
layer, scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
))
getattr(layer, scale_name).set_value(quanted_weight_scale)

def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
XPU compute Fused MoE.
"""
from fastdeploy.model_executor.ops.xpu import xpu_moe_layer

fused_moe_out = xpu_moe_layer(
x,
layer.gate_weight.transpose([1, 0]),
layer.gate_correction_bias,
layer.moe_ffn1_weight,
layer.moe_ffn2_weight,
None, # ffn1 bias
None, # ffn2 bias
(layer.moe_ffn1_weight_scale
if hasattr(layer, "moe_ffn1_weight_scale") else None),
(layer.moe_ffn2_weight_scale
if hasattr(layer, "moe_ffn2_weight_scale") else None),
(layer.moe_ffn2_in_scale
if hasattr(layer, "moe_ffn2_in_scale") else None),
self.moe_quant_type,
layer.top_k,
False, # moe group, used in deepseek
)
if layer.tp_size > 1:
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
tensor_model_parallel_all_reduce(fused_moe_out)

return fused_moe_out
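
As a rough mental model of what create_weights produces before xpu_moe_layer consumes it, the per-expert quantized weights are transposed to [n, k] and stacked along a leading expert axis. The sizes below are assumptions chosen only to make the arithmetic concrete:

num_experts, hidden, inter = 8, 4096, 1024                # assumed sizes, not from this diff
ffn1_weight_shape = [num_experts, 2 * inter, hidden]      # stacked int8 weights, [n, k] per expert
ffn1_scale_shape = [num_experts, 2 * inter]               # one scale per output channel per expert
ffn2_weight_shape = [num_experts, hidden, inter]
ffn2_scale_shape = [num_experts, hidden]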
@@ -16,11 +16,13 @@
!! This file will be deleted after the platform is fully functional
"""

from typing import Tuple

import numpy as np
import paddle


def xpu_clip_and_round(x):
def xpu_clip_and_round(x: np.ndarray) -> np.ndarray:
"""
Clip and round the input array to the range [-127, 127] and convert to int8.

@@ -33,7 +35,8 @@ def xpu_clip_and_round(x):
return np.clip(np.around(x), -127, 127).astype("int8")


def xpu_quant_qkv_weight(weight_np):
def xpu_quant_qkv_weight(
weight_np: np.ndarray) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""
Quantize the query, key, and value weights for the Transformer model.

@@ -61,7 +64,8 @@ def xpu_quant_qkv_weight(weight_np):
return quanted_weight, weight_scales


def xpu_quant_weight(weight_np):
def xpu_quant_weight(
weight_np: np.ndarray) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""
Quantize the weight tensor for XPU devices.
@@ -28,7 +28,7 @@ class VocabParallelEmbedding(nn.Layer):
def __init__(
self,
llm_config,
fd_config,
num_embeddings,
embedding_dim=768,
params_dtype="bfloat16",
@@ -38,7 +38,7 @@ class VocabParallelEmbedding(nn.Layer):
Initialize the VocabParallelEmbedding layer for the model.

Args:
llm_config (LLMConfig): Arguments related to inference, containing
fd_config (FDConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
num_embeddings : vocabulary size.
@@ -48,21 +48,21 @@ class VocabParallelEmbedding(nn.Layer):
you can give it any name you like.
"""
super().__init__()
self.fd_config = fd_config
hcg = fleet.get_hybrid_communicate_group()
self.mp_rank = hcg.get_model_parallel_rank()
self.column_cut = llm_config.parallel_config.column_cut
self.column_cut = fd_config.parallel_config.column_cut
self.world_size = hcg.get_model_parallel_world_size()
self.ring_id = hcg.get_model_parallel_group().id
self.use_rope = llm_config.model_config.use_rope
self.rope_head_dim = llm_config.model_config.rope_head_dim
self.use_ep = llm_config.parallel_config.use_ep
self.hidden_dropout_prob = llm_config.model_config.hidden_dropout_prob
self.initializer_range = llm_config.model_config.initializer_range
self.weight_sharing = llm_config.model_config.weight_sharing
self.sequence_parallel = llm_config.parallel_config.sequence_parallel
self.weight_sharing_add_bias = llm_config.model_config.weight_sharing_add_bias
self.max_position_embeddings = llm_config.model_config.max_position_embeddings
self.freeze_embedding = llm_config.model_config.freeze_embedding
self.use_rope = fd_config.model_config.use_rope
self.rope_head_dim = fd_config.model_config.rope_head_dim
self.use_ep = fd_config.parallel_config.use_ep
self.hidden_dropout_prob = fd_config.model_config.hidden_dropout_prob
self.initializer_range = fd_config.model_config.initializer_range
self.sequence_parallel = fd_config.parallel_config.sequence_parallel
self.max_position_embeddings = fd_config.model_config.max_position_embeddings
self.freeze_embedding = fd_config.model_config.freeze_embedding
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings

if self.use_ep:
self.word_embeddings = nn.Embedding(
@@ -78,8 +78,7 @@ class VocabParallelEmbedding(nn.Layer):
get_model_parallel_group(),
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range),
),
mean=0.0, std=self.initializer_range), ),
)
else:
# column cut embedding
@@ -87,6 +86,7 @@ class VocabParallelEmbedding(nn.Layer):
num_embeddings,
embedding_dim // self.world_size,
)

self.word_embeddings.weight.is_distributed = True
self.word_embeddings.weight.split_axis = 1

@@ -94,34 +94,12 @@ class VocabParallelEmbedding(nn.Layer):
self.position_embeddings = nn.Embedding(
self.max_position_embeddings,
embedding_dim,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range),
),
weight_attr=paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range), ),
)

self.prefix = prefix

if self.weight_sharing and self.weight_sharing_add_bias:
assert num_embeddings % self.world_size == 0
if self.use_ep:
self.bias = self.create_parameter(
shape=[num_embeddings],
dtype=paddle.get_default_dtype(),
attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.0),
),
is_bias=True,
)
else:
self.bias = self.create_parameter(
shape=[num_embeddings // self.world_size],
dtype=paddle.get_default_dtype(),
attr=mask_lm_out_bias_attr,
is_bias=True,
)
self.bias.is_distributed = True

if self.freeze_embedding:
self.word_embeddings.weight.learning_rate = 0.0
if not self.use_rope:
@@ -138,9 +116,14 @@ class VocabParallelEmbedding(nn.Layer):
Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
self.word_embeddings.weight.set_value(
get_tensor(state_dict.pop(self.prefix + ".weight")).astype(
paddle.get_default_dtype()))
if self.tie_word_embeddings:
self.word_embeddings.weight.set_value(
get_tensor(state_dict[self.prefix + ".weight"]).astype(
paddle.get_default_dtype()))
else:
self.word_embeddings.weight.set_value(
get_tensor(state_dict.pop(self.prefix + ".weight")).astype(
paddle.get_default_dtype()))
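
The branch matters for tied weights: with tie_word_embeddings the key is read with state_dict[...] rather than popped, so the same tensor is still present when the LM head loads later. A hedged illustration of the difference (the "embed_tokens" prefix is an assumed example):

state_dict = {"embed_tokens.weight": embedding_tensor}      # assumed prefix and tensor
_ = state_dict["embed_tokens.weight"]        # tied path: key is left for ParallelLMHead
_ = state_dict.pop("embed_tokens.weight")    # untied path: key is consumed here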

def forward(self, ids_remove_padding=None):
"""
@@ -14,7 +14,7 @@
# limitations under the License.
"""

from paddlenlp.utils.log import logger
from paddleformers.utils.log import logger

import paddle
import paddle.nn.functional as F
@@ -14,29 +14,25 @@
# limitations under the License.
"""

import os

import fastdeploy
from paddlenlp.utils.log import logger

import paddle
from paddle import nn

from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
from fastdeploy.platforms import current_platform

from .utils import _set_var_distributed, divide, get_tensor

import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm


class LinearBase(nn.Layer):
"""
LinearBase Layer
LinearBase Layer.
"""

def __init__(
self,
llm_config,
fd_config: FDConfig,
prefix: str = "",
input_size: int = None,
output_size: int = None,
@@ -48,31 +44,26 @@ class LinearBase(nn.Layer):
Initializes a linear layer and provides additional parameters required for inference and quantization.

Args:
llm_config (LLMConfig): Inference-related parameters containing attributes such as
weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
fd_config (FDConfig): Inference-related parameters.
prefix (str): Unique name of the layer, used to name internal attributes.
Can be arbitrarily named.
input_size (int, optional): Number of input features. Defaults to None.
output_size (int, optional): Number of output features. Defaults to None.
weight_key (Any, optional): Key for weights. Defaults to None.
bias_key (Any, optional): Key for biases. Defaults to None.
skip_quant (bool, optional): Whether to skip quantization. Defaults to False.
input_size (int): Number of input features. Defaults to None.
output_size (int): Number of output features. Defaults to None.
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False.

Raises:
NotImplementedError: Raised if the current platform is not a CUDA platform.
"""
super().__init__()
if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_xpu():
self.forward = self.forward_cuda
else:
raise NotImplementedError

self.llm_config = llm_config
self.fd_config = fd_config
self.skip_quant = skip_quant
self.use_smooth_quant = llm_config.model_config.use_smooth_quant
self.weight_dtype = llm_config.model_config.weight_dtype
self.act_dtype = llm_config.model_config.act_dtype
self.input_size = input_size
self.output_size = output_size
self.with_bias = with_bias
@@ -86,61 +77,27 @@ class LinearBase(nn.Layer):
self.out_scale_key = f"{prefix}.out_scale"

self._dtype = self._helper.get_default_dtype()

if llm_config.quant_config:
self.quant_method = llm_config.quant_config.get_quant_method(self)
self.use_offline_quant = llm_config.tmp_config.use_offline_quant

def is_y_transposed(self):
"""
Returns whether the y tensor should be transposed for inference.
Args:
None.

Returns:
bool, whether the y tensor should be transposed for inference.
"""
if self.weight_dtype == "int4":
return True
if self.weight_dtype == "int8":
return True
if "float8" in self.weight_dtype:
return True
# bf16/fp16/fp32 y is not transposed
return False

def init_weight_shape(self, trans=False):
"""
Initialize the weight shape for the first feedforward network layer.

Args:
trans (bool, optional): Whether to transpose the weight shape.
Defaults to False. If True, the shape will be reversed.

Returns:
None.
"""
self.weight_dtype = self._dtype
self.linear_weight_shape = [
self.input_size,
self.output_size,
]
if trans:
self.linear_weight_shape.reverse()
if self.use_smooth_quant:
self.linear_shift_shape = [self.output_size]
self.linear_smooth_shape = [self.output_size]
if self.weight_dtype == "int4":
self.linear_weight_shape[0] //= 2
if fd_config.quant_config:
self.quant_method = fd_config.quant_config.get_quant_method(self)
if fd_config.model_config.is_quantized:
self.weight_key = f"{prefix}.quant_weight"
self.weight_scale_key = f"{prefix}.weight_scale"
self.act_scale_key = f"{prefix}.activation_scale"

def init_weight(self):
"""
Initialize the weights and biases.
"""
self.init_weight_shape(self.is_y_transposed())

if self.skip_quant:
self.weight_dtype = self._dtype
self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
dtype=self.get_weight_create_dtype(),
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
@@ -156,117 +113,57 @@ class LinearBase(nn.Layer):
# smooth quant
self.linear_shift = None
self.linear_smooth = None
if self.use_smooth_quant:
self.linear_shift = self.create_parameter(
shape=self.linear_shift_shape,
dtype=self._dtype,
is_bias=False,
)
self.linear_smooth = self.create_parameter(
shape=self.linear_smooth_shape,
dtype=self._dtype,
is_bias=False,
)

def get_weight_create_dtype(self):
def load_prequant_weight(self, state_dict: dict):
"""
Get the data type for creating weights based on quantization settings.
Load the prequantized weight from the state dictionary.

Args:
self (object): The instance of the class where this method is defined.

Returns:
str: The data type for creating weights. It depends on the quantization settings:
- If `self.skip_quant` is True, returns the original data type `self._dtype`.
- If `self.weight_dtype` is "int4", returns "int8" to ensure compatibility or optimization.
- Otherwise, returns the specified weight data type `self.weight_dtype`.
state_dict (dict): A dictionary containing the prequantized weights and scales.
"""
if self.skip_quant:
return self._dtype
if self.weight_dtype == "int4":
return "int8"
# TODO(wangzhe24) create_parameter not support FP8
if "float8" in self.weight_dtype:
return self._dtype
return self.weight_dtype
self.quant_method.process_prequanted_weights(self, state_dict)

def load_weight(self, state_dict: dict):
"""
Load the weight from the state dictionary.

def load_offline_quant_state_dict(self, quant_weight, quant_scale=None):
Args:
state_dict (dict): A dictionary containing the weights
"""
Load offline the checkpoint state dictionary into the layer.
"""
if quant_scale is None:
if "float8" in self.weight_dtype:
self.linear_weight.copy_(quant_weight, False)
else:
self.linear_weight.set_value(quant_weight)
weight_tensor = get_tensor(state_dict.pop(self.weight_key))

if self.fd_config.quant_config:
self.quant_method.process_loaded_weights(self, weight_tensor)
else:
if self.inference_args.weight_block_size[0] != -1:
self.linear_weight.copy_(quant_weight.view(paddle.float8_e4m3fn), False)
else:
self.linear_weight.set_value(quant_weight)
self.linear_weight_scale.set_value(quant_scale)
self.linear_weight.set_value(weight_tensor)

def load_state_dict(self, state_dict):
def load_state_dict(self, state_dict: dict):
"""
Load the checkpoint state dictionary into the layer.

Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
if self.use_offline_quant:
self.load_offline_quant_state_dict(
quant_weight=get_tensor(
state_dict.pop(self.weight_key + ".quant_weight")
),
quant_scale=get_tensor(
state_dict.pop(self.weight_key + ".quant_scale")
),
)
# weight
self.state_dict = state_dict
assert self.weight_key is not None, 'weight_key should not be None.'
if self.fd_config.model_config.is_quantized:
self.load_prequant_weight(state_dict)
else:
# weight
assert self.weight_key is not None, 'weight_key should not be None.'
weight_tensor = get_tensor(state_dict.pop(self.weight_key))

if self.llm_config.quant_config:
self.quant_method.process_loaded_weights(self, weight_tensor)
else:
self.linear_weight.set_value(weight_tensor)
self.load_weight(state_dict)

# bias
if self.with_bias:
bias_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.bias_key)))
bias_tensor = paddle.to_tensor(
get_tensor(state_dict.pop(self.bias_key)))
self.linear_bias.set_value(bias_tensor)

# smooth quant
if self.use_smooth_quant:
if self.shift_key in state_dict:
shift_tensor = get_tensor(state_dict.pop(self.shift_key)).astype(
paddle.get_default_dtype()
)
else:
shift_tensor = paddle.zeros(
shape=self.linear_shift_shape,
dtype=paddle.get_default_dtype(),
)
self.linear_shift.set_value(shift_tensor)
if self.smooth_key in state_dict:
smooth_tensor = get_tensor(state_dict.pop(self.smooth_key)).astype(
paddle.get_default_dtype()
)
else:
smooth_tensor = paddle.ones(
shape=[self.linear_smooth_shape],
dtype=paddle.get_default_dtype(),
)
self.linear_smooth.set_value(smooth_tensor)

def forward_cuda(self, x):
def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
"""
Forward function for ColumnParallelLinear.
Forward function for Linear.

Args:
x (Tensor): Input tensor to the ColumnParallelLinear layer.
x (Tensor): Input tensor to the Linear.

Returns:
Tensor: Output tensor.
@@ -274,22 +171,24 @@ class LinearBase(nn.Layer):
Raises:
NotImplementedError: If the weight dtype is not float8 or act dtype is not equal to weight dtype.
"""
if self.llm_config.quant_config:
if self.fd_config.quant_config:
linear_out = self.quant_method.apply(self, x)
else:
linear_out = paddle.matmul(x, self.linear_weight)
if self.with_bias:
linear_out = paddle.add(linear_out, self.linear_bias)

return linear_out
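
The forward path therefore reduces to a simple dispatch between the quantized kernel and a plain GEMM. A standalone sketch of the same logic, with illustrative names, assuming a layer object shaped like the class above:

def linear_forward(layer, x):
    # Quantized path delegates to the backend-specific kernel, otherwise plain matmul.
    if layer.fd_config.quant_config:
        out = layer.quant_method.apply(layer, x)
    else:
        out = paddle.matmul(x, layer.linear_weight)
    if layer.with_bias:
        out = paddle.add(out, layer.linear_bias)
    return out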


class ReplicatedLinear(LinearBase):
"""
ReplicatedLinear Layer
ReplicatedLinear Layer.
"""

def __init__(
self,
llm_config,
fd_config: FDConfig,
prefix: str = "",
input_size: int = None,
output_size: int = None,
@@ -298,74 +197,39 @@ class ReplicatedLinear(LinearBase):
skip_quant: bool = False,
):
"""
Initialize a linear layer with additional parameters for inference and quantization.
Initializes a replicated linear layer.

Args:
llm_config (LLMConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
prefix (str): Unique name of the layer, used for naming internal attributes,
you can give it any name you like.
layer_index (int): The index of the linear layer in the model

fd_config (FDConfig): Inference-related parameters.
prefix (str): Unique name of the layer, used to name internal attributes.
Can be arbitrarily named.
input_size (int): Number of input features. Defaults to None.
output_size (int): Number of output features. Defaults to None.
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
super().__init__(llm_config=llm_config,
super().__init__(fd_config=fd_config,
prefix=prefix,
input_size=input_size,
output_size=output_size,
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant)
self.nranks = llm_config.parallel_config.mp_size
self.input_size = input_size
self.init_weight()
self.quant_method.create_weights(self)

def init_weight(self):
"""
Initialize the weights and biases.
"""
self.init_weight_shape(self.is_y_transposed())

self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
dtype=self.get_weight_create_dtype(),
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)

self.linear_bias = None
if self.with_bias:
self.linear_bias = self.create_parameter(
shape=[self.output_size],
dtype=self._dtype,
is_bias=True,
)

# smooth quant
self.linear_shift = None
self.linear_smooth = None
if self.use_smooth_quant:
self.linear_shift = self.create_parameter(
shape=self.linear_shift_shape,
dtype=self._dtype,
is_bias=False,
)
self.linear_smooth = self.create_parameter(
shape=self.linear_smooth_shape,
dtype=self._dtype,
is_bias=False,
)


class ColumnParallelLinear(LinearBase):
"""
ColumnParallelLinear Layer
ColumnParallelLinear Layer.

The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
"""

def __init__(
self,
llm_config,
fd_config: FDConfig,
prefix: str = "",
input_size: int = None,
output_size: int = None,
@@ -374,40 +238,45 @@ class ColumnParallelLinear(LinearBase):
skip_quant: bool = False,
):
"""
Initialize a linear layer with additional parameters for inference and quantization.
Initializes a linear layer and provides additional parameters required for inference and quantization.

Args:
llm_config (LLMConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
prefix (str): Unique name of the layer, used for naming internal attributes,
you can give it any name you like.
layer_index (int): The index of the linear layer in the model

fd_config (FDConfig): Inference-related parameters.
prefix (str): Unique name of the layer, used to name internal attributes.
Can be arbitrarily named.
input_size (int): Number of input features. Defaults to None.
output_size (int): Number of output features. Defaults to None.
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
super().__init__(llm_config=llm_config,
super().__init__(fd_config=fd_config,
prefix=prefix,
input_size=input_size,
output_size=output_size,
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant)
self.nranks = llm_config.parallel_config.mp_size
self.nranks = fd_config.parallel_config.tensor_parallel_degree
self.input_size = input_size
self.output_size = divide(output_size, self.nranks)
self.linear_weight_shape = [
self.input_size,
self.output_size,
]
if fd_config.quant_config:
self.quant_method.create_weights(self)
self.init_weight()

self.quant_method.create_weights(self)

def init_weight(self):
"""
Initialize the weights and biases.
"""
self.init_weight_shape(self.is_y_transposed())

if self.skip_quant:
self.weight_dtype = self._dtype
self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
dtype=self.get_weight_create_dtype(),
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
@@ -429,62 +298,51 @@ class ColumnParallelLinear(LinearBase):
# smooth quant
self.linear_shift = None
self.linear_smooth = None
if self.use_smooth_quant:
self.linear_shift = self.create_parameter(
shape=self.linear_shift_shape,
dtype=self._dtype,
is_bias=False,
)
self.linear_smooth = self.create_parameter(
shape=self.linear_smooth_shape,
dtype=self._dtype,
is_bias=False,
)


class MergedColumnParallelLinear(ColumnParallelLinear):
"""
MergedColumnParallelLinear Layer.

Similar to ColumnParallelLinear, but the weight matrix is concatenated
along the output dimension. When the weight matrix is loaded, the
different partitions are sharded separately.
"""

def __init__(
self,
llm_config,
prefix,
with_bias=False,
add_bias=False,
activation="gelu",
use_fast_ffn=False,
skip_quant=False,
fd_config: FDConfig,
prefix: str,
input_size: int = None,
output_size: int = None,
with_bias: bool = False,
add_bias: bool = False,
activation: str = "gelu",
use_fast_ffn: bool = False,
skip_quant: bool = False,
):
"""Packed linear layers with column parallelism.

"""
Initialize the fused ffn1 Linear layer with given parameters.

Args:
llm_config (LLMConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.

prefix (str): Unique name of the layer, used for naming weights and biases.
weight_key (str): Key name of weight in the pdparams state dict.
bias_key (str): Key name of bias in the pdparams state dict. Defaults to None, means no bias.
with_bias (bool, optional): Whether to include bias term. Defaults to True.
activation (str, optional): Activation function to use. Defaults to "gelu".
use_fast_ffn (bool, optional): Whether to use a faster FFN implementation.
fd_config (FDConfig): Inference-related parameters.
prefix (str): Unique name of the layer, used to name internal attributes.
Can be arbitrarily named.
input_size (int): Number of input features. Defaults to None.
output_size (int): Number of output features. Defaults to None.
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
activation (str): Activation function to use. Defaults to "gelu".
use_fast_ffn (bool): Whether to use a faster FFN implementation.
Defaults to False.
skip_quant (bool, optional): Whether to skip quantization steps. Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
self.use_fast_ffn = use_fast_ffn
self.activation = activation
self.embed_dim = llm_config.model_config.hidden_size
self.dim_feedforward = llm_config.model_config.ffn_hidden_size
self.nranks = llm_config.parallel_config.mp_size
self.dim_feedforward_per_rank = divide(self.dim_feedforward,
self.nranks)
input_size = self.embed_dim
output_size = self.dim_feedforward * 2
super().__init__(llm_config=llm_config,
self.embed_dim = fd_config.model_config.hidden_size
self.nranks = fd_config.parallel_config.tensor_parallel_degree

super().__init__(fd_config=fd_config,
prefix=prefix,
input_size=input_size,
output_size=output_size,
@@ -492,7 +350,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
add_bias=add_bias,
skip_quant=skip_quant)

def load_state_dict(self, state_dict):
def load_state_dict(self, state_dict: dict):
"""
Load the checkpoint state dictionary into the layer.

@@ -542,47 +400,40 @@ class QKVParallelLinear(ColumnParallelLinear):
QKVParallelLinear Layer.
"""

def __init__(self, llm_config, prefix, with_bias=False, add_bias=True):
def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
"""
Initialize the QKV Linear layer with given parameters.

Args:
llm_config (LLMConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.

prefix (str): Unique name of the layer, used for naming weights and biases.
weight_key (str): Key name of weight in the pdparams state dict.
bias_key (str): Key name of bias in the pdparams state dict. Defaults to None, means no bias.
with_bias (bool, optional): Whether to include bias term. Defaults to True.
skip_quant (bool, optional): Whether to skip quantization steps. Defaults to False.
fd_config (FDConfig): Inference-related parameters.
prefix (str): Unique name of the layer, used to name internal attributes.
Can be arbitrarily named.
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to True.
"""
self.num_heads = llm_config.model_config.num_attention_heads
self.kv_num_heads = llm_config.model_config.num_key_value_heads
self.embed_dim = llm_config.model_config.hidden_size
self.head_dim = llm_config.model_config.head_dim
self.nranks = llm_config.parallel_config.mp_size
self.num_heads = fd_config.model_config.num_attention_heads
self.kv_num_heads = fd_config.model_config.num_key_value_heads
self.embed_dim = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim
self.nranks = fd_config.parallel_config.tensor_parallel_degree
self.num_heads_per_rank = divide(self.num_heads, self.nranks)
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
input_size = self.embed_dim
output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
super().__init__(llm_config=llm_config,
super().__init__(fd_config=fd_config,
prefix=prefix,
input_size=input_size,
output_size=output_size,
with_bias=with_bias,
add_bias=add_bias)
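
A worked example of the fused QKV sizing, with assumed GQA dimensions (the numbers below are illustrative, not taken from any model in this diff):

num_heads, kv_num_heads, head_dim, tp = 32, 8, 128, 4     # assumed model config
output_size = (num_heads + 2 * kv_num_heads) * head_dim   # 48 * 128 = 6144 columns for fused QKV
num_heads_per_rank = num_heads // tp                      # 8 query heads per rank
kv_num_heads_per_rank = kv_num_heads // tp                # 2 KV heads per rank
per_rank_output = output_size // tp                       # 1536 columns held on each rank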

def load_state_dict(self, state_dict):
def load_weight(self, state_dict: dict):
"""
Load the checkpoint state dictionary into the layer.
Load the weight from the state dictionary.

Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
state_dict (dict): A dictionary containing the weights
"""
# weight
assert self.weight_key is not None, 'weight_key should not be None.'
# qkv fused in disk
if self.weight_key in state_dict.keys():
weight_tensor = get_tensor(state_dict.pop(self.weight_key))
else:
@@ -601,11 +452,27 @@ class QKVParallelLinear(ColumnParallelLinear):
])
weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0])

if self.llm_config.quant_config:
if self.fd_config.quant_config:
self.quant_method.process_loaded_weights(self, weight_tensor)
else:
self.linear_weight.set_value(weight_tensor)

def load_state_dict(self, state_dict: dict):
"""
Load the checkpoint state dictionary into the layer.

Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
# weight
assert self.weight_key is not None, 'weight_key should not be None.'
# qkv fused in disk

if self.fd_config.model_config.is_quantized:
self.load_prequant_weight(state_dict)
else:
self.load_weight(state_dict)

# bias
if self.with_bias:
if self.bias_key in state_dict.keys():
@@ -622,38 +489,25 @@ class QKVParallelLinear(ColumnParallelLinear):
qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1)
self.linear_bias.set_value(qkv_bias)

# smooth quant
if self.use_smooth_quant:
if self.shift_key in state_dict:
shift_tensor = get_tensor(state_dict.pop(self.shift_key)).astype(
paddle.get_default_dtype()
)
else:
shift_tensor = paddle.zeros(
shape=self.linear_shift_shape,
dtype=paddle.get_default_dtype(),
)
self.linear_shift.set_value(shift_tensor)
if self.smooth_key in state_dict:
smooth_tensor = get_tensor(state_dict.pop(self.smooth_key)).astype(
paddle.get_default_dtype()
)
else:
smooth_tensor = paddle.ones(
shape=[self.linear_smooth_shape],
dtype=paddle.get_default_dtype(),
)
self.linear_smooth.set_value(smooth_tensor)


class RowParallelLinear(LinearBase):
"""
RowParallelLinear Layer
RowParallelLinear Layer.

The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
- -
| A_1 |
| . |
A = | . | X = [X_1, ..., X_p]
| . |
| A_p |
- -
"""

def __init__(
self,
llm_config,
fd_config: FDConfig,
prefix: str = "",
input_size: int = None,
output_size: int = None,
@@ -665,57 +519,50 @@ class RowParallelLinear(LinearBase):
Initialize a linear layer with additional parameters for inference and quantization.

Args:
llm_config (LLMConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
prefix (str): Unique name of the layer, used for naming internal attributes,
you can give it any name you like.
layer_index (int): The index of the linear layer in the model

fd_config (FDConfig): Inference-related parameters.
prefix (str): Unique name of the layer, used to name internal attributes.
Can be arbitrarily named.
input_size (int): Number of input features. Defaults to None.
output_size (int): Number of output features. Defaults to None.
with_bias (bool): Whether to include bias or not. Defaults to False.
add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
super().__init__(llm_config=llm_config,
super().__init__(fd_config=fd_config,
prefix=prefix,
input_size=input_size,
output_size=output_size,
with_bias=with_bias,
add_bias=add_bias,
skip_quant=skip_quant)
self.llm_config = llm_config
self.fd_config = fd_config
self.skip_quant = False
self.use_smooth_quant = llm_config.model_config.use_smooth_quant
self.weight_dtype = llm_config.model_config.weight_dtype
self.act_dtype = llm_config.model_config.act_dtype
self.nranks = llm_config.parallel_config.mp_size
self.embed_dim = llm_config.model_config.hidden_size
self.head_dim = llm_config.model_config.hidden_size // llm_config.model_config.num_attention_heads
self.num_heads = llm_config.model_config.num_attention_heads // self.nranks
self.dim_feedforward = llm_config.model_config.ffn_hidden_size // self.nranks
self.with_bias = with_bias
self.prefix = prefix
self.shift_key = f"{prefix}.shift_bias"
self.smooth_key = f"{prefix}.smooth_weight"
self.weight_key = f"{prefix}.weight"
self.bias_key = f"{prefix}.bias"
self.weight_only_scale_key = f"{prefix}.weight_only_scale"
self.out_scale_key = f"{prefix}.out_scale"
self.nranks = fd_config.parallel_config.tensor_parallel_degree
self.embed_dim = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim
self.num_heads = fd_config.model_config.num_attention_heads // self.nranks

self.linear_weight_shape = [
self.input_size,
self.output_size,
]
self._dtype = self._helper.get_default_dtype()

if llm_config.quant_config:
self.quant_method = llm_config.quant_config.get_quant_method(self)
if fd_config.quant_config:
self.quant_method = fd_config.quant_config.get_quant_method(self)
self.quant_method.create_weights(self)

self.init_weight()

def init_weight(self):
"""
Initialize the weights and biases.
"""
self.init_weight_shape(self.is_y_transposed())
if self.skip_quant:
self.weight_dtype = self._dtype

self.linear_weight = self.create_parameter(
shape=self.linear_weight_shape,
dtype=self.get_weight_create_dtype(),
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
@@ -735,27 +582,159 @@ class RowParallelLinear(LinearBase):
# smooth quant
self.linear_shift = None
self.linear_smooth = None
if self.use_smooth_quant:
self.linear_shift = self.create_parameter(
shape=self.linear_shift_shape,
dtype=self._dtype,
is_bias=False,
)
self.linear_smooth = self.create_parameter(
shape=self.linear_smooth_shape,
dtype=self._dtype,
is_bias=False,
)

def forward_cuda(self, x):
if self.llm_config.quant_config:
def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
if self.fd_config.quant_config:
out = self.quant_method.apply(self, x)
else:
out = paddle.matmul(x, self.linear_weight)

if self.nranks > 1:
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
tensor_model_parallel_all_reduce(out)

return out
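
Conceptually, each rank multiplies its input shard by its weight shard and the partial results are summed across ranks. A hedged sketch of that reduction, with assumed shard shapes:

# Assumed shapes: x_shard [batch, in_features // tp], w_shard [in_features // tp, out_features].
partial = paddle.matmul(x_shard, w_shard)     # partial product held on this rank
tensor_model_parallel_all_reduce(partial)     # in-place sum over the tensor-parallel group
out = partial                                 # every rank now holds the identical full output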


class KVBatchLinear(LinearBase):
"""
KVBatchLinear Layer for handling combined KV projections with bmm.
"""

def __init__(
self,
fd_config: FDConfig,
prefix: str = "",
kv_lora_rank: int = None,
num_attention_heads: int = None,
qk_nope_head_dim: int = None,
v_head_dim: int = None,
with_bias: bool = False,
skip_quant: bool = False,
):
"""
Initializes a KV batch linear layer that internally splits into K and V projections.

Args:
fd_config (FDConfig): Inference-related parameters.
prefix (str): Unique name of the layer, used to name internal attributes.
kv_lora_rank (int): LoRA rank for KV projection. Defaults to None.
num_attention_heads (int): Number of attention heads. Defaults to None.
qk_nope_head_dim (int): Dimension for Q/K projection (nope part). Defaults to None.
v_head_dim (int): Dimension for V projection. Defaults to None.
with_bias (bool): Whether to include bias or not. Defaults to False.
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
self.nranks = fd_config.parallel_config.tensor_parallel_degree
self.kv_lora_rank = kv_lora_rank
self.num_attention_heads = num_attention_heads
self.qk_nope_head_dim = qk_nope_head_dim
self.v_head_dim = v_head_dim
# Split num_attention_heads when using TP inference.
self.num_heads_per_partition = divide(num_attention_heads, self.nranks)

# Initialize parent with combined dimensions
super().__init__(
fd_config=fd_config,
prefix=prefix,
input_size=None, # Will be determined from weight shape
output_size=None, # Will be determined from weight shape
with_bias=with_bias,
add_bias=False,
skip_quant=skip_quant,
)
self.weight_dtype = self._dtype

# Override weight keys to use the combined kv_b_proj
self.weight_key = f"{prefix}.weight" # e.g., "kv_b_proj.weight"
self.k_weight_key = f"{prefix.replace('kv_b_proj', 'k_b_proj')}.weight"
self.v_weight_key = f"{prefix.replace('kv_b_proj', 'v_b_proj')}.weight"

def load_state_dict(self, state_dict: dict):
"""
Load the combined KV weight and split it into K and V projections
"""
# Get the combined KV weight
# NOTE(Ryan):Do not pop weight_key here, it will be popped in other class
kv_weight_tensor = get_tensor(state_dict[self.weight_key])

# Reshape and split the weight
w = kv_weight_tensor.reshape([
self.kv_lora_rank,
self.num_heads_per_partition,
-1,
]).transpose(perm=[1, 2, 0])

# Split into K and V weights
# wk_b: [num_heads, qk_nope_head_dim, kv_lora_rank]
wk_b = w[:, :self.qk_nope_head_dim, :]

if self.v_head_dim is None:
raise ValueError("self.v_head_dim should not be None")
# wv_b: [num_heads, kv_lora_rank, v_head_dim]
wv_b = w[:, -self.v_head_dim:, :].transpose(perm=[0, 2, 1])

# Create K projection weight
self.k_b_proj_weight = self.create_parameter(
shape=wk_b.shape,
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)

# Create V projection weight
self.v_b_proj_weight = self.create_parameter(
shape=wv_b.shape,
dtype=self.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)

self.k_b_proj_weight.set_value(wk_b)
self.v_b_proj_weight.set_value(wv_b)
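
A worked shape example for this split, using assumed MLA dimensions (the sizes are illustrative only):

kv_lora_rank, heads_per_rank = 512, 16                                    # assumed values
qk_nope_head_dim, v_head_dim = 128, 128
w_shape = [heads_per_rank, qk_nope_head_dim + v_head_dim, kv_lora_rank]  # after reshape + transpose([1, 2, 0])
wk_b_shape = [heads_per_rank, qk_nope_head_dim, kv_lora_rank]            # [16, 128, 512], used by forward_k_b
wv_b_shape = [heads_per_rank, kv_lora_rank, v_head_dim]                  # [16, 512, 128], used by forward_v_b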

def forward_k_b(self, x: paddle.Tensor) -> paddle.Tensor:
"""
Forward pass for K_b projection using bmm

Args:
x: Input tensor (e.g., query_nope.transpose([1, 0, 2]))

Returns:
K_b projection output
"""

out = paddle.bmm(x, self.k_b_proj_weight)
return out

def forward_v_b(self, x: paddle.Tensor) -> paddle.Tensor:
"""
Forward pass for V_b projection using bmm

Args:
x: Input tensor (e.g., fmha_out_decode)

Returns:
V_b projection output
"""
out = paddle.bmm(x, self.v_b_proj_weight)
return out

def forward_cuda(self,
x: paddle.Tensor,
proj_type: str = 'k') -> paddle.Tensor:
"""
Forward function that can handle both K and V projections

Args:
x: Input tensor
proj_type: 'k' or 'v' to select which projection to use

Returns:
Projection output
"""
if proj_type == 'k':
return self.forward_k_b(x)
elif proj_type == 'v':
return self.forward_v_b(x)
else:
raise ValueError(f"proj_type must be 'k' or 'v', got {proj_type}")
@@ -21,48 +21,6 @@ from paddle.distributed import fleet
from .utils import get_tensor


def parallel_matmul(lm_output, logit_weights, parallel_output):
"""
Performs parallel matrix multiplication for large-scale language models.

Args:
lm_output (Tensor): The output tensor from the language model layers,
which will be multiplied with the logit weights.
logit_weights (Tensor): The weights used in the matrix multiplication,
typically the weights of the output layer.
parallel_output (bool): A flag indicating whether to return the parallel
outputs or concatenate them. If True, returns the outputs from the
parallel computation directly. If False, concatenates the outputs
across the model parallel group before returning.

Returns:
Tensor: The result of the matrix multiplication. If `parallel_output` is True,
returns the parallel outputs. If `parallel_output` is False and
model parallel world size is greater than 1, returns the concatenated
outputs across the model parallel group. Otherwise, returns the direct
matrix multiplication result.
"""
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
world_size = hcg.get_model_parallel_world_size()
# rank = hcg.get_model_parallel_rank()

if world_size > 1:
input_parallel = paddle.distributed.collective._c_identity(
lm_output, group=model_parallel_group)

logits = paddle.matmul(input_parallel, logit_weights, transpose_y=True)

if parallel_output:
return logits

return paddle.distributed.collective._c_concat(
logits, group=model_parallel_group)
else:
logits = paddle.matmul(lm_output, logit_weights, transpose_y=True)
return logits


class ParallelLMHead(nn.Layer):
"""
"Parallelized LM head.
@@ -70,75 +28,69 @@ class ParallelLMHead(nn.Layer):

def __init__(
self,
llm_config,
fd_config,
num_embeddings,
embedding_dim,
prefix="",
with_bias=False,
tie_word_embeddings=None,
):
"""
Parallelized LMhead.

Args:
llm_config (LLMConfig): Arguments related to inference, containing
fd_config (FDConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
num_embeddings (int): vocabulary size.
embedding_dim (int): size of hidden state.
tie_embeddings_weight (bool, optional): Whether to share weights across model parallel ranks,
defaults to None.
prefix (str): full name of the layer in the state dict
"""
super(ParallelLMHead, self).__init__()
self.use_moe = llm_config.model_config.use_moe
self.linear_weight_key = prefix + ".weight"
if with_bias:
self.linear_bias_key = prefix + ".bias"
else:
self.linear_bias_key = None
self.use_ep = llm_config.parallel_config.use_ep
self.use_ep = fd_config.parallel_config.use_ep
self.column_cut = True
self.fused_linear = True

hcg = fleet.get_hybrid_communicate_group()
mp_rank = hcg.get_model_parallel_rank()
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear

self.tie_word_embeddings = tie_word_embeddings
self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings

if self.tie_word_embeddings is None:
if self.use_ep:
self.weight = self.create_parameter(
shape=[embedding_dim, num_embeddings],
dtype=paddle.get_default_dtype(),
is_bias=False,
if self.use_ep:
self.weight = self.create_parameter(
shape=[embedding_dim, num_embeddings],
dtype=paddle.get_default_dtype(),
is_bias=False,
)
else:
if self.column_cut:
need_gather = True
self.out_linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False, # False keeps the numerical diff smaller
)
else:
if self.column_cut:
need_gather = True
self.out_linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True,
gather_output=need_gather,
fuse_matmul_bias=self.fused_linear, # False keeps the numerical diff smaller
)
else:
self.out_linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True,
input_is_parallel=False,
fuse_matmul_bias=self.fused_linear, # False keeps the numerical diff smaller
)
self.out_linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().
get_model_parallel_group(),
weight_attr=None,
has_bias=True
if self.linear_bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False, # False keeps the numerical diff smaller
)

def load_state_dict(self, state_dict):
"""
@@ -148,25 +100,26 @@ class ParallelLMHead(nn.Layer):
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""

if self.tie_word_embeddings is None:
if self.use_ep:
self.weight.set_value(
get_tensor(state_dict.pop(self.linear_weight_key)).astype(
paddle.get_default_dtype()))
else:
if self.use_ep:
self.weight.set_value(
get_tensor(state_dict.pop(self.linear_weight_key)).astype(
paddle.get_default_dtype()))
else:
if self.tie_word_embeddings:
self.out_linear.weight.set_value(
get_tensor(state_dict.pop(self.linear_weight_key)).astype(
paddle.get_default_dtype()))
paddle.get_default_dtype()).transpose([1, 0]))
else:
weight_tensor = get_tensor(
state_dict.pop(self.linear_weight_key)).astype(
paddle.get_default_dtype())
if self.out_linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.out_linear.weight.set_value(weight_tensor)

bias = (
get_tensor(state_dict.pop(self.linear_bias_key)).astype(
paddle.get_default_dtype()
)
if self.linear_bias_key is not None
else paddle.zeros(
self.out_linear.bias.shape, dtype=paddle.get_default_dtype()
)
)
if self.linear_bias_key is not None:
bias = get_tensor(state_dict.pop(self.linear_bias_key)).astype(
paddle.get_default_dtype())
self.out_linear.bias.set_value(bias)

def forward(self, input):
@@ -180,11 +133,8 @@ class ParallelLMHead(nn.Layer):
Tensor: The output tensor after processing through the layer.
"""
logits = input
if self.tie_word_embeddings is not None:
logits = parallel_matmul(logits, self.tie_word_embeddings, False)
if self.use_ep:
logits = paddle.matmul(logits, self.weight)
else:
if self.use_ep:
logits = paddle.matmul(logits, self.weight)
else:
logits = self.out_linear(logits)
logits = self.out_linear(logits)
return logits
|
||||
|
@@ -11,3 +11,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .fused_moe_cutlass_backend import (CutlassW4A8MoEMethod,
|
||||
CutlassWeightOnlyMoEMethod)
|
||||
from .fused_moe_triton_backend import TritonWeightOnlyMoEMethod
|
||||
from .moe import FusedMoE
|
||||
|
||||
__all__ = [
|
||||
"CutlassWeightOnlyMoEMethod", "CutlassW4A8MoEMethod", "FusedMoE",
"TritonWeightOnlyMoEMethod"
|
||||
]
|
||||
|
@@ -1,222 +0,0 @@
|
||||
"""
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.distributed import fleet
|
||||
from paddle.framework import in_dynamic_or_pir_mode
|
||||
from paddle.nn.quant import weight_quantize
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch,
|
||||
moe_expert_ffn,
|
||||
moe_expert_reduce)
|
||||
|
||||
from .fused_moe_method_base import FusedMoEMethodBase
|
||||
|
||||
|
||||
class CutlassFusedMoeMethod(FusedMoEMethodBase):
|
||||
"""
|
||||
Use Cutlass Group Gemm to compute Fused MoE.
|
||||
This method is the oldest way to compute MoE in Paddle.
|
||||
"""
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
moe_compute_params,
|
||||
ffn1_tensor,
|
||||
ffn2_tensor,
|
||||
ffn1_bias=None,
|
||||
ffn2_bias=None,
|
||||
# The following arguments are only used in w4a8.
|
||||
moe_ffn1_weight_scale=None,
|
||||
moe_ffn2_weight_scale=None,
|
||||
moe_ffn1_in_scale=None,
|
||||
moe_ffn2_in_scale=None):
|
||||
"""
|
||||
Paddle cutlass create weight process.
|
||||
"""
|
||||
|
||||
num_local_experts = moe_compute_params.num_local_experts
|
||||
moe_quant_type = moe_compute_params.moe_quant_type
|
||||
|
||||
assert len(ffn1_tensor) == num_local_experts
|
||||
assert len(ffn2_tensor) == num_local_experts
|
||||
assert ffn1_tensor[0].shape == [
|
||||
moe_compute_params.hidden_size,
|
||||
moe_compute_params.moe_intermediate_size * 2
|
||||
]
|
||||
assert ffn2_tensor[0].shape == [
|
||||
moe_compute_params.moe_intermediate_size,
|
||||
moe_compute_params.hidden_size
|
||||
]
|
||||
|
||||
added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
|
||||
added_scale_attrs = ["moe_ffn1_weight_scale", "moe_ffn2_weight_scale"]
|
||||
|
||||
if moe_quant_type == "w4a8":
|
||||
moe_ffn1_in_scale = paddle.concat(moe_ffn1_in_scale)
|
||||
moe_ffn2_in_scale = paddle.concat(moe_ffn2_in_scale)
|
||||
moe_ffn1_in_scale = 1 / moe_ffn1_in_scale
|
||||
moe_ffn2_in_scale = 1 / moe_ffn2_in_scale
|
||||
moe_ffn1_weight_scale = paddle.stack(moe_ffn1_weight_scale, axis=0)
|
||||
moe_ffn2_weight_scale = paddle.stack(moe_ffn2_weight_scale, axis=0)
|
||||
|
||||
moe_ffn1_weight_scale = moe_ffn1_weight_scale / (127 * 112)
|
||||
moe_ffn2_weight_scale = moe_ffn2_weight_scale / (127 * 112)
|
||||
moe_ffn1_weight_scale = moe_ffn1_weight_scale / moe_ffn1_in_scale[:,
|
||||
None]
|
||||
moe_ffn2_weight_scale = moe_ffn2_weight_scale / moe_ffn2_in_scale[:,
|
||||
None]
|
||||
moe_ffn1_weight_scale = moe_ffn1_weight_scale.cast(
|
||||
paddle.get_default_dtype())
|
||||
moe_ffn2_weight_scale = moe_ffn2_weight_scale.cast(
|
||||
paddle.get_default_dtype())
|
||||
|
||||
if moe_quant_type in ["weight_only_int4", "weight_only_int8", "w4a8"]:
|
||||
|
||||
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
|
||||
weight_name = added_weight_attrs[idx]
|
||||
scale_name = added_scale_attrs[idx]
|
||||
|
||||
weight_list = []
|
||||
weight_scale_list = []
|
||||
for i in range(num_local_experts):
|
||||
quant_weight, scale = weight_quantize(weight_tensor[i],
|
||||
algo=moe_quant_type,
|
||||
arch=80)
|
||||
weight_list.append(quant_weight)
|
||||
if moe_quant_type != "w4a8":
|
||||
# scale holds no memory in w4a8, don't touch it!
|
||||
weight_scale_list.append(scale)
|
||||
quanted_weight = paddle.stack(weight_list, axis=0)
|
||||
setattr(
|
||||
layer, weight_name,
|
||||
layer.create_parameter(
|
||||
shape=quanted_weight.shape,
|
||||
dtype=quanted_weight.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
))
|
||||
getattr(layer, weight_name).set_value(quanted_weight)
|
||||
|
||||
# this scale only useful for wint8/4.
|
||||
if moe_quant_type != "w4a8":
|
||||
quanted_weight_scale = paddle.stack(weight_scale_list,
|
||||
axis=0)
|
||||
setattr(
|
||||
layer, scale_name,
|
||||
layer.create_parameter(
|
||||
shape=quanted_weight_scale.shape,
|
||||
dtype=quanted_weight_scale.dtype,
|
||||
))
|
||||
getattr(layer, scale_name).set_value(quanted_weight_scale)
|
||||
|
||||
if moe_quant_type == "w4a8":
|
||||
assert moe_ffn1_weight_scale is not None
|
||||
assert moe_ffn2_weight_scale is not None
|
||||
assert moe_ffn1_in_scale is not None
|
||||
assert moe_ffn2_in_scale is not None
|
||||
added_w4a8_attrs = [
|
||||
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale",
|
||||
"moe_ffn1_in_scale", "moe_ffn2_in_scale"
|
||||
]
|
||||
for idx, weight_tensor in enumerate([
|
||||
moe_ffn1_weight_scale, moe_ffn2_weight_scale,
|
||||
moe_ffn1_in_scale, moe_ffn2_in_scale
|
||||
]):
|
||||
name = added_w4a8_attrs[idx]
|
||||
setattr(
|
||||
layer, name,
|
||||
layer.create_parameter(
|
||||
shape=weight_tensor.shape,
|
||||
dtype=weight_tensor.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
))
|
||||
getattr(layer, name).set_value(weight_tensor)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
moe_compute_params,
|
||||
x: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
"""
|
||||
|
||||
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
|
||||
|
||||
(
|
||||
permute_input,
|
||||
token_nums_per_expert,
|
||||
permute_indices_per_token,
|
||||
topk_weights,
|
||||
topk_idx,
|
||||
expert_idx_per_token,
|
||||
) = moe_expert_dispatch(
|
||||
x,
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
|
||||
else None), # if set, permute_input will be int8_t
|
||||
moe_compute_params.top_k,
|
||||
False,
|
||||
topk_only_mode=False,
|
||||
)
|
||||
|
||||
if moe_compute_params.moe_quant_type != "w4a8":
|
||||
# Only w4a8 needs expert_idx_per_token;
# other quant types don't, so we set it to None.
|
||||
expert_idx_per_token = None
|
||||
else:
|
||||
expert_idx_per_token = expert_idx_per_token.cast("int64")
|
||||
|
||||
ffn_out = moe_expert_ffn(
|
||||
permute_input,
|
||||
token_nums_per_expert,
|
||||
layer.moe_ffn1_weight,
|
||||
layer.moe_ffn2_weight,
|
||||
None,
|
||||
(layer.moe_ffn1_weight_scale
|
||||
if hasattr(layer, "moe_ffn1_weight_scale") else None),
|
||||
(layer.moe_ffn2_weight_scale
|
||||
if hasattr(layer, "moe_ffn2_weight_scale") else None),
|
||||
(layer.moe_ffn2_in_scale
|
||||
if hasattr(layer, "moe_ffn2_in_scale") else None),
|
||||
expert_idx_per_token,
|
||||
moe_compute_params.moe_quant_type,
|
||||
False, # used_in_ep_low_latency
|
||||
)
|
||||
|
||||
if False:
|
||||
if in_dynamic_or_pir_mode():
|
||||
hcg = fleet.get_hybrid_communicate_group()
|
||||
mp_group = hcg.get_model_parallel_group()
|
||||
paddle.distributed.all_reduce(ffn_out, group=mp_group)
|
||||
else:
|
||||
paddle.distributed.all_reduce(ffn_out, group=mp_group)
|
||||
|
||||
# The reduce step applies top-k weight normalization and routed_scaling_factor
|
||||
fused_moe_out = moe_expert_reduce(
|
||||
ffn_out,
|
||||
topk_weights,
|
||||
permute_indices_per_token,
|
||||
topk_idx,
|
||||
None,
|
||||
norm_topk_prob=True,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
return fused_moe_out
|
(File diff suppressed because it is too large)
fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py (new file, 135 lines)
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from fastdeploy.config import MoEPhase
|
||||
|
||||
from ..quantization.quant_base import QuantMethodBase
|
||||
|
||||
|
||||
class MoEMethodBase(QuantMethodBase):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config):
|
||||
super().__init__()
|
||||
if quant_config is None:
|
||||
self.moe_quant_type = "w16a16"
|
||||
else:
|
||||
self.quant_config = quant_config
|
||||
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
|
||||
self.added_scale_attrs = [
|
||||
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
|
||||
]
|
||||
self.pack_num = 1
|
||||
|
||||
def init_ep(self, layer: nn.Layer) -> None:
|
||||
"""
|
||||
Init EP related module
|
||||
"""
|
||||
if layer.ep_size > 1:
|
||||
if layer.fd_config.parallel_config.moe_phase == MoEPhase.DECODER:
|
||||
from .ep import EPDecoderRunner
|
||||
self.ep_decoder_runner = EPDecoderRunner(
|
||||
layer.top_k, layer.hidden_size, layer.num_experts,
|
||||
layer.moe_config.num_max_dispatch_tokens_per_rank,
|
||||
layer.ep_size, layer.ep_rank)
|
||||
else:
|
||||
from .ep import EPPrefillRunner
|
||||
self.ep_prefill_runner = EPPrefillRunner(
|
||||
layer.top_k, layer.hidden_size, layer.num_experts,
|
||||
layer.ep_size, layer.ep_rank)
|
||||
|
||||
def process_loaded_weights(self, layer, weights) -> None:
|
||||
"""
|
||||
process_loaded_weights
|
||||
"""
|
||||
pass
|
||||
|
||||
def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights):
|
||||
"""
|
||||
check layer is valid for this method
|
||||
"""
|
||||
assert ffn1_weights[0].shape == [
|
||||
layer.hidden_size // self.pack_num, layer.moe_intermediate_size * 2
|
||||
]
|
||||
assert ffn2_weights[0].shape == [
|
||||
layer.moe_intermediate_size // self.pack_num, layer.hidden_size
|
||||
]
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Paddle cutlass create weight process.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_ep_prefill(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP prefill method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_ep_decode(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP decoder method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_tp(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
"""
|
||||
if layer.ep_size > 1:
|
||||
if layer.fd_config.parallel_config.moe_phase == MoEPhase.PREFILL:
|
||||
return self.apply_ep_prefill(layer, x, gate_out)
|
||||
else:
|
||||
return self.apply_ep_decode(layer, x, gate_out)
|
||||
else:
|
||||
return self.apply_tp(layer, x, gate_out)
|
@@ -0,0 +1,431 @@
|
||||
"""
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn.quant import weight_quantize
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy.distributed.communication_op import \
|
||||
tensor_model_parallel_all_reduce
|
||||
from ..utils import get_tensor, create_and_set_parameter
|
||||
from .fused_moe_backend_base import MoEMethodBase
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch
|
||||
from fastdeploy.model_executor.ops.gpu import moe_expert_reduce
|
||||
|
||||
|
||||
class CutlassMoEMethod(MoEMethodBase):
|
||||
"""
|
||||
Use Cutlass Group Gemm to compute Fused MoE.
|
||||
This method is the oldest way to compute MoE in Paddle.
|
||||
"""
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Paddle cutlass create weight process.
|
||||
"""
|
||||
# bf16
|
||||
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
|
||||
stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
|
||||
for idx, weight_tensor in enumerate(
|
||||
[stacked_ffn1_weights, stacked_ffn2_weights]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
setattr(
|
||||
layer, weight_name,
|
||||
layer.create_parameter(
|
||||
shape=weight_tensor.shape,
|
||||
dtype=weight_tensor.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
))
|
||||
getattr(layer, weight_name).set_value(weight_tensor)
|
||||
|
||||
def compute_ffn(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
permute_input: paddle.Tensor,
|
||||
token_nums_per_expert: paddle.Tensor,
|
||||
expert_idx_per_token: paddle.Tensor,
|
||||
used_in_ep_low_latency: bool = False,
|
||||
):
|
||||
"""
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
"""
|
||||
return fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
|
||||
permute_input,
|
||||
token_nums_per_expert,
|
||||
layer.moe_ffn1_weight,
|
||||
layer.moe_ffn2_weight,
|
||||
None,
|
||||
(layer.moe_ffn1_weight_scale
|
||||
if hasattr(layer, "moe_ffn1_weight_scale") else None),
|
||||
(layer.moe_ffn2_weight_scale
|
||||
if hasattr(layer, "moe_ffn2_weight_scale") else None),
|
||||
(layer.moe_ffn2_in_scale
|
||||
if hasattr(layer, "moe_ffn2_in_scale") else None),
|
||||
expert_idx_per_token,
|
||||
self.moe_quant_type,
|
||||
used_in_ep_low_latency,
|
||||
)
|
||||
|
||||
def apply_ep_prefill(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP prefill method.
|
||||
"""
|
||||
# 1. Select topk experts and weights
|
||||
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(
|
||||
layer, gate_out)
|
||||
# 2. EP Dispatch
|
||||
(
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
recv_num_tokens_per_expert_list,
|
||||
handle,
|
||||
_,
|
||||
) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)
|
||||
token_all_num = sum(recv_num_tokens_per_expert_list)
|
||||
|
||||
# 3. Compute ffn
|
||||
if token_all_num > 0:
|
||||
logger.info(f"token_all_num {token_all_num}")
|
||||
(
|
||||
permute_input,
|
||||
permute_indices_per_token,
|
||||
recv_num_tokens_per_expert_list_cumsum,
|
||||
dst_weights,
|
||||
dst_indices,
|
||||
cumsum_idx_gpu,
|
||||
expert_idx_per_token,
|
||||
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
(self.moe_ffn1_in_scale
|
||||
if hasattr(self, "moe_ffn1_in_scale") else None),
|
||||
recv_num_tokens_per_expert_list,
|
||||
token_all_num,
|
||||
self.moe_quant_type,
|
||||
)
|
||||
if self.moe_quant_type != "w4a8":
|
||||
# Only w4a8 needs expert_idx_per_token;
# other quant types don't, so we set it to None.
|
||||
expert_idx_per_token = None
|
||||
else:
|
||||
expert_idx_per_token = expert_idx_per_token.cast("int64")
|
||||
|
||||
ffn_out = self.compute_ffn(layer, permute_input,
|
||||
recv_num_tokens_per_expert_list_cumsum,
|
||||
expert_idx_per_token)
|
||||
|
||||
# permute back per rank
|
||||
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
|
||||
ffn_out,
|
||||
dst_weights,
|
||||
permute_indices_per_token,
|
||||
dst_indices,
|
||||
None, # moe_ffn2_bias,
|
||||
False, # norm_topk_prob
|
||||
1.0,
|
||||
)[0]
|
||||
else:
|
||||
tmp_ffn_out = recv_x
|
||||
|
||||
# 4. EP combine
|
||||
return self.ep_prefill_runner.combine(tmp_ffn_out, handle,
|
||||
recv_topk_weights)
|
||||
|
||||
def apply_ep_decode(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP decoder method.
|
||||
"""
|
||||
# 1. Select topk experts and weights
|
||||
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(
|
||||
layer, gate_out)
|
||||
# 2. EP Dispatch
|
||||
permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
|
||||
x, topk_idx, topk_weights)
|
||||
# 3. Compute ffn
|
||||
if self.moe_quant_type == "w4a8":
|
||||
num_local_experts, max_num, _ = permute_input.shape
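# Low-latency dispatch returns a padded [num_local_experts, max_tokens, hidden]
# layout, so each slot's expert id is simply its row index.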
|
||||
expert_idx_per_token = paddle.arange(
|
||||
num_local_experts)[:, None].tile([1, max_num])
|
||||
elif self.moe_quant_type in ["weight_only_int8", "weight_only_int4"]:
|
||||
expert_idx_per_token = None
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
ffn_out = self.compute_ffn(layer, permute_input,
|
||||
token_nums_per_expert.cast("int64"),
|
||||
expert_idx_per_token, True)
|
||||
|
||||
# 4. EP combine
|
||||
return self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights,
|
||||
handle)
|
||||
|
||||
def apply_tp(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
"""
|
||||
(
|
||||
permute_input,
|
||||
token_nums_per_expert,
|
||||
permute_indices_per_token,
|
||||
topk_weights,
|
||||
topk_idx,
|
||||
expert_idx_per_token,
|
||||
) = moe_expert_dispatch(
|
||||
x,
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
|
||||
else None), # if set, permute_input will be int8_t
|
||||
layer.top_k,
|
||||
False,
|
||||
topk_only_mode=False,
|
||||
)
|
||||
|
||||
if self.moe_quant_type != "w4a8":
|
||||
# Only w4a8 needs expert_idx_per_token;
# other quant types don't, so we set it to None.
|
||||
expert_idx_per_token = None
|
||||
else:
|
||||
expert_idx_per_token = expert_idx_per_token.cast("int64")
|
||||
|
||||
ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert,
|
||||
expert_idx_per_token)
|
||||
|
||||
# The reduce step applies top-k weight normalization and routed_scaling_factor
|
||||
fused_moe_out = moe_expert_reduce(
|
||||
ffn_out,
|
||||
topk_weights,
|
||||
permute_indices_per_token,
|
||||
topk_idx,
|
||||
None,
|
||||
norm_topk_prob=True,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
|
||||
if layer.tp_size > 1:
|
||||
tensor_model_parallel_all_reduce(fused_moe_out)
|
||||
|
||||
return fused_moe_out
|
||||
|
||||
|
||||
class CutlassW4A8MoEMethod(CutlassMoEMethod):
|
||||
"""
|
||||
w4a8 MoE Method
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config):
|
||||
super().__init__(quant_config)
|
||||
self.quant_config = quant_config
|
||||
self.moe_quant_type = "w4a8"
|
||||
self.pack_num = 2
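# int4 packs two values per element, so check() compares weight shapes
# against hidden_size // pack_num and moe_intermediate_size // pack_num.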
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Paddle cutlass create weight process.
|
||||
"""
|
||||
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
self.check(layer, ffn1_weights, ffn2_weights)
|
||||
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
weight_list = []
|
||||
for i in range(layer.num_local_experts):
|
||||
quant_weight, scale = weight_quantize(weight_tensor[i],
|
||||
algo=self.moe_quant_type,
|
||||
arch=80)
|
||||
weight_list.append(quant_weight)
|
||||
quanted_weight = paddle.stack(weight_list, axis=0)
|
||||
create_and_set_parameter(layer, weight_name, quanted_weight)
|
||||
|
||||
self.create_w4a8_scale_weights(layer, layer.weight_key_map, state_dict)
|
||||
|
||||
def create_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict,
|
||||
state_dict: dict):
|
||||
"""
|
||||
Get w4a8 weights from state dict and process them.
|
||||
Args:
|
||||
layer (nn.Layer): The layer to add parameters to.
|
||||
weight_key_map (dict): The weight key map.
|
||||
state_dict (dict): The state dict.
|
||||
"""
|
||||
|
||||
def _extract_scale_tensor(state_dict, key_template, expert_idx):
|
||||
return get_tensor(state_dict.pop(key_template.format(expert_idx)))
|
||||
|
||||
def _process_in_scale(name: str, in_scales: list[paddle.Tensor]):
|
||||
processed_in_scale = 1 / paddle.concat(in_scales)
|
||||
create_and_set_parameter(layer, name, processed_in_scale)
|
||||
return processed_in_scale
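# _process_weight_scale below folds the per-expert weight scales together with
# the inverse input scale and the fixed 127 * 112 factor (the same constant
# the previous cutlass backend used for w4a8 dequantization).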
|
||||
|
||||
def _process_weight_scale(name: str,
|
||||
weight_scales: list[paddle.Tensor],
|
||||
processed_in_scale: paddle.Tensor):
|
||||
processed_weight_scale = (paddle.stack(weight_scales, axis=0) /
|
||||
(127 * 112) /
|
||||
processed_in_scale[:, None]).cast(
|
||||
paddle.get_default_dtype())
|
||||
create_and_set_parameter(layer, name, processed_weight_scale)
|
||||
|
||||
# 1. Init scale containers and maps
|
||||
moe_ffn1_weight_scales = []
|
||||
moe_ffn2_weight_scales = []
|
||||
moe_ffn1_in_scales = []
|
||||
moe_ffn2_in_scales = []
|
||||
|
||||
scale_weight_map = {
|
||||
"moe_ffn1_weight_scale": moe_ffn1_weight_scales,
|
||||
"moe_ffn2_weight_scale": moe_ffn2_weight_scales,
|
||||
"moe_ffn1_in_scale": moe_ffn1_in_scales,
|
||||
"moe_ffn2_in_scale": moe_ffn2_in_scales,
|
||||
}
|
||||
scale_key_map = {
|
||||
"moe_ffn1_weight_scale":
|
||||
weight_key_map.get("ffn1_expert_weight_scale_key", None),
|
||||
"moe_ffn2_weight_scale":
|
||||
weight_key_map.get("ffn2_expert_weight_scale_key", None),
|
||||
"moe_ffn1_in_scale":
|
||||
weight_key_map.get("ffn1_expert_in_scale_key", None),
|
||||
"moe_ffn2_in_scale":
|
||||
weight_key_map.get("ffn2_expert_in_scale_key", None),
|
||||
}
|
||||
for name, value in scale_key_map.items():
|
||||
if value is None:
|
||||
raise ValueError(
|
||||
f"scale {name} should not be none in w4a8 mode.")
|
||||
|
||||
# 2. Extract scale tensor from state dict
|
||||
|
||||
for local_expert_idx in range(layer.num_local_experts):
|
||||
expert_idx = local_expert_idx + layer.expert_id_offset * layer.num_local_experts
|
||||
for name, scale_key_template in scale_key_map.items():
|
||||
scale_tensor = _extract_scale_tensor(state_dict,
|
||||
scale_key_template,
|
||||
expert_idx)
|
||||
scale_weight_map[name].append(scale_tensor)
|
||||
|
||||
# 3. Process scale tensor and set to layer
|
||||
in_scales = []
|
||||
for in_scale_name in ["moe_ffn1_in_scale", "moe_ffn2_in_scale"]:
|
||||
in_scales.append(
|
||||
_process_in_scale(in_scale_name,
|
||||
scale_weight_map[in_scale_name]))
|
||||
|
||||
for i, weight_scale_name in enumerate(
|
||||
["moe_ffn1_weight_scale", "moe_ffn2_weight_scale"]):
|
||||
_process_weight_scale(weight_scale_name,
|
||||
scale_weight_map[weight_scale_name],
|
||||
in_scales[i])
|
||||
|
||||
|
||||
class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
|
||||
"""
|
||||
Weight-only quantization method for MoE.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config):
|
||||
super().__init__(quant_config)
|
||||
self.quant_config = quant_config
|
||||
self.moe_quant_type = self.quant_config.algo
|
||||
self.pack_num = 1
|
||||
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Paddle cutlass process prequanted weights.
|
||||
"""
|
||||
ffn1_expert_weight_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_weight_key", None)
|
||||
ffn2_expert_weight_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_weight_key", None)
|
||||
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_weight_scale_key", None)
|
||||
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_weight_scale_key", None)
|
||||
|
||||
ffn1_weights, ffn2_weights = layer.load_experts_weight(
|
||||
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
|
||||
# self.check(layer, ffn1_weights, ffn2_weights)
|
||||
ffn1_weight_scale = []
|
||||
ffn2_weight_scale = []
|
||||
for i in range(layer.num_local_experts):
|
||||
expert_idx = layer.expert_id_offset + i
|
||||
ffn1_weight_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn1_expert_weight_scale_key.format(expert_idx))))
|
||||
ffn2_weight_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_weight_scale_key.format(expert_idx))))
|
||||
|
||||
ffn1_weight = paddle.stack(ffn1_weights, axis=0)
|
||||
ffn2_weight = paddle.stack(ffn2_weights, axis=0)
|
||||
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
|
||||
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)
|
||||
|
||||
name_tensor_map = {
|
||||
"moe_ffn1_weight": ffn1_weight,
|
||||
"moe_ffn2_weight": ffn2_weight,
|
||||
"moe_ffn1_weight_scale": ffn1_weight_scale,
|
||||
"moe_ffn2_weight_scale": ffn2_weight_scale
|
||||
}
|
||||
for name, tensor in name_tensor_map.items():
|
||||
create_and_set_parameter(layer, name, tensor)
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Paddle cutlass create weight process.
|
||||
"""
|
||||
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
self.check(layer, ffn1_weights, ffn2_weights)
|
||||
|
||||
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
scale_name = self.added_scale_attrs[idx]
|
||||
|
||||
weight_list = []
|
||||
weight_scale_list = []
|
||||
for i in range(layer.num_local_experts):
|
||||
quant_weight, scale = weight_quantize(weight_tensor[i],
|
||||
algo=self.moe_quant_type)
|
||||
weight_list.append(quant_weight)
|
||||
weight_scale_list.append(scale)
|
||||
quanted_weight = paddle.stack(weight_list, axis=0)
|
||||
create_and_set_parameter(layer, weight_name, quanted_weight)
|
||||
|
||||
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
|
||||
create_and_set_parameter(layer, scale_name, quanted_weight_scale)
|
@@ -0,0 +1,380 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
import fastdeploy
|
||||
import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
|
||||
from fastdeploy.distributed.communication_op import \
|
||||
tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
|
||||
from ..utils import create_and_set_parameter
|
||||
from .fused_moe_backend_base import MoEMethodBase
|
||||
|
||||
|
||||
class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
"""
|
||||
DeepGemmFusedMoeMethod implements the MoEMethodBase interface for the DeepGemm backend.
|
||||
"""
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
deepgemm create weight process.
|
||||
"""
|
||||
|
||||
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
|
||||
self.check(layer, ffn1_weights, ffn2_weights)
|
||||
|
||||
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
scale_name = self.added_scale_attrs[idx]
|
||||
|
||||
weight_list = []
|
||||
weight_scale_list = []
|
||||
for i in range(layer.num_local_experts):
|
||||
from fastdeploy.model_executor.layers.utils import \
|
||||
per_block_cast_to_fp8
|
||||
quant_weight, scale = per_block_cast_to_fp8(
|
||||
weight_tensor[i], self.quant_config.weight_block_size)
|
||||
|
||||
weight_list.append(quant_weight)
|
||||
weight_scale_list.append(scale)
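# Stack the per-expert FP8 weights/scales and transpose them to [E, N, K];
# the grouped "nt" GEMMs below expect the B operand in transposed layout.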
|
||||
quanted_weight = paddle.stack(weight_list, axis=0)
|
||||
quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous()
|
||||
create_and_set_parameter(layer, weight_name, quanted_weight)
|
||||
|
||||
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
|
||||
quanted_weight_scale = quanted_weight_scale.transpose(
|
||||
[0, 2, 1]).contiguous()
|
||||
create_and_set_parameter(layer, scale_name, quanted_weight_scale)
|
||||
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Paddle cutlass process prequanted weights.
|
||||
"""
|
||||
ffn1_expert_weight_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_weight_key", None)
|
||||
ffn2_expert_weight_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_weight_key", None)
|
||||
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_weight_scale_key", None)
|
||||
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_weight_scale_key", None)
|
||||
|
||||
ffn1_weights, ffn2_weights = layer.load_experts_weight(
|
||||
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
|
||||
# self.check(layer, ffn1_weights, ffn2_weights)
|
||||
ffn1_weight_scale = []
|
||||
ffn2_weight_scale = []
|
||||
for i in range(layer.num_local_experts):
|
||||
expert_idx = layer.expert_id_offset + i
|
||||
ffn1_weight_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn1_expert_weight_scale_key.format(expert_idx))))
|
||||
ffn2_weight_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_weight_scale_key.format(expert_idx))))
|
||||
|
||||
ffn1_weight = paddle.stack(ffn1_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
|
||||
ffn2_weight = paddle.stack(ffn2_weights, axis=0).transpose([0, 2, 1]).contiguous().view("float8_e4m3fn")
|
||||
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
|
||||
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0).transpose([0, 2, 1]).contiguous()
|
||||
|
||||
name_tensor_map = {
|
||||
"moe_ffn1_weight": ffn1_weight,
|
||||
"moe_ffn2_weight": ffn2_weight,
|
||||
"moe_ffn1_weight_scale": ffn1_weight_scale,
|
||||
"moe_ffn2_weight_scale": ffn2_weight_scale
|
||||
}
|
||||
for name, tensor in name_tensor_map.items():
|
||||
create_and_set_parameter(layer, name, tensor)
|
||||
|
||||
def apply_ep_prefill(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP prefill method.
|
||||
"""
|
||||
# 1. Select topk experts and weights
|
||||
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(
|
||||
layer, gate_out)
|
||||
# 2. Dynamic compute blockwise quantization scales
|
||||
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
x, self.quant_config.weight_block_size[0])
|
||||
# 3. EP Dispatch
|
||||
(
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
recv_num_tokens_per_expert_list,
|
||||
handle,
|
||||
_,
|
||||
) = self.ep_prefill_runner.dispatch(x,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
x_scale_tensor=x_scale_tensor)
|
||||
|
||||
token_all_num = sum(recv_num_tokens_per_expert_list)
|
||||
|
||||
# 4. Compute ffn
|
||||
if token_all_num > 0:
|
||||
logger.info(f"token_all_num {token_all_num}")
|
||||
(recv_x, recv_x_scale) = recv_x
|
||||
tmp = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
|
||||
(
|
||||
permute_input,
|
||||
permute_scale,
|
||||
permute_indices_per_token,
|
||||
recv_num_tokens_per_expert_list_cumsum,
|
||||
recv_num_tokens_per_expert_list_padded_cumsum,
|
||||
dst_weights,
|
||||
dst_indices,
|
||||
cumsum_idx_gpu,
|
||||
m_indices,
|
||||
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
|
||||
recv_x,
|
||||
recv_x_scale,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
tmp[0],
|
||||
tmp[1]
|
||||
)
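# Re-lay the dispatched per-token scales column-major in memory
# (transpose + contiguous, then transpose back) while keeping their
# logical shape for the grouped GEMM below.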
|
||||
|
||||
permute_scale = permute_scale.transpose([1, 0]).contiguous()
|
||||
permute_scale = permute_scale.transpose([1, 0])
|
||||
|
||||
# ffn1
|
||||
ffn_out = paddle.empty(
|
||||
(permute_input.shape[0], layer.moe_ffn1_weight.shape[1]),
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(permute_input, permute_scale),
|
||||
(layer.moe_ffn1_weight, layer.moe_ffn1_weight_scale),
|
||||
ffn_out,
|
||||
m_indices,
|
||||
)
|
||||
# swiglu
|
||||
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out, None)
|
||||
|
||||
# ffn2
|
||||
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
ffn_out, self.quant_config.weight_block_size[0])
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose(
|
||||
[1, 0]).contiguous()
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
|
||||
|
||||
ffn_out = paddle.empty(
|
||||
(ffn_out.shape[0], layer.moe_ffn2_weight.shape[1]),
|
||||
dtype=paddle.bfloat16)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(ffn_in_x, ffn_in_x_scale_tensor),
|
||||
(layer.moe_ffn2_weight, layer.moe_ffn2_weight_scale),
|
||||
ffn_out,
|
||||
m_indices,
|
||||
)
|
||||
# permute back per rank
|
||||
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
|
||||
ffn_out,
|
||||
dst_weights,
|
||||
permute_indices_per_token,
|
||||
dst_indices,
|
||||
None, # moe_ffn2_bias
|
||||
False, # norm_topk_prob
|
||||
1.0,
|
||||
)[0]
|
||||
|
||||
else:
|
||||
tmp_ffn_out = paddle.cast(recv_x[0], paddle.bfloat16)
|
||||
|
||||
# 5. EP combine
|
||||
return self.ep_prefill_runner.combine(tmp_ffn_out, handle,
|
||||
recv_topk_weights)
|
||||
|
||||
def apply_ep_decode(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP decoder method.
|
||||
"""
|
||||
# 1. Select topk experts and weights
|
||||
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(
|
||||
layer, gate_out)
|
||||
# 2. EP Dispatch
|
||||
permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
|
||||
x, topk_idx, topk_weights, use_fp8=True)
|
||||
|
||||
# 3. Compute ffn
|
||||
assert isinstance(permute_input, tuple)
|
||||
ffn1_out = paddle.empty(
|
||||
[
|
||||
layer.num_local_experts,
|
||||
layer.ep_size *
|
||||
layer.moe_config.num_max_dispatch_tokens_per_rank,
|
||||
layer.moe_intermediate_size * 2,
|
||||
],
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
|
||||
ffn_out = paddle.empty(
|
||||
[
|
||||
layer.num_local_experts,
|
||||
layer.ep_size *
|
||||
layer.moe_config.num_max_dispatch_tokens_per_rank,
|
||||
layer.hidden_size,
|
||||
],
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
|
||||
expected_m = 128
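# expected_m is a scheduling hint for the masked grouped GEMM (expected
# number of valid tokens per expert); 128 is a fixed guess here.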
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
permute_input,
|
||||
(
|
||||
layer.moe_ffn1_weight,
|
||||
layer.moe_ffn1_weight_scale,
|
||||
),
|
||||
ffn1_out,
|
||||
token_nums_per_expert,
|
||||
expected_m,
|
||||
)
|
||||
|
||||
act_out = fastdeploy.model_executor.ops.gpu.group_swiglu_with_masked(
|
||||
ffn1_out, token_nums_per_expert)
|
||||
|
||||
act_out_fp8, scale = fastdeploy.model_executor.ops.gpu.masked_per_token_quant(
|
||||
act_out, token_nums_per_expert,
|
||||
self.quant_config.weight_block_size[0])
|
||||
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
(act_out_fp8, scale),
|
||||
(
|
||||
layer.moe_ffn2_weight,
|
||||
layer.moe_ffn2_weight_scale,
|
||||
),
|
||||
ffn_out,
|
||||
token_nums_per_expert,
|
||||
expected_m,
|
||||
)
|
||||
|
||||
# 4. EP combine
|
||||
return self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights,
|
||||
handle)
|
||||
|
||||
def apply_tp(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Use DeepGemm to compute Fused MoE.
This is the TP compute path.
|
||||
"""
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
layer.top_k,
|
||||
True, # apply_norm_weight
|
||||
False,
|
||||
)
|
||||
|
||||
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
|
||||
|
||||
recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
x, 128)
|
||||
|
||||
(
|
||||
permute_input,
|
||||
permute_scale,
|
||||
permute_indices_per_token,
|
||||
recv_num_tokens_per_expert_list_cumsum,
|
||||
recv_num_tokens_per_expert_list_padded_cumsum,
|
||||
dst_weights,
|
||||
dst_indices,
|
||||
cumsum_idx_gpu,
|
||||
m_indices,
|
||||
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
|
||||
recv_x,
|
||||
recv_x_scale,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
tmp[0],
|
||||
tmp[1],
|
||||
)
|
||||
|
||||
permute_scale = permute_scale.transpose([1, 0]).contiguous()
|
||||
permute_scale = permute_scale.transpose([1, 0])
|
||||
|
||||
# ffn1
|
||||
ffn_out = paddle.empty(
|
||||
(permute_input.shape[0], layer.moe_ffn1_weight.shape[1]),
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(permute_input, permute_scale),
|
||||
(layer.moe_ffn1_weight, layer.moe_ffn1_weight_scale),
|
||||
ffn_out,
|
||||
m_indices,
|
||||
)
|
||||
# swiglu
|
||||
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out)
|
||||
|
||||
# ffn2
|
||||
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
ffn_out, self.quant_config.weight_block_size[0])
|
||||
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose(
|
||||
[1, 0]).contiguous()
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
|
||||
|
||||
ffn_out = paddle.empty(
|
||||
(ffn_out.shape[0], layer.moe_ffn2_weight.shape[1]),
|
||||
dtype=paddle.bfloat16)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(ffn_in_x, ffn_in_x_scale_tensor),
|
||||
(layer.moe_ffn2_weight, layer.moe_ffn2_weight_scale),
|
||||
ffn_out,
|
||||
m_indices,
|
||||
)
|
||||
# permute back per rank
|
||||
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
|
||||
ffn_out,
|
||||
dst_weights,
|
||||
permute_indices_per_token,
|
||||
dst_indices,
|
||||
None,
|
||||
False, # norm_topk_prob
|
||||
1.0,
|
||||
)[0]
|
||||
if layer.tp_size > 1:
|
||||
tensor_model_parallel_all_reduce(tmp_ffn_out)
|
||||
|
||||
return tmp_ffn_out
|
fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py (new file, 285 lines)
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy.distributed.communication_op import \
|
||||
tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.ops.gpu import (MoeWna16MarlinGemmApi,
|
||||
tritonmoe_preprocess_func)
|
||||
|
||||
from ..quantization.quant_base import QuantMethodBase
|
||||
|
||||
|
||||
def gptq_marlin_moe_repack(b_q_weight: paddle.Tensor, perm: paddle.Tensor,
|
||||
size_k: int, size_n: int,
|
||||
num_bits: int) -> paddle.Tensor:
|
||||
"""
|
||||
Util function.
|
||||
"""
|
||||
from fastdeploy.model_executor.ops.gpu import gptq_marlin_repack
|
||||
num_experts = b_q_weight.shape[0]
|
||||
assert size_k % 16 == 0
|
||||
output = paddle.empty(
|
||||
[num_experts, size_k // 16, size_n * (num_bits // 2)],
|
||||
dtype=b_q_weight.dtype)
|
||||
for e in range(num_experts):
|
||||
output[e] = gptq_marlin_repack(b_q_weight[e], perm[e], size_k, size_n,
|
||||
num_bits)
|
||||
return output
|
||||
|
||||
|
||||
def get_scale_perms():
|
||||
"""
|
||||
Util function.
|
||||
"""
|
||||
scale_perm: list[int] = []
|
||||
for i in range(8):
|
||||
scale_perm.extend([i + 8 * j for j in range(8)])
|
||||
scale_perm_single: list[int] = []
|
||||
for i in range(4):
|
||||
scale_perm_single.extend(
|
||||
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
|
||||
return scale_perm, scale_perm_single
|
||||
|
||||
|
||||
def marlin_permute_scales(s: paddle.Tensor, size_k: int, size_n: int,
|
||||
group_size: int) -> paddle.Tensor:
|
||||
"""
|
||||
Util function.
|
||||
"""
|
||||
scale_perm, scale_perm_single = get_scale_perms()
|
||||
if group_size < size_k and group_size != -1:
|
||||
s = s.reshape([-1, len(scale_perm)])[:, scale_perm]
|
||||
else:
|
||||
s = s.reshape([-1, len(scale_perm_single)])[:, scale_perm_single]
|
||||
s = s.reshape((-1, size_n)).contiguous()
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def marlin_moe_permute_scales(
|
||||
s: paddle.Tensor,
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
group_size: int,
|
||||
):
|
||||
"""
|
||||
Util function.
|
||||
"""
|
||||
num_experts = s.shape[0]
|
||||
output = paddle.empty(
|
||||
[num_experts, s.shape[1], s.shape[2]],
|
||||
dtype=s.dtype,
|
||||
)
|
||||
|
||||
for e in range(num_experts):
|
||||
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
|
||||
return output
|
||||
|
||||
|
||||
class MarlinWeightOnlyMoEMethod(QuantMethodBase):
|
||||
"""
|
||||
Use Marlin Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_method=None):
|
||||
"""
|
||||
Marlin Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
self.quant_method = quant_method
|
||||
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
|
||||
self.added_scale_attrs = [
|
||||
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
|
||||
]
|
||||
self.added_zeros_attrs = ["zeros0", "zeros1"]
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Marlin MoE create weight process.
|
||||
"""
|
||||
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
assert len(ffn1_weights) == layer.num_local_experts
|
||||
assert len(ffn2_weights) == layer.num_local_experts
|
||||
assert ffn1_weights[0].shape == [
|
||||
layer.hidden_size, layer.moe_intermediate_size * 2
|
||||
]
|
||||
assert ffn2_weights[0].shape == [
|
||||
layer.moe_intermediate_size, layer.hidden_size
|
||||
]
|
||||
|
||||
ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
|
||||
ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
|
||||
|
||||
max_bound = 7
|
||||
|
||||
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
scale_name = self.added_scale_attrs[idx]
|
||||
|
||||
weight_scale = weight_tensor.abs().max(axis=1)
|
||||
quanted_weight = weight_tensor / weight_scale[:,
|
||||
None, :] * max_bound
|
||||
quanted_weight = paddle.round(quanted_weight).astype("int32")
|
||||
|
||||
quanted_weight[quanted_weight > 7] = 7
|
||||
quanted_weight[quanted_weight < -7] = -7
|
||||
quanted_weight += 8
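# Pack eight 4-bit values along K into a single int32; the +8 offset maps
# [-7, 7] into the unsigned "uint4b8" range used by the Marlin kernels.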
|
||||
|
||||
E, K, N = quanted_weight.shape
|
||||
quanted_weight = quanted_weight.reshape([0, K // 8, 8, N])
|
||||
res = paddle.zeros([E, K // 8, N], dtype='int32')
|
||||
for j in range(8):
|
||||
tmp = quanted_weight[:, :, j, :]
|
||||
res = res | (tmp << (j * 4))
|
||||
quanted_weight = paddle.assign(res)
|
||||
weight_scale = weight_scale / max_bound
|
||||
weight_scale = weight_scale[:, None, :]
|
||||
|
||||
group_size = -1 # means per_channel
|
||||
|
||||
g_idx_sort_indices = paddle.empty([E, 0], dtype="int32")
|
||||
quanted_weight = gptq_marlin_moe_repack(
|
||||
quanted_weight,
|
||||
g_idx_sort_indices,
|
||||
K,
|
||||
N,
|
||||
4,
|
||||
)
|
||||
|
||||
weight_scale = marlin_moe_permute_scales(
|
||||
weight_scale,
|
||||
size_k=layer.moe_intermediate_size,  # unused here: group_size == -1 (per-channel)
|
||||
size_n=N,
|
||||
group_size=group_size)
|
||||
|
||||
for (name, tensor) in [(weight_name, quanted_weight),
|
||||
(scale_name, weight_scale)]:
|
||||
setattr(
|
||||
layer, name,
|
||||
layer.create_parameter(
|
||||
shape=tensor.shape,
|
||||
dtype=tensor.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
))
|
||||
getattr(layer, name).set_value(tensor)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Marlin compute Fused MoE.
|
||||
"""
|
||||
token_num = x.shape[0]
|
||||
top_k = layer.top_k
|
||||
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
num_experts = layer.num_experts
|
||||
|
||||
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
top_k,
|
||||
True, # apply_norm_weight,
|
||||
False,
|
||||
)
|
||||
|
||||
block_size_m = 64
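# Pick the smallest block size M such that the average number of routed
# tokens per expert fills less than 90% of a block.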
|
||||
|
||||
for m in [8, 16, 32, 48, 64]:
|
||||
if token_num * top_k / num_experts / m < 0.9:
|
||||
block_size_m = m
|
||||
break
|
||||
|
||||
topk = top_k
|
||||
|
||||
# for H100 132 sms
|
||||
workspace = paddle.empty([528], dtype="int32")
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
|
||||
topk_ids, num_experts, block_size_m)
|
||||
|
||||
ffn_out = MoeWna16MarlinGemmApi(
|
||||
x,
|
||||
c_or_none=None,
|
||||
b_q_weight=layer.moe_ffn1_weight,
|
||||
b_scales=layer.moe_ffn1_weight_scale,
|
||||
global_scale_or_none=None,
|
||||
b_zeros_or_none=None,
|
||||
g_idx_or_none=None,
|
||||
perm_or_none=None,
|
||||
workspace=workspace,
|
||||
sorted_token_ids=sorted_token_ids,
|
||||
expert_ids=expert_ids,
|
||||
num_tokens_post_padded=num_tokens_post_padded,
|
||||
topk_weights=topk_weights,
|
||||
moe_block_size=block_size_m,
|
||||
top_k=topk,
|
||||
mul_topk_weights=False,
|
||||
is_ep=False,
|
||||
b_q_type_str="uint4b8",
|
||||
size_m=token_num,
|
||||
size_n=moe_intermediate_size * 2,
|
||||
size_k=hidden_size,
|
||||
is_k_full=True,
|
||||
use_atomic_add=True,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False)[0]
|
||||
|
||||
swiglu_out = paddle.incubate.nn.functional.swiglu(ffn_out)
|
||||
|
||||
ffn_out = MoeWna16MarlinGemmApi(
|
||||
swiglu_out,
|
||||
c_or_none=None,
|
||||
b_q_weight=layer.moe_ffn2_weight,
|
||||
b_scales=layer.moe_ffn2_weight_scale,
|
||||
global_scale_or_none=None,
|
||||
b_zeros_or_none=None,
|
||||
g_idx_or_none=None,
|
||||
perm_or_none=None,
|
||||
workspace=workspace,
|
||||
sorted_token_ids=sorted_token_ids,
|
||||
expert_ids=expert_ids,
|
||||
num_tokens_post_padded=num_tokens_post_padded,
|
||||
topk_weights=topk_weights,
|
||||
moe_block_size=block_size_m,
|
||||
top_k=1,
|
||||
mul_topk_weights=True,
|
||||
is_ep=False,
|
||||
b_q_type_str="uint4b8",
|
||||
size_m=token_num * topk,
|
||||
size_n=hidden_size,
|
||||
size_k=moe_intermediate_size,
|
||||
is_k_full=True,
|
||||
use_atomic_add=True,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False)[0]
|
||||
|
||||
ffn_out.reshape_([token_num, -1, hidden_size])
|
||||
ffn_out = ffn_out.sum(axis=1)
|
||||
|
||||
if layer.tp_size > 1:
|
||||
tensor_model_parallel_all_reduce(ffn_out)
|
||||
|
||||
return ffn_out
|
@@ -1,57 +0,0 @@
|
||||
"""
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from fastdeploy.model_executor.layers.quantization.quant_base import \
|
||||
QuantMethodBase
|
||||
|
||||
|
||||
class FusedMoEMethodBase(QuantMethodBase):
|
||||
"""
|
||||
All MoE Method should inherit this class.
|
||||
and must implement the following methods.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self,
|
||||
layer: nn.Layer,
|
||||
moe_compute_params,
|
||||
ffn1_tensor,
|
||||
ffn2_tensor,
|
||||
ffn1_bias=None,
|
||||
ffn2_bias=None):
|
||||
"""
|
||||
How to create weights, you must implement this method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
moe_compute_params,
|
||||
x: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Compute methods, you must implement this method.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py (new file, 479 lines)
@@ -0,0 +1,479 @@
|
||||
"""
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from fastdeploy.distributed.communication_op import \
|
||||
tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map,
|
||||
get_tensor)
|
||||
from fastdeploy.utils import ceil_div
|
||||
|
||||
from ..quantization.quant_base import QuantMethodBase
|
||||
|
||||
|
||||
class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
"""
|
||||
Use Triton Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_method=None):
|
||||
"""
|
||||
Triton Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
self.quant_method = quant_method
|
||||
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
|
||||
self.added_scale_attrs = [
|
||||
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
|
||||
]
|
||||
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
|
||||
"""process_prequanted_weights"""
|
||||
pass
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Triton MoE create weight process.
|
||||
"""
|
||||
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
assert len(ffn1_weights) == layer.num_local_experts
|
||||
assert len(ffn2_weights) == layer.num_local_experts
|
||||
assert layer.quant_method.quant_config.name() == "wint8"
|
||||
assert ffn1_weights[0].shape == [
|
||||
layer.hidden_size, layer.moe_intermediate_size * 2
|
||||
]
|
||||
assert ffn2_weights[0].shape == [
|
||||
layer.moe_intermediate_size, layer.hidden_size
|
||||
]
|
||||
|
||||
ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
|
||||
ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
|
||||
|
||||
if self.quant_config.name() == "wint8":
|
||||
max_bound = 127
|
||||
elif self.quant_config.name() == "wint4":
|
||||
max_bound = 7
|
||||
|
||||
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
scale_name = self.added_scale_attrs[idx]
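# Per-output-channel symmetric quantization: scale is max|w| along the K
# axis and the weights are rounded to int8 within [-max_bound, max_bound].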
|
||||
|
||||
quanted_weight_scale = weight_tensor.abs().max(axis=1)
|
||||
quanted_weight = weight_tensor / quanted_weight_scale[:,
|
||||
None, :] * max_bound
|
||||
quanted_weight = paddle.round(quanted_weight).astype("int8")
|
||||
quanted_weight_scale = quanted_weight_scale / max_bound
|
||||
|
||||
setattr(
|
||||
layer, weight_name,
|
||||
layer.create_parameter(
|
||||
shape=quanted_weight.shape,
|
||||
dtype=quanted_weight.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
))
|
||||
getattr(layer, weight_name).set_value(quanted_weight)
|
||||
|
||||
setattr(
|
||||
layer, scale_name,
|
||||
layer.create_parameter(
|
||||
shape=quanted_weight_scale.shape,
|
||||
dtype=quanted_weight_scale.dtype,
|
||||
))
|
||||
getattr(layer, scale_name).set_value(quanted_weight_scale)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
"""
|
||||
token_num = x.shape[0]
|
||||
top_k = layer.top_k
|
||||
num_local_experts = layer.num_local_experts
|
||||
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
|
||||
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
|
||||
scores = paddle.nn.functional.softmax(gate_out, axis=-1)
|
||||
|
||||
topk_weights, topk_ids = paddle.topk(scores,
|
||||
k=top_k,
|
||||
axis=-1,
|
||||
sorted=False)
|
||||
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
|
||||
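        # Editor's note, a worked routing example (illustrative only): with 4 experts and
        # top_k = 2, a token whose softmax scores are [0.1, 0.4, 0.2, 0.3] selects experts
        # 1 and 3 with raw weights [0.4, 0.3]; after the renormalization above its output
        # becomes 0.571 * expert_1(x) + 0.429 * expert_3(x).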
|
||||
intermediate_cache1 = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
dtype=x.dtype,
|
||||
)
|
||||
intermediate_cache2 = paddle.empty(
|
||||
(token_num * top_k, moe_intermediate_size),
|
||||
dtype=x.dtype,
|
||||
)
|
||||
intermediate_cache3 = paddle.empty(
|
||||
(token_num * top_k, hidden_size),
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
        from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess

        from .triton_moe_kernels import fused_moe_kernel_paddle
        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
            topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
        max_num_tokens_padded = sorted_token_ids.shape[0]
        grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
                ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
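        # Editor's note: the 1-D launch grid covers every (M-block, N-block) tile pair.
        # For example, with 96 padded rows, BLOCK_SIZE_M=32, an output width of
        # moe_intermediate_size * 2 = 512 and BLOCK_SIZE_N=128 (values chosen for
        # illustration), the grid is ceil_div(96, 32) * ceil_div(512, 128) = 3 * 4 = 12.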
|
||||
fused_moe_kernel_paddle[grid](
|
||||
x,
|
||||
layer.moe_ffn1_weight,
|
||||
intermediate_cache1,
|
||||
None,
|
||||
layer.moe_ffn1_weight_scale,
|
||||
None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
moe_intermediate_size * 2,
|
||||
hidden_size,
|
||||
max_num_tokens_padded,
|
||||
token_num * top_k,
|
||||
stride_am=x.strides[0],
|
||||
stride_ak=x.strides[1],
|
||||
stride_be=layer.moe_ffn1_weight.strides[0],
|
||||
stride_bk=layer.moe_ffn1_weight.strides[1],
|
||||
stride_bn=layer.moe_ffn1_weight.strides[2],
|
||||
stride_cm=intermediate_cache1.strides[0],
|
||||
stride_cn=intermediate_cache1.strides[1],
|
||||
#
|
||||
stride_asm=-1,
|
||||
stride_ask=-1,
|
||||
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
|
||||
stride_bsk=-1,
|
||||
stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
|
||||
group_n=-1,
|
||||
group_k=-1,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
|
||||
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
|
||||
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
|
||||
GROUP_SIZE_M=config["GROUP_SIZE_M"],
|
||||
MUL_ROUTED_WEIGHT=False,
|
||||
top_k=top_k,
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a16=True,
|
||||
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
|
||||
intermediate_cache1)
|
||||
|
||||
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
|
||||
fused_moe_kernel_paddle[grid](
|
||||
intermediate_cache2,
|
||||
layer.moe_ffn2_weight,
|
||||
intermediate_cache3,
|
||||
None,
|
||||
layer.moe_ffn2_weight_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
hidden_size,
|
||||
moe_intermediate_size,
|
||||
max_num_tokens_padded,
|
||||
token_num * top_k,
|
||||
stride_am=intermediate_cache2.strides[0],
|
||||
stride_ak=intermediate_cache2.strides[1],
|
||||
stride_be=layer.moe_ffn2_weight.strides[0],
|
||||
stride_bk=layer.moe_ffn2_weight.strides[1],
|
||||
stride_bn=layer.moe_ffn2_weight.strides[2],
|
||||
stride_cm=intermediate_cache3.strides[0],
|
||||
stride_cn=intermediate_cache3.strides[1],
|
||||
stride_asm=-1,
|
||||
stride_ask=-1,
|
||||
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
|
||||
stride_bsk=-1,
|
||||
stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
|
||||
group_n=-1,
|
||||
group_k=-1,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
|
||||
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
|
||||
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
|
||||
GROUP_SIZE_M=config["GROUP_SIZE_M"],
|
||||
MUL_ROUTED_WEIGHT=True,
|
||||
top_k=1,
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a16=True,
|
||||
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||

        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
        out = intermediate_cache3.sum(axis=1)
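        # Editor's note: MUL_ROUTED_WEIGHT=True in the second kernel launch already scaled
        # each expert output by its routing weight, so a plain sum over the top_k axis is
        # the correct weighted combination.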
return out
|
||||
|
||||
|
||||
class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
"""
|
||||
    Use Triton group GEMM with tensor-wise FP8 quantization to compute the fused MoE.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_method=None):
|
||||
"""
|
||||
Triton Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
self.quant_method = quant_method
|
||||
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
|
||||
"""process_prequanted_weights"""
|
||||
|
||||
ffn1_tensor, ffn2_tensor = layer.extract_moe_ffn_weights(state_dict)
|
||||
assert ffn1_tensor[0].shape == [
|
||||
layer.hidden_size, layer.moe_intermediate_size * 2
|
||||
]
|
||||
assert ffn2_tensor[0].shape == [
|
||||
layer.moe_intermediate_size, layer.hidden_size
|
||||
]
|
||||
|
||||
ffn1_tensor = paddle.stack(ffn1_tensor, axis=0)
|
||||
ffn2_tensor = paddle.stack(ffn2_tensor, axis=0)
|
||||
|
||||
added_wfp8afp8_attrs = [
|
||||
"moe_ffn1_weight", "moe_ffn2_weight", "moe_ffn1_weight_scale",
|
||||
"moe_ffn2_weight_scale", "moe_ffn1_in_scale", "moe_ffn2_in_scale"
|
||||
]
|
||||
|
||||
def _extract_scale_tensor(key_template):
|
||||
result = []
|
||||
for i in range(layer.num_experts):
|
||||
result.append(
|
||||
get_tensor(state_dict.pop(key_template.format(i))))
|
||||
return paddle.concat(result).cast("float32")
|
||||
|
||||
weight_key_map = layer.weight_key_map
|
||||
moe_ffn1_weight_scale = _extract_scale_tensor(
|
||||
weight_key_map["ffn1_expert_weight_scale_key"])
|
||||
moe_ffn2_weight_scale = _extract_scale_tensor(
|
||||
weight_key_map["ffn2_expert_weight_scale_key"])
|
||||
moe_ffn1_in_scale = _extract_scale_tensor(
|
||||
weight_key_map["ffn1_expert_in_scale_key"])
|
||||
moe_ffn2_in_scale = _extract_scale_tensor(
|
||||
weight_key_map["ffn2_expert_in_scale_key"])
|
||||
|
||||
for idx, weight_tensor in enumerate([
|
||||
ffn1_tensor, ffn2_tensor, moe_ffn1_weight_scale,
|
||||
moe_ffn2_weight_scale, moe_ffn1_in_scale, moe_ffn2_in_scale
|
||||
]):
|
||||
name = added_wfp8afp8_attrs[idx]
|
||||
setattr(
|
||||
layer, name,
|
||||
layer.create_parameter(
|
||||
shape=weight_tensor.shape,
|
||||
dtype=weight_tensor.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
))
|
||||
getattr(layer, name).set_value(weight_tensor)
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Triton MoE create weight process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
"""
|
||||
|
||||
token_num = x.shape[0]
|
||||
top_k = layer.top_k
|
||||
num_local_experts = layer.num_local_experts
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
|
||||
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
|
||||
scores = paddle.nn.functional.softmax(gate_out, axis=-1)
|
||||
|
||||
topk_weights, topk_ids = paddle.topk(scores,
|
||||
k=top_k,
|
||||
axis=-1,
|
||||
sorted=False)
|
||||
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
|
||||
|
||||
intermediate_cache1 = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
dtype=x.dtype,
|
||||
)
|
||||
intermediate_cache2 = paddle.empty(
|
||||
(token_num * top_k, moe_intermediate_size),
|
||||
dtype=x.dtype,
|
||||
)
|
||||
intermediate_cache3 = paddle.empty(
|
||||
(token_num * top_k, hidden_size),
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
|
||||
topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
|
||||
max_num_tokens_padded = sorted_token_ids.shape[0]
|
||||
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
|
||||
|
||||
        hadamard_matrix = create_hadamard_matrix_map[hidden_size]
        x = paddle.matmul(x.cast("float32"), hadamard_matrix)

        permute_x = x[:, None, :].tile([1, top_k, 1])
        permute_x = permute_x.reshape([-1, hidden_size])

        quant_activation_scale = layer.moe_ffn1_in_scale[topk_ids].reshape(
            [-1, 1])
        permute_x = permute_x / quant_activation_scale
        permute_x = permute_x.astype("float8_e4m3fn")
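        # Editor's note (illustrative): the Hadamard rotation spreads activation outliers
        # across channels before quantization, and each routed token row is divided by the
        # input scale of its destination expert so that the cast to float8_e4m3fn stays in
        # range; the ffn2 input further below is prepared the same way.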
|
||||
from .triton_moe_kernels import fused_moe_kernel_paddle
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
permute_x,
|
||||
layer.moe_ffn1_weight.view(paddle.float8_e4m3fn),
|
||||
intermediate_cache1,
|
||||
layer.moe_ffn1_in_scale,
|
||||
layer.moe_ffn1_weight_scale,
|
||||
None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
moe_intermediate_size * 2,
|
||||
hidden_size,
|
||||
max_num_tokens_padded,
|
||||
token_num * top_k,
|
||||
stride_am=x.strides[0],
|
||||
stride_ak=x.strides[1],
|
||||
stride_be=layer.moe_ffn1_weight.strides[0],
|
||||
stride_bk=layer.moe_ffn1_weight.strides[1],
|
||||
stride_bn=layer.moe_ffn1_weight.strides[2],
|
||||
stride_cm=intermediate_cache1.strides[0],
|
||||
stride_cn=intermediate_cache1.strides[1],
|
||||
#
|
||||
stride_asm=-1, # only used in blockwise fp8
|
||||
stride_ask=-1, # only used in blockwise fp8
|
||||
stride_bse=-1,
|
||||
stride_bsk=-1,
|
||||
stride_bsn=-1,
|
||||
group_n=-1,
|
||||
group_k=-1,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
|
||||
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
|
||||
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
|
||||
GROUP_SIZE_M=config["GROUP_SIZE_M"],
|
||||
MUL_ROUTED_WEIGHT=False,
|
||||
top_k=1,
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=True,
|
||||
use_int8_w8a16=False,
|
||||
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
|
||||
intermediate_cache1)
|
||||
|
||||
hadamard_matrix = create_hadamard_matrix_map[moe_intermediate_size]
|
||||
intermediate_cache2 = paddle.matmul(
|
||||
intermediate_cache2.cast("float32"), hadamard_matrix)
|
||||
quant_activation_scale = layer.moe_ffn2_in_scale[topk_ids].reshape(
|
||||
[-1, 1])
|
||||
intermediate_cache2 = intermediate_cache2 / quant_activation_scale
|
||||
intermediate_cache2 = intermediate_cache2.astype("float8_e4m3fn")
|
||||
|
||||
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
|
||||
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
intermediate_cache2,
|
||||
layer.moe_ffn2_weight.view(paddle.float8_e4m3fn),
|
||||
intermediate_cache3,
|
||||
layer.moe_ffn2_in_scale,
|
||||
layer.moe_ffn2_weight_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
hidden_size,
|
||||
moe_intermediate_size,
|
||||
max_num_tokens_padded,
|
||||
token_num * top_k,
|
||||
stride_am=intermediate_cache2.strides[0],
|
||||
stride_ak=intermediate_cache2.strides[1],
|
||||
stride_be=layer.moe_ffn2_weight.strides[0],
|
||||
stride_bk=layer.moe_ffn2_weight.strides[1],
|
||||
stride_bn=layer.moe_ffn2_weight.strides[2],
|
||||
stride_cm=intermediate_cache3.strides[0],
|
||||
stride_cn=intermediate_cache3.strides[1],
|
||||
stride_asm=-1,
|
||||
stride_ask=-1,
|
||||
stride_bse=-1,
|
||||
stride_bsk=-1,
|
||||
stride_bsn=-1,
|
||||
group_n=-1,
|
||||
group_k=-1,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
|
||||
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
|
||||
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
|
||||
GROUP_SIZE_M=config["GROUP_SIZE_M"],
|
||||
MUL_ROUTED_WEIGHT=True,
|
||||
top_k=1,
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=True,
|
||||
use_int8_w8a16=False,
|
||||
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
intermediate_cache3.reshape_([token_num, top_k, hidden_size])
|
||||
out = intermediate_cache3.sum(axis=1)
|
||||
|
||||
if layer.tp_size > 1:
|
||||
tensor_model_parallel_all_reduce(out)
|
||||
|
||||
return out
|
fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py (new file, 236 lines)
@@ -0,0 +1,236 @@
|
||||
"""
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
import fastdeploy
|
||||
|
||||
from ..quantization.quant_base import QuantMethodBase
|
||||
from ..utils import create_and_set_parameter, get_tensor
|
||||
|
||||
|
||||
class Wint2MoeMethod(QuantMethodBase):
|
||||
"""
|
||||
    Base method for computing the fused MoE with wint2 (2-bit) quantized weights.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config):
|
||||
super().__init__()
|
||||
self.moe_quant_type = quant_config.moe_quant_type
|
||||
|
||||
def process_loaded_weights(self, layer, weights) -> None:
|
||||
"""
|
||||
process_loaded_weights
|
||||
"""
|
||||
pass
|
||||
|
||||
def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights):
|
||||
"""
|
||||
check layer is valid for this method
|
||||
"""
|
||||
assert len(
|
||||
ffn1_weights
|
||||
) == layer.num_local_experts, "ffn1_weights length should be equal to num_local_experts."
|
||||
assert len(
|
||||
ffn2_weights
|
||||
) == layer.num_local_experts, "ffn2_weights length should be equal to num_local_experts."
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Paddle cutlass create weight process.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TritonWint2FusedMoeMethod(Wint2MoeMethod):
|
||||
"""
|
||||
Use Triton Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config):
|
||||
super().__init__(quant_config)
|
||||
self.moe_quant_type = quant_config.moe_quant_type
|
||||
|
||||
def process_loaded_weights(self, layer, weights) -> None:
|
||||
"""
|
||||
process_loaded_weights
|
||||
"""
|
||||
pass
|
||||
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
        Process pre-quantized wint2 expert weights and register them on the layer.
|
||||
"""
|
||||
ffn1_expert_weight_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_weight_key", None)
|
||||
ffn2_expert_weight_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_weight_key", None)
|
||||
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_weight_scale_key", None)
|
||||
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_weight_scale_key", None)
|
||||
ffn1_expert_super_scales_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_super_scales_key", None)
|
||||
ffn2_expert_super_scales_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_super_scales_key", None)
|
||||
ffn1_expert_code_scale_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_code_scale_key", None)
|
||||
ffn2_expert_code_scale_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_code_scale_key", None)
|
||||
ffn1_expert_code_zp_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_code_zp_key", None)
|
||||
ffn2_expert_code_zp_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_code_zp_key", None)
|
||||
|
||||
ffn1_weights, ffn2_weights = layer.load_experts_weight(
|
||||
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
|
||||
# self.check(layer, ffn1_weights, ffn2_weights)
|
||||
|
||||
ffn1_weight_scale = []
|
||||
ffn2_weight_scale = []
|
||||
ffn1_super_scales = []
|
||||
ffn2_super_scales = []
|
||||
ffn1_code_scale = []
|
||||
ffn2_code_scale = []
|
||||
ffn1_code_zp = []
|
||||
ffn2_code_zp = []
|
||||
for i in range(layer.num_experts):
|
||||
expert_idx = layer.expert_id_offset + i
|
||||
ffn1_weight_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn1_expert_weight_scale_key.format(expert_idx))))
|
||||
ffn2_weight_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_weight_scale_key.format(expert_idx))))
|
||||
ffn1_super_scales.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn1_expert_super_scales_key.format(expert_idx))))
|
||||
ffn2_super_scales.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_super_scales_key.format(expert_idx))))
|
||||
ffn1_code_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn1_expert_code_scale_key.format(expert_idx))))
|
||||
ffn2_code_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_code_scale_key.format(expert_idx))))
|
||||
ffn1_code_zp.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn1_expert_code_zp_key.format(expert_idx))))
|
||||
ffn2_code_zp.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_code_zp_key.format(expert_idx))))
|
||||
|
||||
ffn1_weight = paddle.stack(ffn1_weights, axis=0)
|
||||
ffn2_weight = paddle.stack(ffn2_weights, axis=0)
|
||||
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
|
||||
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)
|
||||
ffn1_super_scales = paddle.stack(ffn1_super_scales, axis=0)
|
||||
ffn2_super_scales = paddle.stack(ffn2_super_scales, axis=0)
|
||||
ffn1_code_scale = paddle.stack(ffn1_code_scale, axis=0)
|
||||
ffn2_code_scale = paddle.stack(ffn2_code_scale, axis=0)
|
||||
ffn1_code_zp = paddle.stack(ffn1_code_zp, axis=0)
|
||||
ffn2_code_zp = paddle.stack(ffn2_code_zp, axis=0)
|
||||
|
||||
name_tensor_map = {
|
||||
"moe_ffn1_weight": ffn1_weight,
|
||||
"moe_ffn2_weight": ffn2_weight,
|
||||
"moe_ffn1_weight_scale": ffn1_weight_scale,
|
||||
"moe_ffn2_weight_scale": ffn2_weight_scale,
|
||||
"moe_ffn1_super_scales": ffn1_super_scales,
|
||||
"moe_ffn2_super_scales": ffn2_super_scales,
|
||||
"moe_ffn1_code_scale": ffn1_code_scale,
|
||||
"moe_ffn2_code_scale": ffn2_code_scale,
|
||||
"moe_ffn1_code_zp": ffn1_code_zp,
|
||||
"moe_ffn2_code_zp": ffn2_code_zp
|
||||
}
|
||||
for name, tensor in name_tensor_map.items():
|
||||
create_and_set_parameter(layer, name, tensor)
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Paddle cutlass create weight process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Use Wint2 Triton Fusedmoe compute Fused MoE.
|
||||
"""
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch
|
||||
(
|
||||
permute_input,
|
||||
token_nums_per_expert,
|
||||
permute_indices_per_token,
|
||||
topk_weights,
|
||||
topk_idx,
|
||||
expert_idx_per_token,
|
||||
) = moe_expert_dispatch(
|
||||
x,
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
(layer.moe_ffn1_in_scale if hasattr(layer, "moe_ffn1_in_scale")
|
||||
else None), # if set, permute_input will be int8_t
|
||||
layer.top_k,
|
||||
False,
|
||||
topk_only_mode=False,
|
||||
)
|
||||
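        # Editor's note (interpretation from the output names, not asserted by the diff):
        # moe_expert_dispatch gathers and permutes the routed tokens per expert, optionally
        # quantizing them when an input scale is provided; moe_expert_ffn_wint2 then runs
        # the 2-bit expert FFNs, and moe_expert_reduce below scatters the weighted results
        # back to the original token order.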
|
||||
ffn_out = fastdeploy.model_executor.ops.gpu.moe_expert_ffn_wint2(
|
||||
permute_input,
|
||||
token_nums_per_expert,
|
||||
layer.moe_ffn1_weight,
|
||||
layer.moe_ffn2_weight,
|
||||
None,
|
||||
layer.moe_ffn1_super_scales,
|
||||
layer.moe_ffn2_super_scales,
|
||||
layer.moe_ffn1_weight_scale,
|
||||
layer.moe_ffn1_code_scale,
|
||||
layer.moe_ffn1_code_zp,
|
||||
layer.moe_ffn2_weight_scale,
|
||||
layer.moe_ffn2_code_scale,
|
||||
layer.moe_ffn2_code_zp,
|
||||
False,
|
||||
)
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import moe_expert_reduce
|
||||
|
||||
fused_moe_out = moe_expert_reduce(
|
||||
ffn_out,
|
||||
topk_weights,
|
||||
permute_indices_per_token,
|
||||
topk_idx,
|
||||
None,
|
||||
norm_topk_prob=True,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
|
||||
return fused_moe_out
|
@@ -1,273 +0,0 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from fastdeploy.model_executor.layers.moe.moe import MoELayer
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
|
||||
|
||||
class TextMoELayer(MoELayer):
|
||||
"""
|
||||
    Text-modality MoE layer that performs Mixture of Experts computation over text tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
初始化函数,用于设置类的属性和方法。
|
||||
参数:
|
||||
- args (tuple, optional): 可变长度的位置参数列表,默认为空元组。
|
||||
- kwargs (dict, optional): 关键字参数字典,默认为空字典。
|
||||
返回值:
|
||||
无返回值,直接修改类的属性和方法。
|
||||
"""
|
||||
kwargs["moe_tag"] = "Text"
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def load_gate_state_dict(self, state_dict):
|
||||
"""
|
||||
加载门状态字典,用于初始化网络参数。
|
||||
将从给定的状态字典中弹出的参数赋值给网络的门参数。
|
||||
|
||||
Args:
|
||||
state_dict (OrderedDict): 包含网络门参数的字典。
|
||||
|
||||
Returns:
|
||||
tuple (list, list): 返回两个列表,分别代表上阶网关投影和下阶投影的参数。
|
||||
每个元素都是一个列表,长度为网络的专家数量。
|
||||
"""
|
||||
up_gate_proj_weight = []
|
||||
up_gate_proj_weight_scale = []
|
||||
down_proj_weight = []
|
||||
down_proj_weight_scale = []
|
||||
for j in range(0, self.num_experts):
|
||||
up_gate_proj_weight.append(
|
||||
get_tensor(state_dict.pop(self.ffn1_expert_weight_key.format(j)))
|
||||
)
|
||||
down_proj_weight.append(
|
||||
get_tensor(state_dict.pop(self.ffn2_expert_weight_key.format(j)))
|
||||
)
|
||||
return (
|
||||
up_gate_proj_weight,
|
||||
down_proj_weight,
|
||||
up_gate_proj_weight_scale,
|
||||
down_proj_weight_scale,
|
||||
)
|
||||
|
||||
def load_gate_correction_bias(self, state_dict):
|
||||
"""
|
||||
加载网关校正偏置。如果使用了网关校正偏置,则从state_dict中获取相应的张量并设置到网关校正偏置上。
|
||||
参数:
|
||||
state_dict (OrderedDict): 包含模型参数和状态的字典。
|
||||
返回值:
|
||||
无返回值,直接修改了网关校正偏置的值。
|
||||
"""
|
||||
if self.moe_config.moe_use_gate_correction_bias:
|
||||
gate_correction_bias_tensor = get_tensor(
|
||||
state_dict[self.gate_correction_bias_key]
|
||||
)
|
||||
self.gate_correction_bias.set_value(
|
||||
gate_correction_bias_tensor[0].unsqueeze(0)
|
||||
)
|
||||
|
||||
|
||||
class ImageMoELayer(MoELayer):
|
||||
"""
|
||||
    Image-modality MoE layer that performs Mixture of Experts computation over image tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
初始化函数,用于设置类的属性和方法。
|
||||
参数:
|
||||
- args (tuple, optional): 可变长度的位置参数列表,默认为空元组。
|
||||
- kwargs (dict, optional): 关键字参数字典,默认为空字典。
|
||||
返回值:
|
||||
无返回值,直接修改类的属性和方法。
|
||||
"""
|
||||
moe_quant_type = os.getenv("ELLM_MM_IMAGE_QUANT_TYPE", None)
|
||||
if moe_quant_type is not None:
|
||||
kwargs["moe_quant_type"] = moe_quant_type
|
||||
kwargs["moe_tag"] = "Image"
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def load_gate_state_dict(self, state_dict):
|
||||
"""
|
||||
加载门状态字典。
|
||||
从给定的状态字典中提取并返回两个专家的上下关门投影权重,以及两个专家的下降投影权重。
|
||||
参数:
|
||||
state_dict (OrderedDict): 包含网络参数的有序字典。
|
||||
返回值:
|
||||
tuple (list, list),分别是两个专家的上下关门投影权重和两个专家的下降投影权重,都是列表类型。
|
||||
"""
|
||||
up_gate_proj_weight = []
|
||||
up_gate_proj_weight_scale = []
|
||||
down_proj_weight = []
|
||||
down_proj_weight_scale = []
|
||||
for j in range(self.num_experts, self.num_experts + self.num_experts):
|
||||
up_gate_proj_weight.append(
|
||||
get_tensor(state_dict.pop(self.ffn1_expert_weight_key.format(j)))
|
||||
)
|
||||
down_proj_weight.append(
|
||||
get_tensor(state_dict.pop(self.ffn2_expert_weight_key.format(j)))
|
||||
)
|
||||
return (
|
||||
up_gate_proj_weight,
|
||||
down_proj_weight,
|
||||
up_gate_proj_weight_scale,
|
||||
down_proj_weight_scale,
|
||||
)
|
||||
|
||||
def load_gate_correction_bias(self, state_dict):
|
||||
"""
|
||||
加载门级别校正偏置参数,如果使用门级别校正偏置则从state_dict中获取并设置到gate_correction_bias中。
|
||||
参数:
|
||||
state_dict (OrderedDict): 模型的状态字典,包含所有需要被加载的参数。
|
||||
返回值:
|
||||
无返回值,直接修改了gate_correction_bias的值。
|
||||
"""
|
||||
if self.moe_config.moe_use_gate_correction_bias:
|
||||
gate_correction_bias_tensor = get_tensor(
|
||||
state_dict[self.gate_correction_bias_key]
|
||||
)
|
||||
self.gate_correction_bias.set_value(
|
||||
gate_correction_bias_tensor[1].unsqueeze(0)
|
||||
)
|
||||
|
||||
|
||||
class MultimodalityMoeLayer(nn.Layer):
|
||||
"""
|
||||
Multimodality MOE Layer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inference_args,
|
||||
layer_name,
|
||||
layer_idx,
|
||||
):
|
||||
"""
|
||||
初始化一个 MoELayer。
|
||||
|
||||
Args:
|
||||
inference_args (InferenceArgs): 推理参数类,包含了所有必要的配置信息。
|
||||
layer_name (str): 当前 MoE Layer 的名称。
|
||||
layer_idx (int): 当前 MoE Layer 在模型中的索引。
|
||||
|
||||
Returns:
|
||||
None, 无返回值。
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.text_moe_layer = TextMoELayer(
|
||||
inference_args=inference_args,
|
||||
moe_config=inference_args.moe_config,
|
||||
layer_name=layer_name + ".text",
|
||||
gate_weight_key=f"ernie.layers.{layer_idx}.mlp.gate.weight",
|
||||
ffn1_expert_weight_key=f"ernie.layers.{layer_idx}.mlp.experts"
|
||||
+ ".{}.up_gate_proj.weight",
|
||||
ffn2_expert_weight_key=f"ernie.layers.{layer_idx}.mlp.experts"
|
||||
+ ".{}.down_proj.weight",
|
||||
gate_correction_bias_key=f"ernie.layers.{layer_idx}.mlp.moe_statics.e_score_correction_bias",
|
||||
ffn1_bias_key=None,
|
||||
ffn2_bias_key=None,
|
||||
ffn1_shared_weight_key=None,
|
||||
ffn1_shared_bias_key=None,
|
||||
ffn2_shared_weight_key=None,
|
||||
ffn2_shared_bias_key=None,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
self.image_moe_layer = ImageMoELayer(
|
||||
inference_args=inference_args,
|
||||
moe_config=inference_args.moe_config_1,
|
||||
layer_name=layer_name + ".image",
|
||||
gate_weight_key=f"ernie.layers.{layer_idx}.mlp.gate.weight_1",
|
||||
ffn1_expert_weight_key=f"ernie.layers.{layer_idx}.mlp.experts"
|
||||
+ ".{}.up_gate_proj.weight",
|
||||
ffn2_expert_weight_key=f"ernie.layers.{layer_idx}.mlp.experts"
|
||||
+ ".{}.down_proj.weight",
|
||||
gate_correction_bias_key=f"ernie.layers.{layer_idx}.mlp.moe_statics.e_score_correction_bias",
|
||||
ffn1_bias_key=None,
|
||||
ffn2_bias_key=None,
|
||||
ffn1_shared_weight_key=None,
|
||||
ffn1_shared_bias_key=None,
|
||||
ffn2_shared_weight_key=None,
|
||||
ffn2_shared_bias_key=None,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
加载模型参数。
|
||||
将给定的字典中的参数覆盖到当前模型上,并返回一个新的字典,其中包含未被覆盖的键值对。
|
||||
|
||||
Args:
|
||||
state_dict (dict): 包含了要加载的模型参数的字典。
|
||||
|
||||
Returns:
|
||||
dict: 包含未被覆盖的键值对的字典。
|
||||
"""
|
||||
self.text_moe_layer.load_state_dict(state_dict)
|
||||
self.image_moe_layer.load_state_dict(state_dict)
|
||||
state_dict.pop(self.text_moe_layer.gate_correction_bias_key)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
"""
|
||||
前向计算函数,将输入的张量进行处理并返回结果。
|
||||
该函数接受以下键值对参数:
|
||||
- token_type_ids (Optional, Tensor, default=None): 一个bool型Tensor,用于指定每个元素是否为文本类型(值为0)或图像类型(值为1)。
|
||||
如果未提供此参数,则会引发AssertionError。
|
||||
返回值是一个Tensor,形状与输入相同,表示处理后的结果。
|
||||
|
||||
Args:
|
||||
x (Tensor): 输入张量,形状为[token_num, hidden_size],其中token_num是序列长度,hidden_size是隐藏状态维度。
|
||||
kwargs (dict, optional): 可选参数字典,默认为None,包含以下键值对:
|
||||
- token_type_ids (Tensor, optional): 一个bool型Tensor,用于指定每个元素是否为文本类型(值为0)或图像类型(值为1),默认为None。
|
||||
|
||||
Returns:
|
||||
Tensor: 一个Tensor,形状与输入相同,表示处理后的结果。
|
||||
|
||||
Raises:
|
||||
AssertionError: 当未提供token_type_ids参数时会引发此错误。
|
||||
"""
|
||||
token_type_ids = kwargs.get("token_type_ids", None)
|
||||
assert token_type_ids is not None
|
||||

        # x.shape is [token_num, hidden_size]
        fused_moe_out = paddle.zeros_like(x)

        text_mask = token_type_ids == 0  # [token_num]
        image_mask = token_type_ids == 1

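        # Editor's example (illustrative): for token_type_ids = [0, 0, 1, 0] the text MoE
        # processes rows 0, 1 and 3 while the image MoE processes row 2; each result is
        # written back into fused_moe_out at its original position.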
if text_mask.any():
|
||||
text_out = self.text_moe_layer(x[text_mask])
|
||||
fused_moe_out[text_mask] = text_out
|
||||
|
||||
if image_mask.any():
|
||||
image_out = self.image_moe_layer(x[image_mask])
|
||||
fused_moe_out[image_mask] = image_out
|
||||
|
||||
return fused_moe_out
|
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,34 +14,13 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddlenlp.utils.log import logger
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
|
||||
from .cutlass_fused_moe import CutlassFusedMoeMethod
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoEComputeParams:
|
||||
"""
|
||||
some params for computing MoE.
|
||||
it is given to different compute methods.
|
||||
"""
|
||||
global_num_experts: int = -1
|
||||
top_k: int = -1
|
||||
hidden_size: int = -1
|
||||
num_local_experts: int = -1
|
||||
moe_intermediate_size: int = -1
|
||||
|
||||
tp_size: int = -1
|
||||
ep_size: int = -1
|
||||
dp_size: int = -1
|
||||
|
||||
moe_quant_type: str = ""
|
||||
|
||||
|
||||
class FusedMoE(nn.Layer):
|
||||
"""
|
||||
@@ -50,174 +29,195 @@ class FusedMoE(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_config,
|
||||
fd_config,
|
||||
moe_intermediate_size: int = -1,
|
||||
num_experts: int = -1,
|
||||
expert_id_offset: int = 0,
|
||||
top_k: int = -1,
|
||||
moe_use_gate_correction_bias: bool = False,
|
||||
moe_quant_type: str = "weight_only_int4",
|
||||
layer_idx: int = -1,
|
||||
gate_weight_key=None,
|
||||
gate_correction_bias_key=None,
|
||||
ffn1_expert_weight_key=None,
|
||||
ffn2_expert_weight_key=None,
|
||||
moe_ffn1_bias_keys=None,
|
||||
moe_ffn2_bias_keys=None,
|
||||
moe_ffn1_weight_scale_keys=None,
|
||||
moe_ffn2_weight_scale_keys=None,
|
||||
moe_ffn1_in_scale_keys=None,
|
||||
moe_ffn2_in_scale_keys=None,
|
||||
moe_tag: str = "",
|
||||
weight_key_map: dict = {},
|
||||
):
|
||||
"""
|
||||
Initialize the Moe layer with given parameters.
|
||||
Args:
|
||||
llm_config (LLMConfig): Arguments related to inference, containing
|
||||
fd_config (FDConfig): Arguments related to inference, containing
|
||||
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
|
||||
num_attention_heads, and ffn_hidden_size.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.llm_config = llm_config
|
||||
self.fd_config = fd_config
|
||||
self.layer_idx = layer_idx
|
||||
self.tp_size = llm_config.parallel_config.mp_size
|
||||
self.ep_size = llm_config.parallel_config.ep_size
|
||||
|
||||
self.moe_use_gate_correction_bias = moe_use_gate_correction_bias
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_degree
|
||||
self.ep_size = fd_config.parallel_config.expert_parallel_degree
|
||||
self.ep_rank = fd_config.parallel_config.expert_parallel_rank
|
||||
|
||||
assert (self.tp_size >= 1 and self.ep_size == 1) or \
|
||||
(self.tp_size == 1 and self.ep_size > 1), \
|
||||
'MoE only support parallelism on TP or EP dimension.'
|
||||
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.moe_config = fd_config.moe_config
|
||||
|
||||
self.hidden_size = llm_config.model_config.hidden_size
|
||||
self.moe_config = llm_config.moe_config
|
||||
self.use_offline_quant = llm_config.tmp_config.use_offline_quant
|
||||
moe_tag = self.llm_config.moe_config.moe_tag
|
||||
logger.info(f"{moe_tag}MoE is running in {moe_quant_type} mode")
|
||||
|
||||
self.moe_quant_type = moe_quant_type
|
||||
self.num_experts = num_experts
|
||||
self.num_local_experts = self.num_experts // self.ep_size
|
||||
|
||||
logger.info(f'''MoE config is num_experts:{num_experts},
|
||||
top_k:{top_k},
|
||||
hidden_size:{self.hidden_size},
|
||||
moe_intermediate_size:{moe_intermediate_size}''')
|
||||
logger.info(
|
||||
f"MoE is running on moe_quant_type: {self.moe_quant_type}, ep:{self.ep_size}, tp:{self.tp_size} mode"
|
||||
)
|
||||
self.moe_intermediate_size = moe_intermediate_size // self.tp_size
|
||||
|
||||
self.gate_weight_key = gate_weight_key
|
||||
self.gate_correction_bias_key = gate_correction_bias_key
|
||||
self.top_k = top_k
|
||||
self.hidden_size = self.hidden_size
|
||||
self.moe_intermediate_size = moe_intermediate_size // self.tp_size
|
||||
self.weight_key_map = weight_key_map
|
||||
|
||||
self.ffn1_expert_weight_key = ffn1_expert_weight_key
|
||||
self.ffn2_expert_weight_key = ffn2_expert_weight_key
|
||||
self.ffn1_bias_key = moe_ffn1_bias_keys
|
||||
self.ffn2_bias_key = moe_ffn2_bias_keys
|
||||
self.use_method = envs.FD_MOE_BACKEND.lower()
|
||||
self.gate_correction_bias = None
|
||||
self.moe_tag = moe_tag
|
||||
|
||||
if self.moe_quant_type == "w4a8":
|
||||
# below keys are only used in MoE W4A8!
|
||||
self.ffn1_expert_weight_scale_key = moe_ffn1_weight_scale_keys
|
||||
self.ffn2_expert_weight_scale_key = moe_ffn2_weight_scale_keys
|
||||
self.ffn1_expert_in_scale_key = moe_ffn1_in_scale_keys
|
||||
self.ffn2_expert_in_scale_key = moe_ffn2_in_scale_keys
|
||||
if self.ep_size > 1:
|
||||
expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts
|
||||
|
||||
self.compute_method = CutlassFusedMoeMethod()
|
||||
self.expert_id_offset = expert_id_offset
|
||||
|
||||
self.moe_compute_params = MoEComputeParams()
|
||||
self.moe_compute_params.global_num_experts = self.num_experts
|
||||
self.moe_compute_params.top_k = top_k
|
||||
self.moe_compute_params.hidden_size = self.hidden_size
|
||||
self.moe_compute_params.num_local_experts = self.num_local_experts
|
||||
self.moe_compute_params.moe_quant_type = self.moe_quant_type
|
||||
self.moe_compute_params.moe_intermediate_size = self.moe_intermediate_size
|
||||
self.moe_compute_params.ep_size = self.ep_size
|
||||
self.moe_compute_params.tp_size = self.tp_size
|
||||
if fd_config.quant_config:
|
||||
self.quant_method = fd_config.quant_config.get_quant_method(self)
|
||||
else:
|
||||
# now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
|
||||
from .fused_moe_cutlass_backend import CutlassMoEMethod
|
||||
self.quant_method = CutlassMoEMethod(None)
|
||||
|
||||
def load_gate_state_dict(self, state_dict):
|
||||
if self.ep_size > 1:
|
||||
self.quant_method.init_ep(self)
|
||||
|
||||
logger.info(
|
||||
f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset+self.num_local_experts}), \
|
||||
{top_k=}, hidden_size={self.hidden_size}, {moe_intermediate_size=}, \
|
||||
, ep_size={self.ep_size}, \
|
||||
tp_size={self.tp_size}.")
|
||||
|
||||
def load_experts_weight(self, state_dict: dict,
|
||||
ffn1_expert_weight_key: str,
|
||||
ffn2_expert_weight_key: str):
|
||||
"""
|
||||
load_gate_state_dict function.
|
||||
Load experts weight from state_dict.
|
||||
Args:
|
||||
state_dict (dict): The state_dict of model.
|
||||
ffn1_expert_weight_key (str): The key of ffn1 expert weight.
|
||||
ffn2_expert_weight_key (str): The key of ffn2 expert weight.
|
||||
"""
|
||||
up_gate_proj_weight = []
|
||||
up_gate_proj_weight_scale = []
|
||||
down_proj_weight = []
|
||||
down_proj_weight_scale = []
|
||||
for j in range(self.num_experts):
|
||||
up_gate_proj_weight.append(
|
||||
get_tensor(
|
||||
state_dict.pop(self.ffn1_expert_weight_key.format(j))))
|
||||
down_proj_weight.append(
|
||||
get_tensor(
|
||||
state_dict.pop(self.ffn2_expert_weight_key.format(j))))
|
||||
return up_gate_proj_weight, down_proj_weight
|
||||
ffn1_weights = []
|
||||
ffn2_weights = []
|
||||
is_ffn_merged = ffn1_expert_weight_key.format(
|
||||
self.expert_id_offset) in state_dict
|
||||
if is_ffn_merged:
|
||||
for i in range(self.num_local_experts):
|
||||
expert_idx = self.expert_id_offset + i
|
||||
ffn1_weights.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn1_expert_weight_key.format(expert_idx))))
|
||||
ffn2_weights.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_weight_key.format(expert_idx))))
|
||||
else:
|
||||
gate_expert_weight_key = ffn1_expert_weight_key.replace(
|
||||
"up_gate_proj", "gate_proj")
|
||||
up_expert_weight_key = ffn1_expert_weight_key.replace(
|
||||
"up_gate_proj", "up_proj")
|
||||
for j in range(self.num_local_experts):
|
||||
expert_idx = self.expert_id_offset + j
|
||||
gate = get_tensor(
|
||||
state_dict.pop(gate_expert_weight_key.format(expert_idx)))
|
||||
up = get_tensor(
|
||||
state_dict.pop(up_expert_weight_key.format(expert_idx)))
|
||||
ffn1_weights.append(paddle.concat([gate, up], axis=-1))
|
||||
ffn2_weights.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_weight_key.format(expert_idx))))
|
||||
return ffn1_weights, ffn2_weights
|
||||
|
||||
def load_state_dict(self, state_dict, is_update: bool = False):
|
||||
def extract_moe_ffn_weights(self, state_dict: dict):
|
||||
"""
|
||||
Extract MoE FFN weights from state dict based on weight key mapping.
|
||||
|
||||
Args:
|
||||
state_dict (dict): Model state dictionary containing the weights.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing two lists:
|
||||
- ffn1_weights: List of tensors for first FFN layer weights
|
||||
- ffn2_weights: List of tensors for second FFN layer weights
|
||||
|
||||
Raises:
|
||||
AssertionError: If required weight keys are missing or number of weights
|
||||
doesn't match number of local experts.
|
||||
"""
|
||||
ffn1_expert_weight_key = self.weight_key_map.get(
|
||||
"ffn1_expert_weight_key", None)
|
||||
ffn2_expert_weight_key = self.weight_key_map.get(
|
||||
"ffn2_expert_weight_key", None)
|
||||
assert ffn1_expert_weight_key is not None, "ffn1_expert_weight_key should not be none."
|
||||
assert ffn2_expert_weight_key is not None, "ffn2_expert_weight_key should not be none."
|
||||
|
||||
ffn1_weights, ffn2_weights = self.load_experts_weight(
|
||||
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
|
||||
assert len(
|
||||
ffn1_weights
|
||||
) == self.num_local_experts, "ffn1_weights length should be equal to num_local_experts."
|
||||
assert len(
|
||||
ffn2_weights
|
||||
) == self.num_local_experts, "ffn2_weights length should be equal to num_local_experts."
|
||||
|
||||
return ffn1_weights, ffn2_weights
|
||||
|
||||
def extract_gate_correction_bias(self, gate_correction_bias_key,
|
||||
state_dict):
|
||||
"""
|
||||
extract_gate_correction_bias function.
|
||||
"""
|
||||
gate_correction_bias_tensor = get_tensor(
|
||||
state_dict.pop(gate_correction_bias_key)).astype("float32")
|
||||
return gate_correction_bias_tensor
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
load_state_dict function.
|
||||
"""
|
||||
# gate
|
||||
if not is_update:
|
||||
gate_weight_tensor = get_tensor(state_dict.pop(self.gate_weight_key))
|
||||
self.gate_weight = self.create_parameter(
|
||||
shape=gate_weight_tensor.shape,
|
||||
dtype="float32",
|
||||
)
|
||||
self.gate_weight.set_value(gate_weight_tensor)
|
||||
|
||||
# gate_correction_bias
|
||||
self.gate_correction_bias_key = self.weight_key_map.get(
|
||||
"gate_correction_bias_key", None)
|
||||
if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
|
||||
self.moe_use_gate_correction_bias = True
|
||||
else:
|
||||
self.moe_use_gate_correction_bias = False
|
||||
if self.moe_use_gate_correction_bias:
|
||||
gate_correction_bias_tensor = get_tensor(
|
||||
state_dict.pop(self.gate_correction_bias_key))
|
||||
|
||||
gate_correction_bias_tensor = self.extract_gate_correction_bias(
|
||||
self.gate_correction_bias_key, state_dict)
|
||||
self.gate_correction_bias = self.create_parameter(
|
||||
shape=gate_correction_bias_tensor.shape,
|
||||
dtype="float32",
|
||||
)
|
||||
|
||||
self.gate_correction_bias.set_value(gate_correction_bias_tensor)
|
||||
|
||||
gate_weight_key = self.weight_key_map.get("gate_weight_key", None)
|
||||
assert gate_weight_key is not None, "gate_weight_key should not be None, please check model checkpoints"
|
||||
|
||||
gate_weight_tensor = get_tensor(state_dict.pop(gate_weight_key))
|
||||
|
||||
self.gate_weight = self.create_parameter(
|
||||
shape=gate_weight_tensor.shape,
|
||||
dtype="float32",
|
||||
)
|
||||
self.gate_weight.set_value(gate_weight_tensor.astype("float32"))
|
||||
|
||||
if self.fd_config.model_config.is_quantized:
|
||||
self.quant_method.process_prequanted_weights(self, state_dict)
|
||||
else:
|
||||
self.gate_correction_bias = None
|
||||
self.quant_method.create_weights(self, state_dict)
|
||||
|
||||
up_gate_proj_weight, down_proj_weight = self.load_gate_state_dict(
|
||||
state_dict)
|
||||
|
||||
weight1_scale = None
|
||||
weight2_scale = None
|
||||
ffn1_in_scale = None
|
||||
ffn2_in_scale = None
|
||||
if self.moe_quant_type == "w4a8":
|
||||
weight1_scale = []
|
||||
weight2_scale = []
|
||||
ffn1_in_scale = []
|
||||
ffn2_in_scale = []
|
||||
|
||||
for j in range(self.num_experts):
|
||||
weight1_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
self.ffn1_expert_weight_scale_key.format(
|
||||
self.layer_idx, j))))
|
||||
weight2_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
self.ffn2_expert_weight_scale_key.format(
|
||||
self.layer_idx, j))))
|
||||
ffn1_in_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
self.ffn1_expert_in_scale_key.format(
|
||||
self.layer_idx, j))))
|
||||
ffn2_in_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
self.ffn2_expert_in_scale_key.format(
|
||||
self.layer_idx, j))))
|
||||
|
||||
# other weight is with compute_method
|
||||
# different method may have different way to create weights
|
||||
self.compute_method.create_weights(self, self.moe_compute_params,
|
||||
up_gate_proj_weight,
|
||||
down_proj_weight, None, None,
|
||||
weight1_scale, weight2_scale,
|
||||
ffn1_in_scale, ffn2_in_scale)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
def forward(self, x: paddle.Tensor):
|
||||
"""
|
||||
Defines the forward computation of the moe layer.
|
||||
|
||||
@@ -225,13 +225,9 @@ class FusedMoE(nn.Layer):
|
||||
x (Tensor): Input tensor to the moe layer.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor.
|
||||
            Tensor: Output tensor.
|
||||
|
||||
"""
|
||||
|
||||
out = self.compute_method.apply(self, self.moe_compute_params, x)
|
||||
if self.tp_size > 1:
|
||||
from fastdeploy.distributed.communication_op import \
|
||||
tensor_model_parallel_all_reduce
|
||||
tensor_model_parallel_all_reduce(out)
|
||||
gate_out = paddle.matmul(x.cast("float32"), self.gate_weight)
|
||||
out = self.quant_method.apply(self, x, gate_out)
|
||||
return out
|
||||
|
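# Editor's usage sketch for the FusedMoE layer above (illustrative; the argument values and
# the "..." placeholders are not taken from the diff):
#
#     moe = FusedMoE(fd_config,
#                    moe_intermediate_size=1024,
#                    num_experts=64,
#                    top_k=8,
#                    layer_idx=0,
#                    weight_key_map={"gate_weight_key": "...",
#                                    "ffn1_expert_weight_key": "...",
#                                    "ffn2_expert_weight_key": "..."})
#     moe.load_state_dict(state_dict)   # creates / loads expert weights via quant_method
#     out = moe(x)                      # x: [token_num, hidden_size]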
@@ -1,126 +0,0 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
import paddle
|
||||
import fastdeploy
|
||||
import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
|
||||
from fastdeploy.model_executor.layers.moe.moe import MoELayer
|
||||
|
||||
|
||||
class MoeTPDecoerDeepDeepGEMMLayer(MoELayer):
|
||||
"""
|
||||
MoeTPDecoerDeepDeepGEMMLayer
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
"""
|
||||
forward
|
||||
"""
|
||||
gate_out = paddle.matmul(x.cast("float32"), self.gate_weight)
|
||||
if os.getenv("EP_DECODER_PERF_TEST", "False") == "True":
|
||||
gate_out = paddle.rand(shape=gate_out.shape, dtype=gate_out.dtype)
|
||||
ffn1_out = paddle.empty(
|
||||
[
|
||||
self.num_local_experts,
|
||||
self.max_batch_size,
|
||||
self.moe_intermediate_size * 2,
|
||||
],
|
||||
dtype=self._dtype,
|
||||
)
|
||||
|
||||
ffn_out = paddle.empty(
|
||||
[
|
||||
self.num_local_experts,
|
||||
self.max_batch_size,
|
||||
self.embed_dim,
|
||||
],
|
||||
dtype=self._dtype,
|
||||
)
|
||||
|
||||
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
(
|
||||
self.gate_correction_bias
|
||||
if self.moe_config.moe_use_gate_correction_bias
|
||||
else None
|
||||
),
|
||||
self.top_k,
|
||||
True, # apply_norm_weight
|
||||
False,
|
||||
)
|
||||
permute_input, token_nums_per_expert, permute_indices_per_token = (
|
||||
fastdeploy.model_executor.ops.gpu.moe_deepgemm_permute(
|
||||
x, topk_idx, self.num_local_experts, self.max_batch_size
|
||||
)
|
||||
)
|
||||
|
||||
expected_m = 128
|
||||
|
||||
permute_input_fp8, scale = fastdeploy.model_executor.ops.gpu.masked_per_token_quant(
|
||||
permute_input, token_nums_per_expert, 128
|
||||
)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
(permute_input_fp8, scale),
|
||||
(
|
||||
self.moe_ffn1_weight,
|
||||
self.moe_ffn1_weight_scale,
|
||||
),
|
||||
ffn1_out,
|
||||
token_nums_per_expert,
|
||||
expected_m,
|
||||
)
|
||||
|
||||
act_out = fastdeploy.model_executor.ops.gpu.group_swiglu_with_masked(
|
||||
ffn1_out, token_nums_per_expert
|
||||
)
|
||||
|
||||
act_out_fp8, scale = fastdeploy.model_executor.ops.gpu.masked_per_token_quant(
|
||||
act_out, token_nums_per_expert, 128
|
||||
)
|
||||
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
(act_out_fp8, scale),
|
||||
(
|
||||
self.moe_ffn2_weight,
|
||||
self.moe_ffn2_weight_scale,
|
||||
),
|
||||
ffn_out,
|
||||
token_nums_per_expert,
|
||||
expected_m,
|
||||
)
|
||||
|
||||
fused_moe_out = fastdeploy.model_executor.ops.gpu.moe_deepgemm_depermute(
|
||||
ffn_out, permute_indices_per_token, topk_idx, topk_weights
|
||||
)[0]
|
||||
|
||||
return fused_moe_out
|
||||
|
||||
|
||||
class MoeTPPrefillDeepDeepGEMMLayer(MoELayer):
|
||||
"""
|
||||
MoeTPPrefillDeepDeepGEMMLayer
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
"""
|
||||
forward
|
||||
"""
|
||||
        raise NotImplementedError("Prefill is coming soon...")
|
fastdeploy/model_executor/layers/moe/triton_moe_kernels.py (new file, 198 lines)
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_moe_kernel_paddle(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
topk_weights_ptr,
|
||||
sorted_token_ids_ptr,
|
||||
expert_ids_ptr,
|
||||
num_tokens_post_padded_ptr,
|
||||
|
||||
# Matrix dimensions
|
||||
N,
|
||||
K,
|
||||
num_tokens_post_padded,
|
||||
num_valid_tokens,
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_be,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
# Block size for block-wise fp8 quantization
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
compute_type_enum: tl.constexpr,
|
||||
use_fp8_w8a8: tl.constexpr,
|
||||
use_int8_w8a16: tl.constexpr,
|
||||
even_Ks: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
|
||||
Key Parameters:
|
||||
- A: The input tensor representing tokens with shape (*, K), where '*' can
|
||||
be any shape representing batches and K is the feature dimension of
|
||||
each token.
|
||||
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
|
||||
the number of experts, K is the input feature dimension, and N is
|
||||
the output feature dimension.
|
||||
- C: The output cache tensor with shape (M, topk, N), where M is the
|
||||
total number of tokens post padding, topk is the number of times
|
||||
each token is repeated, and N is the output feature dimension.
|
||||
- sorted_token_ids: A tensor containing the sorted indices of tokens,
|
||||
repeated topk times and arranged by the expert index they are
|
||||
assigned to.
|
||||
- expert_ids: A tensor containing the indices of the expert for each
|
||||
block. It determines which expert matrix from B should be used for
|
||||
each block in A.
|
||||
This kernel performs the multiplication of a token by its corresponding
|
||||
expert matrix as determined by `expert_ids`. The sorting of
|
||||
`sorted_token_ids` by expert index and padding ensures divisibility by
|
||||
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
|
||||
multiplication across different blocks processed by the same expert.
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(num_tokens_post_padded, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
assert compute_type_enum == 1
|
||||
compute_type = tl.bfloat16
|
||||
|
||||
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
|
||||
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
||||
return
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
|
||||
token_mask = offs_token < num_valid_tokens
|
||||
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
|
||||
offs_k[None, :] * stride_ak)
|
||||
|
||||
off_experts = tl.load(expert_ids_ptr + pid_m)
|
||||
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
|
||||
offs_bn[None, :] * stride_bn)
|
||||
|
||||
if use_int8_w8a16:
|
||||
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
|
||||
None, :] * stride_bsn
|
||||
b_scale = tl.load(b_scale_ptrs)
|
||||
|
||||
if use_fp8_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
|
||||
offs_bsn = offs_bn // group_n
|
||||
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
|
||||
else:
|
||||
# (Zkk): every expert has one activation scale and weight scale.
|
||||
a_scale = tl.load(a_scale_ptr + off_experts)
|
||||
b_scale = tl.load(b_scale_ptr + off_experts)
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
if even_Ks:
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=token_mask[:, None],
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(b_ptrs,
|
||||
cache_modifier=".cv",
|
||||
eviction_policy='evict_first')
|
||||
else:
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=token_mask[:, None] &
|
||||
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(b_ptrs,
|
||||
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
|
||||
other=0.0)
|
||||
|
||||
# We accumulate along the K dimension.
|
||||
if use_int8_w8a16:
|
||||
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
|
||||
elif use_fp8_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
k_start = k * BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
|
||||
mask=token_mask,
|
||||
other=0.0)
|
||||
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
|
||||
|
||||
accumulator += tl.dot(a, b) * a_scale[:,
|
||||
None] * b_scale[None, :]
|
||||
else:
|
||||
accumulator = tl.dot(a, b, acc=accumulator)
|
||||
else:
|
||||
accumulator += tl.dot(a, b)
|
||||
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token,
|
||||
mask=token_mask,
|
||||
other=0)
|
||||
accumulator = accumulator * moe_weight[:, None]
|
||||
if use_int8_w8a16:
|
||||
accumulator = (accumulator * b_scale).to(compute_type)
|
||||
elif use_fp8_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
else:
|
||||
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
|
||||
else:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
# Write back the block of the output
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
|
||||
None, :]
|
||||
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
||||
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
@@ -28,18 +28,19 @@ class RMSNorm(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_config,
|
||||
fd_config,
|
||||
hidden_size,
|
||||
eps=1e-5,
|
||||
prefix="",
|
||||
linear_bias=None,
|
||||
quant_scale=None,
|
||||
begin_norm_axis=1,
|
||||
):
|
||||
"""
|
||||
Initializes the normalization layer.
|
||||
|
||||
Args:
|
||||
llm_config (LLMConfig): Arguments related to inference, containing
|
||||
fd_config (FDConfig): Arguments related to inference, containing
|
||||
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
|
||||
num_attention_heads, and ffn_hidden_size.
|
||||
hidden_size (int) : size of hidden state.
|
||||
@@ -52,7 +53,7 @@ class RMSNorm(nn.Layer):
|
||||
NotImplementedError: If the specified norm_type is not supported.
|
||||
"""
|
||||
super().__init__()
|
||||
self.llm_config = llm_config
|
||||
self.fd_config = fd_config
|
||||
self.prefix = prefix
|
||||
self.hidden_size = hidden_size
|
||||
if len(prefix) == 0:
|
||||
@@ -66,6 +67,11 @@ class RMSNorm(nn.Layer):
|
||||
self.quant_scale = quant_scale
|
||||
self._dtype = self._helper.get_default_dtype()
|
||||
self._norm_weight_dtype = self._dtype
|
||||
self.begin_norm_axis = begin_norm_axis
|
||||
self.quant_round_type = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
|
||||
self.quant_max_bound = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
|
||||
self.quant_min_bound = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
|
||||
self.begin_norm_axis = begin_norm_axis
|
||||
|
||||
self.init_weight()
|
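The guarded reads above follow one small pattern: the quantization bounds come from fd_config.quant_config when it is present and fall back to 0 otherwise. A standalone sketch of that fallback (hypothetical helper name):

def _quant_bounds(fd_config):
    # Returns (round_type, max_bound, min_bound); all 0 when no quant config is set.
    qc = getattr(fd_config, "quant_config", None)
    if qc is None:
        return 0, 0, 0
    return qc.quant_round_type, qc.quant_max_bound, qc.quant_min_bound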
||||
|
||||
@@ -118,13 +124,13 @@ class RMSNorm(nn.Layer):
|
||||
norm_weight=self.ln_weight,
|
||||
norm_bias=None,
|
||||
epsilon=self.eps,
|
||||
begin_norm_axis=1,
|
||||
begin_norm_axis=self.begin_norm_axis,
|
||||
bias=self.linear_bias,
|
||||
residual=residual_input,
|
||||
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
|
||||
quant_round_type=self.llm_config.quant_config.quant_round_type,
|
||||
quant_max_bound=self.llm_config.quant_config.quant_max_bound,
|
||||
quant_min_bound=self.llm_config.quant_config.quant_min_bound,
|
||||
quant_round_type=self.quant_round_type,
|
||||
quant_max_bound=self.quant_max_bound,
|
||||
quant_min_bound=self.quant_min_bound,
|
||||
)
|
||||
if residual_input is not None:
|
||||
return norm_out[0], norm_out[1]
|
||||
@@ -139,7 +145,7 @@ class LayerNorm(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_config,
|
||||
fd_config,
|
||||
hidden_size,
|
||||
eps=1e-5,
|
||||
prefix="",
|
||||
@@ -151,7 +157,7 @@ class LayerNorm(nn.Layer):
|
||||
Initializes the normalization layer.
|
||||
|
||||
Args:
|
||||
llm_config (LLMConfig): Arguments related to inference, containing
|
||||
fd_config (FDConfig): Arguments related to inference, containing
|
||||
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
|
||||
num_attention_heads, and ffn_hidden_size.
|
||||
prefix (str): Unique name of the layer, used for naming internal attributes,
|
||||
@@ -163,7 +169,7 @@ class LayerNorm(nn.Layer):
|
||||
NotImplementedError: If the specified norm_type is not supported.
|
||||
"""
|
||||
super().__init__()
|
||||
self.llm_config = llm_config
|
||||
self.fd_config = fd_config
|
||||
self.prefix = prefix
|
||||
self.hidden_size = hidden_size
|
||||
if len(prefix) == 0:
|
||||
@@ -180,6 +186,10 @@ class LayerNorm(nn.Layer):
|
||||
self._dtype = self._helper.get_default_dtype()
|
||||
self._norm_weight_dtype = "float32"
|
||||
|
||||
self.quant_round_type = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
|
||||
self.quant_max_bound = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
|
||||
self.quant_min_bound = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
|
||||
|
||||
self.init_weight()
|
||||
|
||||
def init_weight(self):
|
||||
@@ -240,6 +250,7 @@ class LayerNorm(nn.Layer):
|
||||
The `residual_output` is the result of applying the normalization and possibly other
|
||||
operations (like linear transformation) on the `residual_input`.
|
||||
"""
|
||||
|
||||
norm_out = self.norm_func(
|
||||
x,
|
||||
norm_weight=self.ln_weight,
|
||||
@@ -249,9 +260,9 @@ class LayerNorm(nn.Layer):
|
||||
bias=self.linear_bias,
|
||||
residual=residual_input,
|
||||
quant_scale=-1,
|
||||
quant_round_type=self.llm_config.quant_config.quant_round_type,
|
||||
quant_max_bound=self.llm_config.quant_config.quant_max_bound,
|
||||
quant_min_bound=self.llm_config.quant_config.quant_min_bound,
|
||||
quant_round_type=self.quant_round_type,
|
||||
quant_max_bound=self.quant_max_bound,
|
||||
quant_min_bound=self.quant_min_bound,
|
||||
)
|
||||
if residual_input is not None:
|
||||
return norm_out[0], norm_out[1]
|
||||
|
@@ -19,11 +19,18 @@ from typing import Dict, List, Type
from .quant_base import QuantConfigBase

QUANTIZATION_METHODS: List[str] = [
    "wint2",
    "wint4",
    "wint8",
    "weight_only",
    "block_wise",
    "block_wise_fp8",
    "w4afp8",
    "w8a8",
    "w4a8",
    "wfp8afp8",
    "mix_quant",
    "tensor_wise_fp8",
    "kvcache",
]


@@ -34,20 +41,30 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")

    from .block_wise import BlockWiseConfig
    from .block_wise_fp8 import BlockWiseFP8Config
    from .kv_cache import KvCacheQuantConfig
    from .mix_quant import MixQuantConfig
    from .tensor_wise_fp8 import TensorWiseFP8Config
    from .w4a8 import W4A8Config
    from .w4afp8 import W4AFP8Config
    from .w8a8 import W8A8Config
    from .weight_only import WeightOnlyConfig
    from .weight_only import WeightOnlyConfig, WINT4Config, WINT8Config
    from .wfp8afp8 import WFP8AFP8Config
    from .kv_cache import KvCacheQuantConfig

    from .wint2 import WINT2Config

    method_to_config: Dict[str, Type[QuantConfigBase]] = {
        "wint2": WINT2Config,
        "wint4": WINT4Config,
        "wint8": WINT8Config,
        "weight_only": WeightOnlyConfig,
        "block_wise": BlockWiseConfig,
        "block_wise_fp8": BlockWiseFP8Config,
        "w4afp8": W4AFP8Config,
        "w8a8": W8A8Config,
        "w4a8": W4A8Config,
        "wfp8afp8": WFP8AFP8Config,
        "kvcache": KvCacheQuantConfig
        "tensor_wise_fp8": TensorWiseFP8Config,
        "kvcache": KvCacheQuantConfig,
        "mix_quant": MixQuantConfig,
    }

    return method_to_config[quantization]
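A typical lookup through this registry (illustrative values; the layer argument is assumed to be an existing FastDeploy linear or MoE layer):

quant_cls = get_quantization_config("block_wise_fp8")
quant_cfg = quant_cls.from_config({"weight_block_size": [128, 128]})
quant_method = quant_cfg.get_quant_method(layer)  # dispatches per layer type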
@@ -18,16 +18,13 @@ from typing import Optional
|
||||
import paddle
|
||||
|
||||
import fastdeploy
|
||||
import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
|
||||
from fastdeploy.model_executor.layers.moe import FusedMoE
|
||||
|
||||
from ..utils import per_block_cast_to_fp8
|
||||
from ..utils import per_block_cast_to_fp8, get_tensor
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
|
||||
QUANT_ALIGNMENT_OFFSET = 127
|
||||
QUANT_BLOCK_SIZE = 128
|
||||
|
||||
|
||||
class BlockWiseConfig(QuantConfigBase):
|
||||
class BlockWiseFP8Config(QuantConfigBase):
|
||||
"""
|
||||
block wise quantization config, only support fp8 quant and only supports loading weights in BF16 format.
|
||||
After loading the weights, it will automatically compute quantization sparsity and dynamically perform
|
||||
@@ -37,41 +34,55 @@ class BlockWiseConfig(QuantConfigBase):
|
||||
def __init__(self, weight_block_size: list = [-1, -1]) -> None:
|
||||
super().__init__()
|
||||
self.weight_block_size = weight_block_size
|
||||
self.quant_max_bound = 448
|
||||
self.quant_min_bound = -448
|
||||
self.quant_round_type = 1
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "block_wise"
|
||||
def name(self) -> str:
|
||||
return "block_wise_fp8"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "BlockWiseConfig":
|
||||
weight_block_size = config["weight_block_size"]
|
||||
def from_config(cls, config: dict) -> "BlockWiseFP8Config":
|
||||
weight_block_size = config.get("weight_block_size", [128, 128])
|
||||
return cls(weight_block_size)
|
||||
|
||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||
return BlockWiseLinearMethod(self)
|
||||
'''
|
||||
Get quantization method.
|
||||
'''
|
||||
if isinstance(layer, FusedMoE):
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import \
|
||||
DeepGemmFusedMoeMethod
|
||||
return DeepGemmFusedMoeMethod(self)
|
||||
else:
|
||||
return BlockWiseFP8LinearMethod(self)
|
||||
|
||||
|
||||
class BlockWiseLinearMethod(QuantMethodBase):
|
||||
class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
"""
|
||||
block wise quantization method for linear
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: BlockWiseConfig,
|
||||
quant_config: BlockWiseFP8Config,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer):
|
||||
layer.linear_weight_scale = self.create_parameter(
|
||||
layer.linear_weight_shape.reverse()
|
||||
layer.linear_weight_scale = layer.create_parameter(
|
||||
shape=[
|
||||
(layer.embed_dim + QUANT_ALIGNMENT_OFFSET) // QUANT_BLOCK_SIZE,
|
||||
(layer.num_heads * layer.head_dim + QUANT_ALIGNMENT_OFFSET) //
|
||||
QUANT_BLOCK_SIZE,
|
||||
(layer.output_size + self.quant_config.weight_block_size[0] -
|
||||
1) // self.quant_config.weight_block_size[0],
|
||||
(layer.input_size + self.quant_config.weight_block_size[1] - 1)
|
||||
// self.quant_config.weight_block_size[1],
|
||||
],
|
||||
dtype="float32",
|
||||
is_bias=False,
|
||||
)
|
||||
layer.weight_dtype = "float8_e4m3fn"
|
||||
|
||||
def process_loaded_weights(self, layer, weights) -> None:
|
||||
weight_tensor = weights.transpose([1, 0])
|
||||
@@ -80,15 +91,30 @@ class BlockWiseLinearMethod(QuantMethodBase):
|
||||
layer.linear_weight.copy_(quanted_weight_tensor, False)
|
||||
layer.linear_weight_scale.set_value(weight_block_scale_tensor)
|
||||
|
||||
def process_prequanted_weights(self, layer, state_dict):
|
||||
"""
|
||||
process_prequanted_weights
|
||||
"""
|
||||
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
|
||||
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
|
||||
|
||||
quant_weight = quant_weight.transpose([1, 0]).contiguous()
|
||||
layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False)
|
||||
|
||||
weight_scale = weight_scale.transpose([1, 0])
|
||||
layer.linear_weight_scale.set_value(weight_scale)
|
||||
|
||||
def apply(self, layer, x):
|
||||
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding(
|
||||
x, self.quant_config.weight_block_size[0])
|
||||
linear_out = paddle.empty(
|
||||
(x.shape[0], layer.llm_config.model_config.hidden_size),
|
||||
dtype=paddle.bfloat16)
|
||||
linear_out = paddle.empty((x.shape[0], layer.output_size),
|
||||
dtype=paddle.bfloat16)
|
||||
import fastdeploy.model_executor.ops.gpu.deep_gemm as deep_gemm
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt(
|
||||
(x, x_scale_tensor),
|
||||
(layer.linear_weight, layer.linear_weight_scale),
|
||||
linear_out,
|
||||
)
|
||||
if layer.with_bias:
|
||||
linear_out = paddle.add(linear_out, layer.linear_bias)
|
||||
return linear_out
|
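The scale tensor created above holds one float32 entry per weight tile, using ceil-division on both dimensions. A small sketch of that arithmetic (hypothetical helper, assuming the default [128, 128] block size):

def scale_grid_shape(output_size, input_size, block=(128, 128)):
    rows = (output_size + block[0] - 1) // block[0]  # ceil-div over output dim
    cols = (input_size + block[1] - 1) // block[1]   # ceil-div over input dim
    return [rows, cols]

# e.g. an 8192 x 3584 weight keeps a 64 x 28 grid of block scales
assert scale_grid_shape(8192, 3584) == [64, 28]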
@@ -13,38 +13,66 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
from paddle import nn
|
||||
import os
|
||||
import paddle
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
|
||||
from ..utils import create_and_set_parameter
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
|
||||
|
||||
class KvCacheQuantzationTypes(str, Enum):
|
||||
"""
|
||||
KvCacheQuantzationTypes
|
||||
"""
|
||||
INT8 = "int8"
|
||||
FP8 = "float8_e4m3fn"
|
||||
INT8_ZP = "int8_zp"
|
||||
FP8_ZP = "float8_e4m3fn_zp"
|
||||
|
||||
|
||||
class KvCacheQuantConfig(QuantConfigBase):
|
||||
"""
|
||||
quantization config for weight 4bits and activation fp8
|
||||
"""
|
||||
|
||||
def __init__(self, cachekv_scale_dict) -> None:
|
||||
def __init__(self, kv_cache_quant_type: str) -> None:
|
||||
"""
|
||||
__init__
|
||||
"""
|
||||
super().__init__()
|
||||
self.cachekv_scale_dict = cachekv_scale_dict
|
||||
self.kv_cache_quant_type = kv_cache_quant_type
|
||||
|
||||
def get_name(self) -> str:
|
||||
try:
|
||||
self.quant_type = KvCacheQuantzationTypes(kv_cache_quant_type)
|
||||
except ValueError:
|
||||
raise ValueError(f'Invalid Kvcache type: {kv_cache_quant_type}')
|
||||
|
||||
self.has_zero_point = "zp" in kv_cache_quant_type
|
||||
|
||||
if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP:
|
||||
self.max_bound = 127.0
|
||||
elif self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP:
|
||||
self.max_bound = 448.0
|
||||
else:
|
||||
raise ValueError(f'Invalid Kvcache type: {kv_cache_quant_type}')
|
||||
|
||||
def name(self) -> str:
|
||||
"""
|
||||
get_name
|
||||
"""
|
||||
return "kvcache"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "KvCacheQuantConfig":
|
||||
def from_config(cls, kv_cache_quant_type: str) -> "KvCacheQuantConfig":
|
||||
"""
|
||||
from_config
|
||||
"""
|
||||
cachekv_scale_dict = config["cachekv_scale_dict"]
|
||||
return cls(cachekv_scale_dict)
|
||||
return cls(kv_cache_quant_type)
|
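Usage sketch for the KV-cache config: the quant-type string is validated against the enum and the clipping bound follows from the dtype (the two built-in bounds are 127 for int8 and 448 for FP8):

cfg = KvCacheQuantConfig.from_config("int8")
assert cfg.quant_type == KvCacheQuantzationTypes.INT8
assert cfg.max_bound == 127.0 and not cfg.has_zero_point

cfg_fp8 = KvCacheQuantConfig.from_config("float8_e4m3fn")
assert cfg_fp8.max_bound == 448.0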
||||
|
||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||
"""
|
||||
@@ -66,197 +94,63 @@ class KVCacheMethodBase(QuantMethodBase):
|
||||
KVCacheMethodBase __init__
|
||||
"""
|
||||
super().__init__()
|
||||
self.quant_config = quant_config
|
||||
self.cache_quant_config = quant_config
|
||||
|
||||
def load_zp(self, layer: nn.Layer):
|
||||
def load_zp(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
load_zp
|
||||
"""
|
||||
if self.cache_k_zp_name in self.quant_config.cachekv_scale_dict:
|
||||
cache_k_zp = paddle.cast(
|
||||
paddle.to_tensor(
|
||||
self.quant_config.cachekv_scale_dict[self.cache_k_zp_name]
|
||||
),
|
||||
self.cache_scale_dtype,
|
||||
)
|
||||
else:
|
||||
cache_k_zp = paddle.zeros(
|
||||
(
|
||||
[self.kv_num_heads * self.head_dim]
|
||||
if self.quant_config.is_channel_wise
|
||||
else [self.kv_num_heads]
|
||||
),
|
||||
dtype=self.cache_scale_dtype,
|
||||
)
|
||||
if self.cache_v_zp_name in self.quant_config.cachekv_scale_dict:
|
||||
cache_v_zp = paddle.cast(
|
||||
paddle.to_tensor(
|
||||
self.quant_config.cachekv_scale_dict[self.cache_v_zp_name]
|
||||
),
|
||||
self.cache_scale_dtype,
|
||||
)
|
||||
else:
|
||||
cache_v_zp = paddle.zeros(
|
||||
(
|
||||
[self.kv_num_heads * self.head_dim]
|
||||
if self.quant_config.is_channel_wise
|
||||
else [self.kv_num_heads]
|
||||
),
|
||||
dtype=self.cache_scale_dtype,
|
||||
)
|
||||
layer.cache_k_zp.set_value(cache_k_zp)
|
||||
layer.cache_v_zp.set_value(cache_v_zp)
|
||||
cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name))
|
||||
cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name))
|
||||
|
||||
def load_scale(self, layer: nn.Layer):
|
||||
create_and_set_parameter(layer, "cache_k_zp", cache_k_zeropoint)
|
||||
create_and_set_parameter(layer, "cache_v_zp", cache_v_zeropoint)
|
||||
|
||||
def load_scale(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
load_scale
|
||||
"""
|
||||
if self.cache_k_scale_name in self.quant_config.cachekv_scale_dict:
|
||||
cache_k_scale = paddle.cast(
|
||||
paddle.to_tensor(
|
||||
self.quant_config.cachekv_scale_dict[self.cache_k_scale_name]
|
||||
),
|
||||
self.cache_scale_dtype,
|
||||
)
|
||||
cache_k_out_scale = 1.0 / cache_k_scale
|
||||
else:
|
||||
raise KeyError(
|
||||
f"{self.cache_k_scale_name} not found in scale dict")
|
||||
cache_k_scale_tensor = get_tensor(
|
||||
state_dict.pop(self.cache_k_scale_name)).cast(
|
||||
paddle.get_default_dtype()).reshape_([-1])
|
||||
cache_v_scale_tensor = get_tensor(
|
||||
state_dict.pop(self.cache_v_scale_name)).cast(
|
||||
paddle.get_default_dtype()).reshape_([-1])
|
||||
|
||||
if self.cache_v_scale_name in self.quant_config.cachekv_scale_dict:
|
||||
cache_v_scale = paddle.cast(
|
||||
paddle.to_tensor(
|
||||
self.quant_config.cachekv_scale_dict[self.cache_v_scale_name]
|
||||
),
|
||||
self.cache_scale_dtype,
|
||||
)
|
||||
cache_v_out_scale = 1.0 / cache_v_scale
|
||||
else:
|
||||
raise KeyError(
|
||||
f"{self.cache_v_scale_name} not found in scale dict")
|
||||
cache_k_scale = self.cache_quant_config.max_bound / cache_k_scale_tensor
|
||||
cache_v_scale = self.cache_quant_config.max_bound / cache_v_scale_tensor
|
||||
cache_k_out_scale = cache_k_scale_tensor / self.cache_quant_config.max_bound
|
||||
cache_v_out_scale = cache_v_scale_tensor / self.cache_quant_config.max_bound
|
||||
|
||||
if self.cache_v_scale_name in self.quant_config.cachekv_scale_dict:
|
||||
cache_v_scale = paddle.cast(
|
||||
paddle.to_tensor(
|
||||
self.quant_config.cachekv_scale_dict[self.cache_v_scale_name]
|
||||
),
|
||||
self.cache_scale_dtype,
|
||||
)
|
||||
cache_v_out_scale = 1.0 / cache_v_scale
|
||||
else:
|
||||
raise KeyError(
|
||||
f"{self.cache_v_scale_name} not found in scale dict")
|
||||
create_and_set_parameter(layer, "cache_k_scale", cache_k_scale)
|
||||
create_and_set_parameter(layer, "cache_v_scale", cache_v_scale)
|
||||
create_and_set_parameter(layer, "cache_k_out_scale", cache_k_out_scale)
|
||||
create_and_set_parameter(layer, "cache_v_out_scale", cache_v_out_scale)
|
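The four parameters set above encode one symmetric mapping: with max_bound = 127 for int8, the stored activation scale s becomes a write-side multiplier and a read-side multiplier that invert each other. A numeric sketch with a hypothetical scale value:

max_bound = 127.0
s = 0.05                             # per-head activation scale from the checkpoint
cache_k_scale = max_bound / s        # applied before casting K to int8
cache_k_out_scale = s / max_bound    # applied after reading K back
assert abs(cache_k_scale * cache_k_out_scale - 1.0) < 1e-6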
||||
|
||||
layer.cache_k_scale.set_value(cache_k_scale)
|
||||
layer.cache_v_scale.set_value(cache_v_scale)
|
||||
layer.cache_k_out_scale.set_value(cache_k_out_scale)
|
||||
layer.cache_v_out_scale.set_value(cache_v_out_scale)
|
||||
|
||||
def create_scale(self, layer: nn.Layer):
|
||||
"""
|
||||
create_scale
|
||||
"""
|
||||
layer.cache_k_scale = layer.create_parameter(
|
||||
shape=(
|
||||
[layer.kv_num_heads * layer.head_dim]
|
||||
if self.quant_config.is_channel_wise
|
||||
else [layer.kv_num_heads]
|
||||
),
|
||||
dtype=self.cache_scale_dtype,
|
||||
is_bias=False,
|
||||
)
|
||||
layer.cache_v_scale = layer.create_parameter(
|
||||
shape=(
|
||||
[layer.kv_num_heads * layer.head_dim]
|
||||
if self.quant_config.is_channel_wise
|
||||
else [layer.kv_num_heads]
|
||||
),
|
||||
dtype=self.cache_scale_dtype,
|
||||
is_bias=False,
|
||||
)
|
||||
layer.cache_k_out_scale = layer.create_parameter(
|
||||
shape=(
|
||||
[layer.kv_num_heads * layer.head_dim]
|
||||
if self.quant_config.is_channel_wise
|
||||
else [layer.kv_num_heads]
|
||||
),
|
||||
attr=None,
|
||||
dtype=self.cache_scale_dtype,
|
||||
is_bias=False,
|
||||
)
|
||||
layer.cache_v_out_scale = layer.create_parameter(
|
||||
shape=(
|
||||
[layer.kv_num_heads * layer.head_dim]
|
||||
if self.quant_config.is_channel_wise
|
||||
else [layer.kv_num_heads]
|
||||
),
|
||||
attr=None,
|
||||
dtype=self.cache_scale_dtype,
|
||||
is_bias=False,
|
||||
)
|
||||
|
||||
def create_zp(self, layer: nn.Layer):
|
||||
"""
|
||||
create_zp
|
||||
"""
|
||||
layer.cache_k_zp = layer.create_parameter(
|
||||
shape=(
|
||||
[layer.kv_num_heads * layer.head_dim]
|
||||
if self.quant_config.is_channel_wise
|
||||
else [layer.kv_num_heads]
|
||||
),
|
||||
dtype=self.cache_scale_dtype,
|
||||
is_bias=False,
|
||||
)
|
||||
layer.cache_v_zp = layer.create_parameter(
|
||||
shape=(
|
||||
[layer.kv_num_heads * layer.head_dim]
|
||||
if self.quant_config.is_channel_wise
|
||||
else [layer.kv_num_heads]
|
||||
),
|
||||
dtype=self.cache_scale_dtype,
|
||||
is_bias=False,
|
||||
)
|
||||
|
||||
def create_weights(self, layer: nn.Layer):
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
create_weights
|
||||
"""
|
||||
self.prefix = layer.prefix
|
||||
self.cache_k_scale_name = layer.prefix + ".cachek_matmul.activation_quanter"
|
||||
self.cache_v_scale_name = layer.prefix + ".cachev_matmul.activation_quanter"
|
||||
self.cache_k_zp_name = layer.cache_k_scale_name + ".zero_point"
|
||||
self.cache_v_zp_name = layer.cache_v_scale_name + ".zero_point"
|
||||
self.cache_k_scale_name = layer.prefix + ".cachek_matmul.activation_scale"
|
||||
self.cache_v_scale_name = layer.prefix + ".cachev_matmul.activation_scale"
|
||||
self.cache_k_zp_name = layer.prefix + ".cachek_matmul.activation_zero_point"
|
||||
self.cache_v_zp_name = layer.prefix + ".cachev_matmul.activation_zero_point"
|
||||
|
||||
layer.cache_k_zp = None
|
||||
layer.cache_v_zp = None
|
||||
layer.cache_k_scale = None
|
||||
layer.cache_v_scale = None
|
||||
layer.cache_k_out_scale = None
|
||||
layer.cache_v_out_scale = None
|
||||
if self.cache_quant_config.quant_type == KvCacheQuantzationTypes.INT8:
|
||||
setattr(layer, "cache_quant_type_str", "cache_int8")
|
||||
setattr(layer, "quant_max_bound", 127.0)
|
||||
setattr(layer, "quant_min_bound", -127.0)
|
||||
elif self.cache_quant_config.quant_type == KvCacheQuantzationTypes.FP8:
|
||||
setattr(layer, "cache_quant_type_str", "cache_fp8")
|
||||
setattr(layer, "quant_max_bound", 448.0)
|
||||
setattr(layer, "quant_min_bound", -448.0)
|
||||
else:
|
||||
raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")
|
||||
|
||||
self._dtype = layer._dtype
|
||||
if self._dtype != "bfloat16" and self._dtype != "float16" and self._dtype == "float32":
|
||||
raise ValueError(
|
||||
f"Just support float32, float16 and \
|
||||
bfloat16 as default dtype, but received {self._dtype}"
|
||||
)
|
||||
self.cache_scale_dtype = (
|
||||
self._dtype if self.quant_config.use_append_attn else "float32"
|
||||
)
|
||||
|
||||
if not self.quant_config.use_dynamic_cachekv_quant:
|
||||
if (
|
||||
self.quant_config.cachekv_dtype == "int8"
|
||||
or self.quant_config.cachekv_dtype == "int4"
|
||||
or self.quant_config.cachekv_dtype == "float8_e4m3fn"
|
||||
):
|
||||
self.create_scale(layer)
|
||||
self.load_scale(layer)
|
||||
if self.quant_config.has_zero_point:
|
||||
self.create_zp(layer)
|
||||
self.load_zp(layer)
|
||||
layer.cache_quant_type_str = self.quant_config.cache_quant_type
|
||||
self.load_scale(layer, state_dict)
|
||||
if self.cache_quant_config.has_zero_point:
|
||||
self.load_zp(layer, state_dict)
|
||||
|
||||
def apply(self, layer):
|
||||
"""
|
||||
@@ -264,4 +158,3 @@ class KVCacheMethodBase(QuantMethodBase):
|
||||
"""
|
||||
raise RuntimeError(
|
||||
f"{self.__class__.__name__}.apply should not be called.")
|
||||
|
||||
|
75
fastdeploy/model_executor/layers/quantization/mix_quant.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from ..attention import Attention
|
||||
from ..moe import FusedMoE
|
||||
from . import get_quantization_config
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
|
||||
|
||||
class MixQuantConfig(QuantConfigBase):
|
||||
"""
|
||||
Quantization config for layers that has different quantization methods.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dense_quant_type: str,
|
||||
moe_quant_type: str,
|
||||
kv_cache_quant_type: str = None,
|
||||
image_moe_quant_type: str = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dense_quant_type = dense_quant_type
|
||||
self.moe_quant_type = moe_quant_type
|
||||
self.kv_cache_quant_type = kv_cache_quant_type
|
||||
if image_moe_quant_type is None:
|
||||
self.image_moe_quant_type = moe_quant_type
|
||||
else:
|
||||
self.image_moe_quant_type = image_moe_quant_type
|
||||
self.quant_max_bound = 0
|
||||
self.quant_min_bound = 0
|
||||
self.quant_round_type = 0
|
||||
|
||||
def name(self) -> str:
|
||||
return "mix_quant"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "MixQuantConfig":
|
||||
return cls(config['dense_quant_type'], config['moe_quant_type'],
|
||||
config.get('kv_cache_quant_type', None),
|
||||
config.get('image_moe_quant_type', None))
|
||||
|
||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||
if isinstance(layer, FusedMoE):
|
||||
if layer.moe_tag == "Image":
|
||||
return get_quantization_config(
|
||||
self.image_moe_quant_type).from_config(
|
||||
{}).get_quant_method(layer)
|
||||
else:
|
||||
return get_quantization_config(
|
||||
self.moe_quant_type).from_config(
|
||||
{}).get_quant_method(layer)
|
||||
elif isinstance(layer, Attention):
|
||||
if self.kv_cache_quant_type is not None:
|
||||
return (get_quantization_config("kvcache").from_config(
|
||||
self.kv_cache_quant_type).get_quant_method(layer))
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return get_quantization_config(self.dense_quant_type).from_config(
|
||||
{}).get_quant_method(layer)
|
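An illustrative mixed-quantization spec (hypothetical choices): dense layers in wint8, MoE experts in wint4, and an int8 KV cache; the dispatch above then picks the method per layer type.

mix_cfg = MixQuantConfig.from_config({
    "dense_quant_type": "wint8",
    "moe_quant_type": "wint4",
    "kv_cache_quant_type": "int8",
})
# FusedMoE -> wint4 method, Attention -> kvcache method, other layers -> wint8 method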
@@ -0,0 +1,22 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .cutlass_scaled_mm import cutlass_scaled_mm
from .scaled_fp8_quant import scaled_fp8_quant

__all__ = [
    "cutlass_scaled_mm",
    "scaled_fp8_quant",
]
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
|
||||
import fastdeploy
|
||||
|
||||
|
||||
def cutlass_scaled_mm(a: paddle.Tensor,
|
||||
b: paddle.Tensor,
|
||||
scale_a: paddle.Tensor,
|
||||
scale_b: paddle.Tensor,
|
||||
out_dtype: paddle.dtype,
|
||||
bias: Optional[paddle.Tensor] = None) -> paddle.Tensor:
|
||||
"""
|
||||
`cutlass_scaled_mm` implements a fused version of
|
||||
`output = paddle.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
|
||||
where scale_a * a and scale_b * b are implemented using numpy-style
|
||||
broadcasting.
|
||||
|
||||
In order to support blockwise scaling like found in DeepSeek V3 we also
|
||||
support extended "group" broadcast rules. We extend the numpy-style
|
||||
broadcasting rules with the following rule:
|
||||
"if the extent of a dimension in the source shape is between 1 and
|
||||
corresponding extent in the target shape we repeat each element along
|
||||
that dimension src_shape[dim] // target_shape[dim] times consecutively"
|
||||
example if we have:
|
||||
a = [[1, 2], and target_shape = (2, 4)
|
||||
[3, 4]]
|
||||
then we would expand a to:
|
||||
a = [[1, 1, 2, 2],
|
||||
[3, 3, 4, 4]]
|
||||
currently we only support the case:
|
||||
scale_a.shape * [1, 128] == a.shape
|
||||
scale_b.shape * [128, 128] == b.shape
|
||||
"""
|
||||
assert (out_dtype == paddle.bfloat16 or out_dtype == paddle.float16)
|
||||
assert bias is None or bias.shape[0] == b.shape[
|
||||
0] and bias.dtype == out_dtype
|
||||
# Ensure input tensors have valid shapes
|
||||
# assert a.numel() > 0, "Input tensor 'a' must not be empty"
|
||||
# assert b.numel() > 0, "Input tensor 'b' must not be empty"
|
||||
# assert scale_a.numel() > 0, "Scale tensor 'scale_a' must not be empty"
|
||||
# assert scale_b.numel() > 0, "Scale tensor 'scale_b' must not be empty"
|
||||
|
||||
m = a.shape[0]
|
||||
n = b.shape[0]
|
||||
cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
|
||||
assert cutlass_compatible_b
|
||||
|
||||
out = paddle.empty([m, n], dtype=out_dtype)
|
||||
fastdeploy.model_executor.ops.gpu.cutlass_scaled_mm(
|
||||
out, a, b, scale_a, scale_b, bias)
|
||||
|
||||
return out
|
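A shape-only usage sketch of cutlass_scaled_mm under the blockwise rule documented above (random data, hypothetical sizes): a carries per-token scales over 128-column groups, b carries one scale per 128x128 block.

m, k, n = 16, 7168, 4096
a = paddle.rand([m, k]).astype("float8_e4m3fn")
b = paddle.rand([n, k]).astype("float8_e4m3fn")
scale_a = paddle.ones([m, k // 128], dtype="float32")
scale_b = paddle.ones([n // 128, k // 128], dtype="float32")
out = cutlass_scaled_mm(a, b, scale_a, scale_b, paddle.bfloat16)  # [m, n] bf16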
||||
|
||||
|
||||
def scaled_fp8_quant(
|
||||
input: paddle.Tensor,
|
||||
scale: Optional[paddle.Tensor] = None,
|
||||
num_token_padding: Optional[int] = None,
|
||||
scale_ub: float = 0,
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
) -> tuple[paddle.Tensor, paddle.Tensor]:
|
||||
"""
|
||||
Quantize input tensor to FP8 and return quantized tensor and scale.
|
||||
|
||||
This function supports both static and dynamic quantization: If you
|
||||
provide the scale, it will use static scaling and if you omit it,
|
||||
the scale will be determined dynamically. The function also allows
|
||||
optional padding of the output tensors for downstream kernels that
|
||||
will benefit from padding.
|
||||
|
||||
Args:
|
||||
input: The input tensor to be quantized to FP8
|
||||
scale: Optional scaling factor for the FP8 quantization
|
||||
scale_ub: Optional upper bound for scaling factor in dynamic
|
||||
per token case
|
||||
num_token_padding: If specified, pad the first dimension
|
||||
of the output to at least this value.
|
||||
use_per_token_if_dynamic: Whether to do per_tensor or per_token
|
||||
in the dynamic quantization case.
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, paddle.Tensor]: The output tensor in FP8 and
|
||||
scaling factor.
|
||||
"""
|
||||
# This code assumes batch_dim and num_tokens are flattened
|
||||
assert (input.ndim == 2)
|
||||
shape = input.shape
|
||||
if num_token_padding:
|
||||
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
||||
output = paddle.empty(shape, dtype=paddle.float8_e4m3fn)
|
||||
|
||||
if scale is None:
|
||||
if use_per_token_if_dynamic:
|
||||
scale = paddle.empty([shape[0], 1], dtype=paddle.float32)
|
||||
from fastdeploy.model_executor.ops.gpu import \
|
||||
dynamic_per_token_scaled_fp8_quant
|
||||
dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
|
||||
else:
|
||||
scale = paddle.zeros([1], dtype=paddle.float32)
|
||||
from fastdeploy.model_executor.ops.gpu import \
|
||||
dynamic_scaled_fp8_quant
|
||||
dynamic_scaled_fp8_quant(output, input, scale)
|
||||
else:
|
||||
# num_token_padding not implemented for this case
|
||||
# assert (scale.numel() == 1 or num_token_padding is None)
|
||||
from fastdeploy.model_executor.ops.gpu import static_scaled_fp8_quant
|
||||
static_scaled_fp8_quant(output, input, scale)
|
||||
|
||||
return output, scale
|
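Dynamic per-token usage sketch: the scale tensor is allocated here and filled by the kernel, one scale per row of the input.

x = paddle.rand([8, 4096]).cast("bfloat16")
x_fp8, x_scale = scaled_fp8_quant(x, use_per_token_if_dynamic=True)
# x_fp8: [8, 4096] float8_e4m3fn, x_scale: [8, 1] float32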
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
|
||||
|
||||
def scaled_fp8_quant(
|
||||
input: paddle.Tensor,
|
||||
scale: Optional[paddle.Tensor] = None,
|
||||
num_token_padding: Optional[int] = None,
|
||||
scale_ub: float = 0,
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
) -> tuple[paddle.Tensor, paddle.Tensor]:
|
||||
"""
|
||||
Quantize input tensor to FP8 and return quantized tensor and scale.
|
||||
|
||||
This function supports both static and dynamic quantization: If you
|
||||
provide the scale, it will use static scaling and if you omit it,
|
||||
the scale will be determined dynamically. The function also allows
|
||||
optional padding of the output tensors for downstream kernels that
|
||||
will benefit from padding.
|
||||
|
||||
Args:
|
||||
input: The input tensor to be quantized to FP8
|
||||
scale: Optional scaling factor for the FP8 quantization
|
||||
scale_ub: Optional upper bound for scaling factor in dynamic
|
||||
per token case
|
||||
num_token_padding: If specified, pad the first dimension
|
||||
of the output to at least this value.
|
||||
use_per_token_if_dynamic: Whether to do per_tensor or per_token
|
||||
in the dynamic quantization case.
|
||||
|
||||
Returns:
|
||||
tuple[paddle.Tensor, paddle.Tensor]: The output tensor in FP8 and
|
||||
scaling factor.
|
||||
"""
|
||||
# This code assumes batch_dim and num_tokens are flattened
|
||||
assert (input.ndim == 2)
|
||||
shape = input.shape
|
||||
if num_token_padding:
|
||||
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
||||
output = paddle.empty(shape, dtype=paddle.float8_e4m3fn)
|
||||
|
||||
if scale is None:
|
||||
if use_per_token_if_dynamic:
|
||||
scale = paddle.empty([shape[0], 1], dtype=paddle.float32)
|
||||
from fastdeploy.model_executor.ops.gpu import \
|
||||
dynamic_per_token_scaled_fp8_quant
|
||||
dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
|
||||
else:
|
||||
scale = paddle.zeros([1], dtype=paddle.float32)
|
||||
from fastdeploy.model_executor.ops.gpu import \
|
||||
dynamic_scaled_fp8_quant
|
||||
dynamic_scaled_fp8_quant(output, input, scale)
|
||||
else:
|
||||
# num_token_padding not implemented for this case
|
||||
# assert (scale.numel() == 1 or num_token_padding is None)
|
||||
from fastdeploy.model_executor.ops.gpu import static_scaled_fp8_quant
|
||||
static_scaled_fp8_quant(output, input, scale)
|
||||
|
||||
return output, scale
|
@@ -47,12 +47,9 @@ class QuantConfigBase(ABC):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.quant_round_type = None
|
||||
self.quant_max_bound = None
|
||||
self.quant_min_bound = None
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str:
|
||||
def name(self) -> str:
|
||||
"""Name of the quantization method."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
135
fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.layers.moe import FusedMoE
|
||||
|
||||
from ..utils import get_tensor
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
|
||||
|
||||
class TensorWiseFP8Config(QuantConfigBase):
|
||||
"""
|
||||
Quantization config for weight and activation with FP8.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Nothing else to do!
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
def name(self) -> str:
|
||||
"""
|
||||
Nothing else to do!
|
||||
"""
|
||||
return "tensor_wise_fp8"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "TensorWiseFP8Config":
|
||||
"""
|
||||
Nothing else to do!
|
||||
"""
|
||||
return cls()
|
||||
|
||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||
"""
|
||||
return method according to this config!
|
||||
"""
|
||||
if isinstance(layer, FusedMoE):
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import \
|
||||
TensorWiseFP8MoEMethod
|
||||
return TensorWiseFP8MoEMethod(self)
|
||||
else:
|
||||
return TensorWiseFP8LinearMethod(self)
|
||||
|
||||
|
||||
class TensorWiseFP8LinearMethod(QuantMethodBase):
|
||||
"""
|
||||
Weight and activation quantization method for linear layer with per tensor FP8
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: TensorWiseFP8Config,
|
||||
) -> None:
|
||||
"""
|
||||
Nothing special to do!
|
||||
"""
|
||||
super().__init__()
|
||||
self.quant_config = quant_config
|
||||
self.quant_max_bound = 448
|
||||
self.quant_min_bound = -448
|
||||
self.quant_round_type = 1
|
||||
self.weight_dtype = "float8_e4m3fn"
|
||||
|
||||
def create_weights(self, layer):
|
||||
"""
|
||||
Nothing to do!
|
||||
"""
|
||||
pass
|
||||
|
||||
def process_prequanted_weights(self, layer, state_dict) -> None:
|
||||
"""
|
||||
Process pre-quantized weights before applying them to the model
|
||||
Args:
|
||||
layer: The layer that owns the weights
|
||||
quant_weight: The quantized weights
|
||||
weight_scale: The scale of the quantized weights
|
||||
"""
|
||||
|
||||
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
|
||||
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
|
||||
act_scale = get_tensor(state_dict.pop(layer.act_scale_key))
|
||||
|
||||
quant_weight = quant_weight.transpose([1, 0]).contiguous()
|
||||
layer.linear_weight.copy_(quant_weight.view("float8_e4m3fn"), False)
|
||||
|
||||
self.act_scale = act_scale.item()
|
||||
self.total_scale = (act_scale * weight_scale).item()
|
||||
|
||||
def process_loaded_weights(self, layer, weights, state_dict) -> None:
|
||||
"""
|
||||
Read fp8 weight, act scale, weight scale
|
||||
"""
|
||||
pass
|
||||
|
||||
def apply(self, layer, x):
|
||||
"""
|
||||
compute!
|
||||
"""
|
||||
from fastdeploy.model_executor.ops.gpu import \
|
||||
cutlass_fp8_fp8_half_gemm_fused
|
||||
|
||||
from ..utils import create_hadamard_matrix_map
|
||||
|
||||
hadamard_matrix = create_hadamard_matrix_map[x.shape[-1]]
|
||||
new_x = paddle.matmul(x.cast("float32"), hadamard_matrix)
|
||||
fp8_x = new_x / self.act_scale
|
||||
fp8_x = fp8_x.astype("float8_e4m3fn")
|
||||
|
||||
linear_out = cutlass_fp8_fp8_half_gemm_fused(
|
||||
fp8_x,
|
||||
layer.linear_weight,
|
||||
transpose_x=False,
|
||||
transpose_y=True,
|
||||
bias=None,
|
||||
scale=self.total_scale,
|
||||
output_dtype="bfloat16",
|
||||
activation_type="identity")
|
||||
return linear_out
|
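The rotation applied in apply() is a Hadamard transform; the repo keeps its own precomputed create_hadamard_matrix_map, but the underlying matrix can be sketched with the Sylvester construction for power-of-two sizes (illustrative only, unnormalized):

import numpy as np

def hadamard(n: int) -> np.ndarray:
    h = np.array([[1.0]])
    while h.shape[0] < n:
        h = np.block([[h, h], [h, -h]])
    return h

h = hadamard(4)
assert np.allclose(h @ h.T, 4 * np.eye(4))  # rows are mutually orthogonal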
42
fastdeploy/model_executor/layers/quantization/w4a8.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from ..moe import FusedMoE
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
|
||||
|
||||
class W4A8Config(QuantConfigBase):
|
||||
"""
|
||||
quantization config for weight 4bits and activation 8bits
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def name(self) -> str:
|
||||
return "w4a8"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "W4A8Config":
|
||||
return cls()
|
||||
|
||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||
if isinstance(layer, FusedMoE):
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import CutlassW4A8MoEMethod
|
||||
return CutlassW4A8MoEMethod(self)
|
||||
else:
|
||||
raise ValueError(f"Unsupported layer type {type(layer)} for w4a8")
|
@@ -23,16 +23,21 @@ from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
|
||||
QUANT_SCALING_FACTOR = 448
|
||||
|
||||
|
||||
class W4AFP8Config(QuantConfigBase):
|
||||
"""
|
||||
quantization config for weight 4bits and activation fp8
|
||||
"""
|
||||
|
||||
def __init__(self, weight_scale_dict, act_scale_dict) -> None:
|
||||
super().__init__()
|
||||
self.weight_scale_dict = weight_scale_dict
|
||||
self.act_scale_dict = act_scale_dict
|
||||
self.quant_max_bound = 448
|
||||
self.quant_min_bound = -448
|
||||
self.quant_round_type = 1
|
||||
|
||||
def get_name(self) -> str:
|
||||
def name(self) -> str:
|
||||
return "w4afp8"
|
||||
|
||||
@classmethod
|
||||
@@ -49,6 +54,7 @@ class W4AFP8LinearMethod(QuantMethodBase):
|
||||
"""
|
||||
W4 AFP8 quant method for linear
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: W4AFP8Config,
|
||||
@@ -57,6 +63,9 @@ class W4AFP8LinearMethod(QuantMethodBase):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer):
|
||||
layer.linear_weight_shape.reverse()
|
||||
layer.linear_weight_shape[0] //= 2
|
||||
layer.weight_dtype = "int8"
|
||||
pass
|
||||
|
||||
def process_loaded_weights(self, layer, weights) -> None:
|
||||
@@ -78,11 +87,11 @@ class W4AFP8LinearMethod(QuantMethodBase):
|
||||
layer.linear_weight_scale,
|
||||
zero_points=None,
|
||||
bias=layer.linear_bias if layer.add_bias else None,
|
||||
out_scale=self.quant_config.weight_scale_dict.get(
|
||||
layer.prefix + ".weight_quanter") /
|
||||
(self.quant_config.act_scale_dict.get(layer.prefix +
|
||||
".activation_quanter") *
|
||||
QUANT_SCALING_FACTOR * QUANT_SCALING_FACTOR),
|
||||
out_scale=self.quant_config.weight_scale_dict.get(layer.prefix +
|
||||
".weight_scale")
|
||||
/ (self.quant_config.act_scale_dict.get(layer.prefix +
|
||||
".activation_scale") *
|
||||
QUANT_SCALING_FACTOR * QUANT_SCALING_FACTOR),
|
||||
groupsize=0,
|
||||
out_dtype=layer._dtype,
|
||||
)
|
||||
|
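The new out_scale expression folds both scaling steps into one dequant factor, weight_scale / (act_scale * 448^2). A numeric sketch with hypothetical scale values:

QUANT_SCALING_FACTOR = 448
weight_scale, act_scale = 0.02, 0.01
out_scale = weight_scale / (act_scale * QUANT_SCALING_FACTOR * QUANT_SCALING_FACTOR)
# the GEMM output is multiplied by out_scale to recover values in layer._dtype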
@@ -16,11 +16,12 @@
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
from paddlenlp.utils.log import logger
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy.platforms.utils import convert_to_npu_dequant_scale
|
||||
|
||||
from ..utils import get_tensor
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
|
||||
|
||||
@@ -29,14 +30,18 @@ class W8A8Config(QuantConfigBase):
|
||||
quantization config for weight 8bits and activation 8bits
|
||||
"""
|
||||
|
||||
def __init__(self, weight_scale_dict, act_scale_dict,
|
||||
use_gemm_dequant) -> None:
|
||||
def __init__(self, weight_scale_dict, act_scale_dict, use_gemm_dequant,
|
||||
use_smooth_quant) -> None:
|
||||
super().__init__()
|
||||
self.weight_scale_dict = weight_scale_dict
|
||||
self.act_scale_dict = act_scale_dict
|
||||
self.use_gemm_dequant = use_gemm_dequant
|
||||
self.use_smooth_quant = use_smooth_quant
|
||||
self.quant_max_bound = 127
|
||||
self.quant_min_bound = -127
|
||||
self.quant_round_type = 0
|
||||
|
||||
def get_name(self) -> str:
|
||||
def name(self) -> str:
|
||||
return "w8a8"
|
||||
|
||||
@classmethod
|
||||
@@ -61,12 +66,17 @@ class W8A8LinearMethod(QuantMethodBase):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.quant_config = quant_config
|
||||
self.smooth_quant_method = SmoothQuantLinearMethod(quant_config)
|
||||
|
||||
def create_weights(self, layer):
|
||||
weight_scale = self.quant_config.weight_scale_dict.get(
|
||||
layer.prefix + ".weight_quanter")
|
||||
layer.linear_weight_shape.reverse()
|
||||
layer.weight_dtype = "int8"
|
||||
if self.quant_config.use_smooth_quant:
|
||||
self.smooth_quant_method.create_weights(layer)
|
||||
weight_scale = self.quant_config.weight_scale_dict.get(layer.prefix +
|
||||
".weight_scale")
|
||||
in_scale = self.quant_config.act_scale_dict.get(layer.prefix +
|
||||
".activation_quanter")
|
||||
".activation_scale")
|
||||
self.skip_quant = False
|
||||
if weight_scale is None or in_scale is None:
|
||||
self.skip_quant = True
|
||||
@@ -86,13 +96,15 @@ class W8A8LinearMethod(QuantMethodBase):
|
||||
convert_to_npu_dequant_scale(linear_out_scale))
|
||||
|
||||
def process_loaded_weights(self, layer, weights) -> None:
|
||||
if self.quant_config.use_smooth_quant:
|
||||
self.smooth_quant_method.process_loaded_weights(layer, weights)
|
||||
if self.skip_quant:
|
||||
logger.debug(f"{layer.prefix} skip quant")
|
||||
weight_tensor = weights.cast(layer._dtype)
|
||||
layer.linear_weight.set_value(weight_tensor)
|
||||
else:
|
||||
weight_tensor = weights.transpose([1, 0])
|
||||
weight_tensor = paddle.cast(weight_tensor, layer.weight_dtype)
|
||||
weight_tensor = paddle.cast(weight_tensor, "int8")
|
||||
layer.linear_weight.set_value(weight_tensor)
|
||||
|
||||
def apply(self, layer, x):
|
||||
@@ -107,3 +119,53 @@ class W8A8LinearMethod(QuantMethodBase):
|
||||
linear_out = fastdeploy.model_executor.ops.gpu.dequant_int8(
|
||||
linear_out, layer.linear_out_scale, layer._dtype)
|
||||
return linear_out
|
||||
|
||||
|
||||
class SmoothQuantLinearMethod(QuantMethodBase):
|
||||
"""
|
||||
SmoothQuant Method
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: QuantConfigBase,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer):
|
||||
linear_shift_shape = [layer.output_size]
|
||||
linear_smooth_shape = [layer.output_size]
|
||||
layer.linear_shift = self.create_parameter(
|
||||
shape=linear_shift_shape,
|
||||
dtype=layer._dtype,
|
||||
is_bias=False,
|
||||
)
|
||||
layer.linear_smooth = layer.create_parameter(
|
||||
shape=linear_smooth_shape,
|
||||
dtype=layer._dtype,
|
||||
is_bias=False,
|
||||
)
|
||||
|
||||
def process_loaded_weights(self, layer, weights) -> None:
|
||||
if layer.shift_key in layer.state_dict:
|
||||
shift_tensor = get_tensor(layer.state_dict.pop(
|
||||
layer.shift_key)).astype(paddle.get_default_dtype())
|
||||
else:
|
||||
shift_tensor = paddle.zeros(
|
||||
shape=layer.linear_shift_shape,
|
||||
dtype=paddle.get_default_dtype(),
|
||||
)
|
||||
layer.linear_shift.set_value(shift_tensor)
|
||||
if layer.smooth_key in layer.state_dict:
|
||||
smooth_tensor = get_tensor(layer.state_dict.pop(
|
||||
layer.smooth_key)).astype(paddle.get_default_dtype())
|
||||
else:
|
||||
smooth_tensor = paddle.ones(
|
||||
shape=[layer.linear_smooth_shape],
|
||||
dtype=paddle.get_default_dtype(),
|
||||
)
|
||||
layer.linear_smooth.set_value(smooth_tensor)
|
||||
|
||||
def apply(self, layer, x):
|
||||
pass
|
||||
|
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
@@ -21,6 +22,8 @@ from paddle.nn.quant import weight_only_linear, weight_quantize
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
from ..moe import FusedMoE
|
||||
from ..utils import get_tensor
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
|
||||
|
||||
@@ -28,34 +31,92 @@ class WeightOnlyConfig(QuantConfigBase):
|
||||
"""
|
||||
Quantization config for weight only
|
||||
Args:
|
||||
weight_only_linear_arch: The architecture of weight only linear layer
|
||||
algo: The quant algorithm("weight_only_int8" or "weight_only_int4") used for weight only linear layer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_only_linear_arch: int,
|
||||
algo: str,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_only_linear_arch = weight_only_linear_arch
|
||||
self.algo = algo
|
||||
# arch (int): The compute arch for target device. For example, A100 is 80, v100 is 70,
|
||||
# if you do not assign arch, we will get arch from your device, default: None.
|
||||
self.weight_only_linear_arch = os.getenv(
|
||||
"FLAGS_weight_only_linear_arch")
|
||||
if self.weight_only_linear_arch is not None:
|
||||
self.weight_only_linear_arch = int(self.weight_only_linear_arch)
|
||||
self.quant_max_bound = 0
|
||||
self.quant_min_bound = 0
|
||||
self.quant_round_type = 0
|
||||
|
||||
def get_name(self) -> str:
|
||||
def name(self) -> str:
|
||||
return "weight_only"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "WeightOnlyConfig":
|
||||
weight_only_linear_arch = config["weight_only_linear_arch"]
|
||||
algo = config["algo"]
|
||||
return cls(weight_only_linear_arch, algo)
|
||||
return cls(algo)
|
||||
|
||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||
if current_platform.is_xpu():
|
||||
from fastdeploy.model_executor.layers.backends import XPUWeightOnlyLinearMethod
|
||||
return XPUWeightOnlyLinearMethod(self)
|
||||
from fastdeploy.model_executor.layers.backends import (
|
||||
XPUWeightOnlyLinearMethod, XPUWeightOnlyMoEMethod)
|
||||
if isinstance(layer, FusedMoE):
|
||||
return XPUWeightOnlyMoEMethod(self)
|
||||
else:
|
||||
return XPUWeightOnlyLinearMethod(self)
|
||||
else:
|
||||
return GPUWeightOnlyLinearMethod(self)
|
||||
if isinstance(layer, FusedMoE):
|
||||
if layer.use_method == "cutlass":
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import \
|
||||
CutlassWeightOnlyMoEMethod
|
||||
return CutlassWeightOnlyMoEMethod(self)
|
||||
elif layer.use_method == "triton":
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import \
|
||||
TritonWeightOnlyMoEMethod
|
||||
return TritonWeightOnlyMoEMethod(self)
|
||||
elif layer.use_method == "marlin":
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_marlin_backend import \
|
||||
MarlinWeightOnlyMoEMethod
|
||||
return MarlinWeightOnlyMoEMethod(self)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported MOE backend {layer.use_method}")
|
||||
else:
|
||||
return GPUWeightOnlyLinearMethod(self)
|
||||
|
||||
|
||||
class WINT8Config(WeightOnlyConfig):
|
||||
"""
|
||||
weight only int8 config
|
||||
"""
|
||||
|
||||
def __init__(self, ) -> None:
|
||||
super().__init__("weight_only_int8")
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "WINT8Config":
|
||||
return cls()
|
||||
|
||||
def name(self) -> str:
|
||||
return "wint8"
|
||||
|
||||
|
||||
class WINT4Config(WeightOnlyConfig):
|
||||
"""
|
||||
weight only int4 config
|
||||
"""
|
||||
|
||||
def __init__(self, ) -> None:
|
||||
super().__init__("weight_only_int4")
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "WINT4Config":
|
||||
return cls()
|
||||
|
||||
def name(self) -> str:
|
||||
return "wint4"
|
||||
|
||||
|
||||
class WeightOnlyLinearMethod(QuantMethodBase):
|
||||
@@ -71,12 +132,17 @@ class WeightOnlyLinearMethod(QuantMethodBase):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer):
|
||||
weight_only_scale_name = layer.prefix + ".weight_only_scale"
|
||||
layer.linear_weight_shape.reverse()
|
||||
if self.quant_config.name() == "wint4":
|
||||
layer.linear_weight_shape[0] //= 2
|
||||
layer.weight_dtype = "int8"
|
||||
linear_weight_scale_shape = [layer.embed_dim]
|
||||
if hasattr(layer, "linear_weight_shape"):
|
||||
if isinstance(layer.linear_weight_shape, list):
|
||||
layer_weight_shape = layer.linear_weight_shape
|
||||
linear_weight_scale_shape = layer_weight_shape[:1]
|
||||
if self.quant_config.name() == "wint4":
|
||||
linear_weight_scale_shape[0] *= 2
|
||||
|
||||
layer.linear_weight_scale = layer.create_parameter(
|
||||
shape=linear_weight_scale_shape,
|
||||
@@ -94,7 +160,8 @@ class WeightOnlyLinearMethod(QuantMethodBase):
|
||||
weight=layer.linear_weight,
|
||||
bias=layer.linear_bias if layer.add_bias else None,
|
||||
weight_scale=layer.linear_weight_scale,
|
||||
weight_dtype=layer.weight_dtype,
|
||||
weight_dtype="int8"
|
||||
if self.quant_config.name() == "wint8" else "int4",
|
||||
arch=self.quant_config.weight_only_linear_arch,
|
||||
)
|
||||
return linear_out
|
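For reference, a minimal round trip with Paddle's built-in weight-only kernels looks like the sketch below; the layer code above differs only in that the quantized weight and its scale live in pre-created parameters.

import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize

w = paddle.rand([4096, 4096]).cast("float16")
qw, scale = weight_quantize(w, algo="weight_only_int8")
x = paddle.rand([2, 4096]).cast("float16")
y = weight_only_linear(x, qw, weight_scale=scale, weight_dtype="int8")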
||||
@@ -113,6 +180,20 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
|
||||
) -> None:
|
||||
super().__init__(quant_config)
|
||||
|
||||
def process_prequanted_weights(self, layer, state_dict) -> None:
|
||||
"""
|
||||
Process pre-quantized weights before applying them to the model
|
||||
Args:
|
||||
layer: The layer that owns the weights
|
||||
quant_weight: The quantized weights
|
||||
weight_scale: The scale of the quantized weights
|
||||
"""
|
||||
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
|
||||
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
|
||||
layer.linear_weight.set_value(quant_weight)
|
||||
layer.linear_weight_scale.set_value(
|
||||
weight_scale.astype(paddle.get_default_dtype()))
|
||||
|
||||
def process_loaded_weights(self, layer, weight) -> None:
|
||||
quanted_weight_tensor, weight_scale_tensor = weight_quantize(
|
||||
weight,
|
||||
|
@@ -17,10 +17,10 @@ from typing import Optional
|
||||
|
||||
import paddle
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy.platforms.utils import convert_to_npu_dequant_scale
|
||||
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
from fastdeploy.model_executor.layers.quantization.ops import (
|
||||
cutlass_scaled_mm, scaled_fp8_quant)
|
||||
from fastdeploy.model_executor.layers.quantization.quant_base import (
|
||||
QuantConfigBase, QuantMethodBase)
|
||||
|
||||
|
||||
class WFP8AFP8Config(QuantConfigBase):
|
||||
@@ -32,17 +32,26 @@ class WFP8AFP8Config(QuantConfigBase):
|
||||
super().__init__()
|
||||
self.weight_scale_dict = weight_scale_dict
|
||||
self.act_scale_dict = act_scale_dict
|
||||
self.quant_max_bound = 448
|
||||
self.quant_min_bound = -448
|
||||
self.quant_round_type = 1
|
||||
|
||||
def get_name(self) -> str:
|
||||
def name(self) -> str:
|
||||
"""
|
||||
"""
|
||||
return "wfp8afp8"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "WFP8AFP8Config":
|
||||
weight_scale_dict = config["weight_scale_dict"]
|
||||
act_scale_dict = config["act_scale_dict"]
|
||||
"""
|
||||
"""
|
||||
weight_scale_dict = config.get("weight_scale_dict", None)
|
||||
act_scale_dict = config.get("act_scale_dict", None)
|
||||
return cls(weight_scale_dict, act_scale_dict)
|
||||
|
||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||
"""
|
||||
"""
|
||||
return WFP8AFP8LinearMethod(self)
|
||||
|
||||
|
||||
@@ -59,58 +68,49 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
        self.quant_config = quant_config

    def create_weights(self, layer):
        """
        """
        layer.linear_weight_shape.reverse()
        layer.weight_dtype = "float8_e4m3fn"
        # TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
        weight_scale = self.quant_config.weight_scale_dict.get(
            layer.prefix + ".weight_quanter")
        in_scale = self.quant_config.act_scale_dict.get(layer.prefix +
                                                        ".activation_quanter")
        self.skip_quant = False
        # we will skip quant if weight_scale is not found or in_scale is not found
        if weight_scale is None or in_scale is None:
            self.skip_quant = True
        else:
            max_range = 448.0
            layer.scalar_scale_name = layer.prefix + ".scalar_weight_quanter"
            layer.scalar_scale = layer.create_parameter(
                shape=([1]),
                dtype="float32",
            )
            layer.scalar_scale.set_value(
                paddle.to_tensor([1.0 / (max_range * in_scale)],
                                 dtype="float32"))
            linear_out_scale = paddle.to_tensor(weight_scale /
                                                max_range).astype("float32")
            layer.linear_out_scale = layer.create_parameter(
                shape=[layer.embed_dim],
                dtype="float32",
                is_bias=False,
                default_initializer=paddle.nn.initializer.Constant(0),
            )
            layer.linear_out_scale.set_value(
                convert_to_npu_dequant_scale(linear_out_scale))
        layer.linear_weight_scale = layer.create_parameter(
            shape=[1],
            dtype="float32",
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

    def process_loaded_weights(self, layer, weights) -> None:
        # TODO(YuanRisheng): We should abstract the skip_quant logic to adapt to more quant methods
        """
        """
        if self.skip_quant:
            weight_tensor = weights.cast(layer._dtype)
            layer.linear_weight.set_value(weight_tensor)
            return
        weight_tensor = weights.transpose([1, 0])
        weight_tensor = paddle.cast(weight_tensor, self.weight_dtype)
        self.linear_weight.copy_(weight_tensor, False)
        if weights.dtype != paddle.float8_e4m3fn:
            self.use_per_token_if_dynamic = True
        weight_tensor = weights.transpose([1, 0]).contiguous()
        qweight, weight_scale = scaled_fp8_quant(
            weight_tensor,
            use_per_token_if_dynamic=False,
        )
        layer.linear_weight.copy_(qweight, False)
        layer.linear_weight_scale.set_value(weight_scale)

    def apply(self, layer, x):
        """
        """
        if self.skip_quant:
            linear_out = paddle.matmul(x, layer.linear_weight, False, True)
            return linear_out
        linear_out = fastdeploy.model_executor.ops.gpu.per_channel_fp8_fp8_half_gemm_fused(
            x,
            layer.linear_weight,
            bias=layer.linear_bias if layer.add_bias else None,
            scalar_scale=layer.scalar_scale,
            channel_scale=layer.linear_out_scale,
            transpose_x=False,
            transpose_y=True,
            output_dtype=layer._dtype,
        )
        if self.use_per_token_if_dynamic:
            out_type = x.dtype
            a_q, a_scales = scaled_fp8_quant(
                x, use_per_token_if_dynamic=self.use_per_token_if_dynamic)
            linear_out = cutlass_scaled_mm(a_q, layer.linear_weight, a_scales,
                                           layer.linear_weight_scale, out_type,
                                           layer.linear_bias)
        else:
            raise NotImplementedError
        return linear_out
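# Hedged sketch of the dynamic path in apply() above: activations are quantized
# per token at runtime, then the FP8 GEMM runs through cutlass_scaled_mm with
# the stored per-tensor weight scale (x and layer are assumed to exist).
# a_q, a_scales = scaled_fp8_quant(x, use_per_token_if_dynamic=True)
# y = cutlass_scaled_mm(a_q, layer.linear_weight, a_scales,
#                       layer.linear_weight_scale, x.dtype, layer.linear_bias)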
142 fastdeploy/model_executor/layers/quantization/wint2.py Normal file
@@ -0,0 +1,142 @@
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from ..moe import FusedMoE
|
||||
from . import get_quantization_config
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
|
||||
|
||||
class WINT2Config(QuantConfigBase):
|
||||
"""
|
||||
Quantization config for wint8 linear and w4w2 MoE.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dense_quant_type: str,
|
||||
dense_quant_granularity: str,
|
||||
moe_quant_type: str,
|
||||
moe_w4_quant_type: str,
|
||||
moe_w4_quant_granularity: str,
|
||||
moe_w4_quant_start_layer: int,
|
||||
moe_w4_quant_end_layer: int,
|
||||
moe_w2_quant_type: str,
|
||||
moe_w2_quant_granularity: str,
|
||||
moe_w2_quant_group_size: int,
|
||||
moe_w2_quant_start_layer: int,
|
||||
moe_w2_quant_end_layer: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.quant_max_bound = 0
|
||||
self.quant_min_bound = 0
|
||||
self.quant_round_type = 0
|
||||
|
||||
# wint2 quantization config
|
||||
self.dense_quant_type = dense_quant_type
|
||||
self.dense_quant_granularity = dense_quant_granularity
|
||||
self.moe_quant_type = moe_quant_type
|
||||
self.moe_w4_quant_type = moe_w4_quant_type
|
||||
self.moe_w4_quant_granularity = moe_w4_quant_granularity
|
||||
self.moe_w4_quant_start_layer = moe_w4_quant_start_layer
|
||||
self.moe_w4_quant_end_layer = moe_w4_quant_end_layer
|
||||
self.moe_w2_quant_type = moe_w2_quant_type
|
||||
self.moe_w2_quant_granularity = moe_w2_quant_granularity
|
||||
self.moe_w2_quant_group_size = moe_w2_quant_group_size
|
||||
self.moe_w2_quant_start_layer = moe_w2_quant_start_layer
|
||||
self.moe_w2_quant_end_layer = moe_w2_quant_end_layer
|
||||
|
||||
def name(self) -> str:
|
||||
"""
|
||||
Get the name of the quantization configuration.
|
||||
Returns:
|
||||
str: The name of the quantization configuration.
|
||||
"""
|
||||
return "wint2"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "WINT2Config":
|
||||
"""
|
||||
Create a new instance of `WINT2Config` using the provided configuration dictionary.
|
||||
Args:
|
||||
config (dict): A dictionary containing the configuration parameters for the new instance.
|
||||
|
||||
Returns:
|
||||
WINT2Config: The newly created instance of `WINT2Config`.
|
||||
"""
|
||||
|
||||
dense_quant_type = config.get("dense_quant_config", "wint8")
|
||||
dense_quant_granularity = config.get("dense_quant_granularity",
|
||||
"per_channel")
|
||||
|
||||
moe_quant_config = config.get("moe_quant_config", {})
|
||||
moe_quant_type = moe_quant_config.get("quant_type", "w4w2")
|
||||
|
||||
moe_w4_quant_config = moe_quant_config.get("moe_w4_quant_config", {})
|
||||
moe_w4_quant_type = moe_w4_quant_config.get("quant_type",
|
||||
"wint4")
|
||||
moe_w4_quant_granularity = moe_w4_quant_config.get(
|
||||
"quant_granularity", "per_channel")
|
||||
moe_w4_quant_start_layer = moe_w4_quant_config.get(
|
||||
"quant_start_layer", 0)
|
||||
moe_w4_quant_end_layer = moe_w4_quant_config.get("quant_end_layer", 6)
|
||||
|
||||
moe_w2_quant_config = moe_quant_config.get("moe_w2_quant_config", {})
|
||||
moe_w2_quant_type = moe_w2_quant_config.get("quant_type", "wint2")
|
||||
moe_w2_quant_granularity = moe_w2_quant_config.get(
|
||||
"quant_granularity", "pp_acc")
|
||||
moe_w2_quant_group_size = moe_w2_quant_config.get(
|
||||
"quant_group_size", 0)
|
||||
moe_w2_quant_start_layer = moe_w2_quant_config.get(
|
||||
"quant_start_layer", 0)
|
||||
moe_w2_quant_end_layer = moe_w2_quant_config.get("quant_end_layer", 0)
|
||||
|
||||
return cls(
|
||||
dense_quant_type,
|
||||
dense_quant_granularity,
|
||||
moe_quant_type,
|
||||
moe_w4_quant_type,
|
||||
moe_w4_quant_granularity,
|
||||
moe_w4_quant_start_layer,
|
||||
moe_w4_quant_end_layer,
|
||||
moe_w2_quant_type,
|
||||
moe_w2_quant_granularity,
|
||||
moe_w2_quant_group_size,
|
||||
moe_w2_quant_start_layer,
|
||||
moe_w2_quant_end_layer,
|
||||
)
|
||||
|
||||
    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
        """
        Get the quantization method associated with the given layer based on the current quantization configuration.
        Args:
            layer (Layer): The layer for which the quantization method should be retrieved.

        Returns:
            QuantMethodBase: The quantization method associated with the given layer.
        """
        if isinstance(layer, FusedMoE):
            if layer.layer_idx <= self.moe_w4_quant_end_layer:
                return get_quantization_config(
                    self.moe_w4_quant_type).from_config(
                        {}).get_quant_method(layer)
            else:
                from fastdeploy.model_executor.layers.moe.fused_moe_wint2_backend import \
                    TritonWint2FusedMoeMethod
                return TritonWint2FusedMoeMethod(self)
        else:
            return get_quantization_config(self.dense_quant_type).from_config(
                {}).get_quant_method(layer)
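# Editor's sketch of a config dict that WINT2Config.from_config() can parse;
# the values are illustrative assumptions, only the key names come from the
# code above. Layers up to moe_w4_quant_end_layer get the wint4 MoE method,
# deeper MoE layers fall through to the Triton wint2 backend.
example_wint2_cfg = {
    "dense_quant_config": "wint8",
    "moe_quant_config": {
        "quant_type": "w4w2",
        "moe_w4_quant_config": {"quant_type": "wint4", "quant_end_layer": 6},
        "moe_w2_quant_config": {"quant_type": "wint2", "quant_start_layer": 7},
    },
}
# wint2_cfg = WINT2Config.from_config(example_wint2_cfg)
# method = wint2_cfg.get_quant_method(moe_layer)  # moe_layer assumed to exist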
@@ -14,25 +14,25 @@
# limitations under the License.
"""

from typing import Any, Optional
from typing import Optional

import paddle

from fastdeploy.config import ModelConfig
from fastdeploy.platforms import current_platform

from .utils import CpuGuard


class ErnieRotaryEmbedding:

    def __init__(self,
                 rotary_dim,
                 base,
                 partial_rotary_factor,
                 rope_scaling=None):
    def __init__(self, rotary_dim, base, partial_rotary_factor):
        """
        Pre-calculate rotary position embedding for position_ids.
        """
        self.rotary_dim = rotary_dim
        self.base = base
        self.partial_rotary_factor = partial_rotary_factor
        self.rope_scaling = rope_scaling

    def __call__(self, position_ids):
        bsz, max_seq_len = position_ids.shape[:2]
@@ -70,18 +70,13 @@ class ErnieRotaryEmbedding:

class QwenRotaryEmbedding:

    def __init__(self,
                 rotary_dim,
                 base,
                 partial_rotary_factor,
                 rope_scaling=None):
    def __init__(self, rotary_dim, base, partial_rotary_factor):
        """
        Pre-calculate rotary position embedding for position_ids.
        """
        self.rotary_dim = rotary_dim
        self.base = base
        self.partial_rotary_factor = partial_rotary_factor
        self.rope_scaling = rope_scaling

    def __call__(self, position_ids):
        bsz, max_seq_len = position_ids.shape[:2]
@@ -104,35 +99,72 @@ class QwenRotaryEmbedding:
        return rot_emb

def get_rope_impl(
    rotary_dim: int,
    base: 10000.0,
    position_ids,
    model_config: Optional[ModelConfig] = None,
    partial_rotary_factor=1,
):
    """
    The real implementation of get_rope
    """

    architecture = model_config.architectures[0]
    if model_config is not None and model_config is None or architecture.startswith(
            "Qwen"):
        rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base,
                                               partial_rotary_factor)
        rotary_emb = rotary_emb_layer(position_ids)
    else:
        rotary_emb_layer = ErnieRotaryEmbedding(rotary_dim, base,
                                                partial_rotary_factor)
        rotary_emb = rotary_emb_layer(position_ids)
    return rotary_emb


def get_rope_xpu(
    rotary_dim: int,
    base: 10000.0,
    position_ids,
    model_config: ModelConfig,
    partial_rotary_factor=1,
):
    """
    On XPU, the cos/sin computation must be done on CPU.
    """
    with CpuGuard():
        position_ids = position_ids.cpu()
        rotary_emb = get_rope_impl(rotary_dim, base, position_ids,
                                   model_config, partial_rotary_factor)
    return rotary_emb.to('xpu')


def get_rope(
    rotary_dim: int,
    base: 10000.0,
    position_ids,
    model_config: ModelConfig,
    partial_rotary_factor=1,
    rope_scaling: Optional[dict[str, Any]] = None,
):
    rope_type = rope_scaling.get("architectures", None)
    if "Qwen2ForCausalLM" in rope_type:
        rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base,
                                               partial_rotary_factor,
                                               rope_scaling)
        rotary_emb = rotary_emb_layer(position_ids)
    """
    The wrapper of get_rope
    """
    if current_platform.is_xpu():
        return get_rope_xpu(rotary_dim, base, position_ids, model_config,
                            partial_rotary_factor)
    else:
        rotary_emb_layer = ErnieRotaryEmbedding(rotary_dim, base,
                                                partial_rotary_factor,
                                                rope_scaling)
        rotary_emb = rotary_emb_layer(position_ids)
    return rotary_emb
        return get_rope_impl(rotary_dim, base, position_ids, model_config,
                             partial_rotary_factor)


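# Usage sketch for the wrapper above (editor's example; head_dim, seq_len and
# model_config are assumed to exist): on XPU the call is routed through
# get_rope_xpu, which computes cos/sin on CPU inside CpuGuard and moves the
# result back to the 'xpu' device.
# position_ids = paddle.arange(seq_len, dtype="int64").reshape([1, seq_len])
# rot_emb = get_rope(rotary_dim=head_dim, base=10000.0,
#                    position_ids=position_ids, model_config=model_config)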
class ErnieVlRotaryEmbedding3D:

    def __init__(self, rotary_dim, base, partial_rotary_factor, max_position,
                 freq_allocation, rope_scaling):
                 freq_allocation):
        self.rotary_dim = rotary_dim
        self.base = base
        self.paritial_rotary_factor = partial_rotary_factor
        self.rope_scaling = rope_scaling
        self.max_position = max_position
        self.freq_allocation = freq_allocation

@@ -223,12 +255,10 @@ def get_rope_3d(
    paritial_rotary_factor: 1,
    max_position: 131072,
    freq_allocation: 2,
    rope_scaling: Optional[dict[str, Any]] = None,
):
    rotary_emb3d_layer = ErnieVlRotaryEmbedding3D(rotary_dim, base,
                                                  paritial_rotary_factor,
                                                  max_position,
                                                  freq_allocation,
                                                  rope_scaling)
                                                  freq_allocation)
    rotary_emb_3d = rotary_emb3d_layer(position_ids)
    return rotary_emb_3d
@@ -23,11 +23,12 @@ import paddle
@dataclass
class SamplingMetadata:
    """
    metadata for sampling.
    """

    temperature: paddle.Tensor

    prompt_token_ids: paddle.Tensor
    pre_token_ids: paddle.Tensor
    eos_token_ids: paddle.Tensor
    frequency_penalties: paddle.Tensor
    presence_penalties: paddle.Tensor
@@ -14,8 +14,12 @@
# limitations under the License.
"""

from .apply_penalty_multi_scores import apply_penalty_multi_scores
from .apply_penalty_multi_scores import (
    apply_penalty_multi_scores, apply_speculative_penalty_multi_scores)
from .top_p_sampling import top_p_sampling

__all__ = [
    "apply_penalty_multi_scores",
    "apply_speculative_penalty_multi_scores",
    "top_p_sampling",
]
@@ -20,7 +20,7 @@ from fastdeploy.platforms import current_platform


def apply_penalty_multi_scores(
    prompt_token_ids: paddle.Tensor,
    pre_token_ids: paddle.Tensor,
    logits: paddle.Tensor,
    repetition_penalties: paddle.Tensor,
    frequency_penalties: paddle.Tensor,
@@ -30,16 +30,30 @@ def apply_penalty_multi_scores(
    step_idx: paddle.Tensor,
    min_dec_lens: paddle.Tensor,
    eos_token_ids: paddle.Tensor,
):
) -> paddle.Tensor:
    """
    Args:
    Returns:
        apply_penalty_multi_scores
    """
    if current_platform.is_cuda():
        from fastdeploy.model_executor.ops.gpu import \
            get_token_penalty_multi_scores
        logits = get_token_penalty_multi_scores(
            prompt_token_ids,
            pre_token_ids,
            logits,
            repetition_penalties,
            frequency_penalties,
            presence_penalties,
            temperature,
            bad_words_token_ids,
            step_idx,
            min_dec_lens,
            eos_token_ids,
        )
    elif current_platform.is_xpu():
        from fastdeploy.model_executor.ops.xpu import \
            get_token_penalty_multi_scores
        logits = get_token_penalty_multi_scores(
            pre_token_ids,
            logits,
            repetition_penalties,
            frequency_penalties,
@@ -54,3 +68,48 @@ def apply_penalty_multi_scores(
        raise NotImplementedError()

    return logits


def apply_speculative_penalty_multi_scores(
    pre_token_ids: paddle.Tensor,
    logits: paddle.Tensor,
    repetition_penalties: paddle.Tensor,
    frequency_penalties: paddle.Tensor,
    presence_penalties: paddle.Tensor,
    temperature: paddle.Tensor,
    bad_words_token_ids: paddle.Tensor,
    step_idx: paddle.Tensor,
    min_dec_lens: paddle.Tensor,
    eos_token_ids: paddle.Tensor,
    seq_lens_this_time: paddle.Tensor,
    output_padding_offset: paddle.Tensor,
    output_cum_offsets: paddle.Tensor,
    max_len: int,
):
    """
    apply_speculative_penalty_multi_scores
    """
    if current_platform.is_cuda():
        from fastdeploy.model_executor.ops.gpu import \
            speculate_get_token_penalty_multi_scores

        logits = speculate_get_token_penalty_multi_scores(
            pre_token_ids,
            logits,
            repetition_penalties,
            frequency_penalties,
            presence_penalties,
            temperature,
            bad_words_token_ids,
            step_idx,
            min_dec_lens,
            eos_token_ids,
            seq_lens_this_time,
            output_padding_offset,
            output_cum_offsets,
            max_len,
        )
    else:
        raise NotImplementedError()

    return logits
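# Call sketch (editor's example): the sampler passes per-request penalty tensors
# alongside [batch, vocab] logits; the shapes are assumptions, the argument
# order is the one defined above.
# logits = apply_penalty_multi_scores(
#     prompt_token_ids, pre_token_ids, logits, repetition_penalties,
#     frequency_penalties, presence_penalties, temperature,
#     bad_words_token_ids, step_idx, min_dec_lens, eos_token_ids)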
@@ -0,0 +1,97 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Literal, Optional

import paddle

from fastdeploy import envs


def top_p_sampling(
    x: paddle.Tensor,
    ps: paddle.Tensor,
    threshold: Optional[paddle.Tensor] = None,
    topp_seed: Optional[paddle.Tensor] = None,
    seed: int = -1,
    k: int = 0,
    mode: Literal['truncated', 'non-truncated'] = "truncated",
) -> tuple[paddle.Tensor, paddle.Tensor]:
    """
    top_p_sampling
    """
    top_p_class = envs.FD_SAMPLING_CLASS.lower()
    if top_p_class == "air":
        _, ids = air_top_p_sampling(x,
                                    ps,
                                    threshold,
                                    topp_seed,
                                    seed=seed,
                                    k=k,
                                    mode=mode)
    elif top_p_class == "rejection":
        ids = rejection_top_p_sampling(x, ps, seed)
        _ = None
    else:
        _, ids = paddle.tensor.top_p_sampling(x,
                                              ps,
                                              threshold=threshold,
                                              topp_seed=topp_seed,
                                              seed=seed,
                                              k=k,
                                              mode=mode)
    return _, ids

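# Dispatch sketch (editor's example): the backend is chosen from the
# FD_SAMPLING_CLASS setting read through fastdeploy.envs; values seen above
# are "air" and "rejection", anything else falls back to
# paddle.tensor.top_p_sampling.
# os.environ["FD_SAMPLING_CLASS"] = "rejection"   # assumed way to select it
# _, next_ids = top_p_sampling(probs, top_p)      # probs, top_p assumed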
def air_top_p_sampling(
    x: paddle.Tensor,
    ps: paddle.Tensor,
    threshold: Optional[paddle.Tensor] = None,
    topp_seed: Optional[paddle.Tensor] = None,
    seed: int = -1,
    k: int = 0,
    mode: Literal['truncated', 'non-truncated'] = "truncated",
) -> tuple[paddle.Tensor, paddle.Tensor]:
    """
    air_top_p_sampling
    """
    try:
        from fastdeploy.model_executor.ops.gpu import air_top_p_sampling
        out, ids = air_top_p_sampling(x, ps, threshold, topp_seed, seed, k,
                                      mode)
    except ImportError:
        raise RuntimeError("Cannot import air_top_p_sampling op.")
    return out, ids


def rejection_top_p_sampling(
    x: paddle.Tensor,
    ps: paddle.Tensor,
    seed: int = -1,
) -> paddle.Tensor:
    """
    rejection_top_p_sampling
    """
    try:
        from fastdeploy.model_executor.ops.gpu import rejection_top_p_sampling
        ids = rejection_top_p_sampling(
            x,
            ps,
            seed,
        )
    except ImportError:
        raise RuntimeError("Cannot import rejection_top_p_sampling op.")
    return ids
@@ -13,43 +13,193 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from fastdeploy.distributed.parallel_state import \
    get_tensor_model_parallel_world_size
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import \
    LogitsProcessorBase
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.ops import \
    apply_penalty_multi_scores
from fastdeploy.model_executor.layers.sample.ops import (
    apply_penalty_multi_scores, apply_speculative_penalty_multi_scores,
    top_p_sampling)
from fastdeploy.platforms import current_platform


class SamplerProcessor:
    """
    SamplingProcessor for guided decoding.
    """

    def __init__(self):
        self.async_step = None
        self.token_bitmask = None
        self.logits_processor: Dict[int, Optional[Any]] = dict()
        self.executor = ThreadPoolExecutor()
        self.logits_lock = threading.Lock()

    def add_logits_processor(self,
                             ids: int,
                             future: Optional[Any] = None,
                             prefill_tokens: List[int] = []):
        """ add logits processor to SamplerProcessor """
        with self.logits_lock:
            if future is None:
                if ids in self.logits_processor:
                    del self.logits_processor[ids]
                return

            if isinstance(future, LogitsProcessorBase):
                self.logits_processor[ids] = future
                for token in prefill_tokens:
                    self.logits_processor[ids].accept_token(token)
            elif future.done():
                self.logits_processor[ids] = future.result()
                for token in prefill_tokens:
                    self.logits_processor[ids].accept_token(token)
            else:
                self.logits_processor[ids] = [future, prefill_tokens]

    def update_vocab_mask(self, skip_idx_list: List[int] = []):
        """ update vocab mask. (cpu-heavy operation) """
        if len(self.logits_processor) == 0:
            return

        with self.logits_lock:
            for idx, processor in self.logits_processor.items():
                if processor is None:
                    del self.logits_processor[idx]
                    continue

                if not isinstance(processor, LogitsProcessorBase):
                    future, prefill_tokens = self.logits_processor[idx]
                    self.logits_processor[idx] = future.result()
                    for token in prefill_tokens:
                        self.logits_processor[idx].accept_token(token)

            available_processors = None
            for processor in self.logits_processor.values():
                if processor.is_terminated():
                    continue
                available_processors = processor
            if available_processors is None:
                return

        # allocate token bitmask
        self.token_bitmask = available_processors.allocate_token_bitmask()

        with self.logits_lock:
            # fill token bitmask
            for idx, processor in self.logits_processor.items():
                if processor.is_terminated() or idx in skip_idx_list:
                    continue

                processor.fill_token_bitmask(self.token_bitmask, idx)

    def apply_token_mask(self,
                         logits: paddle.Tensor,
                         skip_idx_list: List[int] = []):
        """ apply token mask to logits """
        if len(self.logits_processor) == 0 or self.token_bitmask is None:
            return logits

        # self.async_step.result()
        available_processors = None
        with self.logits_lock:
            for processor in self.logits_processor.values():
                if processor.is_terminated():
                    continue
                available_processors = processor
            if available_processors is None:
                return logits

        indices = list(self.logits_processor.keys())
        mask_idx = [i for i in indices if i not in skip_idx_list]
        return available_processors.apply_token_mask(logits,
                                                     self.token_bitmask,
                                                     indices=mask_idx)

    def _accept_token(self, idx: int, token: int):
        """ accept token """
        if idx not in self.logits_processor:
            raise ValueError(
                f"Invalid index, idx: {idx}, logit_processors.keys: {self.logits_processor.keys()}"
            )

        if self.logits_processor[idx].is_terminated():
            return

        self.logits_processor[idx].accept_token(token)

    def update_output_tokens(self,
                             next_tokens: paddle.Tensor,
                             skip_idx_list: List[int] = []):
        """ update output tokens """
        if len(self.logits_processor) == 0:
            return

        token_ids = next_tokens.numpy().tolist()
        with self.logits_lock:
            for idx in self.logits_processor.keys():
                token = token_ids[idx][0]
                if token < 0 or self.logits_processor[
                        idx] is None or idx in skip_idx_list:
                    continue

                self._accept_token(idx, token)

    def pre_process(self, skip_idx_list: List[int] = []):
        """ pre process before running """
        # create async operation for guided decoding
        # TODO: support async
        self.update_vocab_mask(skip_idx_list)
        # self.async_step = self.executor.submit(self.update_vocab_mask)

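# Flow sketch for SamplerProcessor (editor's example; request id 0 and the
# grammar-backed processor object are assumptions):
# processor = SamplerProcessor()
# processor.add_logits_processor(0, grammar_processor, prefill_tokens=[])
# processor.pre_process()                      # build the vocab bitmask (CPU-heavy)
# logits = processor.apply_token_mask(logits)  # mask tokens the grammar forbids
# ...sample next_tokens from the masked logits...
# processor.update_output_tokens(next_tokens)  # advance each grammar state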
class Sampler(nn.Layer):
    """
    Sampler for normal generation.
    """

    def __init__(self):
        """
        """
        super().__init__()
        if current_platform.is_cuda():
            self.nranks = get_tensor_model_parallel_world_size()
        if current_platform.is_cuda() or current_platform.is_xpu():
            self.forward = self.forward_cuda
        else:
            raise NotImplementedError()

        self.processor = SamplerProcessor()

    def apply_logits_processor(self,
                               ids: int,
                               future: Optional[Any] = None,
                               prefill_tokens: List[int] = []):
        """ apply logits processor to sampler """
        self.processor.add_logits_processor(ids, future, prefill_tokens)

    def pre_process(self, skip_idx_list: List[int] = []):
        """ pre process before running """
        self.processor.pre_process(skip_idx_list)

    def forward_cuda(
        self,
        logits: paddle.Tensor,
        sampling_metadata: SamplingMetadata,
        skip_idx_list: List[int] = [],
    ) -> paddle.Tensor:
        """
        """
        logits = self.processor.apply_token_mask(logits, skip_idx_list)

        logits = apply_penalty_multi_scores(
            sampling_metadata.prompt_token_ids,
            sampling_metadata.pre_token_ids,
            logits,
            sampling_metadata.repetition_penalties,
            sampling_metadata.frequency_penalties,
@@ -63,10 +213,156 @@ class Sampler(nn.Layer):

        probs = F.softmax(logits)

        _, next_tokens = paddle.tensor.top_p_sampling(probs,
                                                      sampling_metadata.top_p)

        if self.nranks > 1:
            paddle.distributed.broadcast(next_tokens, 0)
        _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)

        self.processor.update_output_tokens(next_tokens, skip_idx_list)
        return next_tokens


class SpeculativeSampler(nn.Layer):
    """
    Sampler for speculative generation.
    """

    def __init__(self, fd_config: FDConfig):
        """
        """
        super().__init__()
        if current_platform.is_cuda():
            self.forward = self.forward_cuda
        else:
            raise NotImplementedError()
        self.speculative_verify_window = fd_config.speculative_config.verify_window
        self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len

    def pre_process(self, skip_idx_list: List[int] = []):
        """ pre process before running """
        pass

    def apply_logits_processor(self,
                               ids: int,
                               future: Optional[Any] = None,
                               prefill_tokens: List[int] = []):
        """ apply logits processor to sampler """
        pass

    def forward_cuda(
        self,
        logits: paddle.Tensor,
        sampling_metadata: SamplingMetadata,
        max_model_len: int,
        share_inputs: List[paddle.Tensor],
    ) -> paddle.Tensor:
        """
        """

        from fastdeploy.model_executor.ops.gpu import (speculate_verify,
                                                       top_p_candidates)

        logits = apply_speculative_penalty_multi_scores(
            sampling_metadata.pre_token_ids,
            logits,
            sampling_metadata.repetition_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.presence_penalties,
            sampling_metadata.temperature,
            sampling_metadata.bad_words_token_ids,
            sampling_metadata.step_idx,
            sampling_metadata.min_dec_lens,
            sampling_metadata.eos_token_ids,
            share_inputs["seq_lens_this_time"],
            share_inputs["output_padding_offset"],
            share_inputs["output_cum_offsets"],
            max_model_len,
        )

        probs = F.softmax(logits)

        verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
            probs,
            sampling_metadata.top_p,
            share_inputs["output_padding_offset"],
            self.speculative_max_candidate_len,
            max_model_len,
        )

        speculate_verify(
            share_inputs["accept_tokens"],
            share_inputs["accept_num"],
            share_inputs["step_idx"],
            share_inputs["stop_flags"],
            share_inputs["seq_lens_encoder"],
            share_inputs["seq_lens_decoder"],
            share_inputs[
                "draft_tokens"],  # Both input and output, need to write the last 1 token accepted to position 0.
            share_inputs["seq_lens_this_time"],
            verify_tokens,
            verify_scores,
            share_inputs["max_dec_len"],
            sampling_metadata.eos_token_ids,
            share_inputs["is_block_step"],
            share_inputs["output_cum_offsets"],
            actual_candidate_len,
            share_inputs["actual_draft_token_num"],
            sampling_metadata.top_p,
            max_model_len,
            self.speculative_verify_window,
            True,  # enable_topp
        )

        return None


class MTPSampler(nn.Layer):
    """
    """

    def __init__(self, fd_config: FDConfig):
        """
        """
        super().__init__()
        if current_platform.is_cuda():
            self.forward = self.forward_cuda
        else:
            raise NotImplementedError()

    def pre_process(self, skip_idx_list: List[int] = []):
        """ pre process before running """
        pass

    def apply_logits_processor(self,
                               ids: int,
                               future: Optional[Any] = None,
                               prefill_tokens: List[int] = []):
        """ apply logits processor to sampler """
        pass

    def forward_cuda(
        self,
        logits: paddle.Tensor,
        sampling_metadata: SamplingMetadata,
        max_model_len: int,
        share_inputs: List[paddle.Tensor],
    ) -> paddle.Tensor:
        """
        """
        logits = apply_speculative_penalty_multi_scores(
            sampling_metadata.pre_token_ids,
            logits,
            sampling_metadata.repetition_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.presence_penalties,
            sampling_metadata.temperature,
            sampling_metadata.bad_words_token_ids,
            sampling_metadata.step_idx,
            sampling_metadata.min_dec_lens,
            sampling_metadata.eos_token_ids,
            share_inputs["seq_lens_this_time"],
            share_inputs["seq_lens_encoder"],
            share_inputs["seq_lens_decoder"],
            max_model_len,
        )
        probs = F.softmax(logits)

        _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
        return next_tokens
@@ -14,32 +14,37 @@
# limitations under the License.
"""

from typing import Tuple
from typing import Tuple, Union

import numpy as np
import paddle
from paddle import Tensor
from paddle import Tensor, nn
from paddle.framework import in_dynamic_mode
from scipy.linalg import block_diag

from fastdeploy.platforms import current_platform

if current_platform.is_cuda() and current_platform.available():
    try:
        from fastdeploy.model_executor.ops.gpu import (
            get_padding_offset,
            speculate_get_padding_offset,
        )
            get_padding_offset, speculate_get_padding_offset)
    except Exception:
        raise ImportError(
            f"Verify environment consistency between compilation and FastDeploy installation. "
            f"And ensure the Paddle version supports FastDeploy's custom operators"
            "Verify environment consistency between compilation and FastDeploy installation. "
            "And ensure the Paddle version supports FastDeploy's custom operators"
        )
import re

import os
cache_params = os.getenv("CACHE_PARAMS", "none")
from fastdeploy import envs

cache_params = envs.FD_CACHE_PARAMS
if cache_params != "none":
    c8_state_dict = paddle.load(cache_params, return_numpy=True)


def per_block_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:

def per_block_cast_to_fp8(x: Tensor,
                          block_size: list = [128,
                                              128]) -> Tuple[Tensor, Tensor]:
    """
    Only used in deep_gemm block wise quant weight.
    copy from FastDeploy/custom_ops/gpu_ops/fp8_deep_gemm/tests/test_core.py.
@@ -48,10 +53,13 @@ def per_block_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:

    assert x.dim() == 2
    m, n = x.shape
    x_padded = paddle.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128),
    x_padded = paddle.zeros((ceil_div(m, block_size[0]) * block_size[0],
                             ceil_div(n, block_size[1]) * block_size[1]),
                            dtype=x.dtype)
    x_padded[:m, :n] = x
    x_view = paddle.view(x_padded, (-1, 128, x_padded.shape[1] // 128, 128))
    x_view = paddle.view(
        x_padded,
        (-1, block_size[0], x_padded.shape[1] // block_size[1], block_size[1]))

    x_abs = paddle.abs(x_view).astype(paddle.float32)
    x_amax = paddle.amax(x_abs, axis=(1, 3), keepdim=True)
@@ -63,15 +71,15 @@ def per_block_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:

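# Usage sketch (editor's example; weight is an assumed 2-D tensor): the helper
# above pads to multiples of the block size, then returns the FP8 tensor plus
# one scale per 128x128 block for deep_gemm.
# w_fp8, w_scale = per_block_cast_to_fp8(weight, block_size=[128, 128])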
# for distributed tensor model parallel
def _set_var_distributed(var, split_axis):
def _set_var_distributed(var: Tensor, split_axis: int):
    """
    Set whether the variable is distributed. If the variable is None, no operation will be performed.

    Args:
        var (Variable, Optional): A Variable object, which can be None. The default value is None.
            The Variable object should have an attribute 'is_distributed' to indicate whether
            the variable has been processed in a distributed manner.
        split_axis (Integer): the sharding dimension of dist tensors
        var (Tensor): A Variable object, which can be None. The default value is None.
            The Variable object should have an attribute 'is_distributed' to indicate whether
            the variable has been processed in a distributed manner.
        split_axis (int): the sharding dimension of dist tensors.

    Returns:
        None. No return value.
@@ -91,10 +99,16 @@ def _set_var_distributed(var, split_axis):
        main_block._find_var_recursive(var.name).is_distributed = True


def get_tensor(input):
def get_tensor(input: Union[paddle.Tensor, np.ndarray, str]) -> paddle.Tensor:
    """
    In expert-parallel mode the weights are stored per layer; to save peak GPU memory,
    the state_dict step only keeps the layer name and the path of the matching weight,
    so the weight has to be converted back to a paddle.Tensor here.
    Return a corresponding PaddlePaddle tensor based on the type and content of the input.

    Args:
        input (Union[paddle.Tensor, np.ndarray, str]): The input data.

    Returns:
        paddle.Tensor: Returns a PaddlePaddle tensor.

    """
    if isinstance(input, paddle.Tensor):
        if input.place.is_cpu_place():
@@ -104,7 +118,6 @@ def get_tensor(input):
        return paddle.to_tensor(input)
    elif isinstance(input, str):
        if ".safetensors" in input:

            match = re.match(r"\[(.*?)\](.*)", input)
            if match:
                key_name = match.group(1)
@@ -116,12 +129,11 @@ def get_tensor(input):
                    weight = f.get_tensor(key_name)
                    weight = paddle.Tensor(weight, zero_copy=True)
                    weight = weight._copy_to(
                        paddle.framework._current_expected_place(), False
                    )
                        paddle.framework._current_expected_place(), False)
                    return weight
            else:
                return None
        else:
        else:
            if cache_params != "none":
                tmp_key = input.split("/")[-1]
                if tmp_key in c8_state_dict:
@@ -129,25 +141,134 @@ def get_tensor(input):
                    return paddle.to_tensor(c8_state_dict.pop(tmp_key))
            return paddle.load(input)
    else:
        # In theory this branch should never be hit.
        return input


def matmul_hadU(X: Tensor) -> paddle.Tensor:
    """
    Perform matrix multiplication using the Hadamard matrix.

    Args:
        X (Tensor): The tensor to be multiplied.

    Returns:
        Tensor: The tensor after Hadamard matrix multiplication, with the same shape as the input tensor X.

    """
    input = X.clone().reshape((-1, X.shape[-1], 1))
    output = input.clone()
    while input.shape[1] > 1:
        input = input.reshape(
            (input.shape[0], input.shape[1] // 2, 2, input.shape[2]))
        output = output.reshape(input.shape)
        output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
        output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
        output = output.reshape((input.shape[0], input.shape[1], -1))
        (input, output) = (output, input)
    del output
    return input.reshape(X.shape)


def random_hadamard_matrix(block_size: int,
                           dtype: Union[paddle.dtype, str]) -> paddle.Tensor:
    """
    Generate a random Hadamard matrix.

    Args:
        block_size (int): The size of the block, i.e., the number of rows and columns of the matrix.
        dtype (str): The data type, for example 'float32'.

    Returns:
        paddle.Tensor: The generated random Hadamard matrix.

    """
    Q = paddle.diag(paddle.ones((block_size), dtype=dtype))
    block = matmul_hadU(Q)
    return block


def create_hadamard_matrix(hidden_size: int) -> paddle.Tensor:
    """
    Generate a Hadamard matrix.

    Args:
        hidden_size (int): The size of the hidden layer.

    Returns:
        paddle.Tensor: The generated Hadamard matrix.

    """
    hadamard_block_size = 32
    h = random_hadamard_matrix(hadamard_block_size, "float32")
    block_num = hidden_size // hadamard_block_size
    hadamard_matrix = paddle.to_tensor(
        block_diag(*[h for i in range(block_num)]))
    return hadamard_matrix


create_hadamard_matrix_map = {}
# Zkk: the keys below are used in the 4.5T fp8 model.
create_hadamard_matrix_map[8192] = create_hadamard_matrix(8192)
create_hadamard_matrix_map[448] = create_hadamard_matrix(448)
create_hadamard_matrix_map[1024] = create_hadamard_matrix(1024)
create_hadamard_matrix_map[3584] = create_hadamard_matrix(3584)

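# Sketch (editor's example): the cache above holds block-diagonal Hadamard
# matrices for the hidden sizes used by the 4.5T fp8 model; other sizes would
# need a fresh create_hadamard_matrix() call. Applying the rotation as a plain
# matmul against the activation is an assumption here, not upstream code.
# h = create_hadamard_matrix_map.get(8192, create_hadamard_matrix(8192))
# x_rot = paddle.matmul(x, h)   # x assumed to have hidden size 8192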
def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    """
    Ensure the numerator is divisible by the denominator.

    Args:
        numerator (int): The numerator.
        denominator (int): The denominator.

    Returns:
        None

    Raises:
        AssertionError: If the numerator cannot be evenly divided by the denominator, an assertion error is raised.

    """
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
def divide(numerator: int, denominator: int):
    """
    Calculate the division result of two numbers.

    Args:
        numerator (int): The dividend.
        denominator (int): The divisor.

    Returns:
        int: The result of the division, which is the quotient of the dividend divided by the divisor.

    """
    ensure_divisibility(numerator, denominator)
    return numerator // denominator

def remove_padding(max_len, input_ids, seq_lens_this_time):

def remove_padding(
    max_len: paddle.Tensor, input_ids: paddle.Tensor,
    seq_lens_this_time: paddle.Tensor
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
           paddle.Tensor]:
    """
    remove_padding
    Remove padded sequences from the input.

    Args:
        max_len (paddle.Tensor): The maximum length of the input sequences.
        input_ids (paddle.Tensor): The IDs of the input sequences.
        seq_lens_this_time (paddle.Tensor): The actual length of each sequence.

    Returns:
        tuple: A tuple containing:
            - The sequence IDs with padding removed (paddle.Tensor).
            - The padding offsets (paddle.Tensor).
            - The cumulative offsets (paddle.Tensor).
            - The query sequence lengths (paddle.Tensor).
            - The key sequence lengths (paddle.Tensor).
    """
    if current_platform.is_cuda():
        cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time)
@@ -159,7 +280,7 @@ def remove_padding(max_len, input_ids, seq_lens_this_time):
            cu_seqlens_q,
            cu_seqlens_k,
        ) = get_padding_offset(input_ids, cum_offsets_now, token_num,
                               seq_lens_this_time)
                               seq_lens_this_time)
        return (
            ids_remove_padding,
            padding_offset,
@@ -168,10 +289,30 @@ def remove_padding(max_len, input_ids, seq_lens_this_time):
            cu_seqlens_k,
        )

def speculate_remove_padding(max_len, input_ids, seq_lens_this_time,
                             draft_tokens, seq_lens_encoder):

def speculate_remove_padding(
    max_len: paddle.Tensor, input_ids: paddle.Tensor,
    seq_lens_this_time: paddle.Tensor, draft_tokens: paddle.Tensor,
    seq_lens_encoder: paddle.Tensor
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
           paddle.Tensor]:
    """
    remove_padding
    Remove padding from sequences.

    Args:
        max_len (paddle.Tensor): The maximum length of the sequences.
        input_ids (paddle.Tensor): The IDs of the input sequences.
        seq_lens_this_time (paddle.Tensor): The lengths of the sequences in the current batch.
        draft_tokens (paddle.Tensor): The draft tokens.
        seq_lens_encoder (paddle.Tensor): The lengths of the encoder sequences.

    Returns:
        tuple: A tuple containing:
            - The input sequence IDs with padding removed (paddle.Tensor).
            - Padding offsets (paddle.Tensor).
            - Cumulative offsets (paddle.Tensor).
            - Query sequence lengths (paddle.Tensor).
            - Key sequence lengths (paddle.Tensor).
    """
    if current_platform.is_cuda():
        cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time)
@@ -197,3 +338,43 @@ def speculate_remove_padding(max_len, input_ids, seq_lens_this_time,
            cu_seqlens_q,
            cu_seqlens_k,
        )


class CpuGuard:
    """CpuGuard"""

    def __init__(self):
        """init"""
        pass

    def __enter__(self):
        """enter"""
        self.ori_device = paddle.device.get_device()
        paddle.device.set_device("cpu")

    def __exit__(self, exc_type, exc_val, exc_tb):
        """exit"""
        paddle.device.set_device(self.ori_device)

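# Usage sketch for CpuGuard (editor's example): run a CPU-only computation
# without permanently switching the active device.
# with CpuGuard():
#     cos_cpu = paddle.cos(paddle.arange(8, dtype="float32"))  # runs on CPU
# # the previous device (e.g. "gpu:0") is restored on exit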
def create_and_set_parameter(layer: nn.Layer, name: str,
                             tensor: paddle.Tensor):
    """
    Create a parameter for a specified layer and set its value to the given tensor.

    Args:
        layer (nn.Layer): The layer object to which the parameter will be added.
        name (str): The name of the parameter to be created.
        tensor (paddle.Tensor): The tensor to set as the value of the parameter.

    Returns:
        None
    """
    setattr(
        layer, name,
        layer.create_parameter(
            shape=tensor.shape,
            dtype=tensor.dtype,
            default_initializer=paddle.nn.initializer.Constant(0),
        ))
    getattr(layer, name).set_value(tensor)
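# Usage sketch (editor's example; layer and out_dim are assumed to exist):
# create_and_set_parameter(layer, "linear_weight_scale",
#                          paddle.ones([out_dim], dtype="float32"))
# getattr(layer, "linear_weight_scale") is now a parameter holding that tensor.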