Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 00:06:38 +08:00)

Compare commits: copilot/ad... to feature/on...
97 Commits
Commits (SHA1):
b272ca9f83, db653644ad, 4aa057f28d, 05b7800d80, 12043fc476, acecd5bebe, 918ccdb123,
9845f0d010, c4830ef24c, 0b62648924, c86945ef49, da74a5f0b3, 718f32a6b0, 5c33be5a7d,
91912cc2e1, cc6e14d2ec, 24180fba0a, ee9d8a840a, 66a98b44ed, a685e5ad35, ddf5606263,
c3b8ebeb18, 62b8b02e08, 98447beb4d, 618ccdbfba, 2745f37017, 896e3bb606, 0d3a57a2c6,
b52971749c, 2adca04f1f, f9766f917b, 2e9e53ff7e, c01a756912, cd09913552, 67e6d8c691,
de8638b1e9, 4f8901489c, e79a1a7938, d682c97dd3, 8e49d99009, 83bf1fd5aa, b70ca35c0b,
befe463f01, 442543cd6b, ed2dcec829, a04365a0c7, 03b3d6175d, 17a27170bc, 113e330030,
69aa2781a1, 46911f903d, b1b33211e8, 9409665713, 29ed617f0f, b1a5b756a3, 4408dc7f67,
ef4a1aa2da, f213ae1e86, 553adb299e, 958abebeab, 4871f18dad, 987609c894, 9ac539471d,
88ea565aba, c86b3357ce, 06f4b49ca3, 805f29a06c, cab7a633fe, 58e0785bab, 8466219ec8,
82dab8a91a, 37f1632732, 4859f40b20, 2056a428bd, 850465e8ed, a47976e82d, abdcef30aa,
d2ec7f6aa2, fec58639db, d2d04c2d5e, d60f7c4661, e4c64a71cc, 2650f58740, 2af0f671b1,
a7392a0ff9, 637d96c6ae, 7ee100903f, 684e93269b, 276f73cf83, d3e4ae3d49, 453487d5b0,
9d0074a91a, c3b2a60fb8, dbab579299, f078a959b6, 3b1da6e4dd, b3fac5bde1
.github/workflows/_accuracy_test.yml (vendored): 1 line changed
@@ -160,6 +160,7 @@ jobs:
          git config --global --add safe.directory /workspace/FastDeploy
          cd FastDeploy
          pushd tests/ce/deploy
          ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
          python3.10 deploy.py > dd.log 2>&1 &
          sleep 3
          curl -X POST http://0.0.0.0:${FLASK_PORT}/start \

.github/workflows/_base_test.yml (vendored): 1 line changed
@@ -160,6 +160,7 @@ jobs:
          git config --global --add safe.directory /workspace/FastDeploy
          cd FastDeploy
          pushd tests/ce/deploy
          ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
          python3.10 deploy.py > dd.log 2>&1 &
          sleep 3
          curl -X POST http://0.0.0.0:${FLASK_PORT}/start \

.github/workflows/_build_linux.yml (vendored): 3 lines changed
@@ -55,7 +55,7 @@ on:
  jobs:
    fd-build:
      runs-on: [self-hosted, GPU-Build]
-     timeout-minutes: 240
+     timeout-minutes: 360
      outputs:
        wheel_path: ${{ steps.set_output.outputs.wheel_path }}
      steps:
@@ -134,6 +134,7 @@ jobs:
          fi

          git config --global --add safe.directory /workspace/FastDeploy
          chown -R $(whoami) /workspace/FastDeploy
          cd FastDeploy
          if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then
            GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD)

.github/workflows/_ci_image_build.yml (vendored): 2 lines changed
@@ -29,7 +29,7 @@ jobs:
      outputs:
        docker_name_precheck: ${{ steps.docker_build.outputs.docker_name_precheck }}
      steps:
-       - name: Code Prepare
+       - name: Docker Build
          id: docker_build
          shell: bash
          env:

.github/workflows/_logprob_test_linux.yml (vendored): 1 line changed
@@ -147,6 +147,7 @@ jobs:
            --skip install

          cd PaddleTest/framework/ServeTest
          ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
          python3.10 deploy.py > dd.log 2>&1 &
          sleep 3
          curl -X POST http://0.0.0.0:${FLASK_PORT}/start \

.github/workflows/_pre_ce_test.yml (vendored): 3 lines changed
@@ -82,6 +82,9 @@ jobs:
          FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
          FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
          FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
+         FD_ZMQ_RECV_REQUEST_SERVER_PORT=$((42048 + DEVICE_PORT * 100))
+         FD_ZMQ_SEND_RESPONSE_SERVER_PORT=$((42038 + DEVICE_PORT * 100))
+         FD_ZMQ_CONTROL_CMD_SERVER_PORTS=$((42028 + DEVICE_PORT * 100))
          echo "Test ENV Parameter:"
          echo "========================================================="
          echo "FLASK_PORT=${FLASK_PORT}"

.github/workflows/_unit_test_coverage.yml (vendored): 2 lines changed
@@ -41,7 +41,7 @@ jobs:

  run_tests_with_coverage:
    runs-on: [self-hosted, GPU-h1z1-2Cards]
-   timeout-minutes: 60
+   timeout-minutes: 90
    needs: check_cov_skip
    if: needs.check_cov_skip.outputs.can-skip != 'true'
    outputs:

.github/workflows/ce_job.yml (vendored): 2 lines changed
@@ -9,7 +9,7 @@ on:
  permissions: read-all

  concurrency:
-   group: ${{ github.ref }}-${{ github.sha }}
+   group: CE-Job-${{ github.ref }}-${{ github.sha }}
    cancel-in-progress: true

  jobs:

.github/workflows/ci_image_update.yml (vendored): 2 lines changed
@@ -8,7 +8,7 @@ on:
  permissions: read-all

  concurrency:
-   group: ${{ github.ref }}-${{ github.sha }}
+   group: CI-Images-Build-${{ github.ref }}-${{ github.sha }}
    cancel-in-progress: true

.github/workflows/publish_job.yml (vendored): 2 lines changed
@@ -13,7 +13,7 @@ on:
  permissions: read-all

  concurrency:
-   group: ${{ github.ref }}-${{ github.sha }}
+   group: Publish-Job-${{ github.ref }}-${{ github.sha }}
    cancel-in-progress: true
.gitmodules (vendored, new file): 10 lines
@@ -0,0 +1,10 @@
+ [submodule "custom_ops/third_party/DeepGEMM"]
+     path = custom_ops/third_party/DeepGEMM
+     url = https://github.com/deepseek-ai/DeepGEMM.git
+     ignore = all
+ [submodule "custom_ops/third_party/cutlass"]
+     path = custom_ops/third_party/cutlass
+     url = https://github.com/NVIDIA/cutlass.git
+ [submodule "custom_ops/third_party/nlohmann_json"]
+     path = custom_ops/third_party/nlohmann_json
+     url = https://github.com/nlohmann/json.git
@@ -98,7 +98,7 @@ def main(args):
        raise ValueError("--max_concurrency should be same length as --s_itl_base_model")

    for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
-       # Wramup
+       # Warmup
        print("Starting warmup...")
        with open(os.devnull, "w") as f:
            with contextlib.redirect_stdout(f):
@@ -6,3 +6,4 @@ tensor_parallel_size: 8
  max_num_batched_tokens: 4096
  max_num_partial_prefills: 3
  max_long_partial_prefills: 3
+ quantization: wint4

benchmarks/yaml/eb45-128k-wint4-tp1-plas.yaml (new file): 6 lines
@@ -0,0 +1,6 @@
+ tensor_parallel_size: 1
+ max_model_len: 131072
+ max_num_seqs: 32
+ quantization: wint4
+ max_num_batched_tokens: 8192
+ plas_attention_config: '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'

@@ -6,3 +6,4 @@ tensor_parallel_size: 8
  max_num_batched_tokens: 4096
  max_num_partial_prefills: 3
  max_long_partial_prefills: 3
+ quantization: wint8

benchmarks/yaml/eb45-32k-wint2-tp4.yaml (new file): 5 lines
@@ -0,0 +1,5 @@
+ max_model_len: 32768
+ max_num_seqs: 256
+ kv_cache_ratio: 0.75
+ tensor_parallel_size: 4
+ gpu_memory_utilization: 0.9

@@ -13,3 +13,4 @@ pd_comm_port: "2334"
  max_num_batched_tokens: 384
  max_num_partial_prefills: 3
  max_long_partial_prefills: 3
+ quantization: wint4

@@ -10,3 +10,4 @@ engine_worker_queue_port: 6677
  cache_transfer_protocol: "rdma,ipc"
  rdma_comm_ports: "7675,7676,7677,7678"
  pd_comm_port: "2333"
+ quantization: wint4

benchmarks/yaml/eb45-vl-128k-wint4-h800-tp8.yaml (new file): 11 lines
@@ -0,0 +1,11 @@
+ enable_mm: True
+ max_model_len: 131072
+ max_num_seqs: 56
+ gpu_memory_utilization: 0.8
+ kv_cache_ratio: 0.8
+ tensor_parallel_size: 8
+ quantization: wint4
+ limit_mm_per_prompt: '{"image": 100, "video": 100}'
+ enable_chunked_prefill: True
+ max_num_batched_tokens: 384
+ reasoning_parser: ernie-45-vl

@@ -1,7 +1,7 @@
  enable_mm: True
  max_model_len: 32768
  max_num_seqs: 36
- gpu_memory_utilization: 0.95
+ gpu_memory_utilization: 0.9
  kv_cache_ratio: 0.8
  tensor_parallel_size: 8
  quantization: wint8

@@ -1,7 +1,7 @@
  enable_mm: True
  max_model_len: 32768
  max_num_seqs: 36
- gpu_memory_utilization: 0.8
+ gpu_memory_utilization: 0.85
  kv_cache_ratio: 0.8
  tensor_parallel_size: 8
  quantization: wint8

benchmarks/yaml/eb45-vl-lite-32k-bf16-a800-tp1.yaml (new file): 9 lines
@@ -0,0 +1,9 @@
+ enable_mm: True
+ max_model_len: 32768
+ max_num_seqs: 128
+ gpu_memory_utilization: 0.9
+ kv_cache_ratio: 0.71
+ tensor_parallel_size: 1
+ enable_chunked_prefill: True
+ max_num_batched_tokens: 384
+ reasoning_parser: ernie-45-vl

benchmarks/yaml/eb45-vl-lite-32k-wint4-a800-tp1.yaml (new file): 10 lines
@@ -0,0 +1,10 @@
+ enable_mm: True
+ max_model_len: 32768
+ max_num_seqs: 128
+ gpu_memory_utilization: 0.9
+ kv_cache_ratio: 0.71
+ tensor_parallel_size: 1
+ enable_chunked_prefill: True
+ max_num_batched_tokens: 384
+ quantization: wint4
+ reasoning_parser: ernie-45-vl

benchmarks/yaml/eb45-vl-lite-32k-wint8-a800-tp1.yaml (new file): 10 lines
@@ -0,0 +1,10 @@
+ enable_mm: True
+ max_model_len: 32768
+ max_num_seqs: 128
+ gpu_memory_utilization: 0.9
+ kv_cache_ratio: 0.71
+ tensor_parallel_size: 1
+ enable_chunked_prefill: True
+ max_num_batched_tokens: 384
+ quantization: wint8
+ reasoning_parser: ernie-45-vl

benchmarks/yaml/request_yaml/eb45-vl-128k.yaml (new file): 1 line
@@ -0,0 +1 @@
+ max_tokens: 131071

benchmarks/yaml/request_yaml/eb45-vl-32k.yaml (new file): 1 line
@@ -0,0 +1 @@
+ max_tokens: 12288

@@ -2,7 +2,7 @@ top_p: 0.95
  temperature: 0.6
  metadata:
    min_tokens: 1
-   max_tokens: 65535
+   max_tokens: 131071
  repetition_penalty: 1.0
  frequency_penalty: 0
  presence_penalty: 0

benchmarks/yaml/x1-a3b-128k-wint8-h800-tp1.yaml (new file): 6 lines
@@ -0,0 +1,6 @@
+ tensor_parallel_size: 1
+ max_model_len: 131072
+ max_num_seqs: 32
+ reasoning_parser: ernie_x1
+ tool_call_parser: ernie_x1
+ load_choices: "default_v1"
build.sh: 4 lines changed
@@ -143,9 +143,9 @@ function build_and_install_ops() {
    TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}`
    is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
    if [ "$is_xpu" = "True" ]; then
-     cd xpu_ops/src
+     cd xpu_ops
      bash build.sh ${TMP_DIR_REAL_PATH}
-     cd ../..
+     cd ..
    elif [ "$FD_CPU_USE_BF16" == "true" ]; then
      if [ "$FD_BUILDING_ARCS" == "" ]; then
        FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
@@ -428,6 +428,142 @@ __global__ void append_decode_cache_T_rope_kernel(
  }
}

template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_neox_partial_rope_kernel(
    const T* __restrict__ qkv,    // [bsz, num_heads + 2 * kv_num_heads, head_size]
    T* __restrict__ key_cache,    // [num_blocks, kv_num_heads, block_size, head_size // 2]
    T* __restrict__ value_cache,  // [num_blocks, kv_num_heads, block_size, head_size // 2]
    T* __restrict__ qkv_out,
    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
    const int* __restrict__ cu_seqlens_q,
    const int* __restrict__ seq_lens,          // [bsz]
    const int* __restrict__ seq_lens_encoder,  // [bsz]
    const float* __restrict__ cos_emb,  // [2, 1, max_model_len, 1, rotary_dim/2]
    const float* __restrict__ sin_emb,  // [2, 1, max_model_len, 1, rotary_dim/2]
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int num_heads,
    const int head_size,
    const int rotary_dim,
    const int block_size,
    const uint32_t elem_cnt,
    const int kv_num_heads,
    const bool rope_3d) {
  using LoadT = AlignedVector<T, VecSize>;
  using LoadBiasT = AlignedVector<T, VecSize>;
  using LoadKVT = AlignedVector<T, VecSize>;
  constexpr int HalfVecSize = VecSize / 2;
  using LoadEmbT = AlignedVector<float, VecSize>;

  LoadT left_vec, right_vec;
  LoadBiasT left_bias_vec, right_bias_vec;
  LoadKVT left_cache_vec, right_cache_vec;
  LoadEmbT cos_emb_vec;
  LoadEmbT sin_emb_vec;

  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  const int half_head_size = head_size / 2;
  const int half_rotary_dim = rotary_dim / 2;
  const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
  const int64_t half_hidden_size = hidden_size / 2;
  // const int64_t offset = 2 * hidden_size;

  for (int32_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const int ori_bi = linear_index / half_hidden_size;
    const int bias = linear_index % half_hidden_size;
    const int hi = bias / half_head_size;  // q + k + v
    const int h_bias = bias % half_head_size;
    if (hi < num_heads && h_bias >= half_rotary_dim) {
      continue;
    }
    const int start_token_idx = cu_seqlens_q[ori_bi];
    if (seq_lens_encoder[ori_bi] > 0) return;
    const int write_seq_id = seq_lens[ori_bi];
    if (write_seq_id == 0) continue;

    const int* block_table_now = nullptr;

    block_table_now = block_tables + ori_bi * max_blocks_per_seq;
    const int block_idx = block_table_now[write_seq_id / block_size];
    const int block_offset = write_seq_id % block_size;
    uint32_t ori_idx_left =
        start_token_idx * hidden_size + hi * head_size + h_bias;
    uint32_t ori_idx_right = ori_idx_left + half_head_size;
    if (hi < num_heads) {
      ori_idx_right = ori_idx_left + half_rotary_dim;
    } else if (hi < num_heads + kv_num_heads) {
      if (h_bias < half_rotary_dim) {
        ori_idx_right = ori_idx_left + half_rotary_dim;
      } else {
        ori_idx_left = ori_idx_left + half_rotary_dim;
        ori_idx_right = ori_idx_left + half_rotary_dim;
      }
    }

    Load<T, VecSize>(&qkv[ori_idx_left], &left_vec);
    Load<T, VecSize>(&qkv[ori_idx_right], &right_vec);

    if (hi < num_heads + kv_num_heads) {
      // q k rope
      const uint32_t emb_idx = write_seq_id * half_rotary_dim + h_bias;
      uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
      if (h_bias < half_rotary_dim) {
        Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
        Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
      }
    }
#pragma unroll
    for (int i = 0; i < VecSize; i++) {
      // rope
      float input_left = static_cast<float>(left_vec[i]);
      float input_right = static_cast<float>(right_vec[i]);
      if (hi < num_heads + kv_num_heads && h_bias < half_rotary_dim) {
        const float cos_tmp = cos_emb_vec[i];
        const float sin_tmp = sin_emb_vec[i];
        left_bias_vec[i] =
            static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
        right_bias_vec[i] =
            static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
      } else {
        left_bias_vec[i] = static_cast<T>(input_left);
        right_bias_vec[i] = static_cast<T>(input_right);
      }
    }
    if (hi < num_heads) {
      // write q
      Store<T, VecSize>(left_bias_vec, &qkv_out[ori_idx_left]);
      Store<T, VecSize>(right_bias_vec, &qkv_out[ori_idx_right]);
    } else {
      // write k/v
      const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
      uint32_t tgt_idx_left =
          block_idx * kv_num_heads * block_size * head_size +
          kv_head_idx * block_size * head_size + block_offset * head_size +
          h_bias;
      uint32_t tgt_idx_right = tgt_idx_left + half_head_size;
      if (hi < num_heads + kv_num_heads) {
        if (h_bias < half_rotary_dim) {
          tgt_idx_right = tgt_idx_left + half_rotary_dim;
        } else {
          tgt_idx_left = tgt_idx_left + half_rotary_dim;
          tgt_idx_right = tgt_idx_left + half_rotary_dim;
        }
        Store<T, VecSize>(left_bias_vec, &key_cache[tgt_idx_left]);
        Store<T, VecSize>(right_bias_vec, &key_cache[tgt_idx_right]);
      } else {
        Store<T, VecSize>(left_bias_vec, &value_cache[tgt_idx_left]);
        Store<T, VecSize>(right_bias_vec, &value_cache[tgt_idx_right]);
      }
    }
  }
}

template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_neox_rope_kernel(
    const T* __restrict__ qkv,  // [bsz, num_heads + 2 * kv_num_heads,

@@ -913,7 +1049,7 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
    local_max = __hmax(local_max, __habs(out_vec2[i]));
  }
#pragma unroll
- for (int m_offset = 16; m_offset > 1; m_offset /= 2) {
+ for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
    local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
  }
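As a side note on what the new partial NeoX-rope kernel above computes: only the first rotary_dim channels of each head are rotated, paired as (i, i + rotary_dim/2), and everything past rotary_dim is passed through unchanged. The sketch below is a standalone host-side illustration in plain C++ with a hypothetical inverse-frequency schedule; it is not FastDeploy code (the real kernels read precomputed cos/sin tables).

    // Standalone sketch of partial NeoX-style RoPE on one head vector.
    // Assumption: inv_freq schedule below is the usual 10000^{-2i/d}; the
    // library precomputes these as cos/sin tables instead.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    void neox_partial_rope(std::vector<float>& head, int rotary_dim, int pos) {
        const int half = rotary_dim / 2;
        for (int i = 0; i < half; ++i) {
            const float inv_freq = std::pow(10000.0f, -2.0f * i / rotary_dim);
            const float c = std::cos(pos * inv_freq), s = std::sin(pos * inv_freq);
            const float left = head[i], right = head[i + half];  // NeoX pairing: (i, i + rotary_dim/2)
            head[i]        = left * c - right * s;
            head[i + half] = right * c + left * s;
        }
        // Channels in [rotary_dim, head_size) are left untouched.
    }

    int main() {
        std::vector<float> head(128, 1.0f);   // head_size = 128
        neox_partial_rope(head, 64, 7);       // rotary_dim = 64 < head_size, position 7
        std::printf("%f %f %f\n", head[0], head[32], head[127]);
        return 0;
    }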
@@ -94,6 +94,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
    const int num_heads,
    const int kv_num_heads,
    const int dim_head,
+   const int rotary_dim,
    const int block_size,
    const int bsz,
    const cudaStream_t& stream,

@@ -133,7 +134,29 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
            kv_num_heads,
            rope_3d);
    } else {
-     append_decode_cache_T_neox_rope_kernel<T, PackSize>
+     if (rotary_dim < dim_head) {
+       append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
+           <<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
+                                                 key_cache,
+                                                 value_cache,
+                                                 qkv_out,
+                                                 block_tables,
+                                                 cu_seqlens_q,
+                                                 seq_lens,
+                                                 seq_lens_encoder,
+                                                 cos_emb,
+                                                 sin_emb,
+                                                 max_seq_len,
+                                                 max_blocks_per_seq,
+                                                 num_heads,
+                                                 dim_head,
+                                                 rotary_dim,
+                                                 block_size,
+                                                 elem_nums,
+                                                 kv_num_heads,
+                                                 rope_3d);
+     } else {
+       append_decode_cache_T_neox_rope_kernel<T, PackSize>
            <<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
                                                  key_cache,
                                                  value_cache,

@@ -152,6 +175,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
                                                  elem_nums,
                                                  kv_num_heads,
                                                  rope_3d);
+     }
    }
  } else {
    if (qkv_out_scales) {

@@ -516,11 +540,20 @@ void DecoderWriteCacheWithRoPEKernel(
  const float* cos_emb =
      rotary_embs ? rotary_embs.get().data<float>() : nullptr;
  const float* sin_emb;
+ int rotary_dim = dim_head;
  if (rotary_embs) {
    sin_emb =
        use_neox_rotary_style
            ? rotary_embs.get().data<float>() + max_seq_len * dim_head
            : rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
+   rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
+   if (rotary_dim < dim_head) {
+     if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || k_norm_weight || cache_quant_type_str != "none") {
+       PADDLE_THROW(phi::errors::Fatal(
+           "partial_rotary_factor < 1.0 only supports neox_rotary_style=True, qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and cache_quant_type_str is 'none'."));
+     }
+     sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
+   }
  }

  if (q_norm_weight && k_norm_weight) {

@@ -609,6 +642,7 @@ void DecoderWriteCacheWithRoPEKernel(
          num_heads,
          kv_num_heads,
          dim_head,
+         rotary_dim,
          block_size,
          bsz,
          stream,
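A minimal sketch of the pointer arithmetic the rotary_dim handling above introduces: the rotary table's last dimension is rotary_dim / 2, and when rotary_dim < dim_head the sin half of the table starts max_seq_len * rotary_dim / 2 floats after the cos half. The sizes below are hypothetical and the snippet is standalone, not the library's API.

    // Host-side sketch of deriving rotary_dim and the sin-table offset.
    #include <cstdio>

    int main() {
        const int dim_head    = 128;     // full head dimension (assumed)
        const int max_seq_len = 8192;    // rows in the rotary table (assumed)
        const int last_dim    = 32;      // rotary_embs.dims().back() == rotary_dim / 2
        const int rotary_dim  = last_dim * 2;                          // 64 -> partial rotary (64 < 128)
        const long cos_offset = 0;                                     // cos table at the buffer start
        const long sin_offset = (long)max_seq_len * rotary_dim / 2;    // sin table follows the cos table
        std::printf("rotary_dim=%d cos@%ld sin@%ld\n", rotary_dim, cos_offset, sin_offset);
        return 0;
    }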
@@ -900,6 +900,74 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
  }
}

template <typename T, int VecSize = 1>
__global__ void GQANeoxVariableLengthPartialRotaryKernel(
    const T *qkv,
    const float *cos_emb,
    const float *sin_emb,
    const int *batch_id_per_token,
    const int *cu_seqlens_q,
    const int *seq_lens,
    const int *seq_lens_decoder,
    const float *qkv_out_scales,
    const T *qkv_biases,
    T *qkv_out,
    const int64_t elem_cnt,
    const int q_num_head,
    const int kv_num_head,
    const int seq_len,
    const int head_dim,
    const int rotary_dim,
    const bool rope_3d) {
  using LoadT = AlignedVector<T, VecSize>;
  using LoadEmbT = AlignedVector<float, VecSize>;
  LoadT left_vec;
  LoadT right_vec;
  LoadEmbT cos_emb_vec;
  LoadEmbT sin_emb_vec;
  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  const int rotary_dim_half = rotary_dim / 2;
  const int offset = (q_num_head + kv_num_head) * rotary_dim_half;
  for (int64_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_idx = linear_index / offset;
    const int ori_bi = batch_id_per_token[token_idx];
    if (seq_lens && seq_lens[ori_bi] == 0) continue;
    const int bias = linear_index % offset;
    const int hi = bias / rotary_dim_half;
    const int h_bias = bias % rotary_dim_half;

    const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

    const int emb_idx = ori_seq_id * rotary_dim_half + h_bias;
    int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * head_dim * seq_len * 2 : emb_idx;
    const int base_idx_left =
        token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
        h_bias;
    const int base_idx_right = base_idx_left + rotary_dim_half;

    Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
    Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
    Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
    Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
    for (int i = 0; i < VecSize; i++) {
      const float input_left = static_cast<float>(left_vec[i]);
      const float input_right = static_cast<float>(right_vec[i]);
      const float cos_tmp = cos_emb_vec[i];
      const float sin_tmp = sin_emb_vec[i];
      left_vec[i] =
          static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
      right_vec[i] =
          static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
    }
    Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
    Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
  }
}

template <typename T, int VecSize = 1>
__global__ void cache_kernel(
    const T *__restrict__ qkv,  // [num_tokens, num_heads + 2 * kv_num_heads,

@@ -2160,6 +2228,7 @@ void gqa_rotary_qk_variable(
    const int seq_len,
    const int input_output_len,
    const int dim_head,
+   const int rotary_dim,
    const cudaStream_t &stream,
    bool use_neox_style = false,
    bool rope_3d = false) {

@@ -2240,7 +2309,38 @@ void gqa_rotary_qk_variable(
            dim_head,
            rope_3d);
  } else {
-   GQANeoxVariableLengthRotaryKernel<T, PackSize>
+   if (rotary_dim < dim_head) {
+     PD_CHECK((rotary_dim / 2) % PackSize == 0);
+     elem_nums =
+         qkv_out_scales
+             ? token_num * (num_heads + 2 * kv_num_heads) * rotary_dim
+             : token_num * (num_heads + kv_num_heads) * rotary_dim;  // for all q k v
+     if (use_neox_style) {
+       elem_nums /= 2;
+     }
+     const int pack_num_new = elem_nums / PackSize;
+     GetNumBlocks<128>(pack_num_new, &grid_size);
+     GQANeoxVariableLengthPartialRotaryKernel<T, PackSize>
+         <<<grid_size, blocksize, 0, stream>>>(
+             reinterpret_cast<const T *>(qkv_input),
+             cos_emb,
+             rotary_emb + input_output_len * rotary_dim / 2,
+             batch_id_per_token,
+             cu_seqlens_q,
+             seq_lens,
+             seq_lens_decoder,
+             qkv_out_scales,
+             qkv_bias,
+             qkv_out,
+             elem_nums,
+             num_heads,
+             kv_num_heads,
+             seq_len,
+             dim_head,
+             rotary_dim,
+             rope_3d);
+   } else {
+     GQANeoxVariableLengthRotaryKernel<T, PackSize>
          <<<grid_size, blocksize, 0, stream>>>(
              reinterpret_cast<const T *>(qkv_input),
              cos_emb,

@@ -2258,6 +2358,7 @@ void gqa_rotary_qk_variable(
              seq_len,
              dim_head,
              rope_3d);
+     }
    }
  }
}
@@ -55,9 +55,19 @@ void EncoderWriteCacheWithRopeKernel(
  auto kv_num_heads = meta_data.kv_num_heads;
  auto head_dim = meta_data.head_dims;
  bool is_scale_channel_wise = false;
+ int rotary_dim = head_dim;
  if (cache_k_scale && cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) {
    is_scale_channel_wise = true;
  }
+ if (rotary_embs) {
+   rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
+   if (rotary_dim < head_dim) {
+     if (!use_neox_style || q_norm_weight || k_norm_weight || num_heads == kv_num_heads || is_scale_channel_wise) {
+       PADDLE_THROW(phi::errors::Fatal(
+           "partial_rotary_factor < 1.0 only supports use_neox_rotary_style=True, q_norm_weight/k_norm_weight) is None, GQA and is_scale_channel_wise=false."));
+     }
+   }
+ }

  if (q_norm_weight && k_norm_weight) {
    if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {

@@ -125,6 +135,7 @@ void EncoderWriteCacheWithRopeKernel(
        max_seq_len,
        rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
        head_dim,
+       rotary_dim,
        stream,
        use_neox_style,
        rope_3d);
@@ -11,10 +11,11 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cute/tensor.hpp"
#include "helper.h"
#include "paddle/extension.h"
#include "paddle/phi/core/memory/memcpy.h"
#include "utils.cuh"

template <int THREADBLOCK_SIZE>
__global__ void

@@ -116,6 +117,93 @@ void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
      max_len_tensor.data<int>(), batch_size);
}

template <uint32_t config_size>
__global__ void search_chunk_size_for_mla(
    const int *__restrict__ seq_lens_q,
    const int *__restrict__ seq_lens_encoder,
    const int *__restrict__ seq_lens_decoder,
    int *__restrict__ num_blocks_x,
    int *__restrict__ res_chunk_size,
    const int bsz,
    const int set_chunk_size,
    const int block_size,
    const int sm_cout) {
  const uint32_t conf_id = threadIdx.x;
  int gridx = 0;
  if (set_chunk_size > 0 && conf_id == 0) {
    for (uint32_t bid = 0; bid < bsz; bid++) {
      int seq_len = seq_lens_q[bid];
      int seq_len_encoder = seq_lens_encoder[bid];
      int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
      if (seq_len == 0 || seq_len_encoder > 0) continue;

      int loop_times;
      loop_times = cute::ceil_div(seq_len_decoder, set_chunk_size);
      gridx += loop_times;
    }
    *num_blocks_x = gridx;
    *res_chunk_size = set_chunk_size;
  } else if (conf_id < config_size) {
    __shared__ int gridx_shared[config_size];
    // chunk_size is a multiple of 64
    const int chunk_size = block_size << conf_id;
    for (uint32_t bid = 0; bid < bsz; bid++) {
      int seq_len = seq_lens_q[bid];
      int seq_len_encoder = seq_lens_encoder[bid];
      int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
      if (seq_len == 0 || seq_len_encoder > 0) continue;

      int loop_times;
      loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
      gridx += loop_times;
    }
    gridx_shared[conf_id] = gridx;
    __syncthreads();
    if (threadIdx.x == 0) {
      uint32_t res_id = 0;
      uint32_t max_last_wave_block = 0;
      for (uint32_t i = 1; i < config_size; i++) {
        uint32_t last_wave_block = gridx_shared[i] % sm_cout;
        if (last_wave_block >= max_last_wave_block) {
          res_id = i;
          max_last_wave_block = last_wave_block;
        }
      }
      *num_blocks_x = gridx_shared[res_id];
      *res_chunk_size = block_size << res_id;
    }
  }
}

__global__ void split_block_for_mla(const int *__restrict__ seq_lens_q,
                                    const int *__restrict__ seq_lens_encoder,
                                    const int *__restrict__ seq_lens_decoder,
                                    int *__restrict__ batch_ids,
                                    int *__restrict__ tile_ids_per_batch,
                                    const int bsz,
                                    const int chunk_size) {
  if (threadIdx.x == 0) {
    int index = 0;
    for (uint32_t bid = 0; bid < bsz; bid++) {
      int seq_len = seq_lens_q[bid];
      int seq_len_encoder = seq_lens_encoder[bid];
      int seq_len_decoder = seq_lens_decoder[bid] + seq_len;

      if (seq_len == 0) continue;

      int loop_times;
      loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
      if (seq_len_encoder > 0) {
        loop_times = 0;
      }
      for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
        batch_ids[index] = bid;
        tile_ids_per_batch[index++] = tile_id;
      }
    }
  }
}

__global__ void split_q_block(const int *__restrict__ seq_lens_q,
                              const int *__restrict__ seq_lens_encoder,
                              int *__restrict__ batch_ids,

@@ -197,7 +285,9 @@ void GetBlockShapeAndSplitKVBlock(
    const paddle::Tensor &seq_lens_this_time,
    paddle::Tensor &decoder_batch_ids,           // Inplace
    paddle::Tensor &decoder_tile_ids_per_batch,  // Inplace
-   paddle::Tensor &decoder_num_blocks_x_cpu,    // Inplace, Pinned Memory
+   paddle::Tensor &decoder_num_blocks_cpu,      // Inplace, Pinned Memory
+   paddle::Tensor &decoder_num_blocks_device,   // Inplace
+   paddle::Tensor &decoder_chunk_size_device,   // Inplace
    paddle::Tensor &max_len_tensor_cpu,          // Inplace, CPU
    paddle::Tensor &encoder_batch_ids,           // Inplace
    paddle::Tensor &encoder_tile_ids_per_batch,  // Inplace

@@ -230,8 +320,6 @@ void GetBlockShapeAndSplitKVBlock(
  int max_system_len = max_len_cpu_ptr[6];
  int max_just_dec_len_without_system = max_len_cpu_ptr[7];

  auto max_len_kv =
      GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
  get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(

@@ -241,6 +329,106 @@ void GetBlockShapeAndSplitKVBlock(
  max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);

  // decoder
  if (max_dec_len_this_time > 0) {
    const bool mla_use_tensorcore = GetMlaUseTensorcore();
    if (mla_use_tensorcore && group_size <= 64) {
      const int set_chunk_size = get_mla_dec_chunk_size(bsz);

      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
          decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));

      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
          decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));

      int device;
      cudaGetDevice(&device);
      int sm_cout;
      cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device);
      constexpr int config_size =
          12;  // search space for chunk size: [64, 128, 256, ... 131072]

      search_chunk_size_for_mla<config_size>
          <<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
                                 seq_lens_encoder.data<int>(),
                                 seq_lens_decoder.data<int>(),
                                 decoder_num_blocks_device.data<int>(),
                                 decoder_chunk_size_device.data<int>(),
                                 bsz,
                                 set_chunk_size,
                                 block_size,
                                 sm_cout);

      decoder_num_blocks_cpu.copy_(
          decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
      auto decoder_chunk_size_cpu =
          decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
      const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];

      // NOTE: (changwenbin) When using auto_chunk,
      // decode_max_tile_size must take into account the maximum case, where * 1024 can cover 128K.
      // const uint32_t decoder_batch_shape = seq_lens_decoder.dims()[0] * 1024;

      const uint32_t decoder_max_tile_size_per_bs_q =
          div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
      const uint32_t decoder_batch_shape =
          bsz * 1024 * decoder_max_tile_size_per_bs_q;

      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(decoder_batch_ids.data<int>(),
                          0,
                          decoder_batch_shape * sizeof(int32_t),
                          stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
                          0,
                          decoder_batch_shape * sizeof(int32_t),
                          stream));

      split_block_for_mla<<<1, 32, 0, stream>>>(
          seq_lens_this_time.data<int>(),
          seq_lens_encoder.data<int>(),
          seq_lens_decoder.data<int>(),
          decoder_batch_ids.data<int>(),
          decoder_tile_ids_per_batch.data<int>(),
          bsz,
          chunk_size);

    } else {
      // Note:(changwenbin) In order to adapt to cudagraph, the maximum value should be taken here
      const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
      const uint32_t decoder_batch_shape = bsz * 1024 * decoder_max_tile_size_per_bs_q;

      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));

      split_q_block<<<1, 32, 0, stream>>>(
          seq_lens_this_time.data<int>(),
          seq_lens_encoder.data<int>(),
          decoder_batch_ids.data<int>(),
          decoder_tile_ids_per_batch.data<int>(),
          decoder_num_blocks_device.data<int>(),
          bsz,
          decoder_block_shape_q,
          group_size);

      decoder_num_blocks_cpu.copy_(
          decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
          decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
    }
  } else {
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
        decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
        decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
    decoder_num_blocks_cpu.copy_(
        decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
  }

  // encoder
  if (max_enc_len_this_time > 0) {
    const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
    const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;

@@ -272,28 +460,6 @@ void GetBlockShapeAndSplitKVBlock(
    encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
  }

- if (max_just_dec_len_this_time > 0) {
-   // Clear buffer
-   const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
-   const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
-   PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
-   PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
-   PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
-
-   auto decoder_num_blocks_x =
-       GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
-   split_q_block<<<1, 32, 0, stream>>>(
-       seq_lens_this_time.data<int>(),
-       seq_lens_encoder.data<int>(),
-       decoder_batch_ids.data<int>(),
-       decoder_tile_ids_per_batch.data<int>(),
-       decoder_num_blocks_x.data<int>(),
-       bsz,
-       decoder_block_shape_q,
-       group_size);
-   decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
- }

}

PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)

@@ -303,7 +469,9 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
      "seq_lens_this_time",
      "decoder_batch_ids",
      "decoder_tile_ids_per_batch",
-     "decoder_num_blocks_x_cpu",
+     "decoder_num_blocks_cpu",
+     "decoder_num_blocks_device",
+     "decoder_chunk_size_device",
      "max_len_tensor_cpu",
      "encoder_batch_ids",
      "encoder_tile_ids_per_batch",
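The chunk-size search added above can be read as the following heuristic, re-implemented here on the host in plain C++ for illustration (it is not the CUDA kernel itself): for each candidate chunk size block_size << i, count the decode tiles the batch would produce, then keep the candidate whose final wave leaves the most SMs busy, i.e. the largest tile count modulo the SM count.

    // Host-side sketch of the chunk-size search heuristic.
    // Assumptions: decode_lens are the per-sequence KV lengths of pure-decode
    // requests; block_size and sm_count are illustrative values.
    #include <cstdio>
    #include <vector>

    int pick_chunk_size(const std::vector<int>& decode_lens, int block_size, int sm_count) {
        const int config_size = 12;                  // candidates: block_size << 0 .. << 11
        int best_id = 0, best_last_wave = 0;
        for (int i = 0; i < config_size; ++i) {
            const int chunk = block_size << i;
            int grid = 0;
            for (int len : decode_lens) grid += (len + chunk - 1) / chunk;   // ceil_div per sequence
            const int last_wave = grid % sm_count;
            if (last_wave >= best_last_wave) { best_last_wave = last_wave; best_id = i; }
        }
        return block_size << best_id;
    }

    int main() {
        std::vector<int> lens = {1024, 4096, 32768};   // hypothetical decode lengths
        std::printf("chosen chunk size: %d\n", pick_chunk_size(lens, 64, 132));
        return 0;
    }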
@@ -46,7 +46,8 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
    const int gqa_group_size,
    const float* q_norm_weight,
    const float* k_norm_weight,
-   const float rms_norm_eps) {
+   const float rms_norm_eps,
+   const bool rope_3d) {
  using LoadT = AlignedVector<T, VecSize>;
  using LoadFloat = AlignedVector<float, VecSize>;
  using LoadInT = AlignedVector<InT, VecSize>;

@@ -109,8 +110,9 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
    if (hi < num_heads + gqa_group_size) {
      // q k rope
      const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
-     Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-     Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+     uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
+     Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+     Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
    }
    float thread_m2 = 0.0f;
    float warp_m2 = 0.0f;

@@ -41,7 +41,8 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
    const bool use_neox_style,
    const float* q_norm_weight,
    const float* k_norm_weight,
-   const float rms_norm_eps) {
+   const float rms_norm_eps,
+   const bool rope_3d) {
  int output_inner_dim = num_heads + 2 * kv_num_heads;
  const uint32_t elem_nums =
      use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2

@@ -53,7 +54,6 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
  const int blocksize = 128;
  int grid_size = 1;
  GetNumBlocks<128>(pack_num, &grid_size);
  if (use_neox_style) {
    PD_THROW(
        "append_speculate_cache_rope_qk_norm not support neox rope yet");

@@ -82,7 +82,8 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
          kv_num_heads,
          q_norm_weight,
          k_norm_weight,
-         rms_norm_eps);
+         rms_norm_eps,
+         rope_3d);
  }
}

@@ -426,7 +427,6 @@ void SpeculateWriteCacheWithRoPEKernel(
            ? rotary_embs.get().data<float>() + max_seq_len * dim_head
            : rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
  }

  if (q_norm_weight && k_norm_weight) {
    if (cache_quant_type_str == "none") {
      append_speculate_cache_rope_qk_norm(

@@ -457,11 +457,13 @@ void SpeculateWriteCacheWithRoPEKernel(
          use_neox_rotary_style,
          reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
          reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
-         rms_norm_eps);
+         rms_norm_eps,
+         rope_3d);
    } else {
      PD_THROW(
          "append_decode_cache_rope_qk_norm not support cachekv quant yet");
    }

  } else {
    if (cache_quant_type_str == "none") {
      append_speculate_cache_rope(
@@ -63,7 +63,7 @@ std::vector<paddle::Tensor> AppendAttention(
    const paddle::Tensor &kv_num_blocks,
    const paddle::Tensor &decoder_batch_ids,
    const paddle::Tensor &decoder_tile_ids_per_batch,
-   const paddle::Tensor &decoder_num_blocks,
+   const paddle::Tensor &decoder_num_blocks_cpu,
    const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
    const paddle::optional<paddle::Tensor> &rotary_embs,
    const paddle::optional<paddle::Tensor> &attn_mask,

@@ -105,7 +105,7 @@ void AppendAttentionWithOutput(
    const paddle::Tensor &kv_num_blocks,
    const paddle::Tensor &decoder_batch_ids,
    const paddle::Tensor &decoder_tile_ids_per_batch,
-   const paddle::Tensor &decoder_num_blocks,
+   const paddle::Tensor &decoder_num_blocks_cpu,
    const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
    paddle::Tensor &fmha_out,
    const paddle::optional<paddle::Tensor> &rotary_embs,

@@ -305,7 +305,9 @@ void GetBlockShapeAndSplitKVBlock(
    const paddle::Tensor &seq_lens_this_time,
    paddle::Tensor &decoder_batch_ids,           // Inplace
    paddle::Tensor &decoder_tile_ids_per_batch,  // Inplace
-   paddle::Tensor &decoder_num_blocks_x_cpu,    // Inplace, Pinned Memory
+   paddle::Tensor &decoder_num_blocks_cpu,      // Inplace, Pinned Memory
+   paddle::Tensor &decoder_num_blocks_device,   // Inplace
+   paddle::Tensor &decoder_chunk_size_device,   // Inplace
    paddle::Tensor &max_len_tensor_cpu,          // Inplace, Pinned Memory
    paddle::Tensor &encoder_batch_ids,           // Inplace
    paddle::Tensor &encoder_tile_ids_per_batch,  // Inplace

@@ -414,8 +416,8 @@ std::vector<paddle::Tensor> MoEDeepGEMMDePermute(
    const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights);

void TextImageIndexOut(const paddle::Tensor &token_type_ids,
-                      const paddle::Tensor &text_input,
-                      const paddle::Tensor &image_input);
+                      paddle::Tensor &text_input,
+                      paddle::Tensor &image_input);

void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
                            paddle::Tensor &image_input,

@@ -473,23 +475,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
    const paddle::Tensor& query,
    const paddle::Tensor& key_cache,
    const paddle::Tensor& value_cache,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& encoder_batch_ids,
    const paddle::Tensor& encoder_tile_ids_per_batch,
    const paddle::Tensor& encoder_num_blocks,
    const paddle::Tensor& kv_batch_ids,
    const paddle::Tensor& kv_tile_ids_per_batch,
    const paddle::Tensor& kv_num_blocks,
    const paddle::Tensor& decoder_batch_ids,
    const paddle::Tensor& decoder_tile_ids_per_batch,
    const paddle::Tensor& decoder_num_blocks,
    const paddle::Tensor& decoder_num_blocks_cpu,
    const paddle::Tensor& max_enc_len_this_time,
    const paddle::Tensor& decoder_num_blocks_device,
    const paddle::Tensor& decoder_chunk_size_device,
    const paddle::Tensor& max_dec_len_this_time,
    const paddle::Tensor& max_len_kv,
    const paddle::optional<paddle::Tensor>& attn_mask,
@@ -303,7 +303,7 @@ class CustomAllreduce {
  bool full_nvlink_;

  RankSignals sg_;
- // Stores an map from a pointer to its peer pointters from all ranks.
+ // Stores an map from a pointer to its peer pointers from all ranks.
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;
@@ -59,6 +59,15 @@ inline uint32_t get_cascade_attention_num_threads() {
inline bool get_mla_use_tensorcore() {
  static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore");
  static const uint32_t mla_use_tensorcore =
-     mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
+     mla_use_tensorcore_env == nullptr ? 0 : std::stoul(std::string(mla_use_tensorcore_env));
  return mla_use_tensorcore != 0 ? true : false;
}
+ inline int get_mla_dec_chunk_size(int bsz) {
+   static const char* mla_dec_chunk_size_env =
+       std::getenv("FLAGS_mla_dec_chunk_size");
+   static const int mla_dec_chunk_size =
+       mla_dec_chunk_size_env == nullptr
+           ? -1
+           : std::stoi(std::string(mla_dec_chunk_size_env));
+   return bsz > 1 ? mla_dec_chunk_size : 64;
+ }
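A standalone sketch (re-implemented locally, not the FastDeploy headers) of how the two environment flags above behave after this change: FLAGS_mla_use_tensorcore now defaults to off, and FLAGS_mla_dec_chunk_size is only honored when the decode batch size is greater than 1, otherwise 64 is used.

    // Re-implementation sketch of the env-flag behavior; setenv is used here
    // only to demonstrate the flags and is a POSIX call, not part of the library.
    #include <cstdio>
    #include <cstdlib>
    #include <string>

    static bool mla_use_tensorcore() {
        const char* e = std::getenv("FLAGS_mla_use_tensorcore");
        return (e == nullptr ? 0 : std::stoul(std::string(e))) != 0;   // default off
    }
    static int mla_dec_chunk_size(int bsz) {
        const char* e = std::getenv("FLAGS_mla_dec_chunk_size");
        const int v = (e == nullptr) ? -1 : std::stoi(std::string(e)); // -1 means auto-search
        return bsz > 1 ? v : 64;                                       // single-sequence decode pins 64
    }

    int main() {
        setenv("FLAGS_mla_use_tensorcore", "1", 1);
        setenv("FLAGS_mla_dec_chunk_size", "256", 1);
        std::printf("tensorcore=%d chunk(bsz=8)=%d chunk(bsz=1)=%d\n",
                    mla_use_tensorcore(), mla_dec_chunk_size(8), mla_dec_chunk_size(1));
        return 0;
    }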
@@ -39,9 +39,6 @@ void GetOutputTopK(const paddle::Tensor& x,
    int k,
    int64_t rank_id,
    bool wait_flag) {
- if (rank_id > 0) {
-   return;
- }

  static struct msgdata msg_rcv;
  int msg_queue_id = 1;
@@ -132,7 +132,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
}

PD_BUILD_STATIC_OP(get_padding_offset)
-   .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
+   .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
    .Outputs({"x_remove_padding",
              "batch_id_per_token",
              "cu_seqlens_q",
@@ -563,3 +563,11 @@ inline int GetSMVersion() {
  return sm_version;

}

+ inline bool GetMlaUseTensorcore() {
+   static const bool flags_mla_use_tensorcore = get_mla_use_tensorcore();
+   static const bool enable_mla_tensorcore = GetSMVersion() >= 90 ? true : false;
+   const bool mla_use_tensorcore =
+       flags_mla_use_tensorcore && enable_mla_tensorcore;
+   return mla_use_tensorcore;
+ }
@@ -30,10 +30,12 @@ paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
            std::optional<paddle::Tensor> const& maybe_token_scales,
            std::string maybe_schedule) {
  machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
- std::optional<int64_t> maybe_group_size_opt;
+ std::optional<int64_t> maybe_group_size_opt = std::optional<int64_t>(maybe_group_size);
  std::optional<std::string> maybe_schedule_opt;
  if (maybe_schedule == "") {
    maybe_schedule_opt = std::nullopt;
  } else {
    maybe_schedule_opt = std::optional<std::string>(maybe_schedule);
  }
  return machete::mm_dispatch({.A = A,
                               .B = B,

@@ -63,6 +65,8 @@ std::vector<paddle::Tensor> MacheteMMKernel(
  paddle::DataType maybe_out_type;
  if (b_type_str == "uint4b8") {
    b_type_id = machete::kU4B8.id();
+ } else if (b_type_str == "uint8b128") {
+   b_type_id = machete::kU8B128.id();
  } else {
    PADDLE_ENFORCE(false, "b_type_str not supported!");
  }

@@ -51,6 +51,8 @@ std::vector<paddle::Tensor> MachetePrepackBKernel(

  if (b_type_str == "uint4b8") {
    b_type_id = machete::kU4B8.id();
+ } else if (b_type_str == "uint8b128") {
+   b_type_id = machete::kU8B128.id();
  } else {
    PADDLE_ENFORCE(false, "b_type_str not supported!");
  }
@@ -70,7 +70,6 @@ void BatchMLAWithPagedKVCacheKernel(
    const paddle::optional<paddle::Tensor>& smooth_weight,  // [num_kv_heads, head_dim]
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_decoder,
-   const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& block_tables,

@@ -78,9 +77,8 @@ void BatchMLAWithPagedKVCacheKernel(
    const paddle::Tensor& tile_ids_per_batch,
    const paddle::Tensor& num_blocks_x_device,
    const std::string& cache_quant_type_str,
-   const int num_blocks_x,
+   const paddle::Tensor& decoder_chunk_size_device,
    const int max_seq_len,
    const int max_dec_len,
    const float softmax_scale,
    const float quant_max_bound,
    const float quant_min_bound,

@@ -97,14 +95,12 @@ void BatchMLAWithPagedKVCacheKernel(
  const auto q_head_num = meta_data.q_num_heads;
  const auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
  const auto max_block_num = bsz * max_block_num_per_seq;
- const uint32_t chunk_size = get_max_partition_size(bsz);

  int q_head_dim = meta_data.head_dims;
  int k_head_dim = meta_data.head_dims;
  int v_head_dim = meta_data.head_dims_v;
  // int num_chunks = max_dec_len / chunk_size;
- int num_chunks = div_up(max_dec_len, chunk_size);
+ int num_chunks = div_up(max_seq_len, 64);

  auto *allocator = paddle::GetAllocator(q.place());
  phi::Allocator::AllocationPtr O_tmp, m_tmp, d_tmp;

@@ -127,14 +123,14 @@ void BatchMLAWithPagedKVCacheKernel(
  params.d = reinterpret_cast<float*>(d_tmp->ptr());
  params.block_tables = const_cast<int*>(block_tables.data<int>());
  params.seq_lens_this_time = const_cast<int*>(seq_lens_this_time.data<int>());
- params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
  params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
  params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
  params.batch_id_per_token = const_cast<int*>(batch_id_per_token.data<int>());
  params.batch_ids = const_cast<int*>(batch_ids.data<int>());
  params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
  params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());
- params.num_blocks_x_int = num_blocks_x;
+ params.chunk_size_device =
+     const_cast<int*>(decoder_chunk_size_device.data<int>());
  params.q_stride_bsz = q_head_num * q_head_dim;
  params.q_stride_head_num = q_head_dim;
  params.kv_stride_block_num = block_size * k_head_dim;

@@ -151,7 +147,6 @@ void BatchMLAWithPagedKVCacheKernel(
  params.block_size = block_size;
  params.max_draft_token_num = draft_token_num;
  params.sm_scale = softmax_scale;
- params.chunk_size = chunk_size;
  params.chunk_num = num_chunks;

  if (q_head_dim == 576) {

@@ -176,7 +171,6 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
    const paddle::optional<paddle::Tensor>& smooth_weight,  // [num_kv_heads, head_dim]
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_decoder,
-   const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& block_tables,

@@ -184,9 +178,8 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
    const paddle::Tensor& tile_ids_per_batch,
    const paddle::Tensor& num_blocks_x_device,
    const std::string& cache_quant_type_str,
-   const int num_blocks_x,
+   const paddle::Tensor& decoder_chunk_size_device,
    const int max_seq_len,
    const int max_dec_len,
    const float softmax_scale,
    const float quant_max_bound,
    const float quant_min_bound,

@@ -210,7 +203,6 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
    const paddle::optional<paddle::Tensor>& smooth_weight,  // [num_kv_heads, head_dim]
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_decoder,
-   const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& block_tables,

@@ -218,9 +210,8 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
    const paddle::Tensor& tile_ids_per_batch,
    const paddle::Tensor& num_blocks_x_device,
    const std::string& cache_quant_type_str,
-   const int num_blocks_x,
+   const paddle::Tensor& decoder_chunk_size_device,
    const int max_seq_len,
    const int max_dec_len,
    const float softmax_scale,
    const float quant_max_bound,
    const float quant_min_bound,

@@ -47,7 +47,6 @@ void BatchMLAWithPagedKVCacheKernel(
    const paddle::optional<paddle::Tensor>& smooth_weight,  // [num_kv_heads, head_dim]
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_decoder,
-   const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& block_tables,

@@ -55,9 +54,8 @@ void BatchMLAWithPagedKVCacheKernel(
    const paddle::Tensor& tile_ids_per_batch,
    const paddle::Tensor& num_blocks_x_device,
    const std::string& cache_quant_type_str,
-   const int num_blocks_x,
+   const paddle::Tensor& decoder_chunk_size_device,
    const int max_seq_len,
    const int max_dec_len,
    const float softmax_scale,
    const float quant_max_bound,
    const float quant_min_bound,
@@ -128,12 +128,13 @@ struct CollectiveMainloop {
DTypeMD const* d_ptr;
IdType const* kv_block_tables;
IdType const* seq_lens_this_time;
IdType const* seq_lens_encoder;
// IdType const* seq_lens_encoder;
IdType const* seq_lens_decoder;
IdType const* cumsum_q_seqlens;
IdType const* batch_ids;
IdType const* tile_ids_per_batch;
IdType const* num_blocks_x;
IdType const* chunk_size_device;
float sm_scale;
int bsz;
int max_block_num;
@@ -144,7 +145,7 @@ struct CollectiveMainloop {
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
int chunk_size;
// int chunk_size;
int chunk_num;
int max_draft_token_num;
};
@@ -160,12 +161,13 @@ struct CollectiveMainloop {
DTypeMD* d_ptr;
IdType* kv_block_tables;
IdType* seq_lens_this_time;
IdType* seq_lens_encoder;
// IdType* seq_lens_encoder;
IdType* seq_lens_decoder;
IdType* cumsum_q_seqlens;
IdType* batch_ids;
IdType* tile_ids_per_batch;
IdType* num_blocks_x;
IdType* chunk_size_device;
float sm_scale;
int bsz;
int max_block_num;
@@ -176,7 +178,7 @@ struct CollectiveMainloop {
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
int chunk_size;
// int chunk_size;
int chunk_num;
int max_draft_token_num;
TMA_KV tma_load_KV;
@@ -198,12 +200,13 @@ struct CollectiveMainloop {
const_cast<DTypeMD*>(args.d_ptr),
const_cast<IdType*>(args.kv_block_tables),
const_cast<IdType*>(args.seq_lens_this_time),
const_cast<IdType*>(args.seq_lens_encoder),
// const_cast<IdType*>(args.seq_lens_encoder),
const_cast<IdType*>(args.seq_lens_decoder),
const_cast<IdType*>(args.cumsum_q_seqlens),
const_cast<IdType*>(args.batch_ids),
const_cast<IdType*>(args.tile_ids_per_batch),
const_cast<IdType*>(args.num_blocks_x),
const_cast<IdType*>(args.chunk_size_device),
args.sm_scale,
args.bsz,
args.max_block_num,
@@ -214,7 +217,7 @@ struct CollectiveMainloop {
args.kv_stride_block_size,
args.o_stride_bsz,
args.o_stride_head_num,
args.chunk_size,
// args.chunk_size,
args.chunk_num,
args.max_draft_token_num,
tma_load_KV
@@ -281,9 +284,9 @@ struct CollectiveMainloop {
auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);

static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;

auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));

@@ -322,9 +325,9 @@ struct CollectiveMainloop {
group_modes<0, 2>(sK), group_modes<0, 2>(gKV));

static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;

auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
@@ -57,7 +57,7 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");

const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size_device[0]);

static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
@@ -84,9 +84,9 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);

const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx =cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx =cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
int kv_tile_idx = end_tile_idx;

auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
@@ -263,7 +263,7 @@ CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");

const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size_device[0]);

static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
@@ -295,9 +295,9 @@ CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
Tensor tOrV4 = threadMmaPVSS.partition_fragment_B(sVt_s4);
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);

const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
int kv_tile_idx = end_tile_idx;

auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
@@ -62,13 +62,12 @@ struct Params {
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
alignas(16) DTypeO *O_tmp; // [num_chunks, bsz, head_num, dim_head]
alignas(16) float *m; // [num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) float *d; // [num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head]
alignas(16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num]

alignas(16) IdType *block_tables;
alignas(16) IdType *seq_lens_this_time;
alignas(16) IdType *seq_lens_encoder;
alignas(16) IdType *seq_lens_decoder;
alignas(16) IdType *cumsum_q_seqlens;
alignas(16) IdType *batch_id_per_token;
@@ -76,7 +75,7 @@ struct Params {
alignas(16) IdType *batch_ids;
alignas(16) IdType *tile_ids_per_batch;
alignas(16) IdType *num_blocks_x;

alignas(16) IdType *chunk_size_device;

uint32_t q_stride_bsz;
uint32_t q_stride_head_num;
@@ -96,9 +95,7 @@ struct Params {
int vo_head_dim;
int block_size;
int max_draft_token_num;
int chunk_size;
int chunk_num;
int num_blocks_x_int;

float sm_scale;
};
@@ -118,7 +115,7 @@ struct Params {
return cudaErrorNotSupported; \
}

template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
typename CollectiveMainloop::Params const mainloop_params,
@@ -137,6 +134,7 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
const int num_blocks_x = mainloop_params.num_blocks_x[0];
const int chunk_size = mainloop_params.chunk_size_device[0];

static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;
@@ -205,58 +203,10 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT

PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
if constexpr(USE_FIXED_BLOCK) {
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));

// load Q
collective_mainloop.load_q(
mainloop_params,
pipeline_q,
smem_pipe_write_q,
shared_storage,
threadIdx.x,
bid);

if constexpr (!use_tma_load_kv) {
// load kv
collective_mainloop.load_kv(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
} else {
if (warp_idx_in_warpgroup == 0) {
// load kv tma
collective_mainloop.load_kv_tma(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
}
}
}
} else {
const int block_id = blockIdx.x;
const int bid = mainloop_params.batch_ids[block_id];
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
@@ -309,76 +259,12 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));

auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
if constexpr(USE_FIXED_BLOCK) {
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));

if constexpr (BLOCK_SHAPE_KV == 64) {
mma_f16<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
} else if (BLOCK_SHAPE_KV == 32) {
mma_f16_two_stages<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
}

collective_epilogue.store(
epilogue_params,
tOrO,
attention_updater.get_lse(),
shared_storage,
tiled_mma_pv,
threadIdx.x - NUM_COPY_THREADS,
bid,
mainloop_params.bsz,
seq_len_now,
start_token_idx,
tile_id,
seq_len_decoder_now,
mainloop_params.chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
} else {
const int block_id = blockIdx.x;
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[block_id];
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
@@ -429,15 +315,15 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
start_token_idx,
tile_id,
seq_len_decoder_now,
mainloop_params.chunk_size,
chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
}
}
}

template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
cudaStream_t stream) {
using DTypeQ = typename KernelTraits::DTypeQ;
@@ -460,12 +346,12 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
params.d,
params.block_tables,
params.seq_lens_this_time,
params.seq_lens_encoder,
params.seq_lens_decoder,
params.cumsum_q_seqlens,
params.batch_ids,
params.tile_ids_per_batch,
params.num_blocks_x,
params.chunk_size_device,
params.sm_scale,
params.bsz,
params.max_block_num,
@@ -476,7 +362,6 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
params.kv_stride_block_size,
params.o_stride_bsz,
params.o_stride_head_num,
params.chunk_size,
params.chunk_num,
params.max_draft_token_num
});
@@ -500,13 +385,9 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);

int gridx;
if constexpr(USE_FIXED_BLOCK) {
gridx = multiprocessor_count;
} else {
gridx = params.num_blocks_x_int;
}
dim3 grid_dims = {gridx, 1, 1};
// NOTE: (changwenbin) Here the grid size is fixed so that MLA can be captured
// by the graph.
dim3 grid_dims = {multiprocessor_count, 1, 1};
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
dim3 block_dims(ctaSize, 1, 1);
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
@@ -517,37 +398,38 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
constexpr int merge_block_size = 256;
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head); // 128k is too large
dim3 grids_merge(multiprocessor_count, params.q_num_head); // 128k is too large
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO><<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE*>(params.O_tmp),
params.m,
params.d,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.seq_lens_encoder,
params.cumsum_q_seqlens,
params.batch_id_per_token,
reinterpret_cast<NV_TYPE*>(params.O),
params.chunk_num,
params.q_num_head,
params.chunk_size,
params.vo_head_dim,
params.token_num,
params.bsz,
params.max_draft_token_num
);
merge_multi_chunks_kernel<NV_TYPE,
vec_size,
blocky,
KernelTraits::HEAD_DIM_VO>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(params.O_tmp),
params.m,
params.d,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.cumsum_q_seqlens,
params.batch_id_per_token,
params.chunk_size_device,
reinterpret_cast<NV_TYPE *>(params.O),
params.q_num_head,
params.vo_head_dim,
params.token_num,
params.bsz,
params.max_draft_token_num);
}
return cudaSuccess;
}

template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
constexpr bool CAUSAL = true;
if constexpr (HEAD_DIM_QK == 576) {
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false,
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
HEAD_DIM_QK,
HEAD_DIM_VO,
GROUP_SIZE,
@@ -249,18 +249,16 @@ struct prefill_softmax_state_t {
};

template <typename T, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim]
const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads]
const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads]
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [max_num_chunks, bsz, max_draft_token, num_heads, head_dim]
const float * __restrict__ multi_m, // [max_num_chunks, bsz, max_draft_token, num_heads]
const float * __restrict__ multi_d, // [max_num_chunks, bsz, max_draft_token, num_heads]
const int * __restrict__ seq_lens_this_time,
const int * __restrict__ seq_lens_decoder,
const int * __restrict__ seq_lens_encoder,
const int *__restrict__ cu_seqlens_q,
const int * __restrict__ batch_id_per_token,
const int * __restrict__ chunk_size_device,
T * __restrict__ out, // [token_num, num_heads, head_dim]
const int num_chunks,
const int num_heads,
const int chunk_size,
const int head_dim,
const int token_num,
const int bsz,
@@ -271,13 +269,15 @@ __global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t bid = batch_id_per_token[qid];
// NOTE : (changwenbin) Batch_id_per_token is initialized to [:]=-1, Marking meaningless batch IDs.
if (bid == -1) continue;
const int seq_len_q = seq_lens_this_time[bid];
if (seq_len_q == 0) continue;
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
int seq_len_kv = seq_lens_decoder[bid];
if (seq_len_kv == 0) continue;
seq_len_kv += seq_len_q;
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size);
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size_device[0]);
if (num_chunks_this_seq <= 1) {
// not need merge
continue;
@@ -33,6 +33,11 @@
__VA_ARGS__ \
break; \
} \
case 3: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 3; \
__VA_ARGS__ \
break; \
} \
case 6: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 6; \
__VA_ARGS__ \
@@ -448,137 +453,71 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
auto place = input.place();
const int gridx = min(132 * 8, num_rows);
if (moe_quant_type == "w4a8") {
if (num_experts_per_rank == 8) {
permute_x_kernel<data_t, int8_t, 8><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<int8_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);
} else if (num_experts_per_rank == 16) {
permute_x_kernel<data_t, int8_t, 16><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<int8_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);
}
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
permute_x_kernel<data_t, int8_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<int8_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);)
} else if (moe_quant_type == "w4afp8") {
if (num_experts_per_rank == 8) {
permute_x_kernel<data_t, data_t_fp8, 8, 512><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t_fp8>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
448.0f,
-448.0f
);
} else if (num_experts_per_rank == 16) {
permute_x_kernel<data_t, data_t_fp8, 16, 512><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t_fp8>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
448.0f,
-448.0f
);
}
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
permute_x_kernel<data_t, data_t_fp8, NUM_EXPERTS_PER_RANK, 512><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t_fp8>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
448.0f,
-448.0f
);)
} else {
if (num_experts_per_rank == 8) {
permute_x_kernel<data_t, data_t, 8><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);
} else if (num_experts_per_rank == 16) {
permute_x_kernel<data_t, data_t, 16><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);
}
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
permute_x_kernel<data_t, data_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);)
}
}
@@ -236,7 +236,7 @@ public:
num_experts, k, stream);
}

topk_gating_softmax_kernelLauncher<float, int>::run(
topk_gating_softmax_kernelLauncher<float, int>(
gating_output, nullptr, expert_scales_float, softmax_out_,
expert_for_source_row, source_rows_, softmax_max_prob, num_rows,
num_experts, k, group_moe, stream);
@@ -248,7 +248,7 @@ public:
permuted_experts_, source_rows_, permuted_rows_, k * num_rows,
false, stream);

initialize_moe_routing_kernelLauncher<T>::run(
initialize_moe_routing_kernelLauncher(
input_activations, permuted_data_, permuted_rows_, nullptr, nullptr,
expanded_source_row_to_expanded_dest_row, num_rows, num_rows,
hidden_size, k, stream);
@@ -335,14 +335,14 @@ public:
num_experts, down_proj_quant_args, stream);
}

finalize_moe_routing_kernelLauncher<T>::run(
finalize_moe_routing_kernelLauncher(
fc2_result, output_, fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float),
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
num_rows, hidden_size, k, static_cast<int>(1), norm_topk_prob,
routed_scaling_factor, stream);
} else {
finalize_moe_routing_kernelLauncher<T>::run(
finalize_moe_routing_kernelLauncher(
// fc2_result,
fc1_out, output_,
fc1_expert_biases, // fc2_expert_biases,

@@ -1139,9 +1139,7 @@ void topk_gating_softmax_launcher_helper(const T* input,
}

template <typename T, typename IdxT = int>
struct topk_gating_softmax_kernelLauncher{

static void run(const T* input,
void topk_gating_softmax_kernelLauncher(const T* input,
const T* gating_correction_bias,
T* output,
T* softmax,
@@ -1221,7 +1219,6 @@ static void run(const T* input,
}
}
}
};

// ========================== Permutation things
// =======================================
@@ -1316,9 +1313,7 @@ __global__ void initialize_moe_routing_kernel(
}

template <typename T, typename OutT = T>
struct initialize_moe_routing_kernelLauncher{

static void run(
void initialize_moe_routing_kernelLauncher(
const T* unpermuted_input,
OutT* permuted_output,
const int* expanded_dest_row_to_expanded_source_row,
@@ -1361,7 +1356,6 @@ static void run(
num_rows * k);
}
}
};

// ============================== Infer GEMM sizes
// =================================
@@ -1472,8 +1466,7 @@ __global__ void finalize_moe_routing_kernel(
}

template <typename T>
struct finalize_moe_routing_kernelLauncher{
static void run(
void finalize_moe_routing_kernelLauncher(
const T* expanded_permuted_rows,
T* reduced_unpermuted_output,
const T* bias,
@@ -1505,5 +1498,4 @@ static void run(
routed_scaling_factor,
num_rows);
}
};
} // namespace phi
@@ -36,6 +36,9 @@ void MoeDispatchKernel(
paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) {
using namespace phi;

if (num_rows == 0){
return;
}
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -100,7 +103,7 @@ void MoeDispatchKernel(
softmax_out_ = nullptr;
}

topk_gating_softmax_kernelLauncher<float, int>::run(
topk_gating_softmax_kernelLauncher(
gating_output.data<float>(),
gating_correction_bias ? gating_correction_bias.get().data<float>()
: nullptr,
@@ -114,13 +117,13 @@ void MoeDispatchKernel(

if (w4a8_in_scale) {
if (permute_input->dtype() == paddle::DataType::INT8) {
initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(
initialize_moe_routing_kernelLauncher(
input.data<data_t>(), permute_input->data<int8_t>(), permuted_rows_,
expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
hidden_size, moe_topk, stream);
} else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
initialize_moe_routing_kernelLauncher<data_t, float8_e4m3fn>::run(
initialize_moe_routing_kernelLauncher(
input.data<data_t>(), permute_input->data<float8_e4m3fn>(),
permuted_rows_, expert_idx_per_token->data<int32_t>(),
w4a8_in_scale->data<float>(),
@@ -128,7 +131,7 @@ void MoeDispatchKernel(
hidden_size, moe_topk, stream);
}
} else {
initialize_moe_routing_kernelLauncher<data_t>::run(
initialize_moe_routing_kernelLauncher(
input.data<data_t>(), permute_input->data<data_t>(), permuted_rows_,
expert_idx_per_token->data<int32_t>(), nullptr,
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
@@ -185,6 +188,15 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
auto expert_idx_per_token =
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);

if (token_rows == 0){
return {permute_input,
tokens_expert_prefix_sum,
permute_indices_per_token,
topk_weight,
topk_idx,
expert_idx_per_token};
}

switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeDispatchKernel<paddle::DataType::BFLOAT16>(

@@ -412,7 +412,9 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input, t_type);

if(permute_input.numel() == 0){
return ffn_out;
}
switch (t_type) {
case paddle::DataType::BFLOAT16:
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,

@@ -36,7 +36,7 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,
typedef typename traits_::data_t data_t;
auto stream = ffn_out.stream();

finalize_moe_routing_kernelLauncher<data_t>::run(
finalize_moe_routing_kernelLauncher(
ffn_out.data<data_t>(), output->data<data_t>(),
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
top_k_weight.data<float>(), permute_indices_per_token.data<int32_t>(),
@@ -59,6 +59,10 @@ paddle::Tensor MoeExpertReduceFunc(

auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);

if(num_rows == 0){
return output;
}

switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeReduceKernel<paddle::DataType::BFLOAT16>(
@@ -22,23 +22,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& decoder_chunk_size_device,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
@@ -64,9 +59,12 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
typedef PDTraits<D> traits_;
typedef typename traits_::data_t data_t;

int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
// NOTE: (changwenbin) In cuda graph, it will be fixed in the capture stage
// int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
int max_len_kv_data = max_len_kv.data<int>()[0];
// int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
//

const bool mla_use_tensorcore = get_mla_use_tensorcore();
auto sm_version = GetSMVersion();
@@ -96,7 +94,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
out_linear_smooths,
seq_lens_this_time,
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
batch_id_per_token,
block_tables,
@@ -104,9 +101,8 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
decoder_tile_ids_per_batch,
decoder_num_blocks,
cache_quant_type_str,
decoder_num_blocks_data,
decoder_chunk_size_device,
max_input_length,
max_len_kv_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
@@ -145,23 +141,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& decoder_chunk_size_device,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
@@ -208,23 +199,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
query,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
cu_seqlens_q,
batch_id_per_token,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
decoder_num_blocks_cpu,
max_enc_len_this_time,
decoder_chunk_size_device,
max_dec_len_this_time,
max_len_kv,
attn_mask,
@@ -254,23 +240,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
query,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
cu_seqlens_q,
batch_id_per_token,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
decoder_num_blocks_cpu,
max_enc_len_this_time,
decoder_chunk_size_device,
max_dec_len_this_time,
max_len_kv,
attn_mask,
@@ -307,23 +288,18 @@ std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
const std::vector<int64_t>& query_shape,
const std::vector<int64_t>& key_cache_shape,
const std::vector<int64_t>& value_cache_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& batch_id_per_token_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& encoder_num_blocks_shape,
const std::vector<int64_t>& kv_batch_ids_shape,
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
const std::vector<int64_t>& kv_num_blocks_shape,
const std::vector<int64_t>& decoder_batch_ids_shape,
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& decoder_num_blocks_shape,
const std::vector<int64_t>& decoder_num_blocks_cpu_shape,
const std::vector<int64_t>& max_enc_len_this_time_shape,
const std::vector<int64_t>& decoder_chunk_size_device_shape,
const std::vector<int64_t>& max_dec_len_this_time_shape,
const std::vector<int64_t>& max_len_kv_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
@@ -361,23 +337,18 @@ std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
const paddle::DataType& query_dtype,
const paddle::DataType& key_cache_dtype,
const paddle::DataType& value_cache_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& batch_id_per_token_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
const paddle::DataType& encoder_num_blocks_dtype,
const paddle::DataType& kv_batch_ids_dtype,
const paddle::DataType& kv_tile_ids_per_batch_dtype,
const paddle::DataType& kv_num_blocks_dtype,
const paddle::DataType& decoder_batch_ids_dtype,
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
const paddle::DataType& decoder_num_blocks_dtype,
const paddle::DataType& decoder_num_blocks_cpu_dtype,
const paddle::DataType& max_enc_len_this_time_dtype,
const paddle::DataType& decoder_chunk_size_device_dtype,
const paddle::DataType& max_dec_len_this_time_dtype,
const paddle::DataType& max_len_kv_dtype,
const paddle::optional<paddle::DataType>& attn_mask_dtype,
@@ -415,23 +386,18 @@ PD_BUILD_STATIC_OP(multi_head_latent_attention)
.Inputs({"query",
"key_cache",
"value_cache",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"cu_seqlens_q",
"batch_id_per_token",
"block_tables",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks",
"decoder_num_blocks_cpu",
"max_enc_len_this_time",
"decoder_chunk_size_device",
"max_dec_len_this_time",
"max_len_kv",
paddle::Optional("attn_mask"),
@@ -59,7 +59,7 @@ __global__ void text_image_scatter_kernel(
constexpr int HalfVecSize = VecSize / 2;
using T_Vec = AlignedVector<T, VecSize>;
T_Vec input_ptr_vec;
T_Vec text_imgaes_vec;
T_Vec text_images_vec;

int64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t step = blockDim.x * gridDim.x * VecSize;
@@ -76,16 +76,20 @@ __global__ void text_image_scatter_kernel(
Load<T, VecSize>(input_ptr + input_load_offset, &input_ptr_vec);
#pragma unroll
for(int vi = 0; vi < VecSize; ++vi) {
text_imgaes_vec[vi] = input_ptr_vec[vi];
text_images_vec[vi] = input_ptr_vec[vi];
}

if (token_type_ids_num == 0) {
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
Store<T,VecSize>(text_imgaes_vec, text_gather_ptr + text_load_offset);
Store<T,VecSize>(text_images_vec, text_gather_ptr + text_load_offset);

} else if(token_type_ids_num == 1){
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
Store<T,VecSize>(text_images_vec, image_gather_ptr + image_load_offset);

} else {
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
Store<T,VecSize>(text_imgaes_vec, image_gather_ptr + image_load_offset);
// skip cuda graph padding value
continue;
}
}
}
@@ -120,9 +124,12 @@ __global__ void text_image_gather_kernel(
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
Load<T,VecSize>(text_gather_ptr + text_load_offset, &text_imgaes_vec);

} else {
} else if (token_type_ids_num == 1){
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
Load<T,VecSize>(image_gather_ptr + image_load_offset, &text_imgaes_vec);
} else {
// skip cuda graph padding value
continue;
}

#pragma unroll
@@ -154,7 +161,6 @@ void LaunchTextImageGatherScatter(
const int64_t token_num = in_dims[0];
const int64_t hidden_size = in_dims[1];

const int VecSize = 16 / sizeof(data_t);
const int64_t tot_element_num = token_num * hidden_size;

@@ -168,7 +174,7 @@ void LaunchTextImageGatherScatter(
PADDLE_ENFORCE_GPU_SUCCESS(GetGridSize(tot_pack_num, block_size, kNumWaves, &grid_size_x));
dim3 grid_dim = dim3(grid_size_x, 1, 1);
if (is_scatter) {
text_image_scatter_kernel<DataType_, 8><<<grid_dim, block_size>>>(
text_image_scatter_kernel<DataType_, VecSize><<<grid_dim, block_size, 0, stream>>>(
reinterpret_cast<DataType_*>(input.data<data_t>()),
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
@@ -179,7 +185,7 @@ void LaunchTextImageGatherScatter(
tot_element_num
);
} else {
text_image_gather_kernel<DataType_, 8><<<grid_dim, block_size>>>(
text_image_gather_kernel<DataType_, VecSize><<<grid_dim, block_size, 0, stream>>>(
reinterpret_cast<DataType_*>(input.data<data_t>()),
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
@@ -16,7 +16,7 @@

template <int VecSize>
__global__ void text_image_index_out_kernel(
int32_t* token_type_ids,
const int32_t* token_type_ids,
int32_t* text_index,
int32_t* image_index,
const int64_t token_num
@@ -31,23 +31,27 @@ __global__ void text_image_index_out_kernel(
if (token_type_ids[i] == 0) {
text_index[i] = text_count;
text_count += 1;
} else {
} else if (token_type_ids[i] == 1) {
image_index[i] = images_count;
images_count += 1;
} else {
// skip cuda graph padding value
continue;
}
}
}

void TextImageIndexOut(
const paddle::Tensor& token_type_ids,
const paddle::Tensor& text_index,
const paddle::Tensor& image_index) {
paddle::Tensor& text_index,
paddle::Tensor& image_index) {

const int64_t token_num = token_type_ids.shape()[0];
text_image_index_out_kernel<1><<<1, 1>>>(
const_cast<int32_t*>(token_type_ids.data<int32_t>()),
const_cast<int32_t*>(text_index.data<int32_t>()),
const_cast<int32_t*>(image_index.data<int32_t>()),
auto stream = token_type_ids.stream();
text_image_index_out_kernel<1><<<1, 1, 0, stream>>>(
token_type_ids.data<int32_t>(),
text_index.data<int32_t>(),
image_index.data<int32_t>(),
token_num
);
}
@@ -32,7 +32,8 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size) {
const int block_size,
bool prefill_one_step_stop) {
int thread_idx = threadIdx.x;
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
@@ -54,23 +55,32 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
seq_lens_encoder[thread_idx] = 0;
} else {
if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= prompt_lens[thread_idx]) {
// decoding
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
seq_lens_this_time[thread_idx] = 1;
seq_lens_encoder[thread_idx] = 0;
int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
input_ids_now[0] = next_tokens[thread_idx];
if (prefill_one_step_stop) {
// prefill done, stop
stop_flags[thread_idx] = true;
seq_lens_this_time[thread_idx] = 0;
seq_lens_decoder[thread_idx] = 0;
seq_lens_encoder[thread_idx] = 0;
stop_flag_now_int = 1;
} else{
// decoding
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
seq_lens_this_time[thread_idx] = 1;
seq_lens_encoder[thread_idx] = 0;
int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
input_ids_now[0] = next_tokens[thread_idx];

// to judge whether block is not enough
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) {
// should be scheduled by server
is_block_step[thread_idx] = true;
seq_lens_this_time[thread_idx]= 0;
stop_flags[thread_idx] = true;
step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
seq_lens_decoder[thread_idx] = 0;
stop_flag_now_int = 1;
// to judge whether block is not enough
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) {
// should be scheduled by server
is_block_step[thread_idx] = true;
seq_lens_this_time[thread_idx]= 0;
stop_flags[thread_idx] = true;
step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
seq_lens_decoder[thread_idx] = 0;
stop_flag_now_int = 1;
}
}
} else
{
@@ -110,6 +120,12 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
#else
auto cu_stream = input_ids.stream();
#endif
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
}
const int max_bsz = stop_flags.shape()[0];
const int now_bsz = seq_lens_this_time.shape()[0];
const int input_ids_stride = input_ids.shape()[1];
@@ -133,7 +149,8 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
max_bsz,
input_ids_stride,
block_num_per_seq,
block_size);
block_size,
prefill_one_step_stop);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
@@ -37,6 +37,52 @@ def load_module_from_path(module_name, path):
    return module


def update_git_repo():
    try:
        print("update third party repo...", flush=True)
        original_dir = os.getcwd()
        submodule_dir = os.path.dirname(os.path.abspath(__file__))
        third_party_path = os.path.join(submodule_dir, "third_party")
        root_path = Path(third_party_path)

        # check if third_party is empty
        update_third_party = False
        for dirpath in root_path.iterdir():
            if dirpath.is_dir():
                has_content = any(dirpath.iterdir())
                if not has_content:
                    update_third_party = True

        if update_third_party:
            os.chdir(submodule_dir)
            subprocess.run(
                "git submodule sync --recursive && git submodule update --init --recursive",
                shell=True,
                check=True,
                text=True,
            )
        else:
            print(
                "\033[33m[===WARNING===]third_party directory already exists, skip clone and update.\033[0m",
                flush=True,
            )

        # apply deep gemm patch
        deep_gemm_dir = "third_party/DeepGEMM"
        dst_path = os.path.join(submodule_dir, deep_gemm_dir)
        patch = "0001-DeepGEMM-95e81b3.patch"
        patch_source = os.path.join(submodule_dir, patch)
        patch_destination = os.path.join(dst_path, patch)
        if not os.path.exists(patch_destination):
            shutil.copy(patch_source, patch_destination)
            apply_cmd = ["git", "apply", patch]
            os.chdir(dst_path)
            subprocess.run(apply_cmd, check=True)
        os.chdir(original_dir)
    except subprocess.CalledProcessError:
        raise Exception("Git submodule update and apply patch failed. Maybe network connection is poor.")


ROOT_DIR = Path(__file__).parent.parent

# cannot import envs directly because it depends on fastdeploy,
@@ -46,6 +92,8 @@ envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.
archs = json.loads(envs.FD_BUILDING_ARCS)
use_bf16 = envs.FD_CPU_USE_BF16 == "True"

update_git_repo()


def download_and_extract(url, destination_directory):
    """
@@ -78,52 +126,6 @@ def download_and_extract(url, destination_directory):
        print(f"Error extracting file: {e}")


def clone_git_repo(version, repo_url, destination_path):
    """
    Clone git repo to destination path.
    """
    try:
        subprocess.run(
            [
                "git",
                "clone",
                "-b",
                version,
                "--single-branch",
                repo_url,
                destination_path,
            ],
            check=True,
        )
        return True
    except subprocess.CalledProcessError:
        return False


def process_git_repo(cur_path, dst_path, commit_id=None, patch=None):
    """
    reset git repo to destination commit and apply patch.
    """
    if commit_id is not None:
        reset_cmd = ["git", "reset", "--hard", commit_id]
    if patch is not None:
        patch_source = os.path.join(cur_path, patch)
        patch_destination = os.path.join(dst_path, patch)
        shutil.copy(patch_source, patch_destination)
        apply_cmd = ["git", "apply", patch]

    try:
        os.chdir(dst_path)
        if commit_id is not None:
            subprocess.run(reset_cmd, check=True)
        if patch is not None:
            subprocess.run(apply_cmd, check=True)
        os.chdir(cur_path)
        return True
    except subprocess.CalledProcessError:
        return False


def get_sm_version(archs):
    """
    Get sm version of paddle.
@@ -191,13 +193,6 @@ def find_end_files(directory, end_str):
|
||||
if paddle.is_compiled_with_rocm():
|
||||
# NOTE(@duanyanhui): paddle.is_compiled_with_cuda() returns True when paddle compiled with rocm.
|
||||
# so we need to check if paddle compiled with rocm at first.
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://bgithub.xyz/nlohmann/json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
sources = [
|
||||
"gpu_ops/save_with_output_msg.cc",
|
||||
"gpu_ops/get_output.cc",
|
||||
@@ -316,28 +311,6 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/ipc_sent_key_value_cache_by_remote_ptr.cu",
|
||||
]
|
||||
|
||||
cutlass_dir = "third_party/cutlass"
|
||||
if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir):
|
||||
if not os.path.exists(cutlass_dir):
|
||||
os.makedirs(cutlass_dir)
|
||||
clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git", cutlass_dir)
|
||||
if not os.listdir(cutlass_dir):
|
||||
raise ValueError("Git clone cutlass failed!")
|
||||
|
||||
# deep gemm
|
||||
deep_gemm_dir = "third_party/DeepGEMM"
|
||||
if not os.path.exists(deep_gemm_dir) or not os.listdir(deep_gemm_dir):
|
||||
if not os.path.exists(deep_gemm_dir):
|
||||
os.makedirs(deep_gemm_dir)
|
||||
clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git", deep_gemm_dir)
|
||||
if not os.listdir(deep_gemm_dir):
|
||||
raise ValueError("Git clone DeepGEMM failed!")
|
||||
cur_path = os.path.dirname(os.path.abspath(__file__))
|
||||
dst_path = os.path.join(cur_path, deep_gemm_dir)
|
||||
commit_id = "95e81b3dd6704e279e5f4757c5b94776ac988a8d"
|
||||
patch = "0001-DeepGEMM-95e81b3.patch"
|
||||
process_git_repo(cur_path, dst_path, commit_id, patch)
|
||||
|
||||
dg_third_party_include_dirs = (
|
||||
"third_party/cutlass/include/cute",
|
||||
"third_party/cutlass/include/cutlass",
|
||||
@@ -365,14 +338,6 @@ elif paddle.is_compiled_with_cuda():
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to copy from {src_dir} to {dst_dir}: {e}")
|
||||
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
|
||||
cc_compile_args = []
|
||||
nvcc_compile_args = get_gencode_flags(archs)
|
||||
nvcc_compile_args += ["-DPADDLE_DEV"]
|
||||
@@ -542,7 +507,7 @@ elif paddle.is_compiled_with_cuda():
|
||||
include_package_data=True,
|
||||
)
|
||||
elif paddle.is_compiled_with_xpu():
|
||||
assert False, "In XPU, we should use setup_ops.py in xpu_ops/src, not this."
|
||||
assert False, "For XPU, please use setup_ops.py in the xpu_ops directory to compile custom ops."
|
||||
elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||
setup(
|
||||
name="fastdeploy_ops",
|
||||
@@ -593,13 +558,6 @@ elif paddle.is_compiled_with_custom_device("gcu"):
|
||||
)
|
||||
elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
maca_path = os.getenv("MACA_PATH", "/opt/maca")
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://gitee.com/learnlov/mirrors_nlohmann_json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
sources = [
|
||||
"gpu_ops/update_inputs_v1.cu",
|
||||
"gpu_ops/save_with_output_msg.cc",
|
||||
|
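Note: a minimal sketch of how the revised build flow above can be driven; the command line and environment values here are assumptions for illustration, not part of the diff. update_git_repo() runs at import time of the setup script, so an empty third_party/ checkout is synced and the DeepGEMM patch applied before any extension is compiled.

import os
import subprocess

env = dict(os.environ)
env["FD_BUILDING_ARCS"] = "[80, 90]"   # parsed via json.loads(envs.FD_BUILDING_ARCS)
env["FD_CPU_USE_BF16"] = "False"       # compared against the string "True" above

# Hypothetical invocation; submodule sync and the DeepGEMM patch are applied
# inside update_git_repo() before setup() is reached.
subprocess.run(["python", "setup_ops.py", "build"], env=env, check=True)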
Submodule custom_ops/third_party/DeepGEMM added at 95e81b3dd6
Submodule custom_ops/third_party/cutlass added at afa1772203
Submodule custom_ops/third_party/nlohmann_json added at 9cca280a4d
@@ -27,7 +27,7 @@ import paddle
from paddle.utils.cpp_extension import CppExtension, setup

current_file = Path(__file__).resolve()
base_dir = current_file.parent
base_dir = os.path.join(current_file.parent, "src")


def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR, XDNN_LIB_DIR):
@@ -136,33 +136,8 @@ def xpu_setup_ops():
    # build plugin
    build_plugin(CLANG_PATH, XRE_INC_PATH, XRE_LIB_DIR, XDNN_INC_PATH, XDNN_LIB_DIR)

    ops = [
        # custom ops
        "./ops/save_with_output_msg.cc",
        "./ops/stop_generation_multi_ends.cc",
        "./ops/set_value_by_flags_and_idx.cc",
        "./ops/get_token_penalty_multi_scores.cc",
        "./ops/get_padding_offset.cc",
        "./ops/update_inputs.cc",
        "./ops/recover_decode_task.cc",
        "./ops/update_inputs_v1.cc",
        "./ops/get_output.cc",
        "./ops/step.cc",
        "./ops/get_infer_param.cc",
        "./ops/adjust_batch.cc",
        "./ops/gather_next_token.cc",
        "./ops/block_attn.cc",
        "./ops/moe_layer.cc",
        "./ops/weight_quantize_xpu.cc",
        # device manage ops
        "./ops/device/get_context_gm_max_mem_demand.cc",
        "./ops/device/get_free_global_memory.cc",
        "./ops/device/get_total_global_memory.cc",
        "./ops/device/get_used_global_memory.cc",
    ]
    ops = [os.path.join(base_dir, op) for op in ops]

    for root, dirs, files in os.walk(base_dir / "ops/mtp_ops"):
    ops = []
    for root, dirs, files in os.walk(os.path.join(base_dir, "ops")):
        for file in files:
            if file.endswith(".cc"):
                ops.append(os.path.join(root, file))
custom_ops/xpu_ops/src/ops/fused_rms_norm.cc (new file, 225 lines)
@@ -0,0 +1,225 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <infer_ops.h>
#include <functional>
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "utility/debug.h"
#include "utility/env.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

XPU_DECLARE_BOOL(ENABLE_XVLLM_SDNN_INFER, false);
namespace api = baidu::xpu::api;

template <typename T>
std::vector<paddle::Tensor> RmsNormKernel(
    const paddle::Tensor& x,
    const paddle::optional<paddle::Tensor>& bias,
    const paddle::optional<paddle::Tensor>& residual,
    const paddle::Tensor& norm_weight,
    const paddle::optional<paddle::Tensor>& norm_bias,
    const float epsilon,
    const int begin_norm_axis,
    const float quant_scale,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound) {
  using XPU_T = typename XPUTypeTrait<T>::Type;
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);

  int ret = -1;
  auto x_shape = x.shape();
  PD_CHECK(quant_scale <= 0, "Quantization is not supported");
  PD_CHECK(begin_norm_axis > 0 && begin_norm_axis <= x_shape.size(),
           "begin_norm_axis check fail");
  PD_CHECK(norm_bias.get_ptr() == nullptr,
           "rms norm kernel don't support norm_bias");

  int64_t m = std::accumulate(x_shape.begin(),
                              x_shape.begin() + begin_norm_axis,
                              static_cast<int64_t>(1),
                              std::multiplies<int64_t>());
  int64_t n = std::accumulate(x_shape.begin() + begin_norm_axis,
                              x_shape.end(),
                              static_cast<int64_t>(1),
                              std::multiplies<int64_t>());

  PD_CHECK(n == norm_weight.shape()[0],
           "The product from begin_norm_axis to the last axis of x must be "
           "equal to the norm_weight's shape[0]");
  if (bias.get_ptr()) {
    PD_CHECK(n == bias.get_ptr()->shape()[0],
             "The product from begin_norm_axis to the last axis of x must be "
             "equal to the bias's shape[0]");
  }

  paddle::Tensor out = paddle::empty(x_shape, x.dtype(), x.place());
  paddle::Tensor residual_out = paddle::empty(x_shape, x.dtype(), x.place());
  const XPU_T* x_data = reinterpret_cast<const XPU_T*>(x.data<T>());
  const XPU_T* norm_weight_data =
      reinterpret_cast<const XPU_T*>(norm_weight.data<T>());
  const XPU_T* bias_data =
      bias.get_ptr() ? reinterpret_cast<const XPU_T*>(bias.get_ptr()->data<T>())
                     : nullptr;
  const XPU_T* residual_data =
      residual.get_ptr()
          ? reinterpret_cast<const XPU_T*>(residual.get_ptr()->data<T>())
          : nullptr;
  XPU_T* out_data = reinterpret_cast<XPU_T*>(const_cast<T*>(out.data<T>()));
  XPU_T* residual_out_data = nullptr;
  if (residual_data) {
    residual_out_data =
        reinterpret_cast<XPU_T*>(const_cast<T*>(residual_out.data<T>()));
  }

  XPU_T* add_out_data = const_cast<XPU_T*>(x_data);
  if (bias_data) {
    ret = api::broadcast_add(
        xpu_ctx->x_context(), x_data, bias_data, out_data, {m, n}, {n});
    PD_CHECK(ret == 0, "broadcast_add");
    add_out_data = out_data;
  }

  bool use_sdnn = FLAGS_ENABLE_XVLLM_SDNN_INFER;
  if (residual_data) {
    ret = infer_ops::add_rms_layer_norm<XPU_T, XPU_T>(xpu_ctx->x_context(),
                                                      add_out_data,
                                                      residual_data,
                                                      out_data,
                                                      m,
                                                      n,
                                                      epsilon,
                                                      norm_weight_data,
                                                      nullptr,
                                                      nullptr,
                                                      residual_out_data,
                                                      nullptr,
                                                      use_sdnn);
    PD_CHECK(ret == 0, "add_rms_layer_norm");
  } else {
    ret = api::rms_layer_norm<XPU_T, XPU_T>(xpu_ctx->x_context(),
                                            add_out_data,
                                            out_data,
                                            m,
                                            n,
                                            epsilon,
                                            norm_weight_data,
                                            nullptr,
                                            nullptr,
                                            false);
    PD_CHECK(ret == 0, "rms_layer_norm");
  }

  return {out, residual_out};
}

std::vector<paddle::Tensor> RmsNorm(
    const paddle::Tensor& x,
    const paddle::optional<paddle::Tensor>& bias,
    const paddle::optional<paddle::Tensor>& residual,
    const paddle::Tensor& norm_weight,
    const paddle::optional<paddle::Tensor>& norm_bias,
    const float epsilon,
    const int begin_norm_axis,
    const float quant_scale,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound) {
  const auto x_type = x.dtype();

#define APPLY_RMS_NORM_KERNEL(TX)           \
  return RmsNormKernel<TX>(x,               \
                           bias,            \
                           residual,        \
                           norm_weight,     \
                           norm_bias,       \
                           epsilon,         \
                           begin_norm_axis, \
                           quant_scale,     \
                           quant_round_type,\
                           quant_max_bound, \
                           quant_min_bound);

  if (x_type == paddle::DataType::BFLOAT16) {
    APPLY_RMS_NORM_KERNEL(paddle::bfloat16);
  } else if (x_type == paddle::DataType::FLOAT16) {
    APPLY_RMS_NORM_KERNEL(paddle::float16);
  } else if (x_type == paddle::DataType::FLOAT32) {
    APPLY_RMS_NORM_KERNEL(float);
  } else {
    PD_THROW("RmsNorm not support x_type=", static_cast<int>(x_type));
    return {};
  }
#undef APPLY_RMS_NORM_KERNEL
}

std::vector<std::vector<int64_t>> RmsNormInferShape(
    const std::vector<int64_t>& x_shape,
    const paddle::optional<std::vector<int64_t>>& bias_shape,
    const paddle::optional<std::vector<int64_t>>& residual_shape,
    const std::vector<int64_t>& norm_weight_shape,
    const paddle::optional<std::vector<int64_t>>& norm_bias_shape,
    const float epsilon,
    const int begin_norm_axis,
    const float quant_scale,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound) {
  PD_CHECK(begin_norm_axis > 0 && begin_norm_axis <= x_shape.size(),
           "begin_norm_axis check fail");
  int64_t m = std::accumulate(x_shape.begin(),
                              x_shape.begin() + begin_norm_axis,
                              static_cast<int64_t>(1),
                              std::multiplies<int64_t>());
  return {x_shape, x_shape, {m}};
}

std::vector<paddle::DataType> RmsNormInferDtype(
    const paddle::DataType& x_dtype,
    const paddle::optional<paddle::DataType>& bias_dtype,
    const paddle::optional<paddle::DataType>& residual_dtype,
    const paddle::DataType& norm_weight_dtype,
    const paddle::optional<paddle::DataType>& norm_bias_dtype,
    const float epsilon,
    const int begin_norm_axis,
    const float quant_scale,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound) {
  // out, residual_out
  return {x_dtype, x_dtype};
}

PD_BUILD_STATIC_OP(fused_rms_norm_xpu)
    .Inputs({"x",
             paddle::Optional("bias"),
             paddle::Optional("residual"),
             "norm_weight",
             paddle::Optional("norm_bias")})
    .Outputs({"out", "residul_out"})
    .Attrs({"epsilon:float",
            "begin_norm_axis:int",
            "quant_scale:float",
            "quant_round_type:int",
            "quant_max_bound:float",
            "quant_min_bound:float"})
    .SetKernelFn(PD_KERNEL(RmsNorm))
    .SetInferShapeFn(PD_INFER_SHAPE(RmsNormInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(RmsNormInferDtype));
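For orientation, a hedged sketch of calling the fused_rms_norm_xpu op registered above from Python once the XPU extension is built; the import path is an assumption, and the argument order simply follows the .Inputs/.Attrs declaration in the diff.

import paddle
from fastdeploy_ops import fused_rms_norm_xpu  # assumed module name

x = paddle.randn([4, 512], dtype="bfloat16")
norm_weight = paddle.randn([512], dtype="bfloat16")
# The kernel requires quant_scale <= 0 (quantization unsupported) and norm_bias=None.
out, residual_out = fused_rms_norm_xpu(
    x, None, None, norm_weight, None,
    1e-6,  # epsilon
    1,     # begin_norm_axis
    -1.0,  # quant_scale (disabled)
    0,     # quant_round_type
    0.0,   # quant_max_bound
    0.0,   # quant_min_bound
)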
@@ -18,13 +18,35 @@
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "msg_utils.h"

#define MAX_BSZ 256
// #define GET_OUTPUT_DEBUG
struct msgdata {
  long mtype;
  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
};
void GetOutputKVSignal(const paddle::Tensor& x,
                       int64_t rank_id,
                       bool wait_flag) {
  int msg_queue_id = 1024 + rank_id;
  static struct msgdatakv msg_rcv;
  static key_t key = ftok("/opt/", msg_queue_id);
  static int msgid = msgget(key, IPC_CREAT | 0666);

  int* out_data = const_cast<int*>(x.data<int>());
  int ret = -1;
  if (!wait_flag) {
    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
  } else {
    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, 0);
  }
  if (ret == -1) {
    out_data[0] = -1;
    out_data[1] = -1;
    return;
  }
  int encoder_count = msg_rcv.mtext[0];

  for (int i = 0; i < encoder_count * 3 + 2; i++) {
    out_data[i] = msg_rcv.mtext[i];
  }
  return;
}

void GetOutput(const paddle::Tensor &x, int64_t rank_id, bool wait_flag,
               int msg_queue_id) {
custom_ops/xpu_ops/src/ops/moe_ep_combine.cc (new file, 119 lines)
@@ -0,0 +1,119 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <infer_ops.h>
#include <xft_api.h>
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "utility/debug.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

template <typename T>
std::vector<paddle::Tensor> MoeEPCombineKernel(
    const paddle::Tensor&
        ffn_out,  // expand_token_num * hidden_dim dtype is fp16/bf16
    const paddle::Tensor& moe_index,  // token_num * topk dtype is int
    const paddle::Tensor&
        weights,  // token_num * topk dtype is same as ffn_out
    int64_t recv_token_num,
    int64_t expand_token_num,
    int64_t hidden_dim,
    int64_t topk) {
  using XPU_T = typename XPUTypeTrait<T>::Type;
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);

  auto combined_out = paddle::empty(
      {recv_token_num, hidden_dim}, ffn_out.dtype(), ffn_out.place());

  const float* dequant_score = nullptr;
  int ret = infer_ops::moe_ep_ffn_post_fusion(
      xpu_ctx->x_context(),
      reinterpret_cast<const XPU_T*>(ffn_out.data<T>()),
      moe_index.data<int32_t>(),
      reinterpret_cast<const XPU_T*>(weights.data<T>()),
      dequant_score,
      reinterpret_cast<XPU_T*>(combined_out.mutable_data<T>()),
      recv_token_num,
      hidden_dim,
      topk,
      expand_token_num);
  PD_CHECK(ret == 0);

  return {combined_out};
}

std::vector<paddle::Tensor> MoeEPCombine(const paddle::Tensor& ffn_out,
                                         const paddle::Tensor& moe_index,
                                         const paddle::Tensor& weights,
                                         const int recv_token_num,
                                         const int expand_token_num,
                                         const int hidden_dim,
                                         const int topk) {
#define APPLY_KERNEL(TX)                          \
  return MoeEPCombineKernel<TX>(ffn_out,          \
                                moe_index,        \
                                weights,          \
                                recv_token_num,   \
                                expand_token_num, \
                                hidden_dim,       \
                                topk);

  const auto ffn_out_dtype = ffn_out.dtype();
  if (ffn_out_dtype == paddle::DataType::FLOAT16) {
    APPLY_KERNEL(paddle::float16);
  } else if (ffn_out_dtype == paddle::DataType::BFLOAT16) {
    APPLY_KERNEL(paddle::bfloat16);
  } else {
    PD_THROW("MoeEPCombine not support ffn_out_type==%d",
             static_cast<int>(ffn_out_dtype));
    return {};
  }

#undef APPLY_KERNEL
}

std::vector<std::vector<int64_t>> MoeEPCombineInferShape(
    const std::vector<int64_t>& ffn_out_shape,
    const std::vector<int64_t>& moe_index_shape,
    const std::vector<int64_t>& weights_shape,
    const int recv_token_num,
    const int expand_token_num,
    const int hidden_dim,
    const int topk) {
  std::vector<int64_t> combined_out_shape = {recv_token_num, hidden_dim};
  return {combined_out_shape};
}

std::vector<paddle::DataType> MoeEPCombineInferDtype(
    const paddle::DataType& ffn_out_dtype,
    const paddle::DataType& moe_index_dtype,
    const paddle::DataType& weights_dtype) {
  return {ffn_out_dtype};
}

PD_BUILD_STATIC_OP(ep_moe_expert_combine)
    .Inputs({"ffn_out", "moe_index", "weights"})
    .Outputs({"combined_out"})
    .Attrs({"recv_token_num: int",
            "expand_token_num: int",
            "hidden_dim: int",
            "topk: int"})
    .SetKernelFn(PD_KERNEL(MoeEPCombine))
    .SetInferShapeFn(PD_INFER_SHAPE(MoeEPCombineInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MoeEPCombineInferDtype));
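As a reading aid, the combine step above reduces the expanded per-expert FFN outputs back to one row per token, weighted by the routing weights. An assumed NumPy reference of that semantics follows (illustration only, not the XPU kernel):

import numpy as np

def ep_moe_expert_combine_ref(ffn_out, moe_index, weights):
    # ffn_out: [expand_token_num, hidden_dim]; moe_index, weights: [recv_token_num, topk]
    recv_token_num, topk = moe_index.shape
    out = np.zeros((recv_token_num, ffn_out.shape[1]), dtype=ffn_out.dtype)
    for i in range(recv_token_num):
        for k in range(topk):
            out[i] += weights[i, k] * ffn_out[moe_index[i, k]]
    return out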
custom_ops/xpu_ops/src/ops/moe_ep_dispatch.cc (new file, 201 lines)
@@ -0,0 +1,201 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <infer_ops.h>
|
||||
#include <infer_ops_eb.h>
|
||||
#include <xft_api.h>
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/backends/xpu/enforce_xpu.h"
|
||||
#include "utility/debug.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
template <typename TX, typename TY>
|
||||
std::vector<paddle::Tensor> EPMoeExpertDispatchKernel(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& topk_weights,
|
||||
const paddle::optional<paddle::Tensor>& input_scales,
|
||||
const std::vector<int>& token_nums_per_expert,
|
||||
const int64_t token_nums_this_rank) {
|
||||
using XPU_TX = typename XPUTypeTrait<TX>::Type;
|
||||
using XPU_TY = typename XPUTypeTrait<TY>::Type;
|
||||
phi::XPUPlace xpu_place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(xpu_place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
|
||||
const auto input_type = input.dtype();
|
||||
auto m = input.dims()[0];
|
||||
auto n = input.dims()[1];
|
||||
const int64_t expert_num = token_nums_per_expert.size();
|
||||
const int topk = topk_ids.dims()[1];
|
||||
auto place = input.place();
|
||||
|
||||
auto block_num = xpu_ctx->x_context()->ncluster();
|
||||
paddle::Tensor permute_input;
|
||||
auto permute_indices_per_token =
|
||||
paddle::empty({m, topk}, paddle::DataType::INT32, place);
|
||||
auto expert_m = paddle::empty({expert_num}, paddle::DataType::INT32, place);
|
||||
auto recv_num_tokens_per_expert_list_cumsum =
|
||||
paddle::empty({expert_num + 1}, paddle::DataType::INT32, place);
|
||||
auto expand_input_scales =
|
||||
paddle::empty({token_nums_this_rank}, paddle::DataType::FLOAT32, place);
|
||||
const int64_t ep_size = 1;
|
||||
const int64_t ep_rank = 0;
|
||||
|
||||
if (std::is_same<TY, int8_t>::value) {
|
||||
permute_input =
|
||||
paddle::empty({token_nums_this_rank, n}, paddle::DataType::INT8, place);
|
||||
auto ret = infer_ops::moe_ffn_pre_sorted_quant_pe<XPU_TX, int>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_TX*>(input.data<TX>()),
|
||||
topk_ids.data<int>(),
|
||||
input_scales.get_ptr()->data<float>(),
|
||||
nullptr,
|
||||
reinterpret_cast<int8_t*>(permute_input.data<int8_t>()),
|
||||
const_cast<int*>(permute_indices_per_token.data<int>()),
|
||||
const_cast<int*>(expert_m.data<int>()),
|
||||
const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
|
||||
expand_input_scales.data<float>(),
|
||||
m,
|
||||
n,
|
||||
expert_num,
|
||||
topk,
|
||||
block_num,
|
||||
token_nums_this_rank);
|
||||
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
|
||||
} else {
|
||||
permute_input = paddle::empty({token_nums_this_rank, n}, input_type, place);
|
||||
auto ret = infer_ops::moe_ep_ffn_pre_sorted<XPU_TX, int>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_TX*>(input.data<TX>()),
|
||||
topk_ids.data<int>(),
|
||||
nullptr,
|
||||
reinterpret_cast<XPU_TX*>(permute_input.data<TX>()),
|
||||
const_cast<int*>(permute_indices_per_token.data<int>()),
|
||||
const_cast<int*>(expert_m.data<int>()),
|
||||
const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
|
||||
m,
|
||||
n,
|
||||
expert_num,
|
||||
topk,
|
||||
block_num,
|
||||
ep_size,
|
||||
ep_rank,
|
||||
token_nums_this_rank);
|
||||
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
|
||||
}
|
||||
return {permute_input,
|
||||
permute_indices_per_token,
|
||||
recv_num_tokens_per_expert_list_cumsum,
|
||||
topk_weights,
|
||||
expand_input_scales};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> EPMoeExpertDispatch(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& topk_weights,
|
||||
const paddle::optional<paddle::Tensor>& input_scales,
|
||||
const std::vector<int>& token_nums_per_expert,
|
||||
const int token_nums_this_rank,
|
||||
const std::string quant_method) {
|
||||
#define APPLY_KERNEL(TX, TY) \
|
||||
return EPMoeExpertDispatchKernel<TX, TY>(input, \
|
||||
topk_ids, \
|
||||
topk_weights, \
|
||||
input_scales, \
|
||||
token_nums_per_expert, \
|
||||
token_nums_this_rank);
|
||||
|
||||
const auto input_dtype = input.dtype();
|
||||
if (input_dtype == paddle::DataType::FLOAT16 && quant_method == "w4a8") {
|
||||
APPLY_KERNEL(paddle::float16, int8_t);
|
||||
} else if (input_dtype == paddle::DataType::FLOAT16 &&
|
||||
quant_method != "w4a8") {
|
||||
APPLY_KERNEL(paddle::float16, paddle::float16);
|
||||
} else if (input_dtype == paddle::DataType::BFLOAT16 &&
|
||||
quant_method == "w4a8") {
|
||||
APPLY_KERNEL(paddle::bfloat16, int8_t);
|
||||
} else if (input_dtype == paddle::DataType::BFLOAT16 &&
|
||||
quant_method != "w4a8") {
|
||||
APPLY_KERNEL(paddle::bfloat16, paddle::bfloat16);
|
||||
} else {
|
||||
PD_THROW("EPMoeExpertDispatch not support input_dtype=",
|
||||
static_cast<int>(input_dtype),
|
||||
"quant_method=",
|
||||
quant_method);
|
||||
return {};
|
||||
}
|
||||
|
||||
#undef APPLY_KERNEL
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> EPMoeExpertDispatchInferShape(
|
||||
const std::vector<int64_t>& input_shape,
|
||||
const std::vector<int64_t>& topk_ids_shape,
|
||||
const std::vector<int64_t>& topk_weights_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& input_scales_shape,
|
||||
const std::vector<int>& token_nums_per_expert,
|
||||
const int token_nums_this_rank,
|
||||
const std::string quant_method) {
|
||||
const int m = input_shape[0];
|
||||
const int hidden_size = input_shape[input_shape.size() - 1];
|
||||
const int topk = topk_ids_shape[topk_ids_shape.size() - 1];
|
||||
const int expert_num = token_nums_per_expert.size();
|
||||
return {{token_nums_this_rank, hidden_size},
|
||||
{expert_num, m},
|
||||
{expert_num},
|
||||
{token_nums_this_rank},
|
||||
{token_nums_this_rank}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> EPMoeExpertDispatchInferDtype(
|
||||
const paddle::DataType& input_dtype,
|
||||
const paddle::DataType& topk_ids_dtype,
|
||||
const paddle::DataType& topk_weights_dtype,
|
||||
const paddle::optional<paddle::DataType>& input_scales_dtype,
|
||||
const std::vector<int>& token_nums_per_expert,
|
||||
const int token_nums_this_rank,
|
||||
const std::string quant_method) {
|
||||
auto output_dtype = input_dtype;
|
||||
if (quant_method == "w4a8") {
|
||||
output_dtype = paddle::DataType::INT8;
|
||||
}
|
||||
return {
|
||||
output_dtype,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32,
|
||||
topk_weights_dtype,
|
||||
paddle::DataType::FLOAT32,
|
||||
};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(ep_moe_expert_dispatch)
|
||||
.Inputs(
|
||||
{"input", "topk_ids", "topk_weights", paddle::Optional("input_scales")})
|
||||
.Outputs({"permute_input",
|
||||
"permute_indices_per_token",
|
||||
"token_nums_per_expert_cumsum",
|
||||
"dst_weights",
|
||||
"expand_input_scales"})
|
||||
.Attrs({"token_nums_per_expert: std::vector<int>",
|
||||
"token_nums_this_rank: int",
|
||||
"quant_method: std::string"})
|
||||
.SetKernelFn(PD_KERNEL(EPMoeExpertDispatch))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(EPMoeExpertDispatchInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(EPMoeExpertDispatchInferDtype));
|
custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc (new file, 535 lines)
@@ -0,0 +1,535 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <blocks/moe_fc_block_eb.h>
|
||||
#include <core/check.h>
|
||||
#include <core/context.h>
|
||||
#include <core/param.h>
|
||||
#include <infer_ops.h>
|
||||
#include <xft_api.h>
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/backends/xpu/enforce_xpu.h"
|
||||
#include "utility/debug.h"
|
||||
#include "utility/env.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
XPU_DECLARE_BOOL(MOE_FFN_USE_DENSE_INPUT, false);
|
||||
XPU_DECLARE_BOOL(BKCL_DISPATCH_ALL_GATHER, false);
|
||||
|
||||
namespace xftblock = baidu::xpu::xftblock;
|
||||
namespace api = baidu::xpu::api;
|
||||
|
||||
template <typename TX1, typename TX2, typename TW, typename TGEMM>
|
||||
void MoeExpertFFNImpl(xftblock::Tensor* ffn_in,
|
||||
xftblock::Tensor* token_num_info,
|
||||
xftblock::Tensor* ffn1_weight,
|
||||
xftblock::Tensor* ffn2_weight,
|
||||
xftblock::Tensor* ffn1_bias,
|
||||
xftblock::Tensor* ffn2_bias,
|
||||
xftblock::Tensor* ffn2_out,
|
||||
float* ffn2_act_scale,
|
||||
TX2* ffn2_shift,
|
||||
TX2* ffn2_smooth,
|
||||
const int hadamard_blocksize) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
|
||||
auto rt_guard = xctx.get_rt_guard();
|
||||
auto xftblock_tx2 = xftblock::DataTypeToEnum<TX2>::value;
|
||||
|
||||
int ret = -1;
|
||||
int expert_num = ffn1_weight->get_dim(0);
|
||||
int inter_dim = ffn1_weight->get_dim(1);
|
||||
int outer_dim = inter_dim / 2;
|
||||
|
||||
bool is_padding_input = ffn_in->get_dims().size() == 3;
|
||||
auto ffn1_out_shape = ffn_in->get_dims();
|
||||
int hidden_dim = ffn1_out_shape[ffn1_out_shape.size() - 1];
|
||||
ffn1_out_shape[ffn1_out_shape.size() - 1] = inter_dim;
|
||||
xftblock::Tensor ffn1_out(rt_guard, xftblock_tx2, ffn1_out_shape);
|
||||
ret = xftblock::xft_moe_fc_block_eb<TX1, TW, TX2, float, int, TGEMM>(
|
||||
&xctx,
|
||||
ffn_in,
|
||||
ffn1_weight,
|
||||
&ffn1_out,
|
||||
ffn1_bias,
|
||||
is_padding_input ? nullptr : token_num_info,
|
||||
is_padding_input ? token_num_info : nullptr,
|
||||
expert_num,
|
||||
1, // moe_topk
|
||||
ffn1_out_shape.size() == 2 ? xftblock::MoeFCInputMode::DENSE
|
||||
: xftblock::MoeFCInputMode::SPARSE);
|
||||
PD_CHECK(ret == 0);
|
||||
|
||||
int token_num = ffn_in->numel() / hidden_dim;
|
||||
auto swiglu_out_shape = ffn1_out_shape;
|
||||
swiglu_out_shape[swiglu_out_shape.size() - 1] /= 2;
|
||||
xftblock::Tensor swiglu_out(rt_guard, xftblock_tx2, swiglu_out_shape);
|
||||
ret = api::fast_swiglu<TX2>(xpu_ctx->x_context(),
|
||||
ffn1_out.data<TX2>(),
|
||||
swiglu_out.mutable_data<TX2>(),
|
||||
{token_num, inter_dim},
|
||||
1,
|
||||
true);
|
||||
PD_CHECK(ret == 0);
|
||||
// TODO(mayang02): use fusion_smooth_transform
|
||||
if (ffn2_shift != nullptr) {
|
||||
ret = api::broadcast_add<TX2>(xpu_ctx->x_context(),
|
||||
ffn2_shift,
|
||||
swiglu_out.data<TX2>(),
|
||||
swiglu_out.mutable_data<TX2>(),
|
||||
{1, outer_dim},
|
||||
{token_num, outer_dim});
|
||||
PD_CHECK(ret == 0);
|
||||
}
|
||||
if (ffn2_smooth != nullptr) {
|
||||
ret = api::broadcast_mul<TX2>(xpu_ctx->x_context(),
|
||||
ffn2_smooth,
|
||||
swiglu_out.data<TX2>(),
|
||||
swiglu_out.mutable_data<TX2>(),
|
||||
{1, outer_dim},
|
||||
{token_num, outer_dim});
|
||||
PD_CHECK(ret == 0);
|
||||
}
|
||||
|
||||
if (hadamard_blocksize > 0) {
|
||||
ret = infer_ops::fast_walsh_transform<TX2>(xpu_ctx->x_context(),
|
||||
swiglu_out.data<TX2>(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
swiglu_out.mutable_data<TX2>(),
|
||||
hadamard_blocksize,
|
||||
token_num,
|
||||
outer_dim);
|
||||
PD_CHECK(ret == 0);
|
||||
}
|
||||
|
||||
xftblock::Tensor ffn2_in(swiglu_out.mutable_data<TX2>(),
|
||||
nullptr,
|
||||
ffn2_act_scale,
|
||||
xftblock_tx2,
|
||||
swiglu_out_shape);
|
||||
ret = xftblock::xft_moe_fc_block_eb<TX2, TW, TX2, float, int, TGEMM>(
|
||||
&xctx,
|
||||
&ffn2_in,
|
||||
ffn2_weight,
|
||||
ffn2_out,
|
||||
nullptr,
|
||||
is_padding_input ? nullptr : token_num_info,
|
||||
is_padding_input ? token_num_info : nullptr,
|
||||
expert_num,
|
||||
1, // moe_topk
|
||||
ffn1_out_shape.size() == 2
|
||||
? xftblock::MoeFCInputMode::DENSE
|
||||
: xftblock::MoeFCInputMode::SPARSE); // bias_mode
|
||||
PD_CHECK(ret == 0);
|
||||
}
|
||||
|
||||
static void convert_to_lod(xftblock::XFTContext* xctx,
|
||||
xftblock::Tensor* token_num_info) {
|
||||
auto rt_guard = xctx->get_rt_guard();
|
||||
auto ctx = xctx->get_context();
|
||||
const int expert_num = token_num_info->numel();
|
||||
xftblock::Tensor tokens_num_lod(
|
||||
rt_guard, xftblock::DataType::DT_INT32, {expert_num + 1});
|
||||
int ret = api::constant(ctx, tokens_num_lod.data<int>(), expert_num + 1, 0);
|
||||
PD_CHECK(ret == 0);
|
||||
ret = api::cumsum<int>(ctx,
|
||||
token_num_info->data<int>(),
|
||||
tokens_num_lod.data<int>() + 1,
|
||||
{expert_num},
|
||||
false,
|
||||
false,
|
||||
0);
|
||||
PD_CHECK(ret == 0);
|
||||
*token_num_info = std::move(tokens_num_lod);
|
||||
}
|
||||
|
||||
template <typename TX1, typename TX2, typename TW>
|
||||
std::vector<paddle::Tensor> MoeExpertFFNKernel(
|
||||
const paddle::Tensor& ffn_in,
|
||||
const paddle::Tensor& token_num_info,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_act_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_act_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_weight_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_weight_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_shift,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_smooth,
|
||||
const std::string& quant_method,
|
||||
const int hadamard_blocksize,
|
||||
const int valid_token_num) {
|
||||
using XPU_TX1 = typename XPUTypeTrait<TX1>::Type;
|
||||
using XPU_TX2 = typename XPUTypeTrait<TX2>::Type;
|
||||
using XPU_TW = typename XPUTypeTrait<TW>::Type;
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
|
||||
auto rt_guard = xctx.get_rt_guard();
|
||||
|
||||
int ret = -1;
|
||||
auto input_shape = ffn_in.shape();
|
||||
auto ffn1_w_shape = ffn1_weight.shape();
|
||||
int expert_num = ffn1_w_shape[0];
|
||||
int hidden_dim = input_shape[input_shape.size() - 1];
|
||||
int inter_dim = ffn1_w_shape[1];
|
||||
int outer_dim = inter_dim / 2;
|
||||
bool is_padding_input = input_shape.size() == 3;
|
||||
if (is_padding_input) {
|
||||
PD_CHECK(input_shape[0] == expert_num);
|
||||
PD_CHECK(token_num_info.numel() == expert_num,
|
||||
"token_num_info.numel() != expert_num, "
|
||||
"token_num_info.numel(): ",
|
||||
token_num_info.numel(),
|
||||
", expert_num: ",
|
||||
expert_num);
|
||||
}
|
||||
|
||||
bool is_w4 = quant_method == "w4a8" || quant_method == "weight_only_int4";
|
||||
auto xftblock_tx1 = xftblock::DataTypeToEnum<XPU_TX1>::value;
|
||||
auto xftblock_tx2 = xftblock::DataTypeToEnum<XPU_TX2>::value;
|
||||
auto xftblock_tw = xftblock::DataTypeToEnum<XPU_TW>::value;
|
||||
if (is_w4) {
|
||||
xftblock_tw = xftblock::DataTypeToEnum<int4_t>::value;
|
||||
}
|
||||
float* ffn1_act_scale_data =
|
||||
ffn1_act_scale.get_ptr() == nullptr
|
||||
? nullptr
|
||||
: const_cast<float*>(ffn1_act_scale.get_ptr()->data<float>());
|
||||
float* ffn2_act_scale_data =
|
||||
ffn2_act_scale.get_ptr() == nullptr
|
||||
? nullptr
|
||||
: const_cast<float*>(ffn2_act_scale.get_ptr()->data<float>());
|
||||
float* ffn1_w_scale_data =
|
||||
ffn1_weight_scale.get_ptr() == nullptr
|
||||
? nullptr
|
||||
: const_cast<float*>(ffn1_weight_scale.get_ptr()->data<float>());
|
||||
xftblock::Tensor xffn1_w(const_cast<TW*>(ffn1_weight.data<TW>()),
|
||||
nullptr,
|
||||
ffn1_w_scale_data,
|
||||
xftblock_tw,
|
||||
{expert_num, inter_dim, hidden_dim});
|
||||
float* ffn2_w_scale_data =
|
||||
ffn2_weight_scale.get_ptr() == nullptr
|
||||
? nullptr
|
||||
: const_cast<float*>(ffn2_weight_scale.get_ptr()->data<float>());
|
||||
xftblock::Tensor xffn2_w(const_cast<TW*>(ffn2_weight.data<TW>()),
|
||||
nullptr,
|
||||
ffn2_w_scale_data,
|
||||
xftblock_tw,
|
||||
{expert_num, hidden_dim, outer_dim});
|
||||
std::shared_ptr<xftblock::Tensor> xffn1_bias;
|
||||
if (ffn1_bias.get_ptr()) {
|
||||
xffn1_bias = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<float*>(ffn1_bias.get_ptr()->data<float>()),
|
||||
xftblock::DataType::DT_FLOAT,
|
||||
ffn1_bias.get_ptr()->shape());
|
||||
}
|
||||
std::shared_ptr<xftblock::Tensor> xffn2_bias;
|
||||
if (ffn2_bias.get_ptr()) {
|
||||
xffn2_bias = std::make_shared<xftblock::Tensor>(
|
||||
const_cast<float*>(ffn2_bias.get_ptr()->data<float>()),
|
||||
xftblock::DataType::DT_FLOAT,
|
||||
ffn2_bias.get_ptr()->shape());
|
||||
}
|
||||
xftblock::Tensor xtoken_num_info(const_cast<int*>(token_num_info.data<int>()),
|
||||
xftblock::DataType::DT_INT32,
|
||||
token_num_info.shape());
|
||||
XPU_TX2* shift_data = nullptr;
|
||||
XPU_TX2* smooth_data = nullptr;
|
||||
if (ffn2_shift.get_ptr()) {
|
||||
shift_data = reinterpret_cast<XPU_TX2*>(
|
||||
const_cast<TX2*>(ffn2_shift.get_ptr()->data<TX2>()));
|
||||
}
|
||||
if (ffn2_smooth.get_ptr()) {
|
||||
smooth_data = reinterpret_cast<XPU_TX2*>(
|
||||
const_cast<TX2*>(ffn2_smooth.get_ptr()->data<TX2>()));
|
||||
}
|
||||
paddle::Tensor ffn2_out =
|
||||
paddle::empty_like(ffn_in, paddle::DataType::BFLOAT16);
|
||||
xftblock::Tensor xffn1_in;
|
||||
xftblock::Tensor xffn2_out;
|
||||
paddle::Tensor ffn1_in_dense;
|
||||
paddle::Tensor ffn1_in_scale_per_token;
|
||||
if (FLAGS_MOE_FFN_USE_DENSE_INPUT && is_padding_input) {
|
||||
convert_to_lod(&xctx, &xtoken_num_info);
|
||||
if (quant_method == "w4a8") {
|
||||
ffn1_in_scale_per_token = paddle::empty(
|
||||
{valid_token_num}, paddle::DataType::FLOAT32, ffn_in.place());
|
||||
ffn1_in_dense = paddle::empty({valid_token_num, hidden_dim},
|
||||
paddle::DataType::INT8,
|
||||
ffn_in.place());
|
||||
xffn1_in = xftblock::Tensor(ffn1_in_dense.data<int8_t>(),
|
||||
nullptr,
|
||||
ffn1_in_scale_per_token.data<float>(),
|
||||
xftblock::DataType::DT_INT8,
|
||||
{valid_token_num, hidden_dim});
|
||||
if (std::is_same<XPU_TX1, int8_t>::value) {
|
||||
PD_CHECK(ffn1_act_scale_data != nullptr,
|
||||
"need ffn1_act_scale for x int8 per expert input");
|
||||
ret = infer_ops::sequence_unpad<float, int>(
|
||||
xpu_ctx->x_context(),
|
||||
ffn1_act_scale_data,
|
||||
ffn1_in_scale_per_token.data<float>(),
|
||||
xtoken_num_info.data<int>(),
|
||||
expert_num,
|
||||
input_shape[1],
|
||||
1,
|
||||
true);
|
||||
PD_CHECK(ret == 0);
|
||||
ret = infer_ops::sequence_unpad<int8_t, int>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const int8_t*>(ffn_in.data<int8_t>()),
|
||||
reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
|
||||
xtoken_num_info.data<int>(),
|
||||
expert_num,
|
||||
input_shape[1],
|
||||
input_shape[2],
|
||||
true);
|
||||
PD_CHECK(ret == 0);
|
||||
} else {
|
||||
ret = infer_ops::quant2d_per_expert<XPU_TX1>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
|
||||
ffn1_act_scale_data,
|
||||
xtoken_num_info.data<int>(),
|
||||
reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
|
||||
ffn1_in_scale_per_token.data<float>(),
|
||||
expert_num,
|
||||
valid_token_num,
|
||||
hidden_dim,
|
||||
true,
|
||||
false,
|
||||
input_shape[1]);
|
||||
PD_CHECK(ret == 0);
|
||||
}
|
||||
} else {
|
||||
ffn1_in_dense = paddle::empty(
|
||||
{valid_token_num, hidden_dim}, ffn_in.dtype(), ffn_in.place());
|
||||
xffn1_in = xftblock::Tensor(ffn1_in_dense.data<TX1>(),
|
||||
nullptr,
|
||||
ffn1_act_scale_data,
|
||||
xftblock_tx1,
|
||||
{valid_token_num, hidden_dim});
|
||||
ret = infer_ops::sequence_unpad<XPU_TX1, int>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
|
||||
reinterpret_cast<XPU_TX1*>(xffn1_in.data<XPU_TX1>()),
|
||||
xtoken_num_info.data<int>(),
|
||||
expert_num,
|
||||
input_shape[1],
|
||||
input_shape[2],
|
||||
true);
|
||||
PD_CHECK(ret == 0);
|
||||
}
|
||||
xffn2_out =
|
||||
xftblock::Tensor(rt_guard, xftblock_tx2, {valid_token_num, hidden_dim});
|
||||
} else if (FLAGS_BKCL_DISPATCH_ALL_GATHER && !is_padding_input &&
|
||||
quant_method == "w4a8") {
|
||||
convert_to_lod(&xctx, &xtoken_num_info);
|
||||
ffn1_in_scale_per_token = paddle::empty(
|
||||
{valid_token_num}, paddle::DataType::FLOAT32, ffn_in.place());
|
||||
ffn1_in_dense = paddle::empty(
|
||||
{valid_token_num, hidden_dim}, paddle::DataType::INT8, ffn_in.place());
|
||||
xffn1_in = xftblock::Tensor(ffn1_in_dense.data<int8_t>(),
|
||||
nullptr,
|
||||
ffn1_in_scale_per_token.data<float>(),
|
||||
xftblock::DataType::DT_INT8,
|
||||
{valid_token_num, hidden_dim});
|
||||
ret = infer_ops::quant2d_per_expert<XPU_TX1>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_TX1*>(ffn_in.data<TX1>()),
|
||||
ffn1_act_scale_data,
|
||||
xtoken_num_info.data<int>(),
|
||||
reinterpret_cast<int8_t*>(xffn1_in.data<int8_t>()),
|
||||
ffn1_in_scale_per_token.data<float>(),
|
||||
expert_num,
|
||||
valid_token_num,
|
||||
hidden_dim);
|
||||
PD_CHECK(ret == 0);
|
||||
xffn2_out =
|
||||
xftblock::Tensor(ffn2_out.data<TX2>(), xftblock_tx2, input_shape);
|
||||
} else {
|
||||
xffn1_in = xftblock::Tensor(const_cast<TX1*>(ffn_in.data<TX1>()),
|
||||
nullptr,
|
||||
ffn1_act_scale_data,
|
||||
xftblock_tx1,
|
||||
input_shape);
|
||||
xffn2_out = xftblock::Tensor(
|
||||
ffn2_out.mutable_data<TX2>(), xftblock_tx2, input_shape);
|
||||
}
|
||||
|
||||
#define FFN_IMPL(TX1, TX2, TW, TGEMM) \
|
||||
MoeExpertFFNImpl<TX1, TX2, TW, TGEMM>(&xffn1_in, \
|
||||
&xtoken_num_info, \
|
||||
&xffn1_w, \
|
||||
&xffn2_w, \
|
||||
xffn1_bias.get(), \
|
||||
xffn2_bias.get(), \
|
||||
&xffn2_out, \
|
||||
ffn2_act_scale_data, \
|
||||
shift_data, \
|
||||
smooth_data, \
|
||||
hadamard_blocksize)
|
||||
if (quant_method == "weight_only_int8") {
|
||||
FFN_IMPL(XPU_TX1, XPU_TX2, int8_t, float);
|
||||
} else if (quant_method == "weight_only_int4") {
|
||||
FFN_IMPL(XPU_TX1, XPU_TX2, int4_t, int4_wo_int15);
|
||||
} else if (quant_method == "w4a8") {
|
||||
if (FLAGS_MOE_FFN_USE_DENSE_INPUT && is_padding_input) {
|
||||
FFN_IMPL(int8_t, XPU_TX2, int4_t, int4_wo_int8);
|
||||
} else if (FLAGS_BKCL_DISPATCH_ALL_GATHER && !is_padding_input) {
|
||||
FFN_IMPL(int8_t, XPU_TX2, int4_t, int4_wo_int8);
|
||||
} else {
|
||||
FFN_IMPL(XPU_TX1, XPU_TX2, int4_t, int4_wo_int8);
|
||||
}
|
||||
} else {
|
||||
FFN_IMPL(XPU_TX1, XPU_TX2, XPU_TW, float);
|
||||
}
|
||||
#undef FFN_IMPL
|
||||
if (FLAGS_MOE_FFN_USE_DENSE_INPUT && is_padding_input) {
|
||||
ret = infer_ops::sequence_pad<XPU_TX2, int>(
|
||||
xpu_ctx->x_context(),
|
||||
const_cast<XPU_TX2*>(xffn2_out.data<XPU_TX2>()),
|
||||
reinterpret_cast<XPU_TX2*>(ffn2_out.data<TX2>()),
|
||||
xtoken_num_info.data<int>(),
|
||||
input_shape[0],
|
||||
input_shape[1],
|
||||
input_shape[2],
|
||||
false,
|
||||
0);
|
||||
PD_CHECK(ret == 0);
|
||||
}
|
||||
|
||||
return {ffn2_out};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
const paddle::Tensor& ffn_in,
|
||||
const paddle::Tensor& token_num_info,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_act_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_act_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_weight_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_weight_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_shift,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_smooth,
|
||||
const std::string& quant_method,
|
||||
const int hadamard_blocksize,
|
||||
const int valid_token_num) {
|
||||
const auto x_type = ffn_in.dtype();
|
||||
const auto w_type = ffn1_weight.dtype();
|
||||
|
||||
#define APPLY_FFN_KERNEL(TX1, TX2, TW) \
|
||||
return MoeExpertFFNKernel<TX1, TX2, TW>(ffn_in, \
|
||||
token_num_info, \
|
||||
ffn1_weight, \
|
||||
ffn2_weight, \
|
||||
ffn1_bias, \
|
||||
ffn2_bias, \
|
||||
ffn1_act_scale, \
|
||||
ffn2_act_scale, \
|
||||
ffn1_weight_scale, \
|
||||
ffn2_weight_scale, \
|
||||
ffn2_shift, \
|
||||
ffn2_smooth, \
|
||||
quant_method, \
|
||||
hadamard_blocksize, \
|
||||
valid_token_num);
|
||||
if (x_type == paddle::DataType::BFLOAT16 &&
|
||||
w_type == paddle::DataType::BFLOAT16) {
|
||||
APPLY_FFN_KERNEL(paddle::bfloat16, paddle::bfloat16, paddle::bfloat16);
|
||||
} else if (x_type == paddle::DataType::BFLOAT16 &&
|
||||
w_type == paddle::DataType::INT8) {
|
||||
APPLY_FFN_KERNEL(paddle::bfloat16, paddle::bfloat16, int8_t);
|
||||
} else if (x_type == paddle::DataType::INT8 &&
|
||||
w_type == paddle::DataType::INT8) {
|
||||
APPLY_FFN_KERNEL(int8_t, paddle::bfloat16, int8_t);
|
||||
} else {
|
||||
PD_THROW("MoeExpertFFN not support x_type=",
|
||||
static_cast<int>(x_type),
|
||||
", w_type=",
|
||||
static_cast<int>(w_type));
|
||||
return {};
|
||||
}
|
||||
#undef APPLY_FFN_KERNEL
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
||||
const std::vector<int64_t>& permute_input_shape,
|
||||
const std::vector<int64_t>& token_num_info_shape,
|
||||
const std::vector<int64_t>& ffn1_weight_shape,
|
||||
const std::vector<int64_t>& ffn2_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_act_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_act_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_weight_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_weight_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_shift_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_smooth_shape) {
|
||||
return {permute_input_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MoeExpertFFNInferDtype(
|
||||
const paddle::DataType& permute_input_dtype,
|
||||
const paddle::DataType& token_num_info_dtype,
|
||||
const paddle::DataType& ffn1_weight_dtype,
|
||||
const paddle::DataType& ffn2_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn2_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn1_act_scale_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn2_act_scale_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn1_weight_scale_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn2_weight_scale_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn2_shift_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn2_smooth_dtype) {
|
||||
if (permute_input_dtype == paddle::DataType::INT8) {
|
||||
return {paddle::DataType::BFLOAT16};
|
||||
} else {
|
||||
return {permute_input_dtype};
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(moe_expert_ffn)
|
||||
.Inputs({"ffn_in",
|
||||
"token_num_info",
|
||||
"ffn1_weight",
|
||||
"ffn2_weight",
|
||||
paddle::Optional("ffn1_bias"),
|
||||
paddle::Optional("ffn2_bias"),
|
||||
paddle::Optional("ffn1_act_scale"),
|
||||
paddle::Optional("ffn2_act_scale"),
|
||||
paddle::Optional("ffn1_weight_scale"),
|
||||
paddle::Optional("ffn2_weight_scale"),
|
||||
paddle::Optional("ffn2_shift"),
|
||||
paddle::Optional("ffn2_smooth")})
|
||||
.Outputs({"ffn_out"})
|
||||
.Attrs({"quant_method:std::string",
|
||||
"hadamard_blocksize:int",
|
||||
"valid_token_num:int"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
|
custom_ops/xpu_ops/src/ops/moe_redundant_topk_select.cc (new file, 134 lines)
@@ -0,0 +1,134 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <infer_ops.h>
#include <xft_api.h>
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "utility/debug.h"

std::vector<paddle::Tensor> MoERedundantTopKSelect(
    const paddle::Tensor& gating_logits,
    const paddle::Tensor& expert_id_to_ep_rank_array,
    const paddle::Tensor& expert_in_rank_num_list,
    paddle::Tensor& tokens_per_expert_stats_list,  // NOLINT
    const paddle::optional<paddle::Tensor>& bias,
    const int moe_topk,
    const bool apply_norm_weight,
    const bool enable_softmax_top_k_fused,
    const int redundant_ep_rank_num_plus_one) {
  namespace api = baidu::xpu::api;
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
  api::Context* ctx = xpu_ctx->x_context();
  if (gating_logits.is_cpu()) {
    ctx = new api::Context(api::kCPU);
  }

  PD_CHECK(apply_norm_weight, "only support apply_norm_weight==true");
  PD_CHECK(enable_softmax_top_k_fused,
           "only support enable_softmax_top_k_fused==true");
  PD_CHECK(bias.get_ptr() != nullptr, "only support bias != nullptr");

  auto gating_logits_dims = gating_logits.shape();
  int expert_num = gating_logits_dims[gating_logits_dims.size() - 1];
  int64_t token_num = 0;
  if (gating_logits_dims.size() == 3) {
    token_num = gating_logits_dims[0] * gating_logits_dims[1];
  } else {
    token_num = gating_logits_dims[0];
  }
  auto topk_ids = paddle::empty(
      {token_num, moe_topk}, paddle::DataType::INT32, gating_logits.place());
  auto topk_ids_tmp = paddle::empty(
      {token_num, moe_topk}, paddle::DataType::INT32, gating_logits.place());
  auto source_rows_tmp = paddle::empty(
      {token_num, moe_topk}, paddle::DataType::INT32, gating_logits.place());
  auto topk_weights = paddle::empty(
      {token_num, moe_topk}, paddle::DataType::FLOAT32, gating_logits.place());

  const float* bias_data =
      bias.get_ptr() != nullptr ? bias.get_ptr()->data<float>() : nullptr;
  int ret = infer_ops::moe_redundant_softmax_topk_normed<float, float, int>(
      ctx,
      gating_logits.data<float>(),
      bias_data,
      expert_id_to_ep_rank_array.data<int>(),
      expert_in_rank_num_list.data<int>(),
      tokens_per_expert_stats_list.data<int>(),
      topk_weights.data<float>(),
      topk_ids.data<int>(),
      topk_ids_tmp.data<int>(),
      source_rows_tmp.data<int>(),
      expert_num,
      moe_topk,
      token_num,
      redundant_ep_rank_num_plus_one);
  PD_CHECK(ret == 0);

  return {topk_ids, topk_weights};
}

std::vector<std::vector<int64_t>> MoERedundantTopKSelectInferShape(
    const std::vector<int64_t>& gating_logits_shape,
    const std::vector<int64_t>& expert_id_to_ep_rank_array_shape,
    const std::vector<int64_t>& expert_in_rank_num_list_shape,
    const std::vector<int64_t>& tokens_per_expert_stats_list_shape,
    const paddle::optional<std::vector<int64_t>>& bias_shape,
    const int moe_topk,
    const bool apply_norm_weight,
    const bool enable_softmax_top_k_fused,
    const int redundant_ep_rank_num_plus_one) {
  int64_t token_rows = -1;
  if (gating_logits_shape.size() == 3) {
    token_rows = gating_logits_shape[0] * gating_logits_shape[1];
  } else {
    token_rows = gating_logits_shape[0];
  }

  std::vector<int64_t> topk_ids_shape = {token_rows, moe_topk};
  std::vector<int64_t> topk_weights_shape = {token_rows, moe_topk};
  return {topk_ids_shape, topk_weights_shape};
}

std::vector<paddle::DataType> MoERedundantTopKSelectInferDtype(
    const paddle::DataType& gating_logits_dtype,
    const paddle::DataType& expert_id_to_ep_rank_array_dtype,
    const paddle::DataType& expert_in_rank_num_list_dtype,
    const paddle::DataType& tokens_per_expert_stats_list_dtype,
    const paddle::optional<paddle::DataType>& bias_type,
    const int moe_topk,
    const bool apply_norm_weight,
    const bool enable_softmax_top_k_fused,
    const int redundant_ep_rank_num_plus_one) {
  return {paddle::DataType::INT32, paddle::DataType::FLOAT32};
}

PD_BUILD_OP(moe_redundant_topk_select)
    .Inputs({"gating_logits",
             "expert_id_to_ep_rank_array",
             "expert_in_rank_num_list",
             "tokens_per_expert_stats_list",
             paddle::Optional("bias")})
    .Outputs({"topk_ids", "topk_weights", "tokens_per_expert_stats_list_out"})
    .Attrs({"moe_topk: int",
            "apply_norm_weight: bool",
            "enable_softmax_top_k_fused:bool",
            "redundant_ep_rank_num_plus_one:int"})
    .SetInplaceMap({{"tokens_per_expert_stats_list",
                     "tokens_per_expert_stats_list_out"}})
    .SetKernelFn(PD_KERNEL(MoERedundantTopKSelect))
    .SetInferShapeFn(PD_INFER_SHAPE(MoERedundantTopKSelectInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MoERedundantTopKSelectInferDtype));
|
custom_ops/xpu_ops/src/ops/moe_topk_select.cc (new file, 86 lines)
@@ -0,0 +1,86 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <infer_ops.h>
#include <xft_api.h>
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "utility/debug.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

std::vector<paddle::Tensor> MoeTopkSelect(
    const paddle::Tensor& gating_logits,
    const paddle::optional<paddle::Tensor>& bias,
    const int moe_topk,
    const bool apply_norm_weight) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);

  PD_CHECK(apply_norm_weight, "only support apply_norm_weight==true");

  auto gating_logits_dims = gating_logits.shape();
  int token_num = gating_logits_dims[0];
  int expert_num = gating_logits_dims[1];
  auto topk_ids = paddle::empty(
      {token_num, moe_topk}, paddle::DataType::INT32, gating_logits.place());
  auto topk_weights = paddle::empty(
      {token_num, moe_topk}, paddle::DataType::FLOAT32, gating_logits.place());
  int32_t* block_statistic = nullptr;
  const float* bias_data =
      bias.get_ptr() != nullptr ? bias.get_ptr()->data<float>() : nullptr;
  if (token_num > 0) {
    int ret = infer_ops::moe_softmax_topk_norm_fusion(
        xpu_ctx->x_context(),
        gating_logits.data<float>(),
        topk_weights.mutable_data<float>(),
        topk_ids.mutable_data<int>(),
        block_statistic,
        token_num,
        expert_num,
        moe_topk,
        0,
        bias_data);
    PD_CHECK(ret == 0);
  }

  return {topk_ids, topk_weights};
}

std::vector<std::vector<int64_t>> MoeTopkSelectInferShape(
    const std::vector<int64_t>& gating_logits_shape,
    const std::vector<int64_t>& bias_shape,
    const int moe_topk,
    const bool apply_norm_weight) {
  std::vector<int64_t> topk_ids_shape = {gating_logits_shape[0], moe_topk};
  std::vector<int64_t> topk_weights_shape = {gating_logits_shape[0], moe_topk};
  return {topk_ids_shape, topk_weights_shape};
}

std::vector<paddle::DataType> MoeTopkSelectInferDtype(
    const paddle::DataType& gating_logits_dtype,
    const paddle::DataType& bias_dtype) {
  return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
}

PD_BUILD_STATIC_OP(moe_topk_select)
    .Inputs({"gating_logits", paddle::Optional("bias")})
    .Outputs({"topk_ids", "topk_weights"})
    .Attrs({"moe_topk: int", "apply_norm_weight: bool"})
    .SetKernelFn(PD_KERNEL(MoeTopkSelect))
    .SetInferShapeFn(PD_INFER_SHAPE(MoeTopkSelectInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MoeTopkSelectInferDtype));
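For reference, a minimal CPU sketch of what the fused infer_ops::moe_softmax_topk_norm_fusion call above is assumed to compute for a single token: a numerically stable softmax over the expert logits (with the optional bias added to the scores), selection of the moe_topk largest probabilities, and renormalization of the selected weights so they sum to one (the apply_norm_weight == true path). The helper below is illustrative only and is not part of the XPU kernel.

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Illustrative CPU reference for one token: gating logits -> (topk_ids, topk_weights).
// Assumes 0 < moe_topk <= logits.size(); bias may be nullptr.
void moe_softmax_topk_norm_reference(const std::vector<float>& logits,
                                     const float* bias,
                                     int moe_topk,
                                     std::vector<int>* topk_ids,
                                     std::vector<float>* topk_weights) {
  const int expert_num = static_cast<int>(logits.size());
  std::vector<float> scores(expert_num);
  float max_val = -std::numeric_limits<float>::infinity();
  for (int e = 0; e < expert_num; ++e) {
    scores[e] = logits[e] + (bias != nullptr ? bias[e] : 0.0f);
    max_val = std::max(max_val, scores[e]);
  }
  float sum = 0.0f;
  for (int e = 0; e < expert_num; ++e) {
    scores[e] = std::exp(scores[e] - max_val);  // stable softmax numerator
    sum += scores[e];
  }
  for (int e = 0; e < expert_num; ++e) scores[e] /= sum;

  // Indices of the moe_topk largest probabilities.
  std::vector<int> order(expert_num);
  for (int e = 0; e < expert_num; ++e) order[e] = e;
  std::partial_sort(order.begin(), order.begin() + moe_topk, order.end(),
                    [&](int a, int b) { return scores[a] > scores[b]; });
  topk_ids->assign(order.begin(), order.begin() + moe_topk);

  // Renormalize the selected weights (apply_norm_weight == true).
  float selected_sum = 0.0f;
  for (int k = 0; k < moe_topk; ++k) selected_sum += scores[(*topk_ids)[k]];
  topk_weights->resize(moe_topk);
  for (int k = 0; k < moe_topk; ++k)
    (*topk_weights)[k] = scores[(*topk_ids)[k]] / selected_sum;
}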
39
custom_ops/xpu_ops/src/ops/msg_utils.h
Normal file
@@ -0,0 +1,39 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/mman.h>
#include <sys/msg.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include "paddle/extension.h"

#define MAX_BSZ 512

struct msgdata {
  long mtype;              // NOLINT
  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
};

struct msgdatakv {
  long mtype;                  // NOLINT
  int mtext[MAX_BSZ * 3 + 2];  // encoder_count, layer_id, bid- pair
};
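The structs above are sized for System V message queues (hence the sys/msg.h include). A minimal, illustrative consumer for the msgdata layout follows; it assumes the definitions above are in scope, and the queue key here is a placeholder rather than a value taken from FastDeploy.

#include <cstdio>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>

// Hypothetical reader for one msgdata record (key 0x5678 is a placeholder).
inline int read_one_token_batch() {
  const key_t key = 0x5678;                         // placeholder key
  const int msqid = msgget(key, IPC_CREAT | 0666);  // open or create the queue
  if (msqid == -1) {
    std::perror("msgget");
    return -1;
  }
  msgdata msg{};
  // msgsz excludes the mtype field; msgtyp == 0 takes the next message of any type.
  if (msgrcv(msqid, &msg, sizeof(msg.mtext), 0, 0) == -1) {
    std::perror("msgrcv");
    return -1;
  }
  const int stop_flag = msg.mtext[0];  // per the comment above: stop_flag, bsz, tokens...
  const int bsz = msg.mtext[1];
  std::printf("stop_flag=%d bsz=%d first_token=%d\n", stop_flag, bsz, msg.mtext[2]);
  return stop_flag;
}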
@@ -17,6 +17,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
                           const paddle::Tensor& base_model_seq_lens_this_time,
                           const paddle::Tensor& base_model_seq_lens_encoder,
@@ -37,7 +41,7 @@ void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "");
}

PD_BUILD_OP(draft_model_postprocess)
PD_BUILD_STATIC_OP(draft_model_postprocess)
    .Inputs({"base_model_draft_tokens",
             "base_model_seq_lens_this_time",
             "base_model_seq_lens_encoder",
@@ -17,6 +17,10 @@
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

namespace api = baidu::xpu::api;
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                          const paddle::Tensor& input_ids,
@@ -90,7 +94,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}

PD_BUILD_OP(draft_model_preprocess)
PD_BUILD_STATIC_OP(draft_model_preprocess)
    .Inputs({"draft_tokens",
             "input_ids",
             "stop_flags",
@@ -17,6 +17,10 @@
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
                      const paddle::Tensor& draft_tokens,
                      const paddle::Tensor& pre_ids,
@@ -86,7 +90,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
  PD_CHECK(r == 0, "draft_model_update failed.");
}

PD_BUILD_OP(draft_model_update)
PD_BUILD_STATIC_OP(draft_model_update)
    .Inputs({"inter_next_tokens",
             "draft_tokens",
             "pre_ids",
@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

namespace api = baidu::xpu::api;
std::vector<paddle::Tensor> EagleGetHiddenStates(
    const paddle::Tensor& input,
@@ -102,7 +106,7 @@ std::vector<paddle::Tensor> EagleGetHiddenStates(
  }
}

PD_BUILD_OP(eagle_get_hidden_states)
PD_BUILD_STATIC_OP(eagle_get_hidden_states)
    .Inputs({"input",
             "seq_lens_this_time",
             "seq_lens_encoder",
@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

namespace api = baidu::xpu::api;
std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
    const paddle::Tensor& input,
@@ -97,7 +101,7 @@ std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
  }
}

PD_BUILD_OP(eagle_get_self_hidden_states)
PD_BUILD_STATIC_OP(eagle_get_self_hidden_states)
    .Inputs(
        {"input", "last_seq_lens_this_time", "seq_lens_this_time", "step_idx"})
    .Outputs({"out"})
@@ -17,6 +17,10 @@
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

namespace api = baidu::xpu::api;
void MTPStepPaddle(
    const paddle::Tensor &base_model_stop_flags,
@@ -64,7 +68,7 @@ void MTPStepPaddle(
  }
}

PD_BUILD_OP(mtp_step_paddle)
PD_BUILD_STATIC_OP(mtp_step_paddle)
    .Inputs({"base_model_stop_flags",
             "stop_flags",
             "batch_drop",
@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
                              const paddle::Tensor& seq_lens_decoder) {
  // printf("enter clear \n");
@@ -31,7 +35,7 @@ void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
  PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed.");
}

PD_BUILD_OP(speculate_clear_accept_nums)
PD_BUILD_STATIC_OP(speculate_clear_accept_nums)
    .Inputs({"accept_num", "seq_lens_decoder"})
    .Outputs({"seq_lens_decoder_out"})
    .SetInplaceMap({{"seq_lens_decoder", "seq_lens_decoder_out"}})
@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
    const paddle::Tensor& output_cum_offsets_tmp,
    const paddle::Tensor& out_token_num,
@@ -69,7 +73,7 @@ std::vector<paddle::DataType> SpeculateGetOutputPaddingOffsetInferDtype(
  return {output_cum_offsets_tmp_dtype, output_cum_offsets_tmp_dtype};
}

PD_BUILD_OP(speculate_get_output_padding_offset)
PD_BUILD_STATIC_OP(speculate_get_output_padding_offset)
    .Inputs({"output_cum_offsets_tmp", "out_token_num", "seq_lens_output"})
    .Outputs({"output_padding_offset", "output_cum_offsets"})
    .Attrs({"max_seq_len: int"})
@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
    const paddle::Tensor& input_ids,
    const paddle::Tensor& draft_tokens,
@@ -110,7 +114,7 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
          seq_len_dtype};
}

PD_BUILD_OP(speculate_get_padding_offset)
PD_BUILD_STATIC_OP(speculate_get_padding_offset)
    .Inputs({"input_ids",
             "draft_tokens",
             "cum_offsets",
@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_encoder,
@@ -61,7 +65,7 @@ std::vector<paddle::DataType> SpeculateGetSeqLensOutputInferDtype(
  return {seq_lens_this_time_dtype};
}

PD_BUILD_OP(speculate_get_seq_lens_output)
PD_BUILD_STATIC_OP(speculate_get_seq_lens_output)
    .Inputs({"seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder"})
    .Outputs({"seq_lens_output"})
    .SetKernelFn(PD_KERNEL(SpeculateGetSeqLensOutput))
@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
                                    const paddle::Tensor &accept_tokens,
                                    const paddle::Tensor &accept_num,
@@ -53,7 +57,7 @@ void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
  PD_CHECK(r == 0, "speculate_clear_accept_nums_kernel failed.");
}

PD_BUILD_OP(speculate_set_value_by_flags_and_idx)
PD_BUILD_STATIC_OP(speculate_set_value_by_flags_and_idx)
    .Inputs({"pre_ids_all",
             "accept_tokens",
             "accept_num",
@@ -17,6 +17,10 @@
#include "speculate_msg.h"  // NOLINT
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

// To avoid changing how this op is called, the input parameters are left unchanged for now.
void SpeculateStepSchedule(
    const paddle::Tensor &stop_flags,
@@ -150,7 +154,7 @@ void SpeculateStepSchedule(
  }
}

PD_BUILD_OP(speculate_step_reschedule)
PD_BUILD_STATIC_OP(speculate_step_reschedule)
    .Inputs({"stop_flags",
             "seq_lens_this_time",
             "ori_seq_lens_encoder",
@@ -17,20 +17,25 @@
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"

void TokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
                             const paddle::Tensor& logits,
                             const paddle::Tensor& penalty_scores,
                             const paddle::Tensor& frequency_scores,
                             const paddle::Tensor& presence_scores,
                             const paddle::Tensor& temperatures,
                             const paddle::Tensor& bad_tokens,
                             const paddle::Tensor& cur_len,
                             const paddle::Tensor& min_len,
                             const paddle::Tensor& eos_token_id,
                             const paddle::Tensor& seq_lens_this_time,
                             const paddle::Tensor& output_padding_offset,
                             const paddle::Tensor& output_cum_offsets,
                             const int max_seq_len) {
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

void SpeculateTokenPenaltyMultiScores(
    const paddle::Tensor& pre_ids,
    const paddle::Tensor& logits,
    const paddle::Tensor& penalty_scores,
    const paddle::Tensor& frequency_scores,
    const paddle::Tensor& presence_scores,
    const paddle::Tensor& temperatures,
    const paddle::Tensor& bad_tokens,
    const paddle::Tensor& cur_len,
    const paddle::Tensor& min_len,
    const paddle::Tensor& eos_token_id,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& output_padding_offset,
    const paddle::Tensor& output_cum_offsets,
    const int max_seq_len) {
  namespace api = baidu::xpu::api;
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
@@ -137,7 +142,7 @@ void TokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
  }
}

PD_BUILD_OP(speculate_get_token_penalty_multi_scores)
PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
    .Inputs({"pre_ids",
             "logits",
             "penalty_scores",
@@ -154,4 +159,4 @@ PD_BUILD_OP(speculate_get_token_penalty_multi_scores)
    .Outputs({"logits_out"})
    .Attrs({"max_seq_len: int"})
    .SetInplaceMap({{"logits", "logits_out"}})
    .SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));
    .SetKernelFn(PD_KERNEL(SpeculateTokenPenaltyMultiScores));
@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

namespace api = baidu::xpu::api;

void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
@@ -66,7 +70,7 @@ void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}

PD_BUILD_OP(speculate_update_v3)
PD_BUILD_STATIC_OP(speculate_update_v3)
    .Inputs({"seq_lens_encoder",
             "seq_lens_decoder",
             "not_need_stop",
@@ -17,10 +17,13 @@
#include "paddle/common/flags.h"
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "ops/utility/debug.h"
#include "xpu/internal/infra_op.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

namespace api = baidu::xpu::api;

void SpeculateVerify(const paddle::Tensor &accept_tokens,
@@ -221,7 +224,7 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens,
  }
}

PD_BUILD_OP(speculate_verify)
PD_BUILD_STATIC_OP(speculate_verify)
    .Inputs({"accept_tokens",
             "accept_num",
             "step_idx",
@@ -16,6 +16,10 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#define FIXED_TOPK_BASE(topk, ...)   \
  case (topk): {                     \
    constexpr auto kTopK = topk;     \
@@ -149,7 +153,7 @@ std::vector<paddle::DataType> TopPCandidatesInferDtype(
  return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32};
}

PD_BUILD_OP(top_p_candidates)
PD_BUILD_STATIC_OP(top_p_candidates)
    .Inputs({"probs", "top_p", "output_padding_offset"})
    .Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"})
    .Attrs({"candidates_len: int", "max_seq_len: int"})
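The FIXED_TOPK_BASE macro above is one case of a common dispatch pattern: a runtime top-k value is switched into a compile-time constant so the kernel can be instantiated per k. A stripped-down, standalone illustration of that pattern (the names and the supported k values here are illustrative, not taken from the kernel):

#include <cstdio>

template <int kTopK>
void run_fixed_topk() {
  // A real kernel would use kTopK as a compile-time bound (e.g. for unrolling or buffers).
  std::printf("instantiated for top-k = %d\n", kTopK);
}

void dispatch_topk(int topk) {
  switch (topk) {
    case 1: run_fixed_topk<1>(); break;
    case 2: run_fixed_topk<2>(); break;
    case 4: run_fixed_topk<4>(); break;
    case 8: run_fixed_topk<8>(); break;
    default: std::printf("unsupported top-k %d\n", topk); break;
  }
}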
91
custom_ops/xpu_ops/src/ops/open_shm_and_get_meta_signal.cc
Normal file
@@ -0,0 +1,91 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ops/pybind/pybind.h"
#include "ops/remote_cache_kv_ipc.h"
#include "ops/utility/env.h"
#include "paddle/extension.h"

XPU_DECLARE_BOOL(fmt_write_cache_completed_signal, false);

using cache_write_complete_signal_type =
    RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;

paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank,
                                           const bool keep_pd_step_flag) {
  cache_write_complete_signal_type kv_signal_metadata;
  const char *fmt_write_cache_completed_signal_str =
      std::getenv("FLAGS_fmt_write_cache_completed_signal");
  if (fmt_write_cache_completed_signal_str &&
      (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
       std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
    kv_signal_metadata =
        RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
            rank, keep_pd_step_flag);
  }

  auto kv_signal_metadata_out =
      paddle::full({3}, -1, paddle::DataType::INT64, paddle::CPUPlace());
  kv_signal_metadata_out.data<int64_t>()[0] =
      static_cast<int64_t>(kv_signal_metadata.layer_id);
  kv_signal_metadata_out.data<int64_t>()[1] =
      reinterpret_cast<int64_t>(kv_signal_metadata.shm_ptr);
  kv_signal_metadata_out.data<int64_t>()[2] =
      static_cast<int64_t>(kv_signal_metadata.shm_fd);
  return kv_signal_metadata_out;
}

void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
                          const paddle::Tensor &seq_lens_this_time_tensor,
                          const paddle::Tensor &seq_lens_decoder_tensor,
                          const int rank,
                          const int num_layers) {
  if (FLAGS_fmt_write_cache_completed_signal) {
    int real_bsz = seq_lens_this_time_tensor.dims()[0];
    // GPU init, cp to cpu?
    auto seq_lens_encoder_cpu =
        seq_lens_encoder_tensor.copy_to(paddle::CPUPlace(), false);
    auto seq_lens_decoder_cpu =
        seq_lens_decoder_tensor.copy_to(paddle::CPUPlace(), false);
    RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query.init(
        seq_lens_encoder_cpu.data<int>(),
        seq_lens_decoder_cpu.data<int>(),
        rank,
        num_layers,
        real_bsz);
  }
}

std::vector<paddle::Tensor> OpenShmAndGetMetaSignal(
    const int rank, const bool keep_pd_step_flag) {
  return {OpenShmAndGetMetaSignalFunc(rank, keep_pd_step_flag)};
}

std::vector<std::vector<int64_t>> OpenShmAndGetMetaSignalShape(
    const int rank, const bool keep_pd_step_flag) {
  return {{3}};
}

std::vector<paddle::DataType> OpenShmAndGetMetaSignalDtype(
    const int rank, const bool keep_pd_step_flag) {
  return {paddle::DataType::INT64};
}

PD_BUILD_OP(open_shm_and_get_meta_signal)
    .Inputs({})
    .Outputs({"kv_signal_metadata"})
    .Attrs({"rank: int", "keep_pd_step_flag: bool"})
    .SetKernelFn(PD_KERNEL(OpenShmAndGetMetaSignal))
    .SetInferShapeFn(PD_INFER_SHAPE(OpenShmAndGetMetaSignalShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(OpenShmAndGetMetaSignalDtype));
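The op above packs the shared-memory signal metadata (layer_id, shm_ptr, shm_fd) into a length-3 INT64 CPU tensor, using -1 as the "not enabled" value. A small illustrative decoder for that layout follows; the struct name is made up for the example and is not part of the op.

#include <cstdint>

// Hypothetical plain view of the 3-element kv_signal_metadata tensor.
struct KvSignalMetadataView {
  int64_t layer_id;  // -1 when the signal path is disabled
  void* shm_ptr;     // pointer round-tripped through int64
  int64_t shm_fd;    // shared-memory file descriptor, -1 if unset
};

inline KvSignalMetadataView DecodeKvSignalMetadata(const int64_t* data) {
  KvSignalMetadataView view;
  view.layer_id = data[0];
  view.shm_ptr = reinterpret_cast<void*>(data[1]);  // inverse of the cast used in the op
  view.shm_fd = data[2];
  return view;
}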
46
custom_ops/xpu_ops/src/ops/pybind/alloc_cache_pinned.cc
Normal file
@@ -0,0 +1,46 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <sys/mman.h>          // NOLINT
#include "cuda_runtime_api.h"  // NOLINT
#include "paddle/extension.h"
#include "xpu/runtime.h"
#include "ops/pybind/pybind.h"

void check_xpu_error(int error) {
  if (error != XPU_SUCCESS) {
    throw XPUError(error);
  }
}

// Python wrapper around xpu_host_alloc.
uintptr_t custom_xpu_host_alloc(size_t size, unsigned int flags) {
  void* ptr = nullptr;
  // check_xpu_error(xpu_host_alloc(&ptr, size, flags));
  ptr = malloc(size);
  PD_CHECK(ptr != nullptr);
  PD_CHECK(mlock(ptr, size) == 0);
  return reinterpret_cast<uintptr_t>(ptr);
}

// Python wrapper around xpu_host_free.
void custom_xpu_host_free(uintptr_t ptr) {
  check_xpu_error(xpu_host_free(reinterpret_cast<void*>(ptr)));
}

// Python wrapper around cudaHostRegister: registers pageable memory as page-locked (pinned).
void xpu_cuda_host_register(uintptr_t ptr, size_t size, unsigned int flags) {
  cudaError_t e = cudaHostRegister(reinterpret_cast<void*>(ptr), size, flags);
  PD_CHECK(e == cudaSuccess, cudaGetErrorString(e));
}
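Note that custom_xpu_host_alloc above pins memory with malloc plus mlock (the xpu_host_alloc call is commented out). A matching release path for that allocation strategy would look like the sketch below; this is illustrative only, since the file as diffed frees through xpu_host_free.

#include <sys/mman.h>
#include <cstddef>
#include <cstdlib>

// Illustrative counterpart to malloc + mlock: unlock the pages, then free the buffer.
inline void release_locked_host_buffer(void* ptr, std::size_t size) {
  munlock(ptr, size);  // undo mlock; error handling omitted to keep the sketch short
  std::free(ptr);
}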