[GCU] Support gcu platform (#2702)

baseline: e7fa57ebae

Co-authored-by: yongqiangma <xing.wo@163.com>
This commit is contained in:
EnflameGCU
2025-07-08 13:00:52 +08:00
committed by GitHub
parent 26d5d737dd
commit d0f4d6ba3a
33 changed files with 2988 additions and 85 deletions

View File

@@ -19,7 +19,7 @@ from typing import Optional
import paddle
from paddle import nn
from paddle.incubate.nn.functional import fused_bias_act
from paddle.incubate.nn.functional import fused_bias_act, swiglu
from fastdeploy.config import FDConfig
from fastdeploy.platforms import current_platform
@@ -66,6 +66,8 @@ class SiluAndMul(nn.Layer):
if current_platform.is_cuda() or current_platform.is_xpu(
) or current_platform.is_iluvatar():
self.forward = self.forward_cuda
elif current_platform.is_gcu():
self.forward = self.forward_gcu
else:
raise NotImplementedError
@@ -123,3 +125,18 @@ class SiluAndMul(nn.Layer):
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
def forward_gcu(self, x):
"""
Forward propagation of the custom activation layer.
Args:
x (Tensor): Input tensor to the activation layer.
Returns:
Tensor: Output tensor.
"""
out = swiglu(x)
if self.bias is not None:
out = out + self.bias
return out

View File

@@ -16,14 +16,24 @@
all backends methods
"""
from .xpu import *
from .npu import *
from fastdeploy.platforms import current_platform
__all__ = []
from . import npu
if hasattr(npu, '__all__'):
__all__.extend(npu.__all__)
from . import xpu
if hasattr(xpu, '__all__'):
__all__.extend(xpu.__all__)
if current_platform.is_xpu():
from . import xpu
from .xpu import *
if hasattr(xpu, '__all__'):
__all__.extend(xpu.__all__)
if current_platform.is_npu():
from . import npu
from .npu import *
if hasattr(npu, '__all__'):
__all__.extend(npu.__all__)
if current_platform.is_gcu():
from . import gcu
from .gcu import *
if hasattr(gcu, '__all__'):
__all__.extend(gcu.__all__)

View File

@@ -0,0 +1,31 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
gcu backend methods
"""
from .attention.flash_attn_backend import GCUFlashAttnBackend
from .attention.mem_efficient_attn_backend import GCUMemEfficientAttnBackend
from .moe.fused_moe_method_gcu_backend import (GCUFusedMoeMethod,
GCUWeightOnlyMoEMethod)
from .quantization.weight_only import GCUWeightOnlyLinearMethod
__all__ = [
'GCUFlashAttnBackend',
'GCUMemEfficientAttnBackend',
'GCUFusedMoeMethod',
'GCUWeightOnlyMoEMethod',
'GCUWeightOnlyLinearMethod',
]

View File

@@ -0,0 +1,21 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .flash_attn_backend import GCUFlashAttnBackend
from .mem_efficient_attn_backend import GCUMemEfficientAttnBackend
__all__ = [
"GCUFlashAttnBackend",
"GCUMemEfficientAttnBackend",
]

View File

@@ -0,0 +1,287 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional
import paddle
import numpy as np
if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta, ForwardMode
from fastdeploy.model_executor.ops.gcu import (fused_rotary_embedding,
mem_efficient_attention,
flash_attn_var_len)
from paddleformers.utils.log import logger
@dataclass
class GCUFlashAttnMetadata(AttentionMetadata):
"""
GCUFlashAttnMetadata
"""
forward_mode: ForwardMode = ForwardMode.MIXED
_dtype: _DTypeLiteral = paddle.bfloat16
seq_lens_encoder: Optional[paddle.Tensor] = None
seq_lens_decoder: Optional[paddle.Tensor] = None
seq_lens_this_time: Optional[paddle.Tensor] = None
cum_offsets: Optional[paddle.Tensor] = None
padding_offset: Optional[paddle.Tensor] = None
cu_seqlens_q: Optional[paddle.Tensor] = None
cu_seqlens_k: Optional[paddle.Tensor] = None
caches: Optional[paddle.Tensor] = None
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
pre_caches_length: int = 0
class GCUFlashAttnBackend(AttentionBackend):
"""
GCUFlashAttnBackend backend implementation.
"""
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
head_dim: int):
"""
GCUFlashAttnBackend __init__
"""
super().__init__()
self.attention_metadata: GCUFlashAttnMetadata = None
self.block_size = fd_config.parallel_config.block_size
self.max_seq_len = fd_config.parallel_config.max_model_len
self.max_num_seqs = fd_config.parallel_config.max_num_seqs
self.causal = getattr(fd_config.model_config, "causal", True)
self.rank = fd_config.parallel_config.tensor_parallel_rank
self.kv_num_heads = kv_num_heads
self.num_heads = num_heads
self.head_dim = head_dim
self.scaling = 1.0 / (self.head_dim**0.5)
self.num_layers = fd_config.model_config.num_layers
self.position_ids_base = paddle.arange(self.max_seq_len)
# TODO(zhengjun): Need to adapt the allocation logic and
# temporarily allocate according to fixed size
self.all_block_tables: List[List[int]] = None
self.all_slot_mapping: List[List[int]] = None
self.rotary_embs = None
self.enable_monitor: bool = bool(os.getenv("FD_GCU_ATTN_MONITOR", False))
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
metadata = GCUFlashAttnMetadata()
metadata.forward_mode = forward_meta.forward_mode
metadata._dtype = paddle.get_default_dtype()
metadata.seq_lens_encoder = forward_meta.seq_lens_encoder
metadata.seq_lens_decoder = forward_meta.seq_lens_decoder
metadata.seq_lens_this_time = forward_meta.seq_lens_this_time
metadata.cum_offsets = forward_meta.cum_offsets
metadata.padding_offset = forward_meta.padding_offset
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
metadata.cu_seqlens_k = forward_meta.cu_seqlens_k
metadata.caches = forward_meta.caches
# metadata.block_tables = forward_meta.block_tables
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask # not init
metadata.pre_caches_length = forward_meta.pre_caches_length # not inited
self.attention_metadata = metadata
if self.rotary_embs is None:
self.rotary_embs = metadata.rotary_embs.reshape((-1, self.head_dim))
# some info for attention
self.seq_lens_this_time_list = forward_meta.seq_lens_this_time.tolist() # List[int]
self.seq_lens_encoder_list = forward_meta.seq_lens_encoder.tolist() # List[List[int]]
self.seq_lens_decoder_list = forward_meta.seq_lens_decoder.tolist() # List[List[int]]
self.seq_lens_sum = np.sum(self.seq_lens_this_time_list)
self.max_seq_len_this_time = np.max(self.seq_lens_this_time_list)
num_seqs = forward_meta.seq_lens_this_time.shape[0]
self.is_decoder = all(x[0] == 0 for x in self.seq_lens_encoder_list)
self.is_all_prefill = all(x[0] == 0 for x in self.seq_lens_decoder_list)
# block_tables and slot_mapping
if self.all_slot_mapping is None:
max_num_blocks_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size
total_blocks = max_num_blocks_per_seq * self.max_num_seqs
self.all_block_tables = np.arange(0, total_blocks, dtype=np.int32).reshape((self.max_num_seqs, max_num_blocks_per_seq)).tolist()
self.all_slot_mapping = np.arange(0, total_blocks * self.block_size, dtype=np.int32).reshape((self.max_num_seqs, -1)).tolist()
block_tables = []
slot_mapping = []
cache_slot_range = []
cache_lens = []
position_ids = []
for seq_idx in range(num_seqs):
cache_len = None
if self.seq_lens_encoder_list[seq_idx][0] != 0: # prefill
cache_len = 0
elif self.seq_lens_decoder_list[seq_idx][0] != 0: # decode
cache_len = self.seq_lens_decoder_list[seq_idx][0]
# else: doesnot have req in this seq_idx
if cache_len is not None:
lens_this_time = self.seq_lens_this_time_list[seq_idx]
start = cache_len
end = start + lens_this_time
slot_mapping.extend(self.all_slot_mapping[seq_idx][start:end])
cache_slot_range.extend(self.all_slot_mapping[seq_idx][0:end])
cache_lens.append(end)
block_tables.append(self.all_block_tables[seq_idx])
position_ids.extend(self.position_ids_base[start:end])
self.block_tables = paddle.to_tensor(block_tables, dtype="int32")
self.slot_mapping = paddle.to_tensor(slot_mapping, dtype="int32")
self.cache_slot_range = paddle.to_tensor(cache_slot_range, dtype="int32")
self.position_ids = paddle.to_tensor(position_ids, dtype="int32")
self.position_ids = self.position_ids.reshape_((1, -1))
if self.enable_monitor:
logger.info(f"[FD_DEBUG] init_attention_metadata, position_ids:\n{self.position_ids}")
cu_query_lens_data = [0]
for seq_idx in range(num_seqs):
if self.seq_lens_this_time_list[seq_idx] != 0:
cu_query_lens_data.append(self.seq_lens_this_time_list[seq_idx])
cu_query_lens = np.array(cu_query_lens_data, dtype=np.int32).cumsum(axis=0)
self.cu_query_lens = paddle.to_tensor(cu_query_lens, dtype="int32")
self.seqused_k = paddle.to_tensor(cache_lens, dtype="int32")
self.max_seqlen_q = self.max_seq_len_this_time
self.max_seqlen_k = np.max(cache_lens)
def get_attntion_meta(self):
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(
self,
max_num_blocks: int,
):
"""
Caculate kv cache shape
"""
# [total_tokens, kv_num_heads, head_dim]
return (max_num_blocks * self.block_size,
self.kv_num_heads,
self.head_dim)
@paddle.no_grad()
def forward_mixed(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""Run a forward for mixed."""
token_num = qkv.shape[0]
q_size = self.num_heads * self.head_dim
kv_size = self.kv_num_heads * self.head_dim
num_or_sections = [q_size, kv_size, kv_size]
query, key, value = paddle.split(qkv, num_or_sections=num_or_sections, axis=-1)
query = query.reshape_((1, -1, self.num_heads, self.head_dim))
key = key.reshape_((1, -1, self.kv_num_heads, self.head_dim))
# 1. Rope
if self.rotary_embs.dtype != query.dtype:
self.rotary_embs = paddle.cast(self.rotary_embs, query.dtype)
query, key = fused_rotary_embedding(
query,
key,
self.rotary_embs,
self.position_ids,
layer.use_neox_rotary_style
)
# 2. Save kv cache
# shape: [total_tokens, kv_num_heads, head_dim]
key = key.reshape_((-1, self.kv_num_heads, self.head_dim))
value = value.reshape_((-1, self.kv_num_heads, self.head_dim))
key_caches = forward_meta.caches[2 * layer.layer_id]
value_caches = forward_meta.caches[2 * layer.layer_id + 1]
key_caches[self.slot_mapping, :, :] = key
value_caches[self.slot_mapping, :, :] = value
# 3. calc attn
query = query.reshape_((-1, self.num_heads, self.head_dim))
key_caches = key_caches.reshape((-1, self.block_size, self.kv_num_heads, self.head_dim))
value_caches = value_caches.reshape((-1, self.block_size, self.kv_num_heads, self.head_dim))
res = flash_attn_var_len(
query=query,
key=key_caches,
value=value_caches,
cu_seqlens_q=self.cu_query_lens,
cu_seqlens_k=None,
seqused_k=self.seqused_k,
leftpad_k=None,
block_table=self.block_tables,
alibi_slopes=None,
max_seqlen_q=self.max_seqlen_q,
max_seqlen_k=self.max_seqlen_k,
p_dropout=0.0,
softmax_scale=self.scaling,
zero_tensors=False,
is_causal=self.causal,
window_size_left=-1,
window_size_right=-1,
softcap=0.0,
return_softmax=False,
)
res = res.reshape_((token_num, -1))
return res

View File

@@ -0,0 +1,357 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional
import paddle
import numpy as np
import math
if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta, ForwardMode
from fastdeploy.model_executor.ops.gcu import (fused_rotary_embedding,
mem_efficient_attention,
flash_attn_var_len)
from paddleformers.utils.log import logger
@dataclass
class GCUMemEfficientAttnMetadata(AttentionMetadata):
"""
GCUMemEfficientAttnMetadata
"""
forward_mode: ForwardMode = ForwardMode.MIXED
_dtype: _DTypeLiteral = paddle.bfloat16
seq_lens_encoder: Optional[paddle.Tensor] = None
seq_lens_decoder: Optional[paddle.Tensor] = None
seq_lens_this_time: Optional[paddle.Tensor] = None
cum_offsets: Optional[paddle.Tensor] = None
padding_offset: Optional[paddle.Tensor] = None
cu_seqlens_q: Optional[paddle.Tensor] = None
cu_seqlens_k: Optional[paddle.Tensor] = None
caches: Optional[paddle.Tensor] = None
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
pre_caches_length: int = 0
class GCUMemEfficientAttnBackend(AttentionBackend):
"""
GCUMemEfficientAttnBackend backend implementation.
"""
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
head_dim: int):
"""
GCUMemEfficientAttnBackend __init__
"""
super().__init__()
self.attention_metadata: GCUMemEfficientAttnMetadata = None
self.block_size = fd_config.parallel_config.block_size
self.max_seq_len = fd_config.parallel_config.max_model_len
self.max_num_seqs = fd_config.parallel_config.max_num_seqs
self.causal = getattr(fd_config.model_config, "causal", True)
self.rank = fd_config.parallel_config.tensor_parallel_rank
self.kv_num_heads = kv_num_heads
self.num_heads = num_heads
self.head_dim = head_dim
self.scaling = 1.0 / (self.head_dim**0.5)
self.num_layers = fd_config.model_config.num_layers
self.position_ids_base = paddle.arange(self.max_seq_len)
# TODO(zhengjun): Need to adapt the allocation logic and
# temporarily allocate according to fixed size
self.all_block_tables: List[List[int]] = None
self.all_slot_mapping: List[List[int]] = None
self.rotary_embs = None
self.use_paddle_native_sdpa = False
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
metadata = GCUMemEfficientAttnMetadata()
metadata.forward_mode = forward_meta.forward_mode
metadata._dtype = paddle.get_default_dtype()
metadata.seq_lens_encoder = forward_meta.seq_lens_encoder
metadata.seq_lens_decoder = forward_meta.seq_lens_decoder
metadata.seq_lens_this_time = forward_meta.seq_lens_this_time
metadata.cum_offsets = forward_meta.cum_offsets
metadata.padding_offset = forward_meta.padding_offset
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
metadata.cu_seqlens_k = forward_meta.cu_seqlens_k
metadata.caches = forward_meta.caches
# metadata.block_tables = forward_meta.block_tables
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask # not init
metadata.pre_caches_length = forward_meta.pre_caches_length # not inited
self.attention_metadata = metadata
if self.rotary_embs is None:
self.rotary_embs = metadata.rotary_embs.reshape((-1, self.head_dim))
# some info for attention
self.seq_lens_this_time_list = forward_meta.seq_lens_this_time.tolist() # List[int]
self.seq_lens_encoder_list = forward_meta.seq_lens_encoder.tolist() # List[List[int]]
self.seq_lens_decoder_list = forward_meta.seq_lens_decoder.tolist() # List[List[int]]
self.seq_lens_sum = np.sum(self.seq_lens_this_time_list)
self.max_seq_len_this_time = np.max(self.seq_lens_this_time_list)
num_seqs = forward_meta.seq_lens_this_time.shape[0]
self.is_decoder = all(x[0] == 0 for x in self.seq_lens_encoder_list)
self.is_all_prefill = all(x[0] == 0 for x in self.seq_lens_decoder_list)
# block_tables and slot_mapping
if self.all_slot_mapping is None:
max_num_blocks_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size
total_blocks = max_num_blocks_per_seq * self.max_num_seqs
self.all_block_tables = np.arange(0, total_blocks, dtype=np.int32).reshape((self.max_num_seqs, max_num_blocks_per_seq)).tolist()
self.all_slot_mapping = np.arange(0, total_blocks * self.block_size, dtype=np.int32).reshape((self.max_num_seqs, -1)).tolist()
block_tables = []
slot_mapping = []
cache_slot_range = []
cache_lens = []
query_lens = []
cached_kv_lens = []
cached_kv_slot_range = []
position_ids = []
for seq_idx in range(num_seqs):
cache_len = None
if self.seq_lens_encoder_list[seq_idx][0] != 0: # prefill
cache_len = 0
elif self.seq_lens_decoder_list[seq_idx][0] != 0: # decode
cache_len = self.seq_lens_decoder_list[seq_idx][0]
# else: doesnot have req in this seq_idx
if cache_len is not None:
lens_this_time = self.seq_lens_this_time_list[seq_idx]
start = cache_len
end = start + lens_this_time
slot_mapping.extend(self.all_slot_mapping[seq_idx][start:end])
cache_slot_range.extend(self.all_slot_mapping[seq_idx][0:end])
cache_lens.append(end)
block_tables.append(self.all_block_tables[seq_idx])
position_ids.extend(self.position_ids_base[start:end])
query_lens.append(lens_this_time)
cached_kv_lens.append(end)
cached_kv_slot_range.append([self.all_slot_mapping[seq_idx][0], self.all_slot_mapping[seq_idx][end]])
self.block_tables = paddle.to_tensor(block_tables, dtype="int32")
self.slot_mapping = paddle.to_tensor(slot_mapping, dtype="int32")
self.cache_slot_range = paddle.to_tensor(cache_slot_range, dtype="int32")
self.position_ids = paddle.to_tensor(position_ids, dtype="int32")
self.position_ids = self.position_ids.reshape_((1, -1))
logger.info(f"[FD_DEBUG] init_attention_metadata, self.position_ids:\n{self.position_ids}")
cu_query_lens_data = [0]
for seq_idx in range(num_seqs):
if self.seq_lens_this_time_list[seq_idx] != 0:
cu_query_lens_data.append(self.seq_lens_this_time_list[seq_idx])
cu_query_lens = np.array(cu_query_lens_data, dtype=np.int32).cumsum(axis=0)
self.cu_query_lens = paddle.to_tensor(cu_query_lens, dtype="int32")
self.seqused_k = paddle.to_tensor(cache_lens, dtype="int32")
self.max_seqlen_q = self.max_seq_len_this_time
self.max_seqlen_k = np.max(cache_lens)
self.query_lens = query_lens
self.cached_kv_lens = cached_kv_lens
self.cached_kv_slot_range = cached_kv_slot_range
def get_attntion_meta(self):
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(
self,
max_num_blocks: int,
):
"""
Caculate kv cache shape
"""
# [total_tokens, kv_num_heads, head_dim]
return (max_num_blocks * self.block_size,
self.kv_num_heads,
self.head_dim)
@paddle.no_grad()
def forward_mixed(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""Run a forward for mixed."""
token_num = qkv.shape[0]
q_size = self.num_heads * self.head_dim
kv_size = self.kv_num_heads * self.head_dim
num_or_sections = [q_size, kv_size, kv_size]
query, key, value = paddle.split(qkv, num_or_sections=num_or_sections, axis=-1)
query = query.reshape_((1, -1, self.num_heads, self.head_dim))
key = key.reshape_((1, -1, self.kv_num_heads, self.head_dim))
# 1. Rope
if self.rotary_embs.dtype != query.dtype:
self.rotary_embs = paddle.cast(self.rotary_embs, query.dtype)
query, key = fused_rotary_embedding(
query,
key,
self.rotary_embs,
self.position_ids,
layer.use_neox_rotary_style
)
# 2. Save kv cache
# shape: [total_tokens, kv_num_heads, head_dim]
key = key.reshape_((-1, self.kv_num_heads, self.head_dim))
value = value.reshape_((-1, self.kv_num_heads, self.head_dim))
key_caches = forward_meta.caches[2 * layer.layer_id]
value_caches = forward_meta.caches[2 * layer.layer_id + 1]
key_caches[self.slot_mapping, :, :] = key
value_caches[self.slot_mapping, :, :] = value
# 3. calc attn
query = query.reshape_((-1, self.num_heads, self.head_dim))
q_start = 0
result = paddle.empty_like(query)
for idx in range(len(self.query_lens)):
q_end = q_start + self.query_lens[idx]
kv_start = self.cached_kv_slot_range[idx][0]
kv_end = self.cached_kv_slot_range[idx][1]
q_ = query[q_start:q_end, :, :]
k_ = key_caches[kv_start:kv_end, :, :]
v_ = value_caches[kv_start:kv_end, :, :]
if self.use_paddle_native_sdpa:
res = self.native_sdpa_impl(
q_, k_, v_
)
else:
res = mem_efficient_attention(
query=q_.unsqueeze(0),
key=k_.unsqueeze(0),
value=v_.unsqueeze(0),
attn_mask=None,
dropout=0.0,
softmax_scale=self.scaling,
mask_mode=1,
seqlens=[0],
causal=self.causal,
)
result[q_start:q_end, :, :] = res
q_start = q_end
result = result.reshape_((token_num, -1))
return result
def get_triangle_upper_mask(self, shape, dtype):
# [batch_size, 1, q_seq_len, kv_seq_len]
shape[1] = 1
q_seq_len = shape[2]
kv_seq_len = shape[3]
paddle_dtype = dtype # paddle.base.data_feeder.convert_dtype(dtype)
mask = paddle.full(shape, paddle.finfo(paddle_dtype).min, dtype=paddle_dtype)
mask = paddle.triu(mask, diagonal=kv_seq_len - q_seq_len + 1)
return mask
def native_sdpa_impl(self, query, key, value):
# input shape: [num_tokens, num_heads, head_dim] -> [1, num_tokens, num_heads, head_dim]
q = query.unsqueeze(0)
k = key.unsqueeze(0)
v = value.unsqueeze(0)
batch, q_seq_len, heads, head_dim = q.shape
kv_seq_len = k.shape[1]
# [batch_size, seq_len, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
q = paddle.transpose(q, [0, 2, 1, 3])
k = paddle.transpose(k, [0, 2, 1, 3])
v = paddle.transpose(v, [0, 2, 1, 3])
# GQA
if q.shape[1] != k.shape[1]:
kv_head = k.shape[1]
k = k.reshape([batch, kv_head, 1, kv_seq_len, head_dim])
k = paddle.tile(k, [1, 1, heads // kv_head, 1, 1])
k = k.reshape([batch, heads, kv_seq_len, head_dim])
v = v.reshape([batch, kv_head, 1, kv_seq_len, head_dim])
v = paddle.tile(v, [1, 1, heads // kv_head, 1, 1])
v = v.reshape([batch, heads, kv_seq_len, head_dim])
# matmul and devide by sqrt(head_dim)
attn_weights = paddle.matmul(q / math.sqrt(head_dim), k.transpose([0, 1, 3, 2]))
attention_mask = self.get_triangle_upper_mask(
[batch, 1, q_seq_len, kv_seq_len], q.dtype
)
attn_weights = attn_weights + attention_mask
attn_weights = paddle.nn.functional.softmax(
attn_weights, axis=-1, dtype="float32"
).astype(q.dtype)
attn_output = paddle.matmul(attn_weights, v)
attn_output = attn_output.transpose([0, 2, 1, 3])
return attn_output.squeeze(0)

View File

@@ -0,0 +1,16 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""
gcu moe
"""

View File

@@ -0,0 +1,402 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import multiprocessing
import os
import numpy as np
import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import \
MoEMethodBase
from fastdeploy.model_executor.layers.utils import (CpuGuard,
create_and_set_parameter,
get_tensor)
from fastdeploy.model_executor.ops.gcu import (invoke_fused_moe_kernel,
moe_align_block_size,
topk_softmax,
weight_quantize_custom_rtn,
weight_quantize_rtn)
class GCUFusedMoeMethod(MoEMethodBase):
"""
Use GCU to compute Fused MoE.
"""
def __init__(self, quant_config):
super().__init__(quant_config)
self.group_size = -1
def create_weights(self, layer: nn.Layer, state_dict):
"""
Paddle gcu create weight process.
"""
# bf16
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
for idx, weight_tensor in enumerate(
[stacked_ffn1_weights, stacked_ffn2_weights]):
# shape [E, K, N] -> [E, N, K]
weight_tensor = paddle.transpose(weight_tensor, [0, 2, 1])
weight_name = self.added_weight_attrs[idx]
setattr(
layer, weight_name,
layer.create_parameter(
shape=weight_tensor.shape,
dtype=weight_tensor.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
))
getattr(layer, weight_name).set_value(weight_tensor)
@paddle.no_grad()
def compute_ffn(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
enable_quant = False
) -> paddle.Tensor:
"""
Paddle gcu compute Fused MoE.
"""
token_num, hidden_size = x.shape
top_k = layer.top_k
moe_intermediate_size = layer.moe_intermediate_size
num_experts = layer.num_local_experts
topk_weights = paddle.empty([token_num, top_k], dtype=gate_out.dtype)
topk_indices = paddle.empty([token_num, top_k], dtype="int32")
token_expert_indices = paddle.empty([token_num, top_k], dtype="int32",)
topk_softmax(topk_weights, topk_indices, token_expert_indices, gate_out, norm_topk_prob=True)
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
}
block_size = config["BLOCK_SIZE_M"]
max_num_tokens_padded = np.prod(topk_indices.shape) + num_experts * (block_size - 1)
max_num_m_blocks = max_num_tokens_padded // block_size
sorted_token_ids = paddle.empty([max_num_tokens_padded], dtype="int32")
expert_ids = paddle.zeros(shape=[max_num_m_blocks], dtype="int32")
num_tokens_post_pad = paddle.empty([1], dtype="int32")
sorted_token_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
sorted_token_ids,
expert_ids,
num_tokens_post_pad,
topk_indices,
num_experts,
block_size,
)
intermediate_cache1 = paddle.empty(
[token_num, top_k, moe_intermediate_size * 2],
dtype=x.dtype,
)
ffn1_B_scale = layer.moe_ffn1_weight_scale if enable_quant else None
ffn1_B_zeros = layer.moe_ffn1_weight_zeros if enable_quant else None
invoke_fused_moe_kernel(
x, # input
layer.moe_ffn1_weight, # weight
intermediate_cache1, # output
None, # A_scale
ffn1_B_scale, # B_scale
ffn1_B_zeros, # B_zp
topk_weights,
topk_indices,
sorted_token_ids,
expert_ids,
num_tokens_post_pad,
False, # mul_routed_weight
top_k,
config,
enable_quant, # use_int4_w4a16
[0, self.group_size], # block_shape
)
intermediate_cache2 = paddle.empty(
(token_num, top_k, moe_intermediate_size),
dtype=x.dtype,
)
intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
intermediate_cache1)
intermediate_cache2 = intermediate_cache2.reshape([-1, moe_intermediate_size])
intermediate_cache3 = paddle.empty(
(token_num, top_k, hidden_size),
dtype=x.dtype,
)
ffn2_B_scale = layer.moe_ffn2_weight_scale if enable_quant else None
ffn2_B_zeros = layer.moe_ffn2_weight_zeros if enable_quant else None
invoke_fused_moe_kernel(
intermediate_cache2, # input
layer.moe_ffn2_weight, # weight
intermediate_cache3, # output
None, # A_scale
ffn2_B_scale, # B_scale
ffn2_B_zeros, # B_zp
topk_weights,
topk_indices,
sorted_token_ids,
expert_ids,
num_tokens_post_pad,
True, # mul_routed_weight
1,
config,
enable_quant, # use_int4_w4a16
[0, self.group_size], # block_shape
)
intermediate_cache3.reshape_([token_num, top_k, hidden_size])
fused_moe_out = intermediate_cache3.sum(axis=1)
fused_moe_out = fused_moe_out.reshape_([token_num, hidden_size])
if layer.tp_size > 1:
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
tensor_model_parallel_all_reduce(fused_moe_out)
return fused_moe_out
def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Paddle gcu compute Fused MoE.
"""
return self.compute_ffn(layer, x, gate_out, enable_quant=False)
def apply_ep_prefill(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
"""
raise NotImplementedError
def apply_ep_decode(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
"""
raise NotImplementedError
def apply_tp(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
"""
raise NotImplementedError
class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
"""
weight only for moe
"""
def __init__(self, quant_config):
super().__init__(quant_config)
self.quant_config = quant_config
self.moe_quant_type = self.quant_config.algo
self.pack_num = 1
assert self.quant_config.algo == "weight_only_int4", \
"GCUWeightOnlyMoEMethod only support weight_only_int4, but got:{self.quant_config.algo}"
self.added_qzeros_attrs = [
"moe_ffn1_weight_zeros", "moe_ffn2_weight_zeros"
]
self.group_size = 64
self.quant_multi_process_group_size = int(
os.getenv("FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE", 8)
)
logger.info(f"GCUWeightOnlyMoEMethod quant_multi_process_group_size: {self.quant_multi_process_group_size}")
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
"""
Paddle gcu process prequanted weights.
"""
ffn1_expert_weight_key = layer.weight_key_map.get(
"ffn1_expert_weight_key", None)
ffn2_expert_weight_key = layer.weight_key_map.get(
"ffn2_expert_weight_key", None)
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
"ffn1_expert_weight_scale_key", None)
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
"ffn2_expert_weight_scale_key", None)
ffn1_weights, ffn2_weights = layer.load_experts_weight(
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
# self.check(layer, ffn1_weights, ffn2_weights)
ffn1_weight_scale = []
ffn2_weight_scale = []
for i in range(layer.num_experts):
expert_idx = layer.expert_id_offset + i
ffn1_weight_scale.append(
get_tensor(
state_dict.pop(
ffn1_expert_weight_scale_key.format(expert_idx))))
ffn2_weight_scale.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_scale_key.format(expert_idx))))
ffn1_weight = paddle.stack(ffn1_weights, axis=0)
ffn2_weight = paddle.stack(ffn2_weights, axis=0)
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)
name_tensor_map = {
"moe_ffn1_weight": ffn1_weight,
"moe_ffn2_weight": ffn2_weight,
"moe_ffn1_weight_scale": ffn1_weight_scale,
"moe_ffn2_weight_scale": ffn2_weight_scale
}
for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor)
@paddle.no_grad()
def create_weights(self, layer: nn.Layer, state_dict):
"""
Paddle cutlass create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, ffn1_weights, ffn2_weights)
def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size):
with CpuGuard():
p_group_size = len(weights)
for group_j in range(p_group_size):
# weight shape [K, N] -> [N/2, K] -> [N, K/2]
quant_weight, scale = weight_quantize_custom_rtn(
weights[group_j],
moe_quant_type,
group_size # group_size
)
shared_dict[p_group_size * p_group_idx + group_j] = (
quant_weight, scale
)
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
zeros_name = self.added_qzeros_attrs[idx]
if self.quant_multi_process_group_size > 0:
process_group_size = self.quant_multi_process_group_size
process_group_num = layer.num_local_experts // process_group_size
grouped_weights_num = process_group_num * process_group_size
remain_weights_start_idx = grouped_weights_num
weight_list = [None] * grouped_weights_num
weight_scale_list = [None] * grouped_weights_num
with multiprocessing.Manager() as manager:
shared_dict = manager.dict({})
processes = []
for i in range(process_group_num):
w = []
for j in range(process_group_size):
w.append(weight_tensor[process_group_size * i + j].to("cpu"))
p = multiprocessing.Process(
target=quant_worker,
args=(i, shared_dict, w, self.moe_quant_type, self.group_size)
)
p.start()
processes.append(p)
for p in processes:
p.join()
dict_ = dict(shared_dict)
for k, v in dict_.items():
weight_list[k] = v[0].to(ffn1_weights[0].place)
weight_scale_list[k] = v[1].to(ffn1_weights[0].place)
else:
remain_weights_start_idx = 0
if remain_weights_start_idx < layer.num_local_experts:
for i in range(remain_weights_start_idx, layer.num_local_experts):
# weight shape [K, N] -> [N/2, K] -> [N, K/2]
quant_weight, scale = weight_quantize_rtn(
weight_tensor[i],
self.moe_quant_type,
self.group_size # group_size
)
weight_list.append(quant_weight)
weight_scale_list.append(scale)
quanted_weight = paddle.stack(weight_list, axis=0)
create_and_set_parameter(layer, weight_name, quanted_weight)
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
create_and_set_parameter(layer, scale_name, quanted_weight_scale)
quanted_weight_zeros = quanted_weight_scale * 8
create_and_set_parameter(layer, zeros_name, quanted_weight_zeros)
def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Paddle gcu compute Fused MoE.
"""
return self.compute_ffn(layer, x, gate_out, enable_quant=True)

View File

@@ -0,0 +1,21 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""
gcu quantization
"""
from .weight_only import GCUWeightOnlyLinearMethod
__all__ = [
"GCUWeightOnlyLinearMethod",
]

View File

@@ -0,0 +1,90 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from fastdeploy.model_executor.layers.quantization.weight_only import (
WeightOnlyConfig, WeightOnlyLinearMethod)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.gcu import linear_quant, weight_quantize_rtn
class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
Weight only quantization method for linear layer on GCU
"""
def __init__(
self,
quant_config: WeightOnlyConfig,
) -> None:
super().__init__(quant_config)
self.quant_config = quant_config
self.group_size = -1
def create_weights(self, layer):
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
layer.linear_weight_shape.reverse()
if self.quant_config.name() == "wint4":
layer.linear_weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
dtype=layer._dtype,
is_bias=False,
)
def process_prequanted_weights(self, layer, state_dict) -> None:
"""
Process pre-quantized weights before applying them to the model
Args:
layer: The layer that owns the weights
quant_weight: The quantized weights
weight_scale: The scale of the quantized weights
"""
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
layer.linear_weight.set_value(quant_weight)
layer.linear_weight_scale.set_value(
weight_scale.astype(paddle.get_default_dtype()))
def process_loaded_weights(self, layer, weight) -> None:
quanted_weight_tensor, weight_scale_tensor = weight_quantize_rtn(
weight,
self.quant_config.algo,
self.group_size, # group_size
)
layer.linear_weight.set_value(quanted_weight_tensor)
layer.linear_weight_scale.set_value(
weight_scale_tensor.astype(paddle.get_default_dtype()))
@paddle.no_grad()
def apply(self, layer, x):
linear_out = linear_quant(
lhs=x,
rhs=layer.linear_weight,
scale=layer.linear_weight_scale,
bias=None,
group_size=self.group_size,
)
return linear_out

View File

@@ -58,7 +58,7 @@ class LinearBase(nn.Layer):
"""
super().__init__()
if current_platform.is_cuda() or current_platform.is_xpu(
) or current_platform.is_iluvatar():
) or current_platform.is_iluvatar() or current_platform.is_gcu():
self.forward = self.forward_cuda
else:
raise NotImplementedError

View File

@@ -20,6 +20,7 @@ from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.platforms import current_platform
class FusedMoE(nn.Layer):
@@ -95,8 +96,13 @@ class FusedMoE(nn.Layer):
self.moe_quant_type = moe_quant_config.name()
else:
# now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
from .fused_moe_cutlass_backend import CutlassMoEMethod
self.quant_method = CutlassMoEMethod(None)
if current_platform.is_cuda():
from .fused_moe_cutlass_backend import CutlassMoEMethod
self.quant_method = CutlassMoEMethod(None)
elif current_platform.is_gcu():
from fastdeploy.model_executor.layers.backends import \
GCUFusedMoeMethod
self.quant_method = GCUFusedMoeMethod(None)
if self.ep_size > 1:
self.quant_method.init_ep(self)

View File

@@ -19,9 +19,14 @@ from typing import Callable, Dict, Optional
import numpy as np
import paddle
from paddle import nn
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
from fastdeploy.platforms import current_platform
if current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import fused_add_rms_norm, rms_norm
else:
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
from fastdeploy.config import FDConfig
from .utils import get_tensor
@@ -69,7 +74,10 @@ class RMSNorm(nn.Layer):
self.weight_key: Optional[str] = f"{prefix}.weight"
self.with_weight: bool = self.weight_key is not None
self.eps: float = eps
self.norm_func: Callable = fused_rms_norm
if current_platform.is_gcu():
self.norm_func: Callable = fused_add_rms_norm
else:
self.norm_func: Callable = fused_rms_norm
self.linear_bias: Optional[paddle.Tensor] = linear_bias
self.quant_scale: Optional[float] = quant_scale
self._dtype: str = self._helper.get_default_dtype()
@@ -129,19 +137,26 @@ class RMSNorm(nn.Layer):
The `residual_output` is the result of applying the normalization and possibly other
operations (like linear transformation) on the `residual_input`.
"""
norm_out = self.norm_func(
x,
norm_weight=self.ln_weight,
norm_bias=None,
epsilon=self.eps,
begin_norm_axis=self.begin_norm_axis,
bias=self.linear_bias,
residual=residual_input,
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
if current_platform.is_gcu():
if residual_input is None:
return rms_norm(x, self.ln_weight, self.eps)
norm_out = self.norm_func(
x, residual_input, self.ln_weight, self.eps
)
else:
norm_out = self.norm_func(
x,
norm_weight=self.ln_weight,
norm_bias=None,
epsilon=self.eps,
begin_norm_axis=self.begin_norm_axis,
bias=self.linear_bias,
residual=residual_input,
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
if residual_input is not None:
return norm_out[0], norm_out[1]
else:
@@ -193,7 +208,10 @@ class LayerNorm(nn.Layer):
self.with_bias: bool = with_bias
self.eps: float = eps
self.quant_scale: float = quant_scale
self.norm_func: Callable = fused_layer_norm
if current_platform.is_gcu():
self.norm_func: Callable = paddle.nn.functional.layer_norm
else:
self.norm_func: Callable = fused_layer_norm
self.linear_bias: Optional[paddle.Tensor] = linear_bias
self._dtype: str = self._helper.get_default_dtype()
self._norm_weight_dtype: str = "float32"
@@ -279,19 +297,40 @@ class LayerNorm(nn.Layer):
else:
raise NotImplementedError("Iluvatar does not support yet!")
norm_out = self.norm_func(
x,
norm_weight=self.ln_weight,
norm_bias=self.ln_bias,
epsilon=self.eps,
begin_norm_axis=1,
bias=self.linear_bias,
residual=residual_input,
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
if current_platform.is_gcu():
if residual_input is not None:
y = x + residual_input
out = self.norm_func(
x=y,
normalized_shape=y.shape[1:],
weight=self.ln_weight,
bias=self.linear_bias,
epsilon=self.eps,
)
return out, y
else:
out = self.norm_func(
x=x,
normalized_shape=x.shape[1:],
weight=self.ln_weight,
bias=self.linear_bias,
epsilon=self.eps,
)
return out
else:
norm_out = self.norm_func(
x,
norm_weight=self.ln_weight,
norm_bias=self.ln_bias,
epsilon=self.eps,
begin_norm_axis=1,
bias=self.linear_bias,
residual=residual_input,
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
if residual_input is not None:
return norm_out[0], norm_out[1]
else:

View File

@@ -66,6 +66,13 @@ class WeightOnlyConfig(QuantConfigBase):
return XPUWeightOnlyMoEMethod(self)
else:
return XPUWeightOnlyLinearMethod(self)
elif current_platform.is_gcu():
from fastdeploy.model_executor.layers.backends import (
GCUWeightOnlyLinearMethod, GCUWeightOnlyMoEMethod)
if isinstance(layer, FusedMoE):
return GCUWeightOnlyMoEMethod(self)
else:
return GCUWeightOnlyLinearMethod(self)
else:
if isinstance(layer, FusedMoE):
if layer.use_method == "cutlass":

View File

@@ -55,6 +55,10 @@ class ErnieRotaryEmbedding:
dtype="float32")
emb = paddle.stack([freqs, freqs], axis=-1).reshape(
(bsz, max_seq_len, self.rotary_dim))
elif current_platform.is_gcu():
# shape: [B, S, D]
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
return rot_emb
else:
# shape: [B, S, D/2]
rot_emb = paddle.zeros(
@@ -95,6 +99,10 @@ class QwenRotaryEmbedding:
# shape: [B, S, D/2]
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"),
inv_freq)
if current_platform.is_gcu():
# shape: [B, S, D]
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
return rot_emb
# shape: [B, S, 1, D]
emb = paddle.concat([freqs, freqs], axis=-1).reshape(
(bsz, max_seq_len, 1, self.rotary_dim))

View File

@@ -79,6 +79,21 @@ def apply_penalty_multi_scores(
min_dec_lens,
eos_token_ids,
)
elif current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import \
get_token_penalty_multi_scores
logits = get_token_penalty_multi_scores(
pre_token_ids,
logits,
repetition_penalties,
frequency_penalties,
presence_penalties,
temperature,
bad_words_token_ids,
step_idx,
min_dec_lens,
eos_token_ids,
)
else:
raise NotImplementedError()

View File

@@ -19,7 +19,11 @@ from typing import Literal, Optional
import paddle
from fastdeploy import envs
from fastdeploy.platforms import current_platform
if current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import \
top_p_sampling as gcu_top_p_sampling
def top_p_sampling(
x: paddle.Tensor,
@@ -46,13 +50,16 @@ def top_p_sampling(
ids = rejection_top_p_sampling(x, ps, seed)
_ = None
else:
_, ids = paddle.tensor.top_p_sampling(x,
ps,
threshold=threshold,
topp_seed=topp_seed,
seed=seed,
k=k,
mode=mode)
if current_platform.is_gcu():
_, ids = gcu_top_p_sampling(x, ps)
else:
_, ids = paddle.tensor.top_p_sampling(x,
ps,
threshold=threshold,
topp_seed=topp_seed,
seed=seed,
k=k,
mode=mode)
return _, ids

View File

@@ -171,7 +171,7 @@ class Sampler(nn.Layer):
"""
super().__init__()
if current_platform.is_cuda() or current_platform.is_xpu(
) or current_platform.is_iluvatar():
) or current_platform.is_iluvatar() or current_platform.is_gcu():
self.forward = self.forward_cuda
else:
raise NotImplementedError()

View File

@@ -17,5 +17,6 @@ from . import cpu
from . import xpu
from . import npu
from . import iluvatar
from . import gcu
__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar"]
__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar", "gcu"]

View File

@@ -0,0 +1,116 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" fastdeploy gcu ops """
from fastdeploy.platforms import current_platform
from fastdeploy.import_ops import import_custom_ops, rename_imported_op
PACKAGE = "fastdeploy.model_executor.ops.gcu"
import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())
if current_platform.is_gcu():
from paddle_custom_device.gcu.ops import (invoke_fused_moe_kernel, # noqa: F401,E402
moe_align_block_size, top_p_sampling, # noqa: F401
topk_softmax, # noqa: F401
weight_quantize_custom_rtn, # noqa: F401
weight_quantize_rtn) # noqa: F401
# ###################### Ops from PaddleCustomDevice ####################
rename_imported_op(
old_name="fused_rotary_embedding_gcu",
new_name="fused_rotary_embedding",
global_ns=globals(),
)
rename_imported_op(
old_name="reshape_and_cache_gcu",
new_name="reshape_and_cache",
global_ns=globals(),
)
rename_imported_op(
old_name="paged_attention_gcu",
new_name="paged_attention",
global_ns=globals(),
)
rename_imported_op(
old_name="mem_efficient_attention_gcu",
new_name="mem_efficient_attention",
global_ns=globals(),
)
rename_imported_op(
old_name="flash_attn_var_len_gcu",
new_name="flash_attn_var_len",
global_ns=globals(),
)
rename_imported_op(
old_name="rms_norm_gcu",
new_name="rms_norm",
global_ns=globals(),
)
rename_imported_op(
old_name="fused_add_rms_norm_op",
new_name="fused_add_rms_norm",
global_ns=globals(),
)
rename_imported_op(
old_name="linear_quant_gcu",
new_name="linear_quant",
global_ns=globals(),
)
# ###################### CPU OPS ####################
rename_imported_op(
old_name="get_padding_offset_gcu",
new_name="get_padding_offset",
global_ns=globals(),
)
rename_imported_op(
old_name="update_inputs_gcu",
new_name="update_inputs",
global_ns=globals(),
)
rename_imported_op(
old_name="rebuild_padding_gcu",
new_name="rebuild_padding",
global_ns=globals(),
)
rename_imported_op(
old_name="get_token_penalty_multi_scores_gcu",
new_name="get_token_penalty_multi_scores",
global_ns=globals(),
)
rename_imported_op(
old_name="set_stop_value_multi_ends_gcu",
new_name="set_stop_value_multi_ends",
global_ns=globals(),
)
rename_imported_op(
old_name="set_value_by_flags_and_idx_gcu",
new_name="set_value_by_flags_and_idx",
global_ns=globals(),
)

View File

@@ -24,6 +24,11 @@ if current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import (
get_padding_offset, save_output, set_stop_value_multi_ends,
step_paddle, update_inputs)
elif current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import (get_padding_offset,
save_output,
set_stop_value_multi_ends,
update_inputs)
else:
from fastdeploy.model_executor.ops.gpu import (
get_padding_offset, save_output, set_stop_value_multi_ends,
@@ -391,6 +396,17 @@ def rebuild_padding(tmp_out: paddle.Tensor,
output_padding_offset,
max_input_length,
)
elif current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import rebuild_padding
hidden_states = rebuild_padding(
tmp_out,
cum_offsets,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length,
)
elif current_platform.is_cpu():
from fastdeploy.model_executor.ops.cpu import rebuild_padding_cpu
hidden_states = rebuild_padding_cpu(