dcu adapter ernie45t (#2756)

Co-authored-by: lifu <lifu@sugon.com>
Co-authored-by: yongqiangma <xing.wo@163.com>
Author: lifulll
Date: 2025-07-09 18:56:27 +08:00
Committed by: GitHub
Parent: 03a74995b8
Commit: 1f28bdf994
30 changed files with 1133 additions and 41 deletions


@@ -77,8 +77,10 @@ function copy_ops(){
     is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
     if [ "$is_rocm" = "True" ]; then
         DEVICE_TYPE="rocm"
+        mkdir -p ../fastdeploy/model_executor/ops/base
+        cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
         cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
-        echo -e "ROCM ops have been copy to fastdeploy"
+        echo -e "BASE and ROCM ops have been copy to fastdeploy"
         return
     fi
     mkdir -p ../fastdeploy/model_executor/ops/base


@@ -214,11 +214,19 @@ HOSTDEVICE inline void Store(const AlignedVector<T, Size> &vec, T *addr) {
   *addr_vec = vec;
 }
 
+#ifdef PADDLE_WITH_HIP
+template <int Size>
+HOSTDEVICE inline void Store(const AlignedVector<hip_bfloat16, Size> &vec,
+                             int8_t *addr) {
+  printf("Error: Store hip_bfloat16 to int8_t is not supported!");
+}
+#else
 template <int Size>
 HOSTDEVICE inline void Store(const AlignedVector<__nv_bfloat16, Size> &vec,
                              int8_t *addr) {
   printf("Error: Store __nv_bfloat16 to int8_t is not supported!");
 }
+#endif
 
 template <int Size>
 HOSTDEVICE inline void Store(const AlignedVector<half, Size> &vec,
@@ -478,7 +486,12 @@ template <typename T>
 static void PrintMatrix3(const T *mat_d, int num, std::string name) {
   std::vector<T> tmp(num);
+#ifdef PADDLE_WITH_HIP
+  hipMemcpy(tmp.data(), mat_d, sizeof(T) * num, hipMemcpyDeviceToHost);
+#else
   cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost);
+#endif
   std::ofstream outfile;
   outfile.open(name + ".txt", std::ios::out);
@@ -495,6 +508,7 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
   outfile.close();
 }
 
+#ifndef PADDLE_WITH_HIP
 __forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr,
                                                     int mode = 0) {
   uint32_t flag;
@@ -534,6 +548,7 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
       cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
   return max_shared_mem_per_block_opt_in;
 }
+#endif
 
 inline int GetSMVersion() {
   static int sm_version = phi::backends::gpu::GetGPUComputeCapability(


@@ -91,7 +91,12 @@ void set_data_ipc(const paddle::Tensor& tmp_input,
   memset((void *)shm, 0, sizeof(*shm));
 
   void *data_ptr_now = reinterpret_cast<void*>(const_cast<data_t*>(tmp_input.data<data_t>()));
+#ifdef PADDLE_WITH_HIP
+  checkCudaErrors(hipIpcGetMemHandle((hipIpcMemHandle_t *)&shm->memHandle, data_ptr_now));
+#else
   checkCudaErrors(cudaIpcGetMemHandle((cudaIpcMemHandle_t *)&shm->memHandle, data_ptr_now));
+#endif
 }


@@ -37,10 +37,18 @@ std::vector<paddle::Tensor> ShareExternalData(paddle::Tensor& input,
   }
   shm = (volatile shmStruct *)info.addr;
   void *ptr = nullptr;
+#ifdef PADDLE_WITH_HIP
+  checkCudaErrors(
+      hipIpcOpenMemHandle(&ptr,
+                          *(hipIpcMemHandle_t *)&shm->memHandle,  // NOLINT
+                          hipIpcMemLazyEnablePeerAccess));
+#else
   checkCudaErrors(
       cudaIpcOpenMemHandle(&ptr,
                            *(cudaIpcMemHandle_t *)&shm->memHandle,  // NOLINT
                            cudaIpcMemLazyEnablePeerAccess));
+#endif
   paddle::Tensor tmp_tensor = paddle::from_blob(
       ptr,
       shape,


@@ -187,39 +187,45 @@ def find_end_files(directory, end_str):
 if paddle.is_compiled_with_rocm():
     # NOTE(@duanyanhui): paddle.is_compiled_with_cuda() returns True when paddle compiled with rocm.
     # so we need to check if paddle compiled with rocm at first.
+    json_dir = "third_party/nlohmann_json"
+    if not os.path.exists(json_dir) or not os.listdir(json_dir):
+        if not os.path.exists(json_dir):
+            os.makedirs(json_dir)
+        clone_git_repo("v3.11.3", "https://bgithub.xyz/nlohmann/json.git", json_dir)
+        if not os.listdir(json_dir):
+            raise ValueError("Git clone nlohmann_json failed!")
+    sources = [
+        "gpu_ops/set_value_by_flags.cu",
+        "gpu_ops/token_penalty_multi_scores.cu",
+        "gpu_ops/stop_generation.cu",
+        "gpu_ops/stop_generation_multi_ends.cu",
+        "gpu_ops/get_padding_offset.cu",
+        "gpu_ops/update_inputs.cu",
+        "gpu_ops/rebuild_padding.cu",
+        "gpu_ops/step.cu",
+        "gpu_ops/set_data_ipc.cu",
+        "gpu_ops/moe/tritonmoe_preprocess.cu",
+        "gpu_ops/step_system_cache.cu",
+        "gpu_ops/get_output_ep.cc",
+        "gpu_ops/speculate_decoding/speculate_get_padding_offset.cu",
+        "gpu_ops/speculate_decoding/speculate_get_output.cc",
+        "gpu_ops/share_external_data.cu",
+        "gpu_ops/speculate_decoding/speculate_clear_accept_nums.cu",
+        "gpu_ops/speculate_decoding/speculate_get_output_padding_offset.cu",
+        "gpu_ops/speculate_decoding/speculate_get_seq_lens_output.cu",
+        "gpu_ops/speculate_decoding/speculate_save_output.cc",
+        "gpu_ops/speculate_decoding/speculate_set_value_by_flags.cu",
+        "gpu_ops/speculate_decoding/speculate_step.cu",
+        "gpu_ops/speculate_decoding/speculate_step_system_cache.cu",
+        "gpu_ops/speculate_decoding/speculate_update_v3.cu",
+        "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
+        "gpu_ops/fused_rotary_position_encoding.cu",
+        "gpu_ops/step_reschedule.cu",
+    ]
     setup(
         name="fastdeploy_ops",
         ext_modules=CUDAExtension(
-            sources=[
-                "gpu_ops/save_with_output.cc",
-                "gpu_ops/set_mask_value.cu",
-                "gpu_ops/set_value_by_flags.cu",
-                "gpu_ops/ngram_mask.cu",
-                "gpu_ops/gather_idx.cu",
-                "gpu_ops/token_penalty_multi_scores.cu",
-                "gpu_ops/token_penalty_only_once.cu",
-                "gpu_ops/stop_generation.cu",
-                "gpu_ops/stop_generation_multi_ends.cu",
-                "gpu_ops/stop_generation_multi_stop_seqs.cu",
-                "gpu_ops/set_flags.cu",
-                "gpu_ops/fused_get_rope.cu",
-                "gpu_ops/transfer_output.cc",
-                "gpu_ops/get_padding_offset.cu",
-                "gpu_ops/update_inputs.cu",
-                "gpu_ops/update_inputs_beam.cu",
-                "gpu_ops/beam_search_softmax.cu",
-                "gpu_ops/rebuild_padding.cu",
-                "gpu_ops/save_with_output_msg.cc",
-                "gpu_ops/get_output.cc",
-                "gpu_ops/get_output_msg_with_topk.cc",
-                "gpu_ops/step.cu",
-                "gpu_ops/step_reschedule.cu",
-                "gpu_ops/set_data_ipc.cu",
-                "gpu_ops/read_data_ipc.cu",
-                "gpu_ops/dequant_int8.cu",
-                "gpu_ops/enforce_generation.cu",
-                "gpu_ops/tune_cublaslt_gemm.cu",
-            ],
+            sources=sources,
             extra_compile_args={
                 "cxx": ["-O3"],
                 "hipcc": [
@@ -231,6 +237,9 @@ if paddle.is_compiled_with_rocm():
                     "-U__HIP_NO_BFLOAT16_CONVERSIONS__",
                     "-U__HIP_NO_BFLOAT162_OPERATORS__",
                     "-U__HIP_NO_BFLOAT162_CONVERSIONS__",
+                    "-DPADDLE_DEV",
+                    "-Ithird_party/nlohmann_json/include",
+                    "-Igpu_ops",
                 ],
             },
         ),


@@ -6,3 +6,4 @@ FastDeploy currently supports installation on the following hardware platforms:
 - [Kunlun XPU Installation](kunlunxin_xpu.md)
 - [Enflame S60 GCU Installation](Enflame_gcu.md)
 - [Iluvatar GPU Installation](iluvatar_gpu.md)
+- [Hygon DCU Installation](hygon_dcu.md)


@@ -0,0 +1,81 @@
# Run the ERNIE-4.5-300B-A47B & ERNIE-4.5-21B-A3B models on a Hygon machine
The current release is only a demonstration of running large models on the Hygon K100AI with the FastDeploy inference framework. Issues may occur when running the latest ERNIE 4.5 models; fixes and performance optimizations will follow, and subsequent releases will provide a more stable version.
## Requirements
First, prepare a machine with the following configuration:
- OS: Linux
- Python: 3.10
- Memory: 2T
- Disk: 4T
- DCU Model: K100AI
- DCU Driver Version: ≥ 6.3.8-V1.9.2
## 1. Set up using Docker (Recommended)
```bash
mkdir Work
cd Work
docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04-py3.10
docker run -it \
--network=host \
--name=ernie45t \
--privileged \
--device=/dev/kfd \
--device=/dev/dri \
--ipc=host \
--shm-size=16G \
--group-add video \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
-u root \
--ulimit stack=-1:-1 \
--ulimit memlock=-1:-1 \
-v `pwd`:/home \
-v /opt/hyhal:/opt/hyhal:ro \
image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04-py3.10 /bin/bash
```
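Before starting the service, you can optionally confirm inside the container that the installed PaddlePaddle build targets ROCm/DCU. This is a quick sanity check, the same test FastDeploy's build script runs, and it should print `True` on the DCU image:

```bash
# Prints "True" if this PaddlePaddle build was compiled with ROCm (DCU support);
# "False" means a CPU/CUDA wheel is installed in this environment instead.
python -c "import paddle; print(paddle.is_compiled_with_rocm())"
```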
## 2. Start service
```bash
export FD_ATTENTION_BACKEND="BLOCK_ATTN"
python -m fastdeploy.entrypoints.openai.api_server \
--model "/models/ERNIE-45-Turbo/ERNIE-4.5-300B-A47B-Paddle/" \
--port 8188 \
--tensor-parallel-size 8 \
--quantization=wint8 \
--gpu-memory-utilization=0.8
```
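The command above serves the 300B model across 8 cards. The title also mentions ERNIE-4.5-21B-A3B; a launch for it would look like the sketch below, where the model path and tensor-parallel size are illustrative assumptions that you should adjust to your own download location and hardware:

```bash
export FD_ATTENTION_BACKEND="BLOCK_ATTN"
# Assumed local path for the 21B-A3B weights; --tensor-parallel-size 1 assumes the
# wint8-quantized model fits on a single K100AI. Adjust both values for your setup.
python -m fastdeploy.entrypoints.openai.api_server \
    --model "/models/ERNIE-4.5-21B-A3B-Paddle/" \
    --port 8188 \
    --tensor-parallel-size 1 \
    --quantization=wint8 \
    --gpu-memory-utilization=0.8
```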
#### Send requests
Send requests using either curl or Python
```bash
curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "Where is the capital of China?"}
]
}'
```
```python
import openai
ip = "0.0.0.0"
service_http_port = "8188"
client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", api_key="EMPTY_API_KEY")
response = client.chat.completions.create(
model="default",
messages=[
{"role": "user", "content": "Eliza's rate per hour for the first 40 hours she works each week is $10. She also receives an overtime pay of 1.2 times her regular hourly rate. If Eliza worked for 45 hours this week, how much are her earnings for this week?"},
],
temperature=1,
max_tokens=1024,
stream=False,
)
print(response)
```


@@ -6,3 +6,4 @@ FastDeploy currently supports installation on the following hardware platforms:
 - [Kunlunxin XPU Installation](kunlunxin_xpu.md)
 - [Enflame S60 GCU Installation](Enflame_gcu.md)
 - [Iluvatar GPU Installation](iluvatar_gpu.md)
+- [Hygon DCU Installation](hygon_dcu.md)


@@ -0,0 +1,81 @@
# Run ERNIE-4.5-300B-A47B & ERNIE-4.5-21B-A3B on the Hygon K100AI with FastDeploy
The current release is only a demonstration of large-model inference with the K100AI and FastDeploy. Issues may occur when running the latest ERNIE 4.5 models; fixes and performance optimizations will follow, and a more stable version will be provided later.
## Prepare a machine
First, prepare a machine with the following configuration:
- OS: Linux
- Python: 3.10
- Memory: 2T
- Disk: 4T
- DCU Model: K100AI
- DCU Driver Version: ≥ 6.3.8-V1.9.2
## 1. Set up using Docker (Recommended)
```bash
mkdir Work
cd Work
docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04-py3.10
docker run -it \
--network=host \
--name=ernie45t \
--privileged \
--device=/dev/kfd \
--device=/dev/dri \
--ipc=host \
--shm-size=16G \
--group-add video \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
-u root \
--ulimit stack=-1:-1 \
--ulimit memlock=-1:-1 \
-v `pwd`:/home \
-v /opt/hyhal:/opt/hyhal:ro \
image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04-py3.10 /bin/bash
```
## 2. Start the service
```bash
export FD_ATTENTION_BACKEND="BLOCK_ATTN"
python -m fastdeploy.entrypoints.openai.api_server \
--model "/models/ERNIE-45-Turbo/ERNIE-4.5-300B-A47B-Paddle/" \
--port 8188 \
--tensor-parallel-size 8 \
--quantization=wint8 \
--gpu-memory-utilization=0.8
```
#### Send requests
You can send requests to the service over the OpenAI-compatible API, using either curl or Python.
```bash
curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "Where is the capital of China?"}
]
}'
```
```python
import openai
ip = "0.0.0.0"
service_http_port = "8188"
client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", api_key="EMPTY_API_KEY")
response = client.chat.completions.create(
model="default",
messages=[
{"role": "user", "content": "Eliza's rate per hour for the first 40 hours she works each week is $10. She also receives an overtime pay of 1.2 times her regular hourly rate. If Eliza worked for 45 hours this week, how much are her earnings for this week?"},
],
temperature=1,
max_tokens=1024,
stream=False,
)
print(response)
```


@@ -20,9 +20,11 @@ from .mla_attention_backend import MLAAttentionBackend
 from .native_paddle_backend import PaddleNativeAttnBackend
 from .xpu_attn_backend import XPUAttentionBackend
 from .iluvatar_attn_backend import IluvatarAttnBackend
+from .block_multihead_attn_backend import BlockAttentionBackend
 
 __all__ = [
     "AttentionBackend", "PaddleNativeAttnBackend",
     "get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
-    "MLAAttentionBackend", "FlashAttentionBackend", "IluvatarAttnBackend"
+    "MLAAttentionBackend", "FlashAttentionBackend", "IluvatarAttnBackend",
+    "BlockAttentionBackend"
 ]


@@ -0,0 +1,172 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional
import paddle
if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta
@dataclass
class BlockAttentionMetadata(AttentionMetadata):
"""
BlockAttentionMetadata
"""
max_len_kv: paddle.Tensor = None
set_max_lengths: int = -1
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
_dtype: _DTypeLiteral = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
encoder_block_shape_q: Optional[paddle.Tensor] = None
decoder_block_shape_q: Optional[paddle.Tensor] = None
_fuse_kernel_compute_dtype: str = "bf16"
# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
class BlockAttentionBackend(AttentionBackend):
"""
BlockAttentionBackend backend implementation.
"""
def __init__(self, fd_config: FDConfig, kv_num_heads: int,
num_heads: int, head_dim: int):
"""
BlockAttentionBackend __init__
"""
super().__init__()
self.attention_metadata: BlockAttentionMetadata = None
self.block_size = fd_config.parallel_config.block_size
self.max_seq_len = fd_config.parallel_config.max_model_len
self.rope_theta = (10000.0 if fd_config.model_config.rope_theta
is None else fd_config.model_config.rope_theta)
self.rank = fd_config.parallel_config.tensor_parallel_rank
self.kv_num_heads = kv_num_heads
self.num_heads = num_heads
self.head_dim = fd_config.model_config.head_dim
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
metadata = BlockAttentionMetadata()
metadata._dtype = paddle.get_default_dtype()
if metadata._dtype == "bfloat16":
metadata._fuse_kernel_compute_dtype = "bf16"
elif metadata._dtype == "float16":
metadata._fuse_kernel_compute_dtype = "fp16"
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
metadata.block_tables = forward_meta.block_tables
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask
self.attention_metadata = metadata
def get_attntion_meta(self):
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(
self,
max_num_blocks: int,
):
"""
Calculate kv cache shape
"""
return (max_num_blocks, self.kv_num_heads, self.block_size,
self.head_dim)
def forward_mixed(
self,
q,
k,
v,
qkv,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
):
"""
forward_mixed
"""
metadata = self.attention_metadata
res = paddle.incubate.nn.functional.block_multihead_attention(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.padding_offset,
forward_meta.cum_offsets,
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
metadata.block_tables,
getattr(layer, "pre_key_cache", None),
getattr(layer, "pre_value_cache", None),
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
layer.qkv_scale,
layer.qkv_bias,
layer.linear_shift,
layer.linear_smooth,
getattr(layer, "max_enc_len_this_time", None),
getattr(layer, "max_dec_len_this_time", None),
metadata.rotary_embs,
metadata.attn_mask,
None, # tgt_mask
self.max_seq_len,
self.block_size,
layer.use_neox_rotary_style,
getattr(layer, "use_dynamic_cachekv_quant", False),
quant_round_type=getattr(layer, "quant_round_type", 0),
quant_max_bound=getattr(layer, "quant_max_bound", 0.0),
quant_min_bound=getattr(layer, "quant_min_bound", 0.0),
out_scale=getattr(layer, "out_scale", -1.0),
compute_dtype=metadata._fuse_kernel_compute_dtype,
rope_theta=self.rope_theta,
)[0]
return res


@@ -29,7 +29,7 @@ from fastdeploy.model_executor.layers.attention.ops import (
     open_shm_and_get_meta_signal)
 from fastdeploy.platforms import current_platform
 
-if current_platform.is_cuda():
+if current_platform.is_cuda() and not current_platform.is_dcu():
     from fastdeploy.model_executor.ops.gpu import (decode_mla_write_cache,
                                                    multi_head_latent_attention,
                                                    prefill_mla_write_cache)


@@ -20,7 +20,7 @@ import paddle
 from fastdeploy.platforms import current_platform
 
-if current_platform.is_cuda():
+if current_platform.is_cuda() and not current_platform.is_dcu():
     from fastdeploy.model_executor.ops.gpu import \
         append_attention as append_attention_gpu


@@ -37,3 +37,9 @@ if current_platform.is_gcu():
     from .gcu import *
     if hasattr(gcu, '__all__'):
         __all__.extend(gcu.__all__)
+
+if current_platform.is_dcu():
+    from .dcu import *
+    from . import dcu
+    if hasattr(dcu, '__all__'):
+        __all__.extend(dcu.__all__)


@@ -0,0 +1,22 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
dcu backend methods
"""
from .fused_moe_triton_backends import DCUTritonWeightOnlyMoEMethod
from .weight_only import DCUWeightOnlyLinearMethod
__all__ = ['DCUTritonWeightOnlyMoEMethod', 'DCUWeightOnlyLinearMethod']


@@ -0,0 +1,244 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from paddle import nn
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map,
get_tensor)
from fastdeploy.utils import ceil_div
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
"""
Use Triton Group Gemm to compute Fused MoE.
"""
def __init__(self, quant_method=None):
"""
Triton Group Gemm to compute Fused MoE.
"""
self.quant_method = quant_method
self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
self.added_scale_attrs = [
"moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
]
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
"""process_prequanted_weights"""
pass
def create_weights(self, layer: nn.Layer, state_dict):
"""
Triton MoE create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(ffn1_weights) == layer.num_local_experts
assert len(ffn2_weights) == layer.num_local_experts
assert self.quant_method.name() == "wint8"
assert ffn1_weights[0].shape == [
layer.hidden_size, layer.moe_intermediate_size * 2
]
assert ffn2_weights[0].shape == [
layer.moe_intermediate_size, layer.hidden_size
]
ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
ffn2_tensor = paddle.stack(ffn2_weights, axis=0)
if self.quant_method.name() == "wint8":
max_bound = 127
elif self.quant_method.name() == "wint4":
max_bound = 7
for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
quanted_weight_scale = weight_tensor.abs().max(axis=1)
quanted_weight = weight_tensor / quanted_weight_scale[:,
None, :] * max_bound
quanted_weight = paddle.round(quanted_weight).astype("int8")
quanted_weight_scale = quanted_weight_scale / max_bound
setattr(
layer, weight_name,
layer.create_parameter(
shape=quanted_weight.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
))
getattr(layer, weight_name).set_value(quanted_weight)
setattr(
layer, scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
))
getattr(layer, scale_name).set_value(quanted_weight_scale)
def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
"""
token_num = x.shape[0]
top_k = layer.top_k
num_local_experts = layer.num_local_experts
top_k = layer.top_k
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
scores = paddle.nn.functional.softmax(gate_out, axis=-1)
scores += layer.gate_correction_bias
topk_weights, topk_ids = paddle.topk(scores,
k=top_k,
axis=-1,
sorted=False)
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
intermediate_cache1 = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
)
intermediate_cache2 = paddle.empty(
(token_num * top_k, moe_intermediate_size),
dtype=x.dtype,
)
intermediate_cache3 = paddle.empty(
(token_num * top_k, hidden_size),
dtype=x.dtype,
)
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
}
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
from .triton_moe_kernels import fused_moe_kernel_paddle
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
max_num_tokens_padded = sorted_token_ids.shape[0]
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid](
x,
layer.moe_ffn1_weight,
intermediate_cache1,
None,
layer.moe_ffn1_weight_scale,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
moe_intermediate_size * 2,
hidden_size,
max_num_tokens_padded,
token_num * top_k,
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.moe_ffn1_weight.strides[0],
stride_bk=layer.moe_ffn1_weight.strides[1],
stride_bn=layer.moe_ffn1_weight.strides[2],
stride_cm=intermediate_cache1.strides[0],
stride_cn=intermediate_cache1.strides[1],
#
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.moe_ffn1_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
GROUP_SIZE_M=config["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=False,
top_k=top_k,
compute_type_enum=1,
use_fp8_w8a8=False,
use_int8_w8a16=True,
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
)
intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
intermediate_cache1)
grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
fused_moe_kernel_paddle[grid](
intermediate_cache2,
layer.moe_ffn2_weight,
intermediate_cache3,
None,
layer.moe_ffn2_weight_scale,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
hidden_size,
moe_intermediate_size,
max_num_tokens_padded,
token_num * top_k,
stride_am=intermediate_cache2.strides[0],
stride_ak=intermediate_cache2.strides[1],
stride_be=layer.moe_ffn2_weight.strides[0],
stride_bk=layer.moe_ffn2_weight.strides[1],
stride_bn=layer.moe_ffn2_weight.strides[2],
stride_cm=intermediate_cache3.strides[0],
stride_cn=intermediate_cache3.strides[1],
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.moe_ffn2_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
GROUP_SIZE_M=config["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=True,
top_k=1,
compute_type_enum=1,
use_fp8_w8a8=False,
use_int8_w8a16=True,
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
)
intermediate_cache3.reshape_([token_num, top_k, hidden_size])
out = intermediate_cache3.sum(axis=1)
if layer.tp_size > 1:
tensor_model_parallel_all_reduce(out)
return out


@@ -0,0 +1,198 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import triton
import triton.language as tl
@triton.jit
def fused_moe_kernel_paddle(
a_ptr,
b_ptr,
c_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
num_tokens_post_padded,
num_valid_tokens,
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Block size for block-wise fp8 quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type_enum: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
even_Ks: tl.constexpr,
):
"""
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(num_tokens_post_padded, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
assert compute_type_enum == 1
compute_type = tl.bfloat16
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
if use_int8_w8a16:
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8:
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
else:
# (Zkk): every expert has one activation scale and weight scale.
a_scale = tl.load(a_scale_ptr + off_experts)
b_scale = tl.load(b_scale_ptr + off_experts)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
if even_Ks:
a = tl.load(
a_ptrs,
mask=token_mask[:, None],
other=0.0,
)
b = tl.load(b_ptrs,
cache_modifier=".cv",
eviction_policy='evict_first')
else:
a = tl.load(
a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
# We accumulate along the K dimension.
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8:
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
mask=token_mask,
other=0.0)
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
accumulator += tl.dot(a, b) * a_scale[:,
None] * b_scale[None, :]
else:
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8:
if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type)
else:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)


@@ -0,0 +1,46 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from paddle.nn.quant import weight_dequantize
from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig, GPUWeightOnlyLinearMethod
class DCUWeightOnlyLinearMethod(GPUWeightOnlyLinearMethod):
"""
Weight-only quantization method for linear layers on DCU.
The weights are loaded in the BF16 numerical format. After loading, the quantization coefficients will be computed,
and the weights will be quantized to int8 or int4.
"""
def __init__(
self,
quant_config: WeightOnlyConfig,
) -> None:
super().__init__(quant_config)
def apply(self, layer, x):
dequant_out = weight_dequantize(
x=layer.linear_weight,
scale=layer.linear_weight_scale,
algo=self.quant_config.algo,
out_dtype=paddle.get_default_dtype()
)
linear_out = paddle.matmul(x, dequant_out)
if layer.linear_bias is not None:
linear_out = paddle.add(linear_out, layer.linear_bias)
return linear_out


@@ -27,7 +27,7 @@ from fastdeploy.platforms import current_platform
 from ..utils import create_and_set_parameter, get_tensor
 from .fused_moe_backend_base import MoEMethodBase
 
-if current_platform.is_cuda():
+if current_platform.is_cuda() and not current_platform.is_dcu():
     from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch,
                                                    moe_expert_reduce, noaux_tc)
 elif current_platform.is_iluvatar():


@@ -75,6 +75,15 @@ class WeightOnlyConfig(QuantConfigBase):
                 return GCUWeightOnlyMoEMethod(self)
             else:
                 return GCUWeightOnlyLinearMethod(self)
+        elif current_platform.is_dcu():
+            if isinstance(layer, FusedMoE):
+                from fastdeploy.model_executor.layers.backends import (
+                    DCUTritonWeightOnlyMoEMethod)
+                return DCUTritonWeightOnlyMoEMethod(self)
+            else:
+                from fastdeploy.model_executor.layers.backends import (
+                    DCUWeightOnlyLinearMethod)
+                return DCUWeightOnlyLinearMethod(self)
         else:
             if isinstance(layer, FusedMoE):
                 if layer.use_method == "cutlass":


@@ -39,7 +39,7 @@ from fastdeploy.model_executor.models.ernie4_5_moe import (Ernie4_5_Attention,
 from fastdeploy.model_executor.models.model_base import ModelForCasualLM
 from fastdeploy.platforms import current_platform
 
-if current_platform.is_cuda():
+if current_platform.is_cuda() and not current_platform.is_dcu():
     from fastdeploy.model_executor.ops.gpu import (extract_text_token_output,
                                                    text_image_gather_scatter,
                                                    text_image_index_out)


@@ -29,6 +29,10 @@ elif current_platform.is_gcu():
         save_output,
         set_stop_value_multi_ends,
         update_inputs)
+elif current_platform.is_dcu():
+    from fastdeploy.model_executor.ops.gpu import (
+        get_padding_offset, save_output, set_stop_value_multi_ends,
+        step_paddle, update_inputs)
 else:
     from fastdeploy.model_executor.ops.gpu import (
         get_padding_offset, save_output, set_stop_value_multi_ends,


@@ -33,14 +33,14 @@ def __getattr__(name: str):
     # lazy init current_platform.
     global _current_platform
     if _current_platform is None:
-        if paddle.is_compiled_with_cuda():
+        if paddle.is_compiled_with_rocm():
+            _current_platform = DCUPlatform()
+        elif paddle.is_compiled_with_cuda():
             _current_platform = CUDAPlatform()
         elif paddle.is_compiled_with_xpu():
             _current_platform = XPUPlatform()
         elif paddle.is_compiled_with_custom_device("npu"):
             _current_platform = NPUPlatform()
-        elif paddle.is_compiled_with_rocm():
-            _current_platform = DCUPlatform()
         elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
             _current_platform = IluvatarPlatform()
         elif paddle.is_compiled_with_custom_device("gcu"):


@@ -25,6 +25,7 @@ class _Backend(enum.Enum):
     APPEND_ATTN = enum.auto()
     MLA_ATTN = enum.auto()
     FLASH_ATTN = enum.auto()
+    BLOCK_ATTN = enum.auto()
 
 class Platform:


@@ -14,7 +14,9 @@
 """
 dcu platform file
 """
-from .base import Platform
+import paddle
+from .base import Platform, _Backend
+from paddleformers.utils.log import logger
 
 class DCUPlatform(Platform):
@@ -22,3 +24,38 @@ class DCUPlatform(Platform):
     dcu platform class
     """
     device_name = "dcu"
+
+    @classmethod
+    def available(cls):
+        """
+        Check whether DCU devices are available.
+        """
+        try:
+            assert len(paddle.static.cuda_places()) > 0
+            return True
+        except Exception as e:
+            logger.warning(
+                "You are using the DCU version of PaddlePaddle, but no DCU "
+                "device was detected on your machine. The devices may not be "
+                "set up properly."
+                f"\n Original Error is {e}"
+            )
+            return False
+
+    @classmethod
+    def get_attention_backend_cls(
+        cls,
+        selected_backend
+    ):
+        """
+        get_attention_backend_cls
+        """
+        if selected_backend == _Backend.NATIVE_ATTN:
+            logger.info("Using NATIVE ATTN backend.")
+            return ("fastdeploy.model_executor.layers.attention.PaddleNativeAttnBackend")
+        elif selected_backend == _Backend.BLOCK_ATTN:
+            logger.info("Using BLOCK ATTN backend.")
+            return ("fastdeploy.model_executor.layers.attention.BlockAttentionBackend")
+        else:
+            logger.warning(
+                "Other backends are not supported for now."
+            )


@@ -0,0 +1,112 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import gc
import time
from typing import List, Optional
import paddle
import paddle.nn as nn
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import get_logger
from fastdeploy.worker.gpu_model_runner import GPUModelRunner
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.gpu_worker import GpuWorker
logger = get_logger("dcu_worker", "dcu_worker.log")
class DcuWorker(GpuWorker):
""" """
def __init__(
self,
fd_config: FDConfig,
local_rank: int,
rank: int,
):
super().__init__(
fd_config=fd_config,
local_rank=local_rank,
rank=rank,
)
pass
def determine_available_memory(self) -> int:
"""
Profiles the peak memory usage of the model to determine how much
memory can be used for KV cache without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
Tip:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# 1. Record memory state before profile run
Gb = 1024**3
start_time = time.perf_counter()
paddle.device.cuda.reset_max_memory_reserved(self.local_rank)
paddle.device.cuda.reset_max_memory_allocated(self.local_rank)
paddle_reserved_mem_before_run = paddle.device.cuda.max_memory_reserved(
self.local_rank)
paddle_allocated_mem_before_run = paddle.device.cuda.max_memory_allocated(
self.local_rank) # not reserved
total_gpu_memory = paddle.device.cuda.get_device_properties(self.local_rank).total_memory
before_used_gpu_memory = paddle.device.cuda.memory_allocated(self.local_rank)
logger.info((
"Before running the profile, the memory usage info is as follows:",
f"\nDevice Total memory: {total_gpu_memory / Gb}",
f"\nDevice used memory: {before_used_gpu_memory / Gb}",
f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}",
f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}"))
# 2. Profile run
self.model_runner.profile_run()
# 3. Statistical memory information
paddle_reserved_mem_after_run = paddle.device.cuda.max_memory_reserved(
self.local_rank)
paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated(
self.local_rank)
after_used_gpu_memory = paddle.device.cuda.memory_allocated(self.local_rank)
# v0 worker
model_block_memory_used = self.cal_theortical_kvcache()
paddle.device.cuda.empty_cache()
paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run
available_kv_cache_memory = total_gpu_memory * \
self.parallel_config.gpu_memory_utilization - after_used_gpu_memory - paddle_peak_increase
available_kv_cache_memory += model_block_memory_used * self.parallel_config.max_block_num
end_time = time.perf_counter()
logger.info(
("After running the profile, the memory usage info is as follows:",
f"\nDevice Total memory: {total_gpu_memory / Gb}",
f"\nDevice used memory: {after_used_gpu_memory / Gb}",
f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}",
f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}",
f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}",
f"Profile time: {end_time - start_time}"))
return available_kv_cache_memory  # used to calculate the block num on this device


@@ -41,7 +41,9 @@ from fastdeploy.model_executor.pre_and_post_process import (post_process,
                                                              pre_process,
                                                              rebuild_padding,
                                                              step_cuda)
-from fastdeploy.spec_decode import MTPProposer, NgramProposer
+from fastdeploy.platforms import current_platform
+if not current_platform.is_dcu():
+    from fastdeploy.spec_decode import MTPProposer, NgramProposer
 from fastdeploy.worker.forward_meta import ForwardMeta
 from fastdeploy.worker.model_runner_base import ModelRunnerBase
 from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput


@@ -42,6 +42,9 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
     """
     get worker of different device
     """
+    if current_platform.is_dcu():
+        from fastdeploy.worker.dcu_worker import DcuWorker
+        return DcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
     if current_platform.is_cuda():
         from fastdeploy.worker.gpu_worker import GpuWorker
         return GpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)

requirements_dcu.txt (new file, 29 lines)

@@ -0,0 +1,29 @@
setuptools>=62.3.0,<80.0
pre-commit
yapf
flake8
ruamel.yaml
zmq
aiozmq
openai
tqdm
pynvml
uvicorn
fastapi
paddleformers
redis
etcd3
httpx
tool_helpers
pybind11[global]
tabulate
gradio
xlwt
visualdl
setuptools-scm>=8
prometheus-client
decord
moviepy
use-triton-in-paddle
crcmod
fastsafetensors==0.1.14


@@ -146,6 +146,8 @@ def load_requirements():
     requirements_file_name = 'requirements.txt'
     if paddle.is_compiled_with_custom_device('iluvatar_gpu'):
         requirements_file_name = 'requirements_iluvatar.txt'
+    elif paddle.is_compiled_with_rocm():
+        requirements_file_name = 'requirements_dcu.txt'
     requirements_path = os.path.join(os.path.dirname(__file__),
                                      requirements_file_name)
     with open(requirements_path, 'r') as f: