Mirror of https://github.com/PaddlePaddle/FastDeploy.git
polish code with new pre-commit rule (#2923)
@@ -18,14 +18,13 @@ gcu backend methods
 
 from .attention.flash_attn_backend import GCUFlashAttnBackend
 from .attention.mem_efficient_attn_backend import GCUMemEfficientAttnBackend
-from .moe.fused_moe_method_gcu_backend import (GCUFusedMoeMethod,
-                                               GCUWeightOnlyMoEMethod)
+from .moe.fused_moe_method_gcu_backend import GCUFusedMoeMethod, GCUWeightOnlyMoEMethod
 from .quantization.weight_only import GCUWeightOnlyLinearMethod
 
 __all__ = [
-    'GCUFlashAttnBackend',
-    'GCUMemEfficientAttnBackend',
-    'GCUFusedMoeMethod',
-    'GCUWeightOnlyMoEMethod',
-    'GCUWeightOnlyLinearMethod',
+    "GCUFlashAttnBackend",
+    "GCUMemEfficientAttnBackend",
+    "GCUFusedMoeMethod",
+    "GCUWeightOnlyMoEMethod",
+    "GCUWeightOnlyLinearMethod",
 ]
@@ -17,31 +17,33 @@
 
 from __future__ import annotations
 
 import os
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional
 
-import paddle
-
 import numpy as np
+import paddle
 
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
-    AttentionBackend, AttentionMetadata)
+    AttentionBackend,
+    AttentionMetadata,
+)
 
 if TYPE_CHECKING:
     from fastdeploy.model_executor.forward_meta import ForwardMeta, ForwardMode
 
-from fastdeploy.model_executor.ops.gcu import (fused_rotary_embedding,
-                                               mem_efficient_attention,
-                                               flash_attn_var_len)
-from paddleformers.utils.log import logger
+from fastdeploy.model_executor.ops.gcu import flash_attn_var_len, fused_rotary_embedding
 
 
 @dataclass
 class GCUFlashAttnMetadata(AttentionMetadata):
     """
     GCUFlashAttnMetadata
     """
+
     forward_mode: ForwardMode = ForwardMode.MIXED
 
     _dtype: paddle.dtype = paddle.bfloat16
@@ -63,15 +65,18 @@ class GCUFlashAttnMetadata(AttentionMetadata):
     pre_caches_length: int = 0
 
 
-
-
 class GCUFlashAttnBackend(AttentionBackend):
     """
     GCUFlashAttnBackend backend implementation.
     """
 
-    def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
-                 head_dim: int):
+    def __init__(
+        self,
+        fd_config: FDConfig,
+        kv_num_heads: int,
+        num_heads: int,
+        head_dim: int,
+    ):
         """
         GCUFlashAttnBackend __init__
         """
@@ -99,8 +104,6 @@ class GCUFlashAttnBackend(AttentionBackend):
         self.rotary_embs = None
         self.enable_monitor: bool = bool(os.getenv("FD_GCU_ATTN_MONITOR", False))
 
-
-
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""
         metadata = GCUFlashAttnMetadata()
@@ -131,15 +134,14 @@ class GCUFlashAttnBackend(AttentionBackend):
             self.rotary_embs = metadata.rotary_embs.reshape((-1, self.head_dim))
 
         # some info for attention
-        self.seq_lens_this_time_list = forward_meta.seq_lens_this_time.tolist() # List[int]
-        self.seq_lens_encoder_list = forward_meta.seq_lens_encoder.tolist() # List[List[int]]
-        self.seq_lens_decoder_list = forward_meta.seq_lens_decoder.tolist() # List[List[int]]
+        self.seq_lens_this_time_list = forward_meta.seq_lens_this_time.tolist()  # List[int]
+        self.seq_lens_encoder_list = forward_meta.seq_lens_encoder.tolist()  # List[List[int]]
+        self.seq_lens_decoder_list = forward_meta.seq_lens_decoder.tolist()  # List[List[int]]
         self.seq_lens_sum = np.sum(self.seq_lens_this_time_list)
         self.max_seq_len_this_time = np.max(self.seq_lens_this_time_list)
 
         num_seqs = forward_meta.seq_lens_this_time.shape[0]
 
-
         self.is_decoder = all(x[0] == 0 for x in self.seq_lens_encoder_list)
         self.is_all_prefill = all(x[0] == 0 for x in self.seq_lens_decoder_list)
 
@@ -147,8 +149,14 @@ class GCUFlashAttnBackend(AttentionBackend):
         if self.all_slot_mapping is None:
             max_num_blocks_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size
             total_blocks = max_num_blocks_per_seq * self.max_num_seqs
-            self.all_block_tables = np.arange(0, total_blocks, dtype=np.int32).reshape((self.max_num_seqs, max_num_blocks_per_seq)).tolist()
-            self.all_slot_mapping = np.arange(0, total_blocks * self.block_size, dtype=np.int32).reshape((self.max_num_seqs, -1)).tolist()
+            self.all_block_tables = (
+                np.arange(0, total_blocks, dtype=np.int32)
+                .reshape((self.max_num_seqs, max_num_blocks_per_seq))
+                .tolist()
+            )
+            self.all_slot_mapping = (
+                np.arange(0, total_blocks * self.block_size, dtype=np.int32).reshape((self.max_num_seqs, -1)).tolist()
+            )
 
         block_tables = []
         slot_mapping = []
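The chained np.arange(...).reshape(...).tolist() form above only changes layout, not behavior: block tables and slot mappings remain a static, contiguous partition of the paged KV cache. A minimal NumPy sketch with made-up sizes (the real values come from fd_config and the scheduler; the names below are illustrative) shows the resulting structure:

import numpy as np

# Illustrative sizes only.
max_seq_len, block_size, max_num_seqs = 8, 4, 2
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size  # 2
total_blocks = max_num_blocks_per_seq * max_num_seqs                   # 4

# One row per sequence slot: the physical block ids statically assigned to it.
all_block_tables = (
    np.arange(0, total_blocks, dtype=np.int32)
    .reshape((max_num_seqs, max_num_blocks_per_seq))
    .tolist()
)  # [[0, 1], [2, 3]]

# One row per sequence slot: the flat KV-cache slot index of every token position.
all_slot_mapping = (
    np.arange(0, total_blocks * block_size, dtype=np.int32).reshape((max_num_seqs, -1)).tolist()
)  # [[0, 1, ..., 7], [8, 9, ..., 15]]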
@@ -157,9 +165,9 @@ class GCUFlashAttnBackend(AttentionBackend):
         position_ids = []
         for seq_idx in range(num_seqs):
             cache_len = None
-            if self.seq_lens_encoder_list[seq_idx][0] != 0: # prefill
+            if self.seq_lens_encoder_list[seq_idx][0] != 0:  # prefill
                 cache_len = 0
-            elif self.seq_lens_decoder_list[seq_idx][0] != 0: # decode
+            elif self.seq_lens_decoder_list[seq_idx][0] != 0:  # decode
                 cache_len = self.seq_lens_decoder_list[seq_idx][0]
             # else: doesnot have req in this seq_idx
 
@@ -193,7 +201,6 @@ class GCUFlashAttnBackend(AttentionBackend):
         self.max_seqlen_q = self.max_seq_len_this_time
         self.max_seqlen_k = np.max(cache_lens)
 
-
     def get_attntion_meta(self):
         """get_attntion_meta"""
         return self.attention_metadata
@@ -206,9 +213,11 @@ class GCUFlashAttnBackend(AttentionBackend):
         Caculate kv cache shape
         """
         # [total_tokens, kv_num_heads, head_dim]
-        return (max_num_blocks * self.block_size,
-                self.kv_num_heads,
-                self.head_dim)
+        return (
+            max_num_blocks * self.block_size,
+            self.kv_num_heads,
+            self.head_dim,
+        )
 
     @paddle.no_grad()
     def forward_mixed(
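The reformatted return above still describes a flat paged cache laid out as [total_tokens, kv_num_heads, head_dim]. A quick worked example with invented numbers (no real configuration implied):

# Hypothetical values, for illustration only.
block_size = 64
max_num_blocks = 512
kv_num_heads = 8
head_dim = 128

kv_cache_shape = (max_num_blocks * block_size, kv_num_heads, head_dim)
print(kv_cache_shape)  # (32768, 8, 128) -> [total_tokens, kv_num_heads, head_dim]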
@@ -232,7 +241,6 @@ class GCUFlashAttnBackend(AttentionBackend):
         query = query.reshape_((1, -1, self.num_heads, self.head_dim))
         key = key.reshape_((1, -1, self.kv_num_heads, self.head_dim))
 
-
         # 1. Rope
         if self.rotary_embs.dtype != query.dtype:
             self.rotary_embs = paddle.cast(self.rotary_embs, query.dtype)
@@ -242,7 +250,7 @@ class GCUFlashAttnBackend(AttentionBackend):
             key,
             self.rotary_embs,
             self.position_ids,
-            layer.use_neox_rotary_style
+            layer.use_neox_rotary_style,
         )
 
         # 2. Save kv cache
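fused_rotary_embedding is a GCU custom op, so its kernel is not visible in this diff. As a rough, generic reference for what rotary position embedding does to query/key channels (the helper below is hypothetical, and layout conventions such as neox vs. interleaved vary between implementations):

import numpy as np

def rope_reference(x, positions, base=10000.0):
    # x: [num_tokens, head_dim] for a single head; rotate channel pairs by
    # position-dependent angles (neox-style split into two halves).
    num_tokens, head_dim = x.shape
    half = head_dim // 2
    inv_freq = base ** (-np.arange(0, half) / half)
    angles = np.outer(positions, inv_freq)          # [num_tokens, half]
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

q = np.random.rand(4, 8).astype("float32")
print(rope_reference(q, positions=np.arange(4)).shape)  # (4, 8)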
@@ -281,4 +289,3 @@ class GCUFlashAttnBackend(AttentionBackend):
         )
         res = res.reshape_((token_num, -1))
         return res
-
@@ -16,33 +16,35 @@
 
 from __future__ import annotations
 
-import os
-from dataclasses import dataclass, field
+import math
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional
 
-import paddle
-
 import numpy as np
-import math
+import paddle
+from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
-    AttentionBackend, AttentionMetadata)
-
-from fastdeploy.model_executor.ops.gcu import (fused_rotary_embedding,
-                                               mem_efficient_attention,
-                                               flash_attn_var_len)
-from paddleformers.utils.log import logger
+    AttentionBackend,
+    AttentionMetadata,
+)
+from fastdeploy.model_executor.ops.gcu import (
+    fused_rotary_embedding,
+    mem_efficient_attention,
+)
 
 if TYPE_CHECKING:
     from fastdeploy.model_executor.forward_meta import ForwardMeta, ForwardMode
 
 
 @dataclass
 class GCUMemEfficientAttnMetadata(AttentionMetadata):
     """
     GCUMemEfficientAttnMetadata
     """
+
     forward_mode: ForwardMode = ForwardMode.MIXED
     _dtype: paddle.dtype = paddle.bfloat16
+
@@ -63,15 +65,18 @@ class GCUMemEfficientAttnMetadata(AttentionMetadata):
     pre_caches_length: int = 0
 
 
-
-
 class GCUMemEfficientAttnBackend(AttentionBackend):
     """
     GCUMemEfficientAttnBackend backend implementation.
     """
 
-    def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
-                 head_dim: int):
+    def __init__(
+        self,
+        fd_config: FDConfig,
+        kv_num_heads: int,
+        num_heads: int,
+        head_dim: int,
+    ):
         """
         GCUMemEfficientAttnBackend __init__
         """
@@ -99,8 +104,6 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
         self.rotary_embs = None
         self.use_paddle_native_sdpa = False
 
-
-
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""
         metadata = GCUMemEfficientAttnMetadata()
@@ -125,32 +128,35 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
 
         metadata.pre_caches_length = forward_meta.pre_caches_length  # not inited
 
-
         self.attention_metadata = metadata
 
         if self.rotary_embs is None:
             self.rotary_embs = metadata.rotary_embs.reshape((-1, self.head_dim))
 
         # some info for attention
-        self.seq_lens_this_time_list = forward_meta.seq_lens_this_time.tolist() # List[int]
-        self.seq_lens_encoder_list = forward_meta.seq_lens_encoder.tolist() # List[List[int]]
-        self.seq_lens_decoder_list = forward_meta.seq_lens_decoder.tolist() # List[List[int]]
+        self.seq_lens_this_time_list = forward_meta.seq_lens_this_time.tolist()  # List[int]
+        self.seq_lens_encoder_list = forward_meta.seq_lens_encoder.tolist()  # List[List[int]]
+        self.seq_lens_decoder_list = forward_meta.seq_lens_decoder.tolist()  # List[List[int]]
        self.seq_lens_sum = np.sum(self.seq_lens_this_time_list)
         self.max_seq_len_this_time = np.max(self.seq_lens_this_time_list)
 
         num_seqs = forward_meta.seq_lens_this_time.shape[0]
 
-
         self.is_decoder = all(x[0] == 0 for x in self.seq_lens_encoder_list)
         self.is_all_prefill = all(x[0] == 0 for x in self.seq_lens_decoder_list)
 
-
         # block_tables and slot_mapping
         if self.all_slot_mapping is None:
             max_num_blocks_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size
             total_blocks = max_num_blocks_per_seq * self.max_num_seqs
-            self.all_block_tables = np.arange(0, total_blocks, dtype=np.int32).reshape((self.max_num_seqs, max_num_blocks_per_seq)).tolist()
-            self.all_slot_mapping = np.arange(0, total_blocks * self.block_size, dtype=np.int32).reshape((self.max_num_seqs, -1)).tolist()
+            self.all_block_tables = (
+                np.arange(0, total_blocks, dtype=np.int32)
+                .reshape((self.max_num_seqs, max_num_blocks_per_seq))
+                .tolist()
+            )
+            self.all_slot_mapping = (
+                np.arange(0, total_blocks * self.block_size, dtype=np.int32).reshape((self.max_num_seqs, -1)).tolist()
+            )
 
         block_tables = []
         slot_mapping = []
@@ -162,9 +168,9 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
         position_ids = []
         for seq_idx in range(num_seqs):
             cache_len = None
-            if self.seq_lens_encoder_list[seq_idx][0] != 0: # prefill
+            if self.seq_lens_encoder_list[seq_idx][0] != 0:  # prefill
                 cache_len = 0
-            elif self.seq_lens_decoder_list[seq_idx][0] != 0: # decode
+            elif self.seq_lens_decoder_list[seq_idx][0] != 0:  # decode
                 cache_len = self.seq_lens_decoder_list[seq_idx][0]
             # else: doesnot have req in this seq_idx
 
@@ -179,9 +185,12 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
                 position_ids.extend(self.position_ids_base[start:end])
                 query_lens.append(lens_this_time)
                 cached_kv_lens.append(end)
-                cached_kv_slot_range.append([self.all_slot_mapping[seq_idx][0], self.all_slot_mapping[seq_idx][end]])
-
-
+                cached_kv_slot_range.append(
+                    [
+                        self.all_slot_mapping[seq_idx][0],
+                        self.all_slot_mapping[seq_idx][end],
+                    ]
+                )
 
         self.block_tables = paddle.to_tensor(block_tables, dtype="int32")
         self.slot_mapping = paddle.to_tensor(slot_mapping, dtype="int32")
@@ -206,7 +215,6 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
         self.cached_kv_lens = cached_kv_lens
         self.cached_kv_slot_range = cached_kv_slot_range
 
-
     def get_attntion_meta(self):
         """get_attntion_meta"""
         return self.attention_metadata
@@ -219,9 +227,11 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
         Caculate kv cache shape
         """
         # [total_tokens, kv_num_heads, head_dim]
-        return (max_num_blocks * self.block_size,
-                self.kv_num_heads,
-                self.head_dim)
+        return (
+            max_num_blocks * self.block_size,
+            self.kv_num_heads,
+            self.head_dim,
+        )
 
     @paddle.no_grad()
     def forward_mixed(
@@ -245,7 +255,6 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
         query = query.reshape_((1, -1, self.num_heads, self.head_dim))
         key = key.reshape_((1, -1, self.kv_num_heads, self.head_dim))
 
-
         # 1. Rope
         if self.rotary_embs.dtype != query.dtype:
             self.rotary_embs = paddle.cast(self.rotary_embs, query.dtype)
@@ -255,7 +264,7 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
             key,
             self.rotary_embs,
             self.position_ids,
-            layer.use_neox_rotary_style
+            layer.use_neox_rotary_style,
         )
 
         # 2. Save kv cache
@@ -282,9 +291,7 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
             v_ = value_caches[kv_start:kv_end, :, :]
 
             if self.use_paddle_native_sdpa:
-                res = self.native_sdpa_impl(
-                    q_, k_, v_
-                )
+                res = self.native_sdpa_impl(q_, k_, v_)
             else:
                 res = mem_efficient_attention(
                     query=q_.unsqueeze(0),
@@ -302,7 +309,6 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
         result = result.reshape_((token_num, -1))
         return result
 
-
     def get_triangle_upper_mask(self, shape, dtype):
         # [batch_size, 1, q_seq_len, kv_seq_len]
         shape[1] = 1
@@ -313,7 +319,6 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
         mask = paddle.triu(mask, diagonal=kv_seq_len - q_seq_len + 1)
         return mask
 
-
     def native_sdpa_impl(self, query, key, value):
         # input shape: [num_tokens, num_heads, head_dim] -> [1, num_tokens, num_heads, head_dim]
         q = query.unsqueeze(0)
@@ -342,13 +347,9 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
         # matmul and devide by sqrt(head_dim)
         attn_weights = paddle.matmul(q / math.sqrt(head_dim), k.transpose([0, 1, 3, 2]))
 
-        attention_mask = self.get_triangle_upper_mask(
-            [batch, 1, q_seq_len, kv_seq_len], q.dtype
-        )
+        attention_mask = self.get_triangle_upper_mask([batch, 1, q_seq_len, kv_seq_len], q.dtype)
         attn_weights = attn_weights + attention_mask
-        attn_weights = paddle.nn.functional.softmax(
-            attn_weights, axis=-1, dtype="float32"
-        ).astype(q.dtype)
+        attn_weights = paddle.nn.functional.softmax(attn_weights, axis=-1, dtype="float32").astype(q.dtype)
 
         attn_output = paddle.matmul(attn_weights, v)
         attn_output = attn_output.transpose([0, 2, 1, 3])
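For intuition, the masked attention that native_sdpa_impl assembles from paddle.matmul, the triangular mask and softmax can be written as a small NumPy reference (hypothetical helper, single batch and head, simplified dtypes):

import numpy as np

def naive_causal_sdpa(q, k, v):
    # q: [q_len, d], k/v: [kv_len, d]
    d = q.shape[-1]
    scores = (q / np.sqrt(d)) @ k.T                                   # [q_len, kv_len]
    q_len, kv_len = scores.shape
    # Upper-triangular mask: a query only attends to keys at or before its position.
    mask = np.triu(np.full((q_len, kv_len), -1e4, dtype=scores.dtype), k=kv_len - q_len + 1)
    scores = scores + mask
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)                    # softmax
    return weights @ v                                                # [q_len, d]

q = np.random.rand(4, 8).astype("float32")
kv = np.random.rand(6, 8).astype("float32")
print(naive_causal_sdpa(q, kv, kv).shape)  # (4, 8)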
@@ -11,6 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""""
+""" "
 gcu moe
 """
@@ -1,4 +1,3 @@
-
 """
 # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 #
@@ -15,7 +14,6 @@
 # limitations under the License.
 """
 
-
 import multiprocessing
 import os
 
@@ -24,27 +22,30 @@ import paddle
 from paddle import nn
 from paddleformers.utils.log import logger
 
-from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import \
-    MoEMethodBase
-from fastdeploy.model_executor.layers.utils import (CpuGuard,
-                                                    create_and_set_parameter,
-                                                    get_tensor)
-from fastdeploy.model_executor.ops.gcu import (invoke_fused_moe_kernel,
-                                               moe_align_block_size,
-                                               topk_softmax,
-                                               weight_quantize_custom_rtn,
-                                               weight_quantize_rtn)
+from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
+from fastdeploy.model_executor.layers.utils import (
+    CpuGuard,
+    create_and_set_parameter,
+    get_tensor,
+)
+from fastdeploy.model_executor.ops.gcu import (
+    invoke_fused_moe_kernel,
+    moe_align_block_size,
+    topk_softmax,
+    weight_quantize_custom_rtn,
+    weight_quantize_rtn,
+)
 
+
 class GCUFusedMoeMethod(MoEMethodBase):
     """
     Use GCU to compute Fused MoE.
     """
 
     def __init__(self, quant_config):
         super().__init__(quant_config)
         self.group_size = -1
 
-
     def create_weights(self, layer: nn.Layer, state_dict):
         """
         Paddle gcu create weight process.
@@ -53,28 +54,28 @@ class GCUFusedMoeMethod(MoEMethodBase):
         up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
         stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
         stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
-        for idx, weight_tensor in enumerate(
-                [stacked_up_gate_proj_weights, stacked_down_proj_weights]):
+        for idx, weight_tensor in enumerate([stacked_up_gate_proj_weights, stacked_down_proj_weights]):
             # shape [E, K, N] -> [E, N, K]
             weight_tensor = paddle.transpose(weight_tensor, [0, 2, 1])
             weight_name = self.added_weight_attrs[idx]
             setattr(
-                layer, weight_name,
+                layer,
+                weight_name,
                 layer.create_parameter(
                     shape=weight_tensor.shape,
                     dtype=weight_tensor.dtype,
                     default_initializer=paddle.nn.initializer.Constant(0),
-                ))
+                ),
+            )
             getattr(layer, weight_name).set_value(weight_tensor)
 
-
     @paddle.no_grad()
     def compute_ffn(
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
         gate_out: paddle.Tensor,
-        enable_quant = False
+        enable_quant=False,
     ) -> paddle.Tensor:
         """
         Paddle gcu compute Fused MoE.
@@ -86,8 +87,17 @@ class GCUFusedMoeMethod(MoEMethodBase):
 
         topk_weights = paddle.empty([token_num, top_k], dtype=gate_out.dtype)
         topk_indices = paddle.empty([token_num, top_k], dtype="int32")
-        token_expert_indices = paddle.empty([token_num, top_k], dtype="int32",)
-        topk_softmax(topk_weights, topk_indices, token_expert_indices, gate_out, norm_topk_prob=True)
+        token_expert_indices = paddle.empty(
+            [token_num, top_k],
+            dtype="int32",
+        )
+        topk_softmax(
+            topk_weights,
+            topk_indices,
+            token_expert_indices,
+            gate_out,
+            norm_topk_prob=True,
+        )
 
         config = {
             "BLOCK_SIZE_M": 32,
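topk_softmax is a fused GCU routing kernel, so only its call site is visible here. A rough NumPy picture of conventional MoE top-k routing — softmax over expert logits, keep the top_k experts per token, and renormalize the kept weights as norm_topk_prob=True suggests — purely for intuition (the reference function below is hypothetical):

import numpy as np

def topk_softmax_reference(gate_out, top_k, norm_topk_prob=True):
    # gate_out: [token_num, num_experts] router logits
    probs = np.exp(gate_out - gate_out.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)                     # softmax over experts
    topk_indices = np.argsort(-probs, axis=-1)[:, :top_k]          # best experts per token
    topk_weights = np.take_along_axis(probs, topk_indices, axis=-1)
    if norm_topk_prob:
        topk_weights /= topk_weights.sum(axis=-1, keepdims=True)   # renormalize kept weights
    return topk_weights, topk_indices

w, idx = topk_softmax_reference(np.random.randn(4, 8).astype("float32"), top_k=2)
print(w.shape, idx.shape)  # (4, 2) (4, 2)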
@@ -136,7 +146,7 @@ class GCUFusedMoeMethod(MoEMethodBase):
             top_k,
             config,
             enable_quant,  # use_int4_w4a16
-            [0, self.group_size], # block_shape
+            [0, self.group_size],  # block_shape
         )
 
         intermediate_cache2 = paddle.empty(
@@ -144,8 +154,7 @@ class GCUFusedMoeMethod(MoEMethodBase):
             dtype=x.dtype,
         )
 
-        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
-            intermediate_cache1)
+        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(intermediate_cache1)
 
         intermediate_cache2 = intermediate_cache2.reshape([-1, moe_intermediate_size])
 
@@ -181,13 +190,14 @@ class GCUFusedMoeMethod(MoEMethodBase):
         fused_moe_out = fused_moe_out.reshape_([token_num, hidden_size])
 
         if layer.tp_size > 1:
-            from fastdeploy.distributed.communication_op import \
-                tensor_model_parallel_all_reduce
+            from fastdeploy.distributed.communication_op import (
+                tensor_model_parallel_all_reduce,
+            )
+
             tensor_model_parallel_all_reduce(fused_moe_out)
 
         return fused_moe_out
 
-
     def apply(
         self,
         layer: nn.Layer,
@@ -199,7 +209,6 @@ class GCUFusedMoeMethod(MoEMethodBase):
         """
         return self.compute_ffn(layer, x, gate_out, enable_quant=False)
 
-
     def apply_ep_prefill(
         self,
         layer: nn.Layer,
@@ -211,7 +220,6 @@ class GCUFusedMoeMethod(MoEMethodBase):
         """
         raise NotImplementedError
 
-
     def apply_ep_decode(
         self,
         layer: nn.Layer,
@@ -223,7 +231,6 @@ class GCUFusedMoeMethod(MoEMethodBase):
         """
         raise NotImplementedError
 
-
     def apply_tp(
         self,
         layer: nn.Layer,
@@ -247,48 +254,44 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
         self.moe_quant_type = self.quant_config.algo
         self.pack_num = 1
 
-        assert self.quant_config.algo == "weight_only_int4", \
-            "GCUWeightOnlyMoEMethod only support weight_only_int4, but got:{self.quant_config.algo}"
+        assert (
+            self.quant_config.algo == "weight_only_int4"
+        ), "GCUWeightOnlyMoEMethod only support weight_only_int4, but got:{self.quant_config.algo}"
 
         self.added_qzeros_attrs = [
-            "up_gate_proj_weight_zeros", "down_proj_weight_zeros"
+            "up_gate_proj_weight_zeros",
+            "down_proj_weight_zeros",
         ]
         self.group_size = 64
 
-        self.quant_multi_process_group_size = int(
-            os.getenv("FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE", 8)
-        )
+        self.quant_multi_process_group_size = int(os.getenv("FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE", 8))
         logger.info(f"GCUWeightOnlyMoEMethod quant_multi_process_group_size: {self.quant_multi_process_group_size}")
 
-
     def process_prequanted_weights(self, layer: nn.Layer, state_dict):
         """
         Paddle gcu process prequanted weights.
         """
-        up_gate_proj_expert_weight_key = layer.weight_key_map.get(
-            "up_gate_proj_expert_weight_key", None)
-        down_proj_expert_weight_key = layer.weight_key_map.get(
-            "down_proj_expert_weight_key", None)
-        up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get(
-            "up_gate_proj_expert_weight_scale_key", None)
-        down_proj_expert_weight_scale_key = layer.weight_key_map.get(
-            "down_proj_expert_weight_scale_key", None)
+        up_gate_proj_expert_weight_key = layer.weight_key_map.get("up_gate_proj_expert_weight_key", None)
+        down_proj_expert_weight_key = layer.weight_key_map.get("down_proj_expert_weight_key", None)
+        up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
+        down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
 
         up_gate_proj_weights, down_proj_weights = layer.load_experts_weight(
-            state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key)
+            state_dict,
+            up_gate_proj_expert_weight_key,
+            down_proj_expert_weight_key,
+        )
         # self.check(layer, up_gate_proj_weights, down_proj_weights)
         up_gate_proj_weight_scale = []
         down_proj_weight_scale = []
         for i in range(layer.num_experts):
             expert_idx = layer.expert_id_offset + i
             up_gate_proj_weight_scale.append(
-                get_tensor(
-                    state_dict.pop(
-                        up_gate_proj_expert_weight_scale_key.format(expert_idx))))
+                get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
+            )
             down_proj_weight_scale.append(
-                get_tensor(
-                    state_dict.pop(
-                        down_proj_expert_weight_scale_key.format(expert_idx))))
+                get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
+            )
 
         up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
         down_proj_weight = paddle.stack(down_proj_weights, axis=0)
@@ -299,12 +302,11 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
             "up_gate_proj_weight": up_gate_proj_weight,
             "down_proj_weight": down_proj_weight,
             "up_gate_proj_weight_scale": up_gate_proj_weight_scale,
-            "down_proj_weight_scale": down_proj_weight_scale
+            "down_proj_weight_scale": down_proj_weight_scale,
         }
         for name, tensor in name_tensor_map.items():
             create_and_set_parameter(layer, name, tensor)
 
-
     @paddle.no_grad()
     def create_weights(self, layer: nn.Layer, state_dict):
         """
@@ -313,7 +315,6 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
         up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
         self.check(layer, up_gate_proj_weights, down_proj_weights)
 
-
         def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size):
             with CpuGuard():
                 p_group_size = len(weights)
@@ -322,13 +323,13 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
                     quant_weight, scale = weight_quantize_custom_rtn(
                         weights[group_j],
                         moe_quant_type,
-                        group_size # group_size
+                        group_size,  # group_size
                     )
                     shared_dict[p_group_size * p_group_idx + group_j] = (
-                        quant_weight, scale
+                        quant_weight,
+                        scale,
                     )
 
-
         for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
             weight_name = self.added_weight_attrs[idx]
             scale_name = self.added_scale_attrs[idx]
@@ -354,7 +355,13 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
 
                 p = multiprocessing.Process(
                     target=quant_worker,
-                    args=(i, shared_dict, w, self.moe_quant_type, self.group_size)
+                    args=(
+                        i,
+                        shared_dict,
+                        w,
+                        self.moe_quant_type,
+                        self.group_size,
+                    ),
                 )
                 p.start()
                 processes.append(p)
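The reformatted Process(...) call keeps the same fan-out pattern: expert weights are split into groups of quant_multi_process_group_size, each group is quantized in its own process, and results are gathered through a shared dict keyed by global index. A self-contained sketch of that pattern with a toy worker (not the real quantization):

import multiprocessing

def quant_worker(p_group_idx, shared_dict, weights, group_size):
    p_group_size = len(weights)
    for group_j, w in enumerate(weights):
        # Stand-in for weight_quantize_custom_rtn(...).
        shared_dict[p_group_size * p_group_idx + group_j] = ("quantized", w, group_size)

if __name__ == "__main__":
    all_weights = list(range(8))   # stand-in for per-expert weight tensors
    group = 4                      # like FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE
    manager = multiprocessing.Manager()
    shared_dict = manager.dict()
    processes = []
    for i, start in enumerate(range(0, len(all_weights), group)):
        w = all_weights[start:start + group]
        p = multiprocessing.Process(target=quant_worker, args=(i, shared_dict, w, 64))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
    print([shared_dict[k] for k in sorted(shared_dict.keys())])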
@@ -376,7 +383,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
                 quant_weight, scale = weight_quantize_rtn(
                     weight_tensor[i],
                     self.moe_quant_type,
-                    self.group_size # group_size
+                    self.group_size,  # group_size
                 )
                 weight_list.append(quant_weight)
                 weight_scale_list.append(scale)
@@ -389,7 +396,6 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
             quanted_weight_zeros = quanted_weight_scale * 8
             create_and_set_parameter(layer, zeros_name, quanted_weight_zeros)
 
-
     def apply(
         self,
         layer: nn.Layer,
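weight_quantize_rtn / weight_quantize_custom_rtn are GCU ops whose exact packing is not shown in this diff; deriving the zeros tensor as quanted_weight_scale * 8 is consistent with an int4 grid whose symmetric range [-8, 7] is re-expressed with an unsigned zero point of 8. A generic group-wise round-to-nearest sketch in NumPy, for intuition only (the helper and shapes are assumptions):

import numpy as np

def rtn_int4_quantize(w, group_size=64):
    # w: [in_features, out_features]; quantize along the input dim in groups.
    in_f, out_f = w.shape
    w_groups = w.reshape(in_f // group_size, group_size, out_f)
    scale = np.abs(w_groups).max(axis=1, keepdims=True) / 7.0       # int4 symmetric range [-8, 7]
    q = np.clip(np.round(w_groups / scale), -8, 7).astype(np.int8)
    return q.reshape(in_f, out_f), scale.squeeze(1)                 # quantized weights + per-group scales

w = np.random.randn(128, 32).astype("float32")
q, s = rtn_int4_quantize(w)
print(q.shape, s.shape)  # (128, 32) (2, 32)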
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""""
+""" "
 gcu quantization
 """
 from .weight_only import GCUWeightOnlyLinearMethod
@@ -17,7 +17,9 @@
 import paddle
 
 from fastdeploy.model_executor.layers.quantization.weight_only import (
-    WeightOnlyConfig, WeightOnlyLinearMethod)
+    WeightOnlyConfig,
+    WeightOnlyLinearMethod,
+)
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.ops.gcu import linear_quant, weight_quantize_rtn
 
@@ -35,7 +37,6 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         self.quant_config = quant_config
         self.group_size = -1
 
-
     def create_weights(self, layer):
         # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
         weight_scale_shape = [layer.weight_shape[1]]
@@ -50,7 +51,6 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
             is_bias=False,
         )
 
-
     def process_prequanted_weights(self, layer, state_dict) -> None:
         """
         Process pre-quantized weights before applying them to the model
@@ -62,9 +62,7 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         quant_weight = get_tensor(state_dict.pop(layer.weight_key))
         weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
         layer.weight.set_value(quant_weight)
-        layer.weight_scale.set_value(
-            weight_scale.astype(paddle.get_default_dtype()))
-
+        layer.weight_scale.set_value(weight_scale.astype(paddle.get_default_dtype()))
 
     def process_loaded_weights(self, layer, weight) -> None:
         quanted_weight_tensor, weight_scale_tensor = weight_quantize_rtn(
@@ -74,9 +72,7 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
         )
 
         layer.weight.set_value(quanted_weight_tensor)
-        layer.weight_scale.set_value(
-            weight_scale_tensor.astype(paddle.get_default_dtype()))
-
+        layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))
 
     @paddle.no_grad()
     def apply(self, layer, x):