Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 00:33:03 +08:00)
dcu adapter ernie45t (#2756)

Co-authored-by: lifu <lifu@sugon.com>
Co-authored-by: yongqiangma <xing.wo@163.com>

build.sh (4 lines changed)
@@ -77,8 +77,10 @@ function copy_ops(){
     is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
     if [ "$is_rocm" = "True" ]; then
         DEVICE_TYPE="rocm"
+        mkdir -p ../fastdeploy/model_executor/ops/base
+        cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
         cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
-        echo -e "ROCM ops have been copy to fastdeploy"
+        echo -e "BASE and ROCM ops have been copy to fastdeploy"
         return
     fi
     mkdir -p ../fastdeploy/model_executor/ops/base
@@ -214,11 +214,19 @@ HOSTDEVICE inline void Store(const AlignedVector<T, Size> &vec, T *addr) {
   *addr_vec = vec;
 }
 
+#ifdef PADDLE_WITH_HIP
+template <int Size>
+HOSTDEVICE inline void Store(const AlignedVector<hip_bfloat16, Size> &vec,
+                             int8_t *addr) {
+  printf("Error: Store hip_bfloat16 to int8_t is not supported!");
+}
+#else
 template <int Size>
 HOSTDEVICE inline void Store(const AlignedVector<__nv_bfloat16, Size> &vec,
                              int8_t *addr) {
   printf("Error: Store __nv_bfloat16 to int8_t is not supported!");
 }
+#endif
 
 template <int Size>
 HOSTDEVICE inline void Store(const AlignedVector<half, Size> &vec,
@@ -478,7 +486,12 @@ template <typename T>
 static void PrintMatrix3(const T *mat_d, int num, std::string name) {
 
   std::vector<T> tmp(num);
+#ifdef PADDLE_WITH_HIP
+  hipMemcpy(tmp.data(), mat_d, sizeof(T) * num, hipMemcpyDeviceToHost);
+#else
   cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost);
+#endif
 
   std::ofstream outfile;
   outfile.open(name + ".txt", std::ios::out);
@@ -495,6 +508,7 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
   outfile.close();
 }
 
+#ifndef PADDLE_WITH_HIP
 __forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr,
                                                     int mode = 0) {
   uint32_t flag;
@@ -534,6 +548,7 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
       cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
   return max_shared_mem_per_block_opt_in;
 }
+#endif
 
 inline int GetSMVersion() {
   static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
@@ -91,7 +91,12 @@ void set_data_ipc(const paddle::Tensor& tmp_input,
   memset((void *)shm, 0, sizeof(*shm));
 
   void *data_ptr_now = reinterpret_cast<void*>(const_cast<data_t*>(tmp_input.data<data_t>()));
+#ifdef PADDLE_WITH_HIP
+  checkCudaErrors(hipIpcGetMemHandle((hipIpcMemHandle_t *)&shm->memHandle, data_ptr_now));
+#else
   checkCudaErrors(cudaIpcGetMemHandle((cudaIpcMemHandle_t *)&shm->memHandle, data_ptr_now));
+#endif
 
 }
 
@@ -37,10 +37,18 @@ std::vector<paddle::Tensor> ShareExternalData(paddle::Tensor& input,
   }
   shm = (volatile shmStruct *)info.addr;
   void *ptr = nullptr;
+#ifdef PADDLE_WITH_HIP
+  checkCudaErrors(
+      hipIpcOpenMemHandle(&ptr,
+                          *(hipIpcMemHandle_t *)&shm->memHandle,  // NOLINT
+                          hipIpcMemLazyEnablePeerAccess));
+#else
   checkCudaErrors(
       cudaIpcOpenMemHandle(&ptr,
                            *(cudaIpcMemHandle_t *)&shm->memHandle,  // NOLINT
                            cudaIpcMemLazyEnablePeerAccess));
+#endif
 
   paddle::Tensor tmp_tensor = paddle::from_blob(
       ptr,
       shape,
@@ -187,39 +187,45 @@ def find_end_files(directory, end_str):
 if paddle.is_compiled_with_rocm():
     # NOTE(@duanyanhui): paddle.is_compiled_with_cuda() returns True when paddle compiled with rocm.
     # so we need to check if paddle compiled with rocm at first.
+    json_dir = "third_party/nlohmann_json"
+    if not os.path.exists(json_dir) or not os.listdir(json_dir):
+        if not os.path.exists(json_dir):
+            os.makedirs(json_dir)
+        clone_git_repo("v3.11.3", "https://bgithub.xyz/nlohmann/json.git", json_dir)
+        if not os.listdir(json_dir):
+            raise ValueError("Git clone nlohmann_json failed!")
+    sources=[
+        "gpu_ops/set_value_by_flags.cu",
+        "gpu_ops/token_penalty_multi_scores.cu",
+        "gpu_ops/stop_generation.cu",
+        "gpu_ops/stop_generation_multi_ends.cu",
+        "gpu_ops/get_padding_offset.cu",
+        "gpu_ops/update_inputs.cu",
+        "gpu_ops/rebuild_padding.cu",
+        "gpu_ops/step.cu",
+        "gpu_ops/set_data_ipc.cu",
+        "gpu_ops/moe/tritonmoe_preprocess.cu",
+        "gpu_ops/step_system_cache.cu",
+        "gpu_ops/get_output_ep.cc",
+        "gpu_ops/speculate_decoding/speculate_get_padding_offset.cu",
+        "gpu_ops/speculate_decoding/speculate_get_output.cc",
+        "gpu_ops/share_external_data.cu",
+        "gpu_ops/speculate_decoding/speculate_clear_accept_nums.cu",
+        "gpu_ops/speculate_decoding/speculate_get_output_padding_offset.cu",
+        "gpu_ops/speculate_decoding/speculate_get_seq_lens_output.cu",
+        "gpu_ops/speculate_decoding/speculate_save_output.cc",
+        "gpu_ops/speculate_decoding/speculate_set_value_by_flags.cu",
+        "gpu_ops/speculate_decoding/speculate_step.cu",
+        "gpu_ops/speculate_decoding/speculate_step_system_cache.cu",
+        "gpu_ops/speculate_decoding/speculate_update_v3.cu",
+        "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
+        "gpu_ops/fused_rotary_position_encoding.cu",
+        "gpu_ops/step_reschedule.cu",
+    ]
     setup(
         name="fastdeploy_ops",
         ext_modules=CUDAExtension(
-            sources=[
-                "gpu_ops/save_with_output.cc",
-                "gpu_ops/set_mask_value.cu",
-                "gpu_ops/set_value_by_flags.cu",
-                "gpu_ops/ngram_mask.cu",
-                "gpu_ops/gather_idx.cu",
-                "gpu_ops/token_penalty_multi_scores.cu",
-                "gpu_ops/token_penalty_only_once.cu",
-                "gpu_ops/stop_generation.cu",
-                "gpu_ops/stop_generation_multi_ends.cu",
-                "gpu_ops/stop_generation_multi_stop_seqs.cu",
-                "gpu_ops/set_flags.cu",
-                "gpu_ops/fused_get_rope.cu",
-                "gpu_ops/transfer_output.cc",
-                "gpu_ops/get_padding_offset.cu",
-                "gpu_ops/update_inputs.cu",
-                "gpu_ops/update_inputs_beam.cu",
-                "gpu_ops/beam_search_softmax.cu",
-                "gpu_ops/rebuild_padding.cu",
-                "gpu_ops/save_with_output_msg.cc",
-                "gpu_ops/get_output.cc",
-                "gpu_ops/get_output_msg_with_topk.cc",
-                "gpu_ops/step.cu",
-                "gpu_ops/step_reschedule.cu",
-                "gpu_ops/set_data_ipc.cu",
-                "gpu_ops/read_data_ipc.cu",
-                "gpu_ops/dequant_int8.cu",
-                "gpu_ops/enforce_generation.cu",
-                "gpu_ops/tune_cublaslt_gemm.cu",
-            ],
+            sources=sources,
             extra_compile_args={
                 "cxx": ["-O3"],
                 "hipcc": [
@@ -231,6 +237,9 @@ if paddle.is_compiled_with_rocm():
                     "-U__HIP_NO_BFLOAT16_CONVERSIONS__",
                     "-U__HIP_NO_BFLOAT162_OPERATORS__",
                     "-U__HIP_NO_BFLOAT162_CONVERSIONS__",
+                    "-DPADDLE_DEV",
+                    "-Ithird_party/nlohmann_json/include",
+                    "-Igpu_ops",
                 ],
             },
         ),
@@ -6,3 +6,4 @@ FastDeploy currently supports installation on the following hardware platforms:
 - [Kunlun XPU Installation](kunlunxin_xpu.md)
 - [Enflame S60 GCU Installation](Enflame_gcu.md)
 - [Iluvatar GPU Installation](iluvatar_gpu.md)
+- [Hygon DCU Installation](hygon_dcu.md)

docs/get_started/installation/hygon_dcu.md (new file, 81 lines)
@@ -0,0 +1,81 @@
# Run the ERNIE-4.5-300B-A47B & ERNIE-4.5-21B-A3B models on a Hygon machine
The current release is only a demonstration of the Hygon K100AI running large models with the FastDeploy inference framework. Issues may occur when running the latest ERNIE 4.5 models; fixes and performance optimizations will follow, and later releases will provide a more stable version.

## Requirements
First, prepare a machine with the following configuration:
- OS: Linux
- Python: 3.10
- Memory: 2T
- Disk: 4T
- DCU model: K100AI
- DCU driver version: ≥ 6.3.8-V1.9.2

## 1. Set up using Docker (Recommended)

```bash
mkdir Work
cd Work
docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04-py3.10

docker run -it \
  --network=host \
  --name=ernie45t \
  --privileged \
  --device=/dev/kfd \
  --device=/dev/dri \
  --ipc=host \
  --shm-size=16G \
  --group-add video \
  --cap-add=SYS_PTRACE \
  --security-opt seccomp=unconfined \
  -u root \
  --ulimit stack=-1:-1 \
  --ulimit memlock=-1:-1 \
  -v `pwd`:/home \
  -v /opt/hyhal:/opt/hyhal:ro \
  image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04-py3.10 /bin/bash
```

## 2. Start the service

```bash
export FD_ATTENTION_BACKEND="BLOCK_ATTN"
python -m fastdeploy.entrypoints.openai.api_server \
    --model "/models/ERNIE-45-Turbo/ERNIE-4.5-300B-A47B-Paddle/" \
    --port 8188 \
    --tensor-parallel-size 8 \
    --quantization=wint8 \
    --gpu-memory-utilization=0.8
```

#### Send requests
Send requests using either curl or Python:

```bash
curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
  "messages": [
    {"role": "user", "content": "Where is the capital of China?"}
  ]
}'
```

```python
import openai

ip = "0.0.0.0"
service_http_port = "8188"
client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", api_key="EMPTY_API_KEY")

response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "user", "content": "Eliza's rate per hour for the first 40 hours she works each week is $10. She also receives an overtime pay of 1.2 times her regular hourly rate. If Eliza worked for 45 hours this week, how much are her earnings for this week?"},
    ],
    temperature=1,
    max_tokens=1024,
    stream=False,
)
print(response)
```
@@ -6,3 +6,4 @@ FastDeploy currently supports installation on the following hardware platforms:
 - [Kunlunxin XPU Installation](kunlunxin_xpu.md)
 - [Enflame S60 GCU Installation](Enflame_gcu.md)
 - [Iluvatar GPU Installation](iluvatar_gpu.md)
+- [Hygon DCU Installation](hygon_dcu.md)

docs/zh/get_started/installation/hygon_dcu.md (new file, 81 lines)
@@ -0,0 +1,81 @@
# Run ERNIE-4.5-300B-A47B & ERNIE-4.5-21B-A3B on the Hygon K100AI with FastDeploy
The current release is only a demo of large-model inference with FastDeploy on the K100AI. Running the latest ERNIE 4.5 models may still hit issues; fixes and performance optimizations will follow, and later releases will be more stable.

## Prepare the machine
First, prepare a machine with the following configuration:
- OS: Linux
- Python: 3.10
- Memory: 2T
- Disk: 4T
- DCU model: K100AI
- DCU driver version: ≥ 6.3.8-V1.9.2

## 1. Set up using Docker (Recommended)

```bash
mkdir Work
cd Work
docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04-py3.10

docker run -it \
  --network=host \
  --name=ernie45t \
  --privileged \
  --device=/dev/kfd \
  --device=/dev/dri \
  --ipc=host \
  --shm-size=16G \
  --group-add video \
  --cap-add=SYS_PTRACE \
  --security-opt seccomp=unconfined \
  -u root \
  --ulimit stack=-1:-1 \
  --ulimit memlock=-1:-1 \
  -v `pwd`:/home \
  -v /opt/hyhal:/opt/hyhal:ro \
  image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04-py3.10 /bin/bash
```

## 2. Start the service

```bash
export FD_ATTENTION_BACKEND="BLOCK_ATTN"
python -m fastdeploy.entrypoints.openai.api_server \
    --model "/models/ERNIE-45-Turbo/ERNIE-4.5-300B-A47B-Paddle/" \
    --port 8188 \
    --tensor-parallel-size 8 \
    --quantization=wint8 \
    --gpu-memory-utilization=0.8
```

#### Send requests
You can send requests to the service with either curl or Python, following the OpenAI protocol.

```bash
curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
  "messages": [
    {"role": "user", "content": "Where is the capital of China?"}
  ]
}'
```

```python
import openai

ip = "0.0.0.0"
service_http_port = "8188"
client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", api_key="EMPTY_API_KEY")

response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "user", "content": "Eliza's rate per hour for the first 40 hours she works each week is $10. She also receives an overtime pay of 1.2 times her regular hourly rate. If Eliza worked for 45 hours this week, how much are her earnings for this week?"},
    ],
    temperature=1,
    max_tokens=1024,
    stream=False,
)
print(response)
```
@@ -20,9 +20,11 @@ from .mla_attention_backend import MLAAttentionBackend
 from .native_paddle_backend import PaddleNativeAttnBackend
 from .xpu_attn_backend import XPUAttentionBackend
 from .iluvatar_attn_backend import IluvatarAttnBackend
+from .block_multihead_attn_backend import BlockAttentionBackend
 
 __all__ = [
     "AttentionBackend", "PaddleNativeAttnBackend",
     "get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
-    "MLAAttentionBackend", "FlashAttentionBackend", "IluvatarAttnBackend"
+    "MLAAttentionBackend", "FlashAttentionBackend", "IluvatarAttnBackend",
+    "BlockAttentionBackend"
 ]
@@ -0,0 +1,172 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional

import paddle

if TYPE_CHECKING:
    from paddle._typing.dtype_like import _DTypeLiteral

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta


@dataclass
class BlockAttentionMetadata(AttentionMetadata):
    """
    BlockAttentionMetadata
    """
    max_len_kv: paddle.Tensor = None
    set_max_lengths: int = -1
    encoder_batch_ids: paddle.Tensor = None
    encoder_tile_ids_per_batch: paddle.Tensor = None
    encoder_num_blocks: paddle.Tensor = None
    kv_batch_ids: paddle.Tensor = None
    kv_tile_ids_per_batch: paddle.Tensor = None
    kv_num_blocks: paddle.Tensor = None
    decoder_batch_ids: paddle.Tensor = None
    decoder_tile_ids_per_batch: paddle.Tensor = None
    decoder_num_blocks: paddle.Tensor = None

    _dtype: _DTypeLiteral = paddle.bfloat16
    encoder_max_partition_size: int = 32768
    max_partition_size: int = 32768
    block_tables: Optional[paddle.Tensor] = None
    rotary_embs: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None
    encoder_block_shape_q: Optional[paddle.Tensor] = None
    decoder_block_shape_q: Optional[paddle.Tensor] = None
    _fuse_kernel_compute_dtype: str = "bf16"

    # pd_disaggregation
    kv_signal_metadata: Optional[paddle.Tensor] = None
    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)


class BlockAttentionBackend(AttentionBackend):
    """
    BlockAttentionBackend backend implementation.
    """

    def __init__(self, fd_config: FDConfig, kv_num_heads: int,
                 num_heads: int, head_dim: int):
        """
        BlockAttentionBackend __init__
        """
        super().__init__()
        self.attention_metadata: BlockAttentionMetadata = None
        self.block_size = fd_config.parallel_config.block_size
        self.max_seq_len = fd_config.parallel_config.max_model_len
        self.rope_theta = (10000.0 if fd_config.model_config.rope_theta
                           is None else fd_config.model_config.rope_theta)
        self.rank = fd_config.parallel_config.tensor_parallel_rank

        self.kv_num_heads = kv_num_heads
        self.num_heads = num_heads
        self.head_dim = fd_config.model_config.head_dim

    def init_attention_metadata(self, forward_meta: ForwardMeta):
        """Initialize attention metadata so all layers in the forward pass can reuse it."""
        metadata = BlockAttentionMetadata()
        metadata._dtype = paddle.get_default_dtype()
        if metadata._dtype == "bfloat16":
            metadata._fuse_kernel_compute_dtype = "bf16"
        elif metadata._dtype == "float16":
            metadata._fuse_kernel_compute_dtype = "fp16"
        elif metadata._dtype == "float32":
            metadata._fuse_kernel_compute_dtype = "fp32"
        metadata.block_tables = forward_meta.block_tables
        metadata.rotary_embs = forward_meta.rotary_embs
        metadata.attn_mask = forward_meta.attn_mask
        self.attention_metadata = metadata

    def get_attntion_meta(self):
        """get_attntion_meta"""
        return self.attention_metadata

    def get_kv_cache_shape(
        self,
        max_num_blocks: int,
    ):
        """
        Calculate kv cache shape
        """
        return (max_num_blocks, self.kv_num_heads, self.block_size,
                self.head_dim)

    def forward_mixed(
        self,
        q,
        k,
        v,
        qkv,
        compressed_kv: paddle.Tensor,
        k_pe: paddle.Tensor,
        layer: Attention,
        forward_meta: ForwardMeta,
    ):
        """
        forward_mixed
        """
        metadata = self.attention_metadata

        res = paddle.incubate.nn.functional.block_multihead_attention(
            qkv,
            forward_meta.caches[2 * layer.layer_id],
            forward_meta.caches[2 * layer.layer_id + 1],
            forward_meta.seq_lens_encoder,
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_this_time,
            forward_meta.padding_offset,
            forward_meta.cum_offsets,
            forward_meta.cu_seqlens_q,
            forward_meta.cu_seqlens_k,
            metadata.block_tables,
            getattr(layer, "pre_key_cache", None),
            getattr(layer, "pre_value_cache", None),
            getattr(layer, "cache_k_scale", None),
            getattr(layer, "cache_v_scale", None),
            getattr(layer, "cache_k_out_scale", None),
            getattr(layer, "cache_v_out_scale", None),
            layer.qkv_scale,
            layer.qkv_bias,
            layer.linear_shift,
            layer.linear_smooth,
            getattr(layer, "max_enc_len_this_time", None),
            getattr(layer, "max_dec_len_this_time", None),
            metadata.rotary_embs,
            metadata.attn_mask,
            None,  # tgt_mask
            self.max_seq_len,
            self.block_size,
            layer.use_neox_rotary_style,
            getattr(layer, "use_dynamic_cachekv_quant", False),
            quant_round_type=getattr(layer, "quant_round_type", 0),
            quant_max_bound=getattr(layer, "quant_max_bound", 0.0),
            quant_min_bound=getattr(layer, "quant_min_bound", 0.0),
            out_scale=getattr(layer, "out_scale", -1.0),
            compute_dtype=metadata._fuse_kernel_compute_dtype,
            rope_theta=self.rope_theta,
        )[0]

        return res
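For orientation, the cache shape returned by `get_kv_cache_shape` above is block-paged: each K (or V) cache tensor holds `max_num_blocks` blocks of `block_size` token slots per KV head, and a per-sequence block table maps logical positions to block indices. A minimal sketch with made-up sizes (illustration only, not the FastDeploy API):

```python
# Illustration only: block-paged KV cache sizing with invented numbers.
# Only the shape logic mirrors BlockAttentionBackend.get_kv_cache_shape.
block_size = 64          # token slots per cache block
kv_num_heads = 8
head_dim = 128
max_num_blocks = 1024    # how many blocks fit in device memory

# Shape of one cache tensor (for K or for V):
kv_cache_shape = (max_num_blocks, kv_num_heads, block_size, head_dim)

# A 200-token sequence needs ceil(200 / 64) = 4 blocks; its block table is
# just the list of block indices assigned to it.
seq_len = 200
blocks_needed = -(-seq_len // block_size)
block_table = [17, 42, 43, 99][:blocks_needed]

print(kv_cache_shape)              # (1024, 8, 64, 128)
print(blocks_needed, block_table)  # 4 [17, 42, 43, 99]
```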
@@ -29,7 +29,7 @@ from fastdeploy.model_executor.layers.attention.ops import (
     open_shm_and_get_meta_signal)
 from fastdeploy.platforms import current_platform
 
-if current_platform.is_cuda():
+if current_platform.is_cuda() and not current_platform.is_dcu():
     from fastdeploy.model_executor.ops.gpu import (decode_mla_write_cache,
                                                    multi_head_latent_attention,
                                                    prefill_mla_write_cache)
@@ -20,7 +20,7 @@ import paddle
 
 from fastdeploy.platforms import current_platform
 
-if current_platform.is_cuda():
+if current_platform.is_cuda() and not current_platform.is_dcu():
     from fastdeploy.model_executor.ops.gpu import \
         append_attention as append_attention_gpu
 
@@ -37,3 +37,9 @@ if current_platform.is_gcu():
     from .gcu import *
     if hasattr(gcu, '__all__'):
         __all__.extend(gcu.__all__)
+
+if current_platform.is_dcu():
+    from .dcu import *
+    from . import dcu
+    if hasattr(dcu, '__all__'):
+        __all__.extend(dcu.__all__)

fastdeploy/model_executor/layers/backends/dcu/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
dcu backend methods
"""

from .fused_moe_triton_backends import DCUTritonWeightOnlyMoEMethod
from .weight_only import DCUWeightOnlyLinearMethod

__all__ = ['DCUTritonWeightOnlyMoEMethod', 'DCUWeightOnlyLinearMethod']
@@ -0,0 +1,244 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle
from paddle import nn

from fastdeploy.distributed.communication_op import \
    tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.utils import (create_hadamard_matrix_map,
                                                    get_tensor)
from fastdeploy.utils import ceil_div

from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase


class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
    """
    Use Triton Group Gemm to compute Fused MoE.
    """

    def __init__(self, quant_method=None):
        """
        Triton Group Gemm to compute Fused MoE.
        """
        self.quant_method = quant_method
        self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
        self.added_scale_attrs = [
            "moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
        ]

    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
        """process_prequanted_weights"""
        pass

    def create_weights(self, layer: nn.Layer, state_dict):
        """
        Triton MoE create weight process.
        """
        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
        assert len(ffn1_weights) == layer.num_local_experts
        assert len(ffn2_weights) == layer.num_local_experts
        assert self.quant_method.name() == "wint8"
        assert ffn1_weights[0].shape == [
            layer.hidden_size, layer.moe_intermediate_size * 2
        ]
        assert ffn2_weights[0].shape == [
            layer.moe_intermediate_size, layer.hidden_size
        ]

        ffn1_tensor = paddle.stack(ffn1_weights, axis=0)
        ffn2_tensor = paddle.stack(ffn2_weights, axis=0)

        if self.quant_method.name() == "wint8":
            max_bound = 127
        elif self.quant_method.name() == "wint4":
            max_bound = 7

        for idx, weight_tensor in enumerate([ffn1_tensor, ffn2_tensor]):
            weight_name = self.added_weight_attrs[idx]
            scale_name = self.added_scale_attrs[idx]

            quanted_weight_scale = weight_tensor.abs().max(axis=1)
            quanted_weight = weight_tensor / quanted_weight_scale[:,
                                                                  None, :] * max_bound
            quanted_weight = paddle.round(quanted_weight).astype("int8")
            quanted_weight_scale = quanted_weight_scale / max_bound

            setattr(
                layer, weight_name,
                layer.create_parameter(
                    shape=quanted_weight.shape,
                    dtype=quanted_weight.dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                ))
            getattr(layer, weight_name).set_value(quanted_weight)

            setattr(
                layer, scale_name,
                layer.create_parameter(
                    shape=quanted_weight_scale.shape,
                    dtype=quanted_weight_scale.dtype,
                ))
            getattr(layer, scale_name).set_value(quanted_weight_scale)

    def apply(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.
        """
        token_num = x.shape[0]
        top_k = layer.top_k
        num_local_experts = layer.num_local_experts
        top_k = layer.top_k
        moe_intermediate_size = layer.moe_intermediate_size
        hidden_size = layer.hidden_size

        gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
        scores = paddle.nn.functional.softmax(gate_out, axis=-1)
        scores += layer.gate_correction_bias
        topk_weights, topk_ids = paddle.topk(scores,
                                             k=top_k,
                                             axis=-1,
                                             sorted=False)
        topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)

        intermediate_cache1 = paddle.empty(
            [token_num * top_k, moe_intermediate_size * 2],
            dtype=x.dtype,
        )
        intermediate_cache2 = paddle.empty(
            (token_num * top_k, moe_intermediate_size),
            dtype=x.dtype,
        )
        intermediate_cache3 = paddle.empty(
            (token_num * top_k, hidden_size),
            dtype=x.dtype,
        )

        config = {
            "BLOCK_SIZE_M": 32,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 1,
        }
        from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess

        from .triton_moe_kernels import fused_moe_kernel_paddle
        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
            topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
        max_num_tokens_padded = sorted_token_ids.shape[0]
        grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
                ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )

        fused_moe_kernel_paddle[grid](
            x,
            layer.moe_ffn1_weight,
            intermediate_cache1,
            None,
            layer.moe_ffn1_weight_scale,
            None,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            moe_intermediate_size * 2,
            hidden_size,
            max_num_tokens_padded,
            token_num * top_k,
            stride_am=x.strides[0],
            stride_ak=x.strides[1],
            stride_be=layer.moe_ffn1_weight.strides[0],
            stride_bk=layer.moe_ffn1_weight.strides[1],
            stride_bn=layer.moe_ffn1_weight.strides[2],
            stride_cm=intermediate_cache1.strides[0],
            stride_cn=intermediate_cache1.strides[1],
            #
            stride_asm=-1,
            stride_ask=-1,
            stride_bse=layer.moe_ffn1_weight_scale.strides[0],
            stride_bsk=-1,
            stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
            group_n=-1,
            group_k=-1,
            # Meta-parameters
            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
            GROUP_SIZE_M=config["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=False,
            top_k=top_k,
            compute_type_enum=1,
            use_fp8_w8a8=False,
            use_int8_w8a16=True,
            even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
        )

        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
            intermediate_cache1)

        grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
                ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
        fused_moe_kernel_paddle[grid](
            intermediate_cache2,
            layer.moe_ffn2_weight,
            intermediate_cache3,
            None,
            layer.moe_ffn2_weight_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            hidden_size,
            moe_intermediate_size,
            max_num_tokens_padded,
            token_num * top_k,
            stride_am=intermediate_cache2.strides[0],
            stride_ak=intermediate_cache2.strides[1],
            stride_be=layer.moe_ffn2_weight.strides[0],
            stride_bk=layer.moe_ffn2_weight.strides[1],
            stride_bn=layer.moe_ffn2_weight.strides[2],
            stride_cm=intermediate_cache3.strides[0],
            stride_cn=intermediate_cache3.strides[1],
            stride_asm=-1,
            stride_ask=-1,
            stride_bse=layer.moe_ffn2_weight_scale.strides[0],
            stride_bsk=-1,
            stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
            group_n=-1,
            group_k=-1,
            # Meta-parameters
            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
            GROUP_SIZE_M=config["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=True,
            top_k=1,
            compute_type_enum=1,
            use_fp8_w8a8=False,
            use_int8_w8a16=True,
            even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
        )

        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
        out = intermediate_cache3.sum(axis=1)

        if layer.tp_size > 1:
            tensor_model_parallel_all_reduce(out)
        return out
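The routing logic at the start of `apply` above is ordinary top-k gating. A small NumPy sketch of the same softmax, bias correction, top-k selection, and renormalization steps (the sizes and the zero bias below are illustrative assumptions, not FastDeploy values):

```python
import numpy as np

# Hypothetical sizes; only the routing math mirrors apply() above.
token_num, hidden_size, num_experts, top_k = 4, 16, 8, 2
rng = np.random.default_rng(0)
x = rng.standard_normal((token_num, hidden_size)).astype(np.float32)
gate_weight = rng.standard_normal((hidden_size, num_experts)).astype(np.float32)
gate_correction_bias = np.zeros(num_experts, dtype=np.float32)  # stand-in value

logits = x @ gate_weight
scores = np.exp(logits - logits.max(-1, keepdims=True))
scores /= scores.sum(-1, keepdims=True)               # softmax over experts
scores = scores + gate_correction_bias

topk_ids = np.argsort(-scores, axis=-1)[:, :top_k]    # chosen experts per token
topk_weights = np.take_along_axis(scores, topk_ids, axis=-1)
topk_weights /= topk_weights.sum(-1, keepdims=True)   # renormalize to sum to 1

# Each token's output is then the topk_weights-weighted sum of its experts'
# FFN outputs, which the two Triton launches plus the final sum(axis=1) compute.
print(topk_ids)
print(topk_weights)
```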
@@ -0,0 +1,198 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import triton
import triton.language as tl


@triton.jit
def fused_moe_kernel_paddle(
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    num_tokens_post_padded,
    num_valid_tokens,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_asm,
    stride_ask,
    stride_bse,
    stride_bsk,
    stride_bsn,
    # Block size for block-wise fp8 quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type_enum: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    even_Ks: tl.constexpr,
):
    """
    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(num_tokens_post_padded, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    assert compute_type_enum == 1
    compute_type = tl.bfloat16

    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                offs_bn[None, :] * stride_bn)

    if use_int8_w8a16:
        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
            None, :] * stride_bsn
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
        else:
            # (Zkk): every expert has one activation scale and weight scale.
            a_scale = tl.load(a_scale_ptr + off_experts)
            b_scale = tl.load(b_scale_ptr + off_experts)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        if even_Ks:
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None],
                other=0.0,
            )
            b = tl.load(b_ptrs,
                        cache_modifier=".cv",
                        eviction_policy='evict_first')
        else:
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None] &
                (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                other=0.0,
            )
            b = tl.load(b_ptrs,
                        mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                        other=0.0)

        # We accumulate along the K dimension.
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8:
            if group_k > 0 and group_n > 0:
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
                                  mask=token_mask,
                                  other=0.0)
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                accumulator += tl.dot(a, b) * a_scale[:,
                                                      None] * b_scale[None, :]
            else:
                accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]
    if use_int8_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            accumulator = accumulator.to(compute_type)
        else:
            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)

    tl.store(c_ptrs, accumulator, mask=c_mask)
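The `sorted_token_ids` / `expert_ids` layout the kernel docstring describes (produced by the `tritonmoe_preprocess` op in the backend above) can be sketched in plain Python. The padding value and exact ordering below are assumptions for illustration, not the op's actual output format:

```python
# Illustration-only sketch of sorting routed token slots by expert and padding
# each expert's group to a multiple of BLOCK_SIZE_M, as the kernel docstring
# describes. The real layout comes from the tritonmoe_preprocess custom op.
BLOCK_SIZE_M = 4
num_experts = 3
# Flattened top-k routing result: entry i is the expert chosen for slot i.
topk_ids_flat = [2, 0, 1, 0, 2, 2, 1, 0]
num_valid_tokens = len(topk_ids_flat)
PAD = num_valid_tokens  # out-of-range id; masked out by token_mask in the kernel

sorted_token_ids, expert_ids = [], []
for e in range(num_experts):
    group = [i for i, ex in enumerate(topk_ids_flat) if ex == e]
    while len(group) % BLOCK_SIZE_M:             # pad so every block is one expert
        group.append(PAD)
    sorted_token_ids.extend(group)
    expert_ids.extend([e] * (len(group) // BLOCK_SIZE_M))  # one expert per block

print(sorted_token_ids)  # block-aligned slots, grouped by expert
print(expert_ids)        # which expert's weight matrix each block should use
```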
fastdeploy/model_executor/layers/backends/dcu/weight_only.py (new file, 46 lines)
@@ -0,0 +1,46 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from paddle.nn.quant import weight_dequantize

from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig, GPUWeightOnlyLinearMethod


class DCUWeightOnlyLinearMethod(GPUWeightOnlyLinearMethod):
    """
    Weight only quantization method for linear layer on GPU.
    The weights are loaded in the BF16 numerical format. After loading, the quantization coefficients will be computed,
    and the weights will be quantized to int8 or int4.
    """

    def __init__(
        self,
        quant_config: WeightOnlyConfig,
    ) -> None:
        super().__init__(quant_config)

    def apply(self, layer, x):
        dequant_out = weight_dequantize(
            x=layer.linear_weight,
            scale=layer.linear_weight_scale,
            algo=self.quant_config.algo,
            out_dtype=paddle.get_default_dtype()
        )
        linear_out = paddle.matmul(x, dequant_out)
        if layer.linear_bias is not None:
            linear_out = paddle.add(linear_out, layer.linear_bias)
        return linear_out
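A minimal NumPy sketch of the per-channel absmax int8 round trip that weight-only linear methods like this one rely on; the actual quantize/dequantize here is done by Paddle's weight-only kernels, so this only illustrates the idea:

```python
import numpy as np

# Minimal sketch of per-channel absmax int8 weight-only quantization: store
# int8 weights plus float scales, dequantize back to the compute dtype before
# the matmul. Not Paddle's actual kernels; shapes and data are invented.
rng = np.random.default_rng(0)
w = rng.standard_normal((64, 32)).astype(np.float32)   # [in_features, out_features]

scale = np.abs(w).max(axis=0) / 127.0                   # one scale per output channel
w_int8 = np.clip(np.round(w / scale), -127, 127).astype(np.int8)

w_dequant = w_int8.astype(np.float32) * scale           # what apply() feeds to matmul
x = rng.standard_normal((4, 64)).astype(np.float32)
y = x @ w_dequant

print(np.abs(w - w_dequant).max())  # quantization error stays small
```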
@@ -27,7 +27,7 @@ from fastdeploy.platforms import current_platform
 from ..utils import create_and_set_parameter, get_tensor
 from .fused_moe_backend_base import MoEMethodBase
 
-if current_platform.is_cuda():
+if current_platform.is_cuda() and not current_platform.is_dcu():
     from fastdeploy.model_executor.ops.gpu import (moe_expert_dispatch,
                                                    moe_expert_reduce, noaux_tc)
 elif current_platform.is_iluvatar():
@@ -75,6 +75,15 @@ class WeightOnlyConfig(QuantConfigBase):
                 return GCUWeightOnlyMoEMethod(self)
             else:
                 return GCUWeightOnlyLinearMethod(self)
+        elif current_platform.is_dcu():
+            if isinstance(layer, FusedMoE):
+                from fastdeploy.model_executor.layers.backends import (
+                    DCUTritonWeightOnlyMoEMethod)
+                return DCUTritonWeightOnlyMoEMethod(self)
+            else:
+                from fastdeploy.model_executor.layers.backends import (
+                    DCUWeightOnlyLinearMethod)
+                return DCUWeightOnlyLinearMethod(self)
         else:
             if isinstance(layer, FusedMoE):
                 if layer.use_method == "cutlass":
@@ -39,7 +39,7 @@ from fastdeploy.model_executor.models.ernie4_5_moe import (Ernie4_5_Attention,
 from fastdeploy.model_executor.models.model_base import ModelForCasualLM
 from fastdeploy.platforms import current_platform
 
-if current_platform.is_cuda():
+if current_platform.is_cuda() and not current_platform.is_dcu():
     from fastdeploy.model_executor.ops.gpu import (extract_text_token_output,
                                                    text_image_gather_scatter,
                                                    text_image_index_out)
@@ -29,6 +29,10 @@ elif current_platform.is_gcu():
         save_output,
         set_stop_value_multi_ends,
         update_inputs)
+elif current_platform.is_dcu():
+    from fastdeploy.model_executor.ops.gpu import (
+        get_padding_offset, save_output, set_stop_value_multi_ends,
+        step_paddle, update_inputs)
 else:
     from fastdeploy.model_executor.ops.gpu import (
         get_padding_offset, save_output, set_stop_value_multi_ends,
@@ -33,14 +33,14 @@ def __getattr__(name: str):
     # lazy init current_platform.
     global _current_platform
     if _current_platform is None:
-        if paddle.is_compiled_with_cuda():
+        if paddle.is_compiled_with_rocm():
+            _current_platform = DCUPlatform()
+        elif paddle.is_compiled_with_cuda():
             _current_platform = CUDAPlatform()
         elif paddle.is_compiled_with_xpu():
             _current_platform = XPUPlatform()
         elif paddle.is_compiled_with_custom_device("npu"):
             _current_platform = NPUPlatform()
-        elif paddle.is_compiled_with_rocm():
-            _current_platform = DCUPlatform()
         elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
             _current_platform = IluvatarPlatform()
         elif paddle.is_compiled_with_custom_device("gcu"):
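The reordering above matters because, as the NOTE in the ops setup script earlier points out, a ROCm build of Paddle also reports `paddle.is_compiled_with_cuda()` as True; checking ROCm first is what lets DCU machines get `DCUPlatform`. A tiny sketch of that dispatch order (stub function, and the fallback name is hypothetical, not from the diff):

```python
# Sketch of why the ROCm branch must precede the CUDA branch: on a ROCm build
# both predicates are True, so checking CUDA first would never select DCU.
def pick_platform(compiled_with_rocm: bool, compiled_with_cuda: bool) -> str:
    if compiled_with_rocm:            # must come first
        return "DCUPlatform"
    elif compiled_with_cuda:
        return "CUDAPlatform"
    return "CPUPlatform"              # hypothetical fallback for the sketch

assert pick_platform(True, True) == "DCUPlatform"    # ROCm build: both flags True
assert pick_platform(False, True) == "CUDAPlatform"  # plain CUDA build
```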
@@ -25,6 +25,7 @@ class _Backend(enum.Enum):
     APPEND_ATTN = enum.auto()
     MLA_ATTN = enum.auto()
     FLASH_ATTN = enum.auto()
+    BLOCK_ATTN = enum.auto()
 
 
 class Platform:
@@ -14,7 +14,9 @@
 """
 dcu platform file
 """
-from .base import Platform
+import paddle
+from .base import Platform, _Backend
+from paddleformers.utils.log import logger
 
 
 class DCUPlatform(Platform):
@@ -22,3 +24,38 @@ class DCUPlatform(Platform):
     dcu platform class
     """
     device_name = "dcu"
+
+    @classmethod
+    def available(self):
+        """
+        Check whether CUDA is available.
+        """
+        try:
+            assert len(paddle.static.cuda_places()) > 0
+            return True
+        except Exception as e:
+            logger.warning(
+                "You are using GPU version PaddlePaddle, but there is no GPU "
+                "detected on your machine. Maybe CUDA devices is not set properly."
+                f"\n Original Error is {e}"
+            )
+            return False
+
+    @classmethod
+    def get_attention_backend_cls(
+        cls,
+        selected_backend
+    ):
+        """
+        get_attention_backend_cls
+        """
+        if selected_backend == _Backend.NATIVE_ATTN:
+            logger.info("Using NATIVE ATTN backend.")
+            return ("fastdeploy.model_executor.layers.attention.PaddleNativeAttnBackend")
+        elif selected_backend == _Backend.BLOCK_ATTN:
+            logger.info("Using BLOCK ATTN backend.")
+            return ("fastdeploy.model_executor.layers.attention.BlockAttentionBackend")
+        else:
+            logger.warning(
+                "Other backends are not supported for now."
+            )
fastdeploy/worker/dcu_worker.py (new file, 112 lines)
@@ -0,0 +1,112 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import gc
import time
from typing import List, Optional

import paddle
import paddle.nn as nn

from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import get_logger
from fastdeploy.worker.gpu_model_runner import GPUModelRunner
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.gpu_worker import GpuWorker

logger = get_logger("dcu_worker", "dcu_worker.log")


class DcuWorker(GpuWorker):
    """ """

    def __init__(
        self,
        fd_config: FDConfig,
        local_rank: int,
        rank: int,
    ):
        super().__init__(
            fd_config=fd_config,
            local_rank=local_rank,
            rank=rank,
        )
        pass

    def determine_available_memory(self) -> int:
        """
        Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # 1. Record memory state before profile run
        Gb = 1024**3
        start_time = time.perf_counter()
        paddle.device.cuda.reset_max_memory_reserved(self.local_rank)
        paddle.device.cuda.reset_max_memory_allocated(self.local_rank)
        paddle_reserved_mem_before_run = paddle.device.cuda.max_memory_reserved(
            self.local_rank)
        paddle_allocated_mem_before_run = paddle.device.cuda.max_memory_allocated(
            self.local_rank)  # not reserved

        total_gpu_memory = paddle.device.cuda.get_device_properties(self.local_rank).total_memory
        before_used_gpu_memory = paddle.device.cuda.memory_allocated(self.local_rank)

        logger.info((
            "Before running the profile, the memory usage info is as follows:",
            f"\nDevice Total memory: {total_gpu_memory / Gb}",
            f"\nDevice used memory: {before_used_gpu_memory / Gb}",
            f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}",
            f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}"))

        # 2. Profile run
        self.model_runner.profile_run()

        # 3. Statistical memory information
        paddle_reserved_mem_after_run = paddle.device.cuda.max_memory_reserved(
            self.local_rank)
        paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated(
            self.local_rank)

        after_used_gpu_memory = paddle.device.cuda.memory_allocated(self.local_rank)

        # v0 worker
        model_block_memory_used = self.cal_theortical_kvcache()
        paddle.device.cuda.empty_cache()
        paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run
        available_kv_cache_memory = total_gpu_memory * \
            self.parallel_config.gpu_memory_utilization - after_used_gpu_memory - paddle_peak_increase
        available_kv_cache_memory += model_block_memory_used * self.parallel_config.max_block_num

        end_time = time.perf_counter()
        logger.info(
            ("After running the profile, the memory usage info is as follows:",
             f"\nDevice Total memory: {total_gpu_memory / Gb}",
             f"\nDevice used memory: {after_used_gpu_memory / Gb}",
             f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}",
             f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}",
             f"\nAvailable KV Cache memory: {available_kv_cache_memory / Gb}",
             f"Profile time: {end_time - start_time}"))

        return available_kv_cache_memory  # used to calculate the block num on this device
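The KV-cache budget computed in `determine_available_memory` above is the device total scaled by `gpu_memory_utilization`, minus what is still allocated after the profile run and the peak reservation increase, plus the memory the profiling KV blocks themselves held. A worked example with invented numbers (not measurements from a real K100AI):

```python
# Worked example of the budget formula in determine_available_memory;
# every number below is illustrative.
Gb = 1024 ** 3
total_gpu_memory        = 64 * Gb
gpu_memory_utilization  = 0.8
after_used_gpu_memory   = 30 * Gb            # allocated after the profile run
paddle_peak_increase    = 6 * Gb             # reserved peak minus pre-run allocation
model_block_memory_used = 2 * 1024 * 1024    # theoretical bytes per KV-cache block
max_block_num           = 2048               # blocks used during profiling

available = (total_gpu_memory * gpu_memory_utilization
             - after_used_gpu_memory
             - paddle_peak_increase)
available += model_block_memory_used * max_block_num

print(f"available KV cache memory: {available / Gb:.2f} GiB")  # 19.20 GiB
```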
@@ -41,6 +41,8 @@ from fastdeploy.model_executor.pre_and_post_process import (post_process,
                                                             pre_process,
                                                             rebuild_padding,
                                                             step_cuda)
+from fastdeploy.platforms import current_platform
+if not current_platform.is_dcu():
     from fastdeploy.spec_decode import MTPProposer, NgramProposer
 from fastdeploy.worker.forward_meta import ForwardMeta
 from fastdeploy.worker.model_runner_base import ModelRunnerBase
@@ -42,6 +42,9 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
     """
     get worker of different device
     """
+    if current_platform.is_dcu():
+        from fastdeploy.worker.dcu_worker import DcuWorker
+        return DcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
     if current_platform.is_cuda():
         from fastdeploy.worker.gpu_worker import GpuWorker
         return GpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)

requirements_dcu.txt (new file, 29 lines)
@@ -0,0 +1,29 @@
setuptools>=62.3.0,<80.0
pre-commit
yapf
flake8
ruamel.yaml
zmq
aiozmq
openai
tqdm
pynvml
uvicorn
fastapi
paddleformers
redis
etcd3
httpx
tool_helpers
pybind11[global]
tabulate
gradio
xlwt
visualdl
setuptools-scm>=8
prometheus-client
decord
moviepy
use-triton-in-paddle
crcmod
fastsafetensors==0.1.14
setup.py (2 lines changed)

@@ -146,6 +146,8 @@ def load_requirements():
     requirements_file_name = 'requirements.txt'
     if paddle.is_compiled_with_custom_device('iluvatar_gpu'):
         requirements_file_name = 'requirements_iluvatar.txt'
+    elif paddle.is_compiled_with_rocm():
+        requirements_file_name = 'requirements_dcu.txt'
     requirements_path = os.path.join(os.path.dirname(__file__),
                                      requirements_file_name)
     with open(requirements_path, 'r') as f: