Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 16:22:57 +08:00)
[GCU] Support gcu platform (#2702)
baseline: e7fa57ebae
Co-authored-by: yongqiangma <xing.wo@163.com>
@@ -19,7 +19,7 @@ from typing import Optional

import paddle
from paddle import nn
from paddle.incubate.nn.functional import fused_bias_act
from paddle.incubate.nn.functional import fused_bias_act, swiglu

from fastdeploy.config import FDConfig
from fastdeploy.platforms import current_platform

@@ -66,6 +66,8 @@ class SiluAndMul(nn.Layer):
        if current_platform.is_cuda() or current_platform.is_xpu(
        ) or current_platform.is_iluvatar():
            self.forward = self.forward_cuda
        elif current_platform.is_gcu():
            self.forward = self.forward_gcu
        else:
            raise NotImplementedError

@@ -123,3 +125,18 @@ class SiluAndMul(nn.Layer):
                quant_max_bound=self.quant_max_bound,
                quant_min_bound=self.quant_min_bound,
            )

    def forward_gcu(self, x):
        """
        Forward propagation of the custom activation layer.

        Args:
            x (Tensor): Input tensor to the activation layer.

        Returns:
            Tensor: Output tensor.
        """
        out = swiglu(x)
        if self.bias is not None:
            out = out + self.bias
        return out
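Note for readers of the hunk above: `swiglu` fuses the SiLU gate with the elementwise product of the two halves of the up-projection output, so `forward_gcu` is the unquantized counterpart of the fused bias-act path. A minimal reference sketch, assuming Paddle's convention that the gate occupies the first half of the last axis:

import paddle
import paddle.nn.functional as F

def swiglu_reference(x: paddle.Tensor) -> paddle.Tensor:
    # Split the fused ffn1 output into its gate and linear halves.
    gate, up = paddle.chunk(x, chunks=2, axis=-1)
    # SiLU(gate) * up, which is what the fused swiglu kernel computes.
    return F.silu(gate) * up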
@@ -16,14 +16,24 @@
all backends methods
"""

from .xpu import *
from .npu import *
from fastdeploy.platforms import current_platform

__all__ = []
from . import npu
if hasattr(npu, '__all__'):
    __all__.extend(npu.__all__)

from . import xpu
if hasattr(xpu, '__all__'):
    __all__.extend(xpu.__all__)

if current_platform.is_xpu():
    from . import xpu
    from .xpu import *
    if hasattr(xpu, '__all__'):
        __all__.extend(xpu.__all__)

if current_platform.is_npu():
    from . import npu
    from .npu import *
    if hasattr(npu, '__all__'):
        __all__.extend(npu.__all__)

if current_platform.is_gcu():
    from . import gcu
    from .gcu import *
    if hasattr(gcu, '__all__'):
        __all__.extend(gcu.__all__)
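After this hunk, backend symbols are re-exported only on the platform that actually provides them, so importing the package stays cheap on other devices. A minimal sketch of how later hunks in this commit pick up the GCU implementations through this package:

from fastdeploy.platforms import current_platform

if current_platform.is_gcu():
    # Resolved via fastdeploy/model_executor/layers/backends/gcu/__init__.py below.
    from fastdeploy.model_executor.layers.backends import (
        GCUFusedMoeMethod, GCUWeightOnlyLinearMethod)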
fastdeploy/model_executor/layers/backends/gcu/__init__.py (new file, 31 lines)
@@ -0,0 +1,31 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
gcu backend methods
"""

from .attention.flash_attn_backend import GCUFlashAttnBackend
from .attention.mem_efficient_attn_backend import GCUMemEfficientAttnBackend
from .moe.fused_moe_method_gcu_backend import (GCUFusedMoeMethod,
                                               GCUWeightOnlyMoEMethod)
from .quantization.weight_only import GCUWeightOnlyLinearMethod

__all__ = [
    'GCUFlashAttnBackend',
    'GCUMemEfficientAttnBackend',
    'GCUFusedMoeMethod',
    'GCUWeightOnlyMoEMethod',
    'GCUWeightOnlyLinearMethod',
]
@@ -0,0 +1,21 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# (Apache License 2.0 header, identical to the one above.)

from .flash_attn_backend import GCUFlashAttnBackend
from .mem_efficient_attn_backend import GCUMemEfficientAttnBackend

__all__ = [
    "GCUFlashAttnBackend",
    "GCUMemEfficientAttnBackend",
]
@@ -0,0 +1,287 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# (Apache License 2.0 header, identical to the one above.)
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional

import paddle

import numpy as np

if TYPE_CHECKING:
    from paddle._typing.dtype_like import _DTypeLiteral

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta, ForwardMode

from fastdeploy.model_executor.ops.gcu import (fused_rotary_embedding,
                                               mem_efficient_attention,
                                               flash_attn_var_len)
from paddleformers.utils.log import logger


@dataclass
class GCUFlashAttnMetadata(AttentionMetadata):
    """
    GCUFlashAttnMetadata
    """
    forward_mode: ForwardMode = ForwardMode.MIXED

    _dtype: _DTypeLiteral = paddle.bfloat16

    seq_lens_encoder: Optional[paddle.Tensor] = None
    seq_lens_decoder: Optional[paddle.Tensor] = None
    seq_lens_this_time: Optional[paddle.Tensor] = None
    cum_offsets: Optional[paddle.Tensor] = None
    padding_offset: Optional[paddle.Tensor] = None

    cu_seqlens_q: Optional[paddle.Tensor] = None
    cu_seqlens_k: Optional[paddle.Tensor] = None
    caches: Optional[paddle.Tensor] = None

    block_tables: Optional[paddle.Tensor] = None
    rotary_embs: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None

    pre_caches_length: int = 0


class GCUFlashAttnBackend(AttentionBackend):
    """
    GCUFlashAttnBackend backend implementation.
    """

    def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
                 head_dim: int):
        """
        GCUFlashAttnBackend __init__
        """
        super().__init__()
        self.attention_metadata: GCUFlashAttnMetadata = None
        self.block_size = fd_config.parallel_config.block_size
        self.max_seq_len = fd_config.parallel_config.max_model_len
        self.max_num_seqs = fd_config.parallel_config.max_num_seqs

        self.causal = getattr(fd_config.model_config, "causal", True)

        self.rank = fd_config.parallel_config.tensor_parallel_rank
        self.kv_num_heads = kv_num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scaling = 1.0 / (self.head_dim**0.5)
        self.num_layers = fd_config.model_config.num_layers
        self.position_ids_base = paddle.arange(self.max_seq_len)

        # TODO(zhengjun): Need to adapt the allocation logic and
        # temporarily allocate according to fixed size
        self.all_block_tables: List[List[int]] = None
        self.all_slot_mapping: List[List[int]] = None

        self.rotary_embs = None
        self.enable_monitor: bool = bool(os.getenv("FD_GCU_ATTN_MONITOR", False))

    def init_attention_metadata(self, forward_meta: ForwardMeta):
        """Initialize attention metadata so that all layers in the forward pass can reuse it."""
        metadata = GCUFlashAttnMetadata()

        metadata.forward_mode = forward_meta.forward_mode

        metadata._dtype = paddle.get_default_dtype()

        metadata.seq_lens_encoder = forward_meta.seq_lens_encoder
        metadata.seq_lens_decoder = forward_meta.seq_lens_decoder
        metadata.seq_lens_this_time = forward_meta.seq_lens_this_time
        metadata.cum_offsets = forward_meta.cum_offsets
        metadata.padding_offset = forward_meta.padding_offset

        metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
        metadata.cu_seqlens_k = forward_meta.cu_seqlens_k
        metadata.caches = forward_meta.caches

        # metadata.block_tables = forward_meta.block_tables
        metadata.rotary_embs = forward_meta.rotary_embs
        metadata.attn_mask = forward_meta.attn_mask  # not init

        metadata.pre_caches_length = forward_meta.pre_caches_length  # not inited

        self.attention_metadata = metadata

        if self.rotary_embs is None:
            self.rotary_embs = metadata.rotary_embs.reshape((-1, self.head_dim))

        # some info for attention
        self.seq_lens_this_time_list = forward_meta.seq_lens_this_time.tolist()  # List[int]
        self.seq_lens_encoder_list = forward_meta.seq_lens_encoder.tolist()  # List[List[int]]
        self.seq_lens_decoder_list = forward_meta.seq_lens_decoder.tolist()  # List[List[int]]
        self.seq_lens_sum = np.sum(self.seq_lens_this_time_list)
        self.max_seq_len_this_time = np.max(self.seq_lens_this_time_list)

        num_seqs = forward_meta.seq_lens_this_time.shape[0]

        self.is_decoder = all(x[0] == 0 for x in self.seq_lens_encoder_list)
        self.is_all_prefill = all(x[0] == 0 for x in self.seq_lens_decoder_list)

        # block_tables and slot_mapping
        if self.all_slot_mapping is None:
            max_num_blocks_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size
            total_blocks = max_num_blocks_per_seq * self.max_num_seqs
            self.all_block_tables = np.arange(0, total_blocks, dtype=np.int32).reshape((self.max_num_seqs, max_num_blocks_per_seq)).tolist()
            self.all_slot_mapping = np.arange(0, total_blocks * self.block_size, dtype=np.int32).reshape((self.max_num_seqs, -1)).tolist()

        block_tables = []
        slot_mapping = []
        cache_slot_range = []
        cache_lens = []
        position_ids = []
        for seq_idx in range(num_seqs):
            cache_len = None
            if self.seq_lens_encoder_list[seq_idx][0] != 0:  # prefill
                cache_len = 0
            elif self.seq_lens_decoder_list[seq_idx][0] != 0:  # decode
                cache_len = self.seq_lens_decoder_list[seq_idx][0]
            # else: no request in this seq_idx

            if cache_len is not None:
                lens_this_time = self.seq_lens_this_time_list[seq_idx]
                start = cache_len
                end = start + lens_this_time
                slot_mapping.extend(self.all_slot_mapping[seq_idx][start:end])
                cache_slot_range.extend(self.all_slot_mapping[seq_idx][0:end])
                cache_lens.append(end)
                block_tables.append(self.all_block_tables[seq_idx])
                position_ids.extend(self.position_ids_base[start:end])

        self.block_tables = paddle.to_tensor(block_tables, dtype="int32")
        self.slot_mapping = paddle.to_tensor(slot_mapping, dtype="int32")
        self.cache_slot_range = paddle.to_tensor(cache_slot_range, dtype="int32")
        self.position_ids = paddle.to_tensor(position_ids, dtype="int32")
        self.position_ids = self.position_ids.reshape_((1, -1))

        if self.enable_monitor:
            logger.info(f"[FD_DEBUG] init_attention_metadata, position_ids:\n{self.position_ids}")

        cu_query_lens_data = [0]
        for seq_idx in range(num_seqs):
            if self.seq_lens_this_time_list[seq_idx] != 0:
                cu_query_lens_data.append(self.seq_lens_this_time_list[seq_idx])
        cu_query_lens = np.array(cu_query_lens_data, dtype=np.int32).cumsum(axis=0)

        self.cu_query_lens = paddle.to_tensor(cu_query_lens, dtype="int32")
        self.seqused_k = paddle.to_tensor(cache_lens, dtype="int32")
        self.max_seqlen_q = self.max_seq_len_this_time
        self.max_seqlen_k = np.max(cache_lens)

    def get_attntion_meta(self):
        """get_attntion_meta"""
        return self.attention_metadata

    def get_kv_cache_shape(
        self,
        max_num_blocks: int,
    ):
        """
        Calculate kv cache shape
        """
        # [total_tokens, kv_num_heads, head_dim]
        return (max_num_blocks * self.block_size,
                self.kv_num_heads,
                self.head_dim)

    @paddle.no_grad()
    def forward_mixed(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        qkv: paddle.Tensor,
        compressed_kv: paddle.Tensor,
        k_pe: paddle.Tensor,
        layer: Attention,
        forward_meta: ForwardMeta,
    ) -> paddle.Tensor:
        """Run a forward for mixed."""
        token_num = qkv.shape[0]
        q_size = self.num_heads * self.head_dim
        kv_size = self.kv_num_heads * self.head_dim
        num_or_sections = [q_size, kv_size, kv_size]
        query, key, value = paddle.split(qkv, num_or_sections=num_or_sections, axis=-1)

        query = query.reshape_((1, -1, self.num_heads, self.head_dim))
        key = key.reshape_((1, -1, self.kv_num_heads, self.head_dim))

        # 1. Rope
        if self.rotary_embs.dtype != query.dtype:
            self.rotary_embs = paddle.cast(self.rotary_embs, query.dtype)

        query, key = fused_rotary_embedding(
            query,
            key,
            self.rotary_embs,
            self.position_ids,
            layer.use_neox_rotary_style
        )

        # 2. Save kv cache
        # shape: [total_tokens, kv_num_heads, head_dim]
        key = key.reshape_((-1, self.kv_num_heads, self.head_dim))
        value = value.reshape_((-1, self.kv_num_heads, self.head_dim))
        key_caches = forward_meta.caches[2 * layer.layer_id]
        value_caches = forward_meta.caches[2 * layer.layer_id + 1]
        key_caches[self.slot_mapping, :, :] = key
        value_caches[self.slot_mapping, :, :] = value

        # 3. calc attn
        query = query.reshape_((-1, self.num_heads, self.head_dim))
        key_caches = key_caches.reshape((-1, self.block_size, self.kv_num_heads, self.head_dim))
        value_caches = value_caches.reshape((-1, self.block_size, self.kv_num_heads, self.head_dim))
        res = flash_attn_var_len(
            query=query,
            key=key_caches,
            value=value_caches,
            cu_seqlens_q=self.cu_query_lens,
            cu_seqlens_k=None,
            seqused_k=self.seqused_k,
            leftpad_k=None,
            block_table=self.block_tables,
            alibi_slopes=None,
            max_seqlen_q=self.max_seqlen_q,
            max_seqlen_k=self.max_seqlen_k,
            p_dropout=0.0,
            softmax_scale=self.scaling,
            zero_tensors=False,
            is_causal=self.causal,
            window_size_left=-1,
            window_size_right=-1,
            softcap=0.0,
            return_softmax=False,
        )
        res = res.reshape_((token_num, -1))
        return res
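The fixed-size allocation above maps each sequence slot to a contiguous range of blocks and cache slots; a small worked example of that layout (hypothetical sizes, not the shipped defaults) helps when reading the prefill/decode bookkeeping in init_attention_metadata:

import numpy as np

# Assume block_size=4, max_seq_len=8, max_num_seqs=2 (illustrative values only).
block_size, max_seq_len, max_num_seqs = 4, 8, 2
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size   # 2
total_blocks = max_num_blocks_per_seq * max_num_seqs                     # 4

all_block_tables = np.arange(total_blocks).reshape(max_num_seqs, -1)     # [[0, 1], [2, 3]]
all_slot_mapping = np.arange(total_blocks * block_size).reshape(max_num_seqs, -1)
# seq 0 owns slots 0..7, seq 1 owns slots 8..15; a decode step for seq 1 with
# cache_len=3 and lens_this_time=1 therefore writes its new token into slot 11.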
@@ -0,0 +1,357 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# (Apache License 2.0 header, identical to the one above.)
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional

import paddle

import numpy as np
import math

if TYPE_CHECKING:
    from paddle._typing.dtype_like import _DTypeLiteral

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta, ForwardMode

from fastdeploy.model_executor.ops.gcu import (fused_rotary_embedding,
                                               mem_efficient_attention,
                                               flash_attn_var_len)
from paddleformers.utils.log import logger


@dataclass
class GCUMemEfficientAttnMetadata(AttentionMetadata):
    """
    GCUMemEfficientAttnMetadata
    """
    forward_mode: ForwardMode = ForwardMode.MIXED
    _dtype: _DTypeLiteral = paddle.bfloat16

    seq_lens_encoder: Optional[paddle.Tensor] = None
    seq_lens_decoder: Optional[paddle.Tensor] = None
    seq_lens_this_time: Optional[paddle.Tensor] = None
    cum_offsets: Optional[paddle.Tensor] = None
    padding_offset: Optional[paddle.Tensor] = None

    cu_seqlens_q: Optional[paddle.Tensor] = None
    cu_seqlens_k: Optional[paddle.Tensor] = None
    caches: Optional[paddle.Tensor] = None

    block_tables: Optional[paddle.Tensor] = None
    rotary_embs: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None

    pre_caches_length: int = 0


class GCUMemEfficientAttnBackend(AttentionBackend):
    """
    GCUMemEfficientAttnBackend backend implementation.
    """

    def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
                 head_dim: int):
        """
        GCUMemEfficientAttnBackend __init__
        """
        super().__init__()
        self.attention_metadata: GCUMemEfficientAttnMetadata = None
        self.block_size = fd_config.parallel_config.block_size
        self.max_seq_len = fd_config.parallel_config.max_model_len
        self.max_num_seqs = fd_config.parallel_config.max_num_seqs

        self.causal = getattr(fd_config.model_config, "causal", True)

        self.rank = fd_config.parallel_config.tensor_parallel_rank
        self.kv_num_heads = kv_num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scaling = 1.0 / (self.head_dim**0.5)
        self.num_layers = fd_config.model_config.num_layers
        self.position_ids_base = paddle.arange(self.max_seq_len)

        # TODO(zhengjun): Need to adapt the allocation logic and
        # temporarily allocate according to fixed size
        self.all_block_tables: List[List[int]] = None
        self.all_slot_mapping: List[List[int]] = None

        self.rotary_embs = None
        self.use_paddle_native_sdpa = False

    def init_attention_metadata(self, forward_meta: ForwardMeta):
        """Initialize attention metadata so that all layers in the forward pass can reuse it."""
        metadata = GCUMemEfficientAttnMetadata()

        metadata.forward_mode = forward_meta.forward_mode

        metadata._dtype = paddle.get_default_dtype()

        metadata.seq_lens_encoder = forward_meta.seq_lens_encoder
        metadata.seq_lens_decoder = forward_meta.seq_lens_decoder
        metadata.seq_lens_this_time = forward_meta.seq_lens_this_time
        metadata.cum_offsets = forward_meta.cum_offsets
        metadata.padding_offset = forward_meta.padding_offset

        metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
        metadata.cu_seqlens_k = forward_meta.cu_seqlens_k
        metadata.caches = forward_meta.caches

        # metadata.block_tables = forward_meta.block_tables
        metadata.rotary_embs = forward_meta.rotary_embs
        metadata.attn_mask = forward_meta.attn_mask  # not init

        metadata.pre_caches_length = forward_meta.pre_caches_length  # not inited

        self.attention_metadata = metadata

        if self.rotary_embs is None:
            self.rotary_embs = metadata.rotary_embs.reshape((-1, self.head_dim))

        # some info for attention
        self.seq_lens_this_time_list = forward_meta.seq_lens_this_time.tolist()  # List[int]
        self.seq_lens_encoder_list = forward_meta.seq_lens_encoder.tolist()  # List[List[int]]
        self.seq_lens_decoder_list = forward_meta.seq_lens_decoder.tolist()  # List[List[int]]
        self.seq_lens_sum = np.sum(self.seq_lens_this_time_list)
        self.max_seq_len_this_time = np.max(self.seq_lens_this_time_list)

        num_seqs = forward_meta.seq_lens_this_time.shape[0]

        self.is_decoder = all(x[0] == 0 for x in self.seq_lens_encoder_list)
        self.is_all_prefill = all(x[0] == 0 for x in self.seq_lens_decoder_list)

        # block_tables and slot_mapping
        if self.all_slot_mapping is None:
            max_num_blocks_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size
            total_blocks = max_num_blocks_per_seq * self.max_num_seqs
            self.all_block_tables = np.arange(0, total_blocks, dtype=np.int32).reshape((self.max_num_seqs, max_num_blocks_per_seq)).tolist()
            self.all_slot_mapping = np.arange(0, total_blocks * self.block_size, dtype=np.int32).reshape((self.max_num_seqs, -1)).tolist()

        block_tables = []
        slot_mapping = []
        cache_slot_range = []
        cache_lens = []
        query_lens = []
        cached_kv_lens = []
        cached_kv_slot_range = []
        position_ids = []
        for seq_idx in range(num_seqs):
            cache_len = None
            if self.seq_lens_encoder_list[seq_idx][0] != 0:  # prefill
                cache_len = 0
            elif self.seq_lens_decoder_list[seq_idx][0] != 0:  # decode
                cache_len = self.seq_lens_decoder_list[seq_idx][0]
            # else: no request in this seq_idx

            if cache_len is not None:
                lens_this_time = self.seq_lens_this_time_list[seq_idx]
                start = cache_len
                end = start + lens_this_time
                slot_mapping.extend(self.all_slot_mapping[seq_idx][start:end])
                cache_slot_range.extend(self.all_slot_mapping[seq_idx][0:end])
                cache_lens.append(end)
                block_tables.append(self.all_block_tables[seq_idx])
                position_ids.extend(self.position_ids_base[start:end])
                query_lens.append(lens_this_time)
                cached_kv_lens.append(end)
                cached_kv_slot_range.append([self.all_slot_mapping[seq_idx][0], self.all_slot_mapping[seq_idx][end]])

        self.block_tables = paddle.to_tensor(block_tables, dtype="int32")
        self.slot_mapping = paddle.to_tensor(slot_mapping, dtype="int32")
        self.cache_slot_range = paddle.to_tensor(cache_slot_range, dtype="int32")
        self.position_ids = paddle.to_tensor(position_ids, dtype="int32")
        self.position_ids = self.position_ids.reshape_((1, -1))

        logger.info(f"[FD_DEBUG] init_attention_metadata, self.position_ids:\n{self.position_ids}")

        cu_query_lens_data = [0]
        for seq_idx in range(num_seqs):
            if self.seq_lens_this_time_list[seq_idx] != 0:
                cu_query_lens_data.append(self.seq_lens_this_time_list[seq_idx])
        cu_query_lens = np.array(cu_query_lens_data, dtype=np.int32).cumsum(axis=0)

        self.cu_query_lens = paddle.to_tensor(cu_query_lens, dtype="int32")
        self.seqused_k = paddle.to_tensor(cache_lens, dtype="int32")
        self.max_seqlen_q = self.max_seq_len_this_time
        self.max_seqlen_k = np.max(cache_lens)

        self.query_lens = query_lens
        self.cached_kv_lens = cached_kv_lens
        self.cached_kv_slot_range = cached_kv_slot_range

    def get_attntion_meta(self):
        """get_attntion_meta"""
        return self.attention_metadata

    def get_kv_cache_shape(
        self,
        max_num_blocks: int,
    ):
        """
        Calculate kv cache shape
        """
        # [total_tokens, kv_num_heads, head_dim]
        return (max_num_blocks * self.block_size,
                self.kv_num_heads,
                self.head_dim)

    @paddle.no_grad()
    def forward_mixed(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        qkv: paddle.Tensor,
        compressed_kv: paddle.Tensor,
        k_pe: paddle.Tensor,
        layer: Attention,
        forward_meta: ForwardMeta,
    ) -> paddle.Tensor:
        """Run a forward for mixed."""
        token_num = qkv.shape[0]
        q_size = self.num_heads * self.head_dim
        kv_size = self.kv_num_heads * self.head_dim
        num_or_sections = [q_size, kv_size, kv_size]
        query, key, value = paddle.split(qkv, num_or_sections=num_or_sections, axis=-1)

        query = query.reshape_((1, -1, self.num_heads, self.head_dim))
        key = key.reshape_((1, -1, self.kv_num_heads, self.head_dim))

        # 1. Rope
        if self.rotary_embs.dtype != query.dtype:
            self.rotary_embs = paddle.cast(self.rotary_embs, query.dtype)

        query, key = fused_rotary_embedding(
            query,
            key,
            self.rotary_embs,
            self.position_ids,
            layer.use_neox_rotary_style
        )

        # 2. Save kv cache
        # shape: [total_tokens, kv_num_heads, head_dim]
        key = key.reshape_((-1, self.kv_num_heads, self.head_dim))
        value = value.reshape_((-1, self.kv_num_heads, self.head_dim))
        key_caches = forward_meta.caches[2 * layer.layer_id]
        value_caches = forward_meta.caches[2 * layer.layer_id + 1]
        key_caches[self.slot_mapping, :, :] = key
        value_caches[self.slot_mapping, :, :] = value

        # 3. calc attn
        query = query.reshape_((-1, self.num_heads, self.head_dim))

        q_start = 0
        result = paddle.empty_like(query)
        for idx in range(len(self.query_lens)):
            q_end = q_start + self.query_lens[idx]
            kv_start = self.cached_kv_slot_range[idx][0]
            kv_end = self.cached_kv_slot_range[idx][1]

            q_ = query[q_start:q_end, :, :]
            k_ = key_caches[kv_start:kv_end, :, :]
            v_ = value_caches[kv_start:kv_end, :, :]

            if self.use_paddle_native_sdpa:
                res = self.native_sdpa_impl(
                    q_, k_, v_
                )
            else:
                res = mem_efficient_attention(
                    query=q_.unsqueeze(0),
                    key=k_.unsqueeze(0),
                    value=v_.unsqueeze(0),
                    attn_mask=None,
                    dropout=0.0,
                    softmax_scale=self.scaling,
                    mask_mode=1,
                    seqlens=[0],
                    causal=self.causal,
                )
            result[q_start:q_end, :, :] = res
            q_start = q_end
        result = result.reshape_((token_num, -1))
        return result

    def get_triangle_upper_mask(self, shape, dtype):
        # [batch_size, 1, q_seq_len, kv_seq_len]
        shape[1] = 1
        q_seq_len = shape[2]
        kv_seq_len = shape[3]
        paddle_dtype = dtype  # paddle.base.data_feeder.convert_dtype(dtype)
        mask = paddle.full(shape, paddle.finfo(paddle_dtype).min, dtype=paddle_dtype)
        mask = paddle.triu(mask, diagonal=kv_seq_len - q_seq_len + 1)
        return mask

    def native_sdpa_impl(self, query, key, value):
        # input shape: [num_tokens, num_heads, head_dim] -> [1, num_tokens, num_heads, head_dim]
        q = query.unsqueeze(0)
        k = key.unsqueeze(0)
        v = value.unsqueeze(0)
        batch, q_seq_len, heads, head_dim = q.shape
        kv_seq_len = k.shape[1]

        # [batch_size, seq_len, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
        q = paddle.transpose(q, [0, 2, 1, 3])
        k = paddle.transpose(k, [0, 2, 1, 3])
        v = paddle.transpose(v, [0, 2, 1, 3])

        # GQA
        if q.shape[1] != k.shape[1]:
            kv_head = k.shape[1]

            k = k.reshape([batch, kv_head, 1, kv_seq_len, head_dim])
            k = paddle.tile(k, [1, 1, heads // kv_head, 1, 1])
            k = k.reshape([batch, heads, kv_seq_len, head_dim])

            v = v.reshape([batch, kv_head, 1, kv_seq_len, head_dim])
            v = paddle.tile(v, [1, 1, heads // kv_head, 1, 1])
            v = v.reshape([batch, heads, kv_seq_len, head_dim])

        # matmul and divide by sqrt(head_dim)
        attn_weights = paddle.matmul(q / math.sqrt(head_dim), k.transpose([0, 1, 3, 2]))

        attention_mask = self.get_triangle_upper_mask(
            [batch, 1, q_seq_len, kv_seq_len], q.dtype
        )
        attn_weights = attn_weights + attention_mask
        attn_weights = paddle.nn.functional.softmax(
            attn_weights, axis=-1, dtype="float32"
        ).astype(q.dtype)

        attn_output = paddle.matmul(attn_weights, v)
        attn_output = attn_output.transpose([0, 2, 1, 3])
        return attn_output.squeeze(0)
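A quick sanity check on the causal mask built by get_triangle_upper_mask above: the diagonal offset kv_seq_len - q_seq_len + 1 aligns the query rows with the tail of the key sequence, which is what a decode-time query block needs. An illustrative sketch (assuming float32 just for readability):

import paddle

# With q_seq_len=2 and kv_seq_len=4 the two query rows correspond to the last two
# key positions, so row 0 may attend keys 0..2 and row 1 may attend keys 0..3.
q_seq_len, kv_seq_len = 2, 4
mask = paddle.full([1, 1, q_seq_len, kv_seq_len],
                   float(paddle.finfo(paddle.float32).min), dtype=paddle.float32)
mask = paddle.triu(mask, diagonal=kv_seq_len - q_seq_len + 1)
# mask[0, 0] == [[0, 0, 0, MIN],
#                [0, 0, 0,   0]]  (MIN standing for the float32 minimum)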
@@ -0,0 +1,16 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# (Apache License 2.0 header, identical to the one above.)
"""
gcu moe
"""
@@ -0,0 +1,402 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# (Apache License 2.0 header, identical to the one above.)
"""

import multiprocessing
import os

import numpy as np
import paddle
from paddle import nn
from paddleformers.utils.log import logger

from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import \
    MoEMethodBase
from fastdeploy.model_executor.layers.utils import (CpuGuard,
                                                    create_and_set_parameter,
                                                    get_tensor)
from fastdeploy.model_executor.ops.gcu import (invoke_fused_moe_kernel,
                                               moe_align_block_size,
                                               topk_softmax,
                                               weight_quantize_custom_rtn,
                                               weight_quantize_rtn)


class GCUFusedMoeMethod(MoEMethodBase):
    """
    Use GCU to compute Fused MoE.
    """
    def __init__(self, quant_config):
        super().__init__(quant_config)
        self.group_size = -1

    def create_weights(self, layer: nn.Layer, state_dict):
        """
        Paddle gcu create weight process.
        """
        # bf16
        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
        stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
        stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
        for idx, weight_tensor in enumerate(
                [stacked_ffn1_weights, stacked_ffn2_weights]):
            # shape [E, K, N] -> [E, N, K]
            weight_tensor = paddle.transpose(weight_tensor, [0, 2, 1])
            weight_name = self.added_weight_attrs[idx]
            setattr(
                layer, weight_name,
                layer.create_parameter(
                    shape=weight_tensor.shape,
                    dtype=weight_tensor.dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                ))
            getattr(layer, weight_name).set_value(weight_tensor)

    @paddle.no_grad()
    def compute_ffn(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
        enable_quant=False
    ) -> paddle.Tensor:
        """
        Paddle gcu compute Fused MoE.
        """
        token_num, hidden_size = x.shape
        top_k = layer.top_k
        moe_intermediate_size = layer.moe_intermediate_size
        num_experts = layer.num_local_experts

        topk_weights = paddle.empty([token_num, top_k], dtype=gate_out.dtype)
        topk_indices = paddle.empty([token_num, top_k], dtype="int32")
        token_expert_indices = paddle.empty([token_num, top_k], dtype="int32",)
        topk_softmax(topk_weights, topk_indices, token_expert_indices, gate_out, norm_topk_prob=True)

        config = {
            "BLOCK_SIZE_M": 32,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 1,
        }

        block_size = config["BLOCK_SIZE_M"]
        max_num_tokens_padded = np.prod(topk_indices.shape) + num_experts * (block_size - 1)
        max_num_m_blocks = max_num_tokens_padded // block_size
        sorted_token_ids = paddle.empty([max_num_tokens_padded], dtype="int32")
        expert_ids = paddle.zeros(shape=[max_num_m_blocks], dtype="int32")
        num_tokens_post_pad = paddle.empty([1], dtype="int32")

        sorted_token_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
            sorted_token_ids,
            expert_ids,
            num_tokens_post_pad,
            topk_indices,
            num_experts,
            block_size,
        )

        intermediate_cache1 = paddle.empty(
            [token_num, top_k, moe_intermediate_size * 2],
            dtype=x.dtype,
        )

        ffn1_B_scale = layer.moe_ffn1_weight_scale if enable_quant else None
        ffn1_B_zeros = layer.moe_ffn1_weight_zeros if enable_quant else None

        invoke_fused_moe_kernel(
            x,  # input
            layer.moe_ffn1_weight,  # weight
            intermediate_cache1,  # output
            None,  # A_scale
            ffn1_B_scale,  # B_scale
            ffn1_B_zeros,  # B_zp
            topk_weights,
            topk_indices,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_pad,
            False,  # mul_routed_weight
            top_k,
            config,
            enable_quant,  # use_int4_w4a16
            [0, self.group_size],  # block_shape
        )

        intermediate_cache2 = paddle.empty(
            (token_num, top_k, moe_intermediate_size),
            dtype=x.dtype,
        )

        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
            intermediate_cache1)

        intermediate_cache2 = intermediate_cache2.reshape([-1, moe_intermediate_size])

        intermediate_cache3 = paddle.empty(
            (token_num, top_k, hidden_size),
            dtype=x.dtype,
        )

        ffn2_B_scale = layer.moe_ffn2_weight_scale if enable_quant else None
        ffn2_B_zeros = layer.moe_ffn2_weight_zeros if enable_quant else None

        invoke_fused_moe_kernel(
            intermediate_cache2,  # input
            layer.moe_ffn2_weight,  # weight
            intermediate_cache3,  # output
            None,  # A_scale
            ffn2_B_scale,  # B_scale
            ffn2_B_zeros,  # B_zp
            topk_weights,
            topk_indices,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_pad,
            True,  # mul_routed_weight
            1,
            config,
            enable_quant,  # use_int4_w4a16
            [0, self.group_size],  # block_shape
        )

        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
        fused_moe_out = intermediate_cache3.sum(axis=1)
        fused_moe_out = fused_moe_out.reshape_([token_num, hidden_size])

        if layer.tp_size > 1:
            from fastdeploy.distributed.communication_op import \
                tensor_model_parallel_all_reduce
            tensor_model_parallel_all_reduce(fused_moe_out)

        return fused_moe_out

    def apply(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Paddle gcu compute Fused MoE.
        """
        return self.compute_ffn(layer, x, gate_out, enable_quant=False)

    def apply_ep_prefill(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.
        """
        raise NotImplementedError

    def apply_ep_decode(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Apply the EP decoder method.
        """
        raise NotImplementedError

    def apply_tp(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Paddle Cutlass compute Fused MoE.
        """
        raise NotImplementedError


class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
    """
    weight only for moe
    """

    def __init__(self, quant_config):
        super().__init__(quant_config)
        self.quant_config = quant_config
        self.moe_quant_type = self.quant_config.algo
        self.pack_num = 1

        assert self.quant_config.algo == "weight_only_int4", \
            f"GCUWeightOnlyMoEMethod only supports weight_only_int4, but got: {self.quant_config.algo}"

        self.added_qzeros_attrs = [
            "moe_ffn1_weight_zeros", "moe_ffn2_weight_zeros"
        ]
        self.group_size = 64

        self.quant_multi_process_group_size = int(
            os.getenv("FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE", 8)
        )
        logger.info(f"GCUWeightOnlyMoEMethod quant_multi_process_group_size: {self.quant_multi_process_group_size}")

    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
        """
        Paddle gcu process prequanted weights.
        """
        ffn1_expert_weight_key = layer.weight_key_map.get(
            "ffn1_expert_weight_key", None)
        ffn2_expert_weight_key = layer.weight_key_map.get(
            "ffn2_expert_weight_key", None)
        ffn1_expert_weight_scale_key = layer.weight_key_map.get(
            "ffn1_expert_weight_scale_key", None)
        ffn2_expert_weight_scale_key = layer.weight_key_map.get(
            "ffn2_expert_weight_scale_key", None)

        ffn1_weights, ffn2_weights = layer.load_experts_weight(
            state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
        # self.check(layer, ffn1_weights, ffn2_weights)
        ffn1_weight_scale = []
        ffn2_weight_scale = []
        for i in range(layer.num_experts):
            expert_idx = layer.expert_id_offset + i
            ffn1_weight_scale.append(
                get_tensor(
                    state_dict.pop(
                        ffn1_expert_weight_scale_key.format(expert_idx))))
            ffn2_weight_scale.append(
                get_tensor(
                    state_dict.pop(
                        ffn2_expert_weight_scale_key.format(expert_idx))))

        ffn1_weight = paddle.stack(ffn1_weights, axis=0)
        ffn2_weight = paddle.stack(ffn2_weights, axis=0)
        ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
        ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)

        name_tensor_map = {
            "moe_ffn1_weight": ffn1_weight,
            "moe_ffn2_weight": ffn2_weight,
            "moe_ffn1_weight_scale": ffn1_weight_scale,
            "moe_ffn2_weight_scale": ffn2_weight_scale
        }
        for name, tensor in name_tensor_map.items():
            create_and_set_parameter(layer, name, tensor)

    @paddle.no_grad()
    def create_weights(self, layer: nn.Layer, state_dict):
        """
        Paddle cutlass create weight process.
        """
        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
        self.check(layer, ffn1_weights, ffn2_weights)

        def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size):
            with CpuGuard():
                p_group_size = len(weights)
                for group_j in range(p_group_size):
                    # weight shape [K, N] -> [N/2, K] -> [N, K/2]
                    quant_weight, scale = weight_quantize_custom_rtn(
                        weights[group_j],
                        moe_quant_type,
                        group_size  # group_size
                    )
                    shared_dict[p_group_size * p_group_idx + group_j] = (
                        quant_weight, scale
                    )

        for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
            weight_name = self.added_weight_attrs[idx]
            scale_name = self.added_scale_attrs[idx]
            zeros_name = self.added_qzeros_attrs[idx]

            if self.quant_multi_process_group_size > 0:
                process_group_size = self.quant_multi_process_group_size
                process_group_num = layer.num_local_experts // process_group_size
                grouped_weights_num = process_group_num * process_group_size
                remain_weights_start_idx = grouped_weights_num

                weight_list = [None] * grouped_weights_num
                weight_scale_list = [None] * grouped_weights_num

                with multiprocessing.Manager() as manager:
                    shared_dict = manager.dict({})
                    processes = []

                    for i in range(process_group_num):
                        w = []
                        for j in range(process_group_size):
                            w.append(weight_tensor[process_group_size * i + j].to("cpu"))

                        p = multiprocessing.Process(
                            target=quant_worker,
                            args=(i, shared_dict, w, self.moe_quant_type, self.group_size)
                        )
                        p.start()
                        processes.append(p)

                    for p in processes:
                        p.join()

                    dict_ = dict(shared_dict)

                    for k, v in dict_.items():
                        weight_list[k] = v[0].to(ffn1_weights[0].place)
                        weight_scale_list[k] = v[1].to(ffn1_weights[0].place)
            else:
                remain_weights_start_idx = 0

            if remain_weights_start_idx < layer.num_local_experts:
                for i in range(remain_weights_start_idx, layer.num_local_experts):
                    # weight shape [K, N] -> [N/2, K] -> [N, K/2]
                    quant_weight, scale = weight_quantize_rtn(
                        weight_tensor[i],
                        self.moe_quant_type,
                        self.group_size  # group_size
                    )
                    weight_list.append(quant_weight)
                    weight_scale_list.append(scale)
            quanted_weight = paddle.stack(weight_list, axis=0)
            create_and_set_parameter(layer, weight_name, quanted_weight)

            quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
            create_and_set_parameter(layer, scale_name, quanted_weight_scale)

            quanted_weight_zeros = quanted_weight_scale * 8
            create_and_set_parameter(layer, zeros_name, quanted_weight_zeros)

    def apply(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Paddle gcu compute Fused MoE.
        """
        return self.compute_ffn(layer, x, gate_out, enable_quant=True)
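The sizing of sorted_token_ids in compute_ffn follows the usual fused-MoE padding argument: every expert may need up to block_size - 1 padding entries so that its token list is a whole number of BLOCK_SIZE_M tiles. A quick arithmetic sketch with illustrative values (not the shipped configuration):

# Illustrative numbers only: 6 tokens routed to top_k=2 of 4 local experts,
# with BLOCK_SIZE_M = 32 as in the config dict above.
token_num, top_k, num_experts, block_size = 6, 2, 4, 32
max_num_tokens_padded = token_num * top_k + num_experts * (block_size - 1)  # 12 + 124 = 136
max_num_m_blocks = max_num_tokens_padded // block_size                      # 4 tiles
# moe_align_block_size then fills sorted_token_ids / expert_ids so that each
# 32-row tile handed to invoke_fused_moe_kernel belongs to a single expert.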
@@ -0,0 +1,21 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# (Apache License 2.0 header, identical to the one above.)
"""
gcu quantization
"""
from .weight_only import GCUWeightOnlyLinearMethod

__all__ = [
    "GCUWeightOnlyLinearMethod",
]
@@ -0,0 +1,90 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# (Apache License 2.0 header, identical to the one above.)
"""

import paddle

from fastdeploy.model_executor.layers.quantization.weight_only import (
    WeightOnlyConfig, WeightOnlyLinearMethod)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.gcu import linear_quant, weight_quantize_rtn


class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
    """
    Weight only quantization method for linear layer on GCU
    """

    def __init__(
        self,
        quant_config: WeightOnlyConfig,
    ) -> None:
        super().__init__(quant_config)
        self.quant_config = quant_config
        self.group_size = -1

    def create_weights(self, layer):
        # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
        linear_weight_scale_shape = [layer.linear_weight_shape[1]]

        layer.linear_weight_shape.reverse()
        if self.quant_config.name() == "wint4":
            layer.linear_weight_shape[0] //= 2
        layer.weight_dtype = "int8"
        layer.linear_weight_scale = layer.create_parameter(
            shape=linear_weight_scale_shape,
            dtype=layer._dtype,
            is_bias=False,
        )

    def process_prequanted_weights(self, layer, state_dict) -> None:
        """
        Process pre-quantized weights before applying them to the model

        Args:
            layer: The layer that owns the weights
            quant_weight: The quantized weights
            weight_scale: The scale of the quantized weights
        """
        quant_weight = get_tensor(state_dict.pop(layer.weight_key))
        weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
        layer.linear_weight.set_value(quant_weight)
        layer.linear_weight_scale.set_value(
            weight_scale.astype(paddle.get_default_dtype()))

    def process_loaded_weights(self, layer, weight) -> None:
        quanted_weight_tensor, weight_scale_tensor = weight_quantize_rtn(
            weight,
            self.quant_config.algo,
            self.group_size,  # group_size
        )

        layer.linear_weight.set_value(quanted_weight_tensor)
        layer.linear_weight_scale.set_value(
            weight_scale_tensor.astype(paddle.get_default_dtype()))

    @paddle.no_grad()
    def apply(self, layer, x):
        linear_out = linear_quant(
            lhs=x,
            rhs=layer.linear_weight,
            scale=layer.linear_weight_scale,
            bias=None,
            group_size=self.group_size,
        )
        return linear_out
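For a wint4 layer, create_weights above flips the weight layout and packs two 4-bit values per int8 entry; a shape-only sketch of that transformation with hypothetical dimensions (4096 in, 11008 out) is:

# Hypothetical linear layer with in_features=4096, out_features=11008.
linear_weight_shape = [4096, 11008]       # [in, out] as created by the layer
scale_shape = [linear_weight_shape[1]]    # one per-channel scale per output dim -> [11008]

linear_weight_shape.reverse()             # stored transposed -> [11008, 4096]
linear_weight_shape[0] //= 2              # wint4 packs two nibbles per int8 -> [5504, 4096]
weight_dtype = "int8"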
@@ -58,7 +58,7 @@ class LinearBase(nn.Layer):
        """
        super().__init__()
        if current_platform.is_cuda() or current_platform.is_xpu(
        ) or current_platform.is_iluvatar():
        ) or current_platform.is_iluvatar() or current_platform.is_gcu():
            self.forward = self.forward_cuda
        else:
            raise NotImplementedError
@@ -20,6 +20,7 @@ from paddleformers.utils.log import logger

from fastdeploy import envs
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.platforms import current_platform


class FusedMoE(nn.Layer):

@@ -95,8 +96,13 @@ class FusedMoE(nn.Layer):
            self.moe_quant_type = moe_quant_config.name()
        else:
            # For now, the no-quant (w_fp16, a_fp16) method cannot be obtained from quant_config; this will be optimized in the future.
            from .fused_moe_cutlass_backend import CutlassMoEMethod
            self.quant_method = CutlassMoEMethod(None)
            if current_platform.is_cuda():
                from .fused_moe_cutlass_backend import CutlassMoEMethod
                self.quant_method = CutlassMoEMethod(None)
            elif current_platform.is_gcu():
                from fastdeploy.model_executor.layers.backends import \
                    GCUFusedMoeMethod
                self.quant_method = GCUFusedMoeMethod(None)

        if self.ep_size > 1:
            self.quant_method.init_ep(self)
@@ -19,9 +19,14 @@ from typing import Callable, Dict, Optional
import numpy as np
import paddle
from paddle import nn
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm

from fastdeploy.platforms import current_platform

if current_platform.is_gcu():
    from fastdeploy.model_executor.ops.gcu import fused_add_rms_norm, rms_norm
else:
    from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm

from fastdeploy.config import FDConfig

from .utils import get_tensor

@@ -69,7 +74,10 @@ class RMSNorm(nn.Layer):
        self.weight_key: Optional[str] = f"{prefix}.weight"
        self.with_weight: bool = self.weight_key is not None
        self.eps: float = eps
        self.norm_func: Callable = fused_rms_norm
        if current_platform.is_gcu():
            self.norm_func: Callable = fused_add_rms_norm
        else:
            self.norm_func: Callable = fused_rms_norm
        self.linear_bias: Optional[paddle.Tensor] = linear_bias
        self.quant_scale: Optional[float] = quant_scale
        self._dtype: str = self._helper.get_default_dtype()

@@ -129,19 +137,26 @@ class RMSNorm(nn.Layer):
        The `residual_output` is the result of applying the normalization and possibly other
        operations (like linear transformation) on the `residual_input`.
        """
        norm_out = self.norm_func(
            x,
            norm_weight=self.ln_weight,
            norm_bias=None,
            epsilon=self.eps,
            begin_norm_axis=self.begin_norm_axis,
            bias=self.linear_bias,
            residual=residual_input,
            quant_scale=-1 if self.quant_scale is None else self.quant_scale,
            quant_round_type=self.quant_round_type,
            quant_max_bound=self.quant_max_bound,
            quant_min_bound=self.quant_min_bound,
        )
        if current_platform.is_gcu():
            if residual_input is None:
                return rms_norm(x, self.ln_weight, self.eps)
            norm_out = self.norm_func(
                x, residual_input, self.ln_weight, self.eps
            )
        else:
            norm_out = self.norm_func(
                x,
                norm_weight=self.ln_weight,
                norm_bias=None,
                epsilon=self.eps,
                begin_norm_axis=self.begin_norm_axis,
                bias=self.linear_bias,
                residual=residual_input,
                quant_scale=-1 if self.quant_scale is None else self.quant_scale,
                quant_round_type=self.quant_round_type,
                quant_max_bound=self.quant_max_bound,
                quant_min_bound=self.quant_min_bound,
            )
        if residual_input is not None:
            return norm_out[0], norm_out[1]
        else:

@@ -193,7 +208,10 @@ class LayerNorm(nn.Layer):
        self.with_bias: bool = with_bias
        self.eps: float = eps
        self.quant_scale: float = quant_scale
        self.norm_func: Callable = fused_layer_norm
        if current_platform.is_gcu():
            self.norm_func: Callable = paddle.nn.functional.layer_norm
        else:
            self.norm_func: Callable = fused_layer_norm
        self.linear_bias: Optional[paddle.Tensor] = linear_bias
        self._dtype: str = self._helper.get_default_dtype()
        self._norm_weight_dtype: str = "float32"

@@ -279,19 +297,40 @@ class LayerNorm(nn.Layer):
        else:
            raise NotImplementedError("Iluvatar does not support yet!")

        norm_out = self.norm_func(
            x,
            norm_weight=self.ln_weight,
            norm_bias=self.ln_bias,
            epsilon=self.eps,
            begin_norm_axis=1,
            bias=self.linear_bias,
            residual=residual_input,
            quant_scale=-1 if self.quant_scale is None else self.quant_scale,
            quant_round_type=self.quant_round_type,
            quant_max_bound=self.quant_max_bound,
            quant_min_bound=self.quant_min_bound,
        )
        if current_platform.is_gcu():
            if residual_input is not None:
                y = x + residual_input
                out = self.norm_func(
                    x=y,
                    normalized_shape=y.shape[1:],
                    weight=self.ln_weight,
                    bias=self.linear_bias,
                    epsilon=self.eps,
                )
                return out, y
            else:
                out = self.norm_func(
                    x=x,
                    normalized_shape=x.shape[1:],
                    weight=self.ln_weight,
                    bias=self.linear_bias,
                    epsilon=self.eps,
                )
                return out
        else:
            norm_out = self.norm_func(
                x,
                norm_weight=self.ln_weight,
                norm_bias=self.ln_bias,
                epsilon=self.eps,
                begin_norm_axis=1,
                bias=self.linear_bias,
                residual=residual_input,
                quant_scale=-1 if self.quant_scale is None else self.quant_scale,
                quant_round_type=self.quant_round_type,
                quant_max_bound=self.quant_max_bound,
                quant_min_bound=self.quant_min_bound,
            )
        if residual_input is not None:
            return norm_out[0], norm_out[1]
        else:
@@ -66,6 +66,13 @@ class WeightOnlyConfig(QuantConfigBase):
                return XPUWeightOnlyMoEMethod(self)
            else:
                return XPUWeightOnlyLinearMethod(self)
        elif current_platform.is_gcu():
            from fastdeploy.model_executor.layers.backends import (
                GCUWeightOnlyLinearMethod, GCUWeightOnlyMoEMethod)
            if isinstance(layer, FusedMoE):
                return GCUWeightOnlyMoEMethod(self)
            else:
                return GCUWeightOnlyLinearMethod(self)
        else:
            if isinstance(layer, FusedMoE):
                if layer.use_method == "cutlass":
@@ -55,6 +55,10 @@ class ErnieRotaryEmbedding:
                dtype="float32")
            emb = paddle.stack([freqs, freqs], axis=-1).reshape(
                (bsz, max_seq_len, self.rotary_dim))
        elif current_platform.is_gcu():
            # shape: [B, S, D]
            rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
            return rot_emb
        else:
            # shape: [B, S, D/2]
            rot_emb = paddle.zeros(

@@ -95,6 +99,10 @@ class QwenRotaryEmbedding:
        # shape: [B, S, D/2]
        freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"),
                              inv_freq)
        if current_platform.is_gcu():
            # shape: [B, S, D]
            rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
            return rot_emb
        # shape: [B, S, 1, D]
        emb = paddle.concat([freqs, freqs], axis=-1).reshape(
            (bsz, max_seq_len, 1, self.rotary_dim))
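On GCU the rope table therefore stores the cosine and sine halves side by side along the last axis; the attention backends above then reshape it to (-1, head_dim) before passing it to fused_rotary_embedding. A minimal sketch of that layout, assuming rotary_dim equals head_dim and an illustrative inverse-frequency formula:

import paddle

# Illustrative shapes only: one sequence of length S with head_dim D.
S, D = 8, 64
inv_freq = 1.0 / (10000 ** (paddle.arange(0, D, 2, dtype="float32") / D))
positions = paddle.arange(S, dtype="float32")
freqs = positions.unsqueeze(1) * inv_freq.unsqueeze(0)         # [S, D/2]
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)   # [S, D]: cos half then sin half
# The GCU attention backends later call rot_emb.reshape((-1, head_dim)) and pass it,
# together with per-token position ids, to fused_rotary_embedding.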
@@ -79,6 +79,21 @@ def apply_penalty_multi_scores(
            min_dec_lens,
            eos_token_ids,
        )
    elif current_platform.is_gcu():
        from fastdeploy.model_executor.ops.gcu import \
            get_token_penalty_multi_scores
        logits = get_token_penalty_multi_scores(
            pre_token_ids,
            logits,
            repetition_penalties,
            frequency_penalties,
            presence_penalties,
            temperature,
            bad_words_token_ids,
            step_idx,
            min_dec_lens,
            eos_token_ids,
        )
    else:
        raise NotImplementedError()
@@ -19,7 +19,11 @@ from typing import Literal, Optional
import paddle

from fastdeploy import envs
from fastdeploy.platforms import current_platform

if current_platform.is_gcu():
    from fastdeploy.model_executor.ops.gcu import \
        top_p_sampling as gcu_top_p_sampling


def top_p_sampling(
    x: paddle.Tensor,

@@ -46,13 +50,16 @@ def top_p_sampling(
        ids = rejection_top_p_sampling(x, ps, seed)
        _ = None
    else:
        _, ids = paddle.tensor.top_p_sampling(x,
                                              ps,
                                              threshold=threshold,
                                              topp_seed=topp_seed,
                                              seed=seed,
                                              k=k,
                                              mode=mode)
        if current_platform.is_gcu():
            _, ids = gcu_top_p_sampling(x, ps)
        else:
            _, ids = paddle.tensor.top_p_sampling(x,
                                                  ps,
                                                  threshold=threshold,
                                                  topp_seed=topp_seed,
                                                  seed=seed,
                                                  k=k,
                                                  mode=mode)
    return _, ids

@@ -171,7 +171,7 @@ class Sampler(nn.Layer):
        """
        super().__init__()
        if current_platform.is_cuda() or current_platform.is_xpu(
        ) or current_platform.is_iluvatar():
        ) or current_platform.is_iluvatar() or current_platform.is_gcu():
            self.forward = self.forward_cuda
        else:
            raise NotImplementedError()
@@ -17,5 +17,6 @@ from . import cpu
from . import xpu
from . import npu
from . import iluvatar
from . import gcu

__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar"]
__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar", "gcu"]
fastdeploy/model_executor/ops/gcu/__init__.py (new file, 116 lines)
@@ -0,0 +1,116 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# (Apache License 2.0 header, identical to the one above.)

""" fastdeploy gcu ops """
from fastdeploy.platforms import current_platform

from fastdeploy.import_ops import import_custom_ops, rename_imported_op

PACKAGE = "fastdeploy.model_executor.ops.gcu"

import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())

if current_platform.is_gcu():
    from paddle_custom_device.gcu.ops import (invoke_fused_moe_kernel,  # noqa: F401,E402
                                              moe_align_block_size, top_p_sampling,  # noqa: F401
                                              topk_softmax,  # noqa: F401
                                              weight_quantize_custom_rtn,  # noqa: F401
                                              weight_quantize_rtn)  # noqa: F401

    # ###################### Ops from PaddleCustomDevice ####################
    rename_imported_op(
        old_name="fused_rotary_embedding_gcu",
        new_name="fused_rotary_embedding",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="reshape_and_cache_gcu",
        new_name="reshape_and_cache",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="paged_attention_gcu",
        new_name="paged_attention",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="mem_efficient_attention_gcu",
        new_name="mem_efficient_attention",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="flash_attn_var_len_gcu",
        new_name="flash_attn_var_len",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="rms_norm_gcu",
        new_name="rms_norm",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="fused_add_rms_norm_op",
        new_name="fused_add_rms_norm",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="linear_quant_gcu",
        new_name="linear_quant",
        global_ns=globals(),
    )

    # ###################### CPU OPS ####################
    rename_imported_op(
        old_name="get_padding_offset_gcu",
        new_name="get_padding_offset",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="update_inputs_gcu",
        new_name="update_inputs",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="rebuild_padding_gcu",
        new_name="rebuild_padding",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="get_token_penalty_multi_scores_gcu",
        new_name="get_token_penalty_multi_scores",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="set_stop_value_multi_ends_gcu",
        new_name="set_stop_value_multi_ends",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="set_value_by_flags_and_idx_gcu",
        new_name="set_value_by_flags_and_idx",
        global_ns=globals(),
    )
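The renames above let platform-agnostic code import the generic op names; functionally this amounts to module-level aliasing. A hedged sketch of an equivalent helper (the real rename_imported_op lives in fastdeploy.import_ops and may behave differently) is:

def _alias_op(old_name: str, new_name: str, global_ns: dict) -> None:
    # Hypothetical stand-in for rename_imported_op: expose the *_gcu custom op
    # under its generic name so shared layers can import it unchanged.
    if old_name in global_ns:
        global_ns[new_name] = global_ns[old_name]

# e.g. _alias_op("rms_norm_gcu", "rms_norm", globals()) would make
# `from fastdeploy.model_executor.ops.gcu import rms_norm` resolve to the GCU kernel.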
@@ -24,6 +24,11 @@ if current_platform.is_iluvatar():
    from fastdeploy.model_executor.ops.iluvatar import (
        get_padding_offset, save_output, set_stop_value_multi_ends,
        step_paddle, update_inputs)
elif current_platform.is_gcu():
    from fastdeploy.model_executor.ops.gcu import (get_padding_offset,
                                                   save_output,
                                                   set_stop_value_multi_ends,
                                                   update_inputs)
else:
    from fastdeploy.model_executor.ops.gpu import (
        get_padding_offset, save_output, set_stop_value_multi_ends,

@@ -391,6 +396,17 @@ def rebuild_padding(tmp_out: paddle.Tensor,
            output_padding_offset,
            max_input_length,
        )
    elif current_platform.is_gcu():
        from fastdeploy.model_executor.ops.gcu import rebuild_padding
        hidden_states = rebuild_padding(
            tmp_out,
            cum_offsets,
            seq_len_this_time,
            seq_lens_decoder,
            seq_lens_encoder,
            output_padding_offset,
            max_input_length,
        )
    elif current_platform.is_cpu():
        from fastdeploy.model_executor.ops.cpu import rebuild_padding_cpu
        hidden_states = rebuild_padding_cpu(