Enable DCU CI (#3402)

Author: lifulll
Date: 2025-08-29 10:23:08 +08:00
Committed by: GitHub
Parent commit: 73d60fe64d
Commit: 72094d4d82
11 changed files with 295 additions and 5 deletions

View File

@@ -46,7 +46,11 @@ __global__ void GetPaddingOffsetKernel(int *batch_id_per_token,
   const int ti = threadIdx.x;
   int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
   for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
+#ifdef PADDLE_WITH_HIP
+    batch_id_per_token[bi * max_seq_len - cum_offset + i] = cum_offset;
+#else
     batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
+#endif
   }
   if (ti == 0) {
     cum_offsets_out[bi] = cum_offset;
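As a reading aid for the hunk above, here is a minimal CPU reference in NumPy (a sketch written for this description, not code from the PR) of what the non-HIP path of GetPaddingOffsetKernel computes: each valid token of batch bi is tagged with its batch id at packed position bi * max_seq_len - cum_offset + i, and cum_offsets_out records the padding removed before each batch. Variable names mirror the kernel; the assumption that the cum_offsets input holds the cumulative padding per batch is mine.

import numpy as np

def get_padding_offset_reference(seq_lens, max_seq_len):
    """CPU sketch of the non-HIP path of GetPaddingOffsetKernel (illustrative only)."""
    bs = len(seq_lens)
    # assumed: padding accumulated up to and including batch bi, like the kernel's cum_offsets input
    cum_offsets = np.cumsum([max_seq_len - s for s in seq_lens])
    batch_id_per_token = np.zeros(sum(seq_lens), dtype=np.int32)
    cum_offsets_out = np.zeros(bs, dtype=np.int32)
    for bi in range(bs):
        cum_offset = 0 if bi == 0 else int(cum_offsets[bi - 1])
        cum_offsets_out[bi] = cum_offset
        for i in range(seq_lens[bi]):
            # packed index of token i of batch bi once padding is stripped
            batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi
    return batch_id_per_token, cum_offsets_out

# e.g. seq_lens=[2, 3], max_seq_len=4 -> batch_id_per_token=[0, 0, 1, 1, 1], cum_offsets_out=[0, 2]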

View File

@@ -197,3 +197,13 @@ class XPUForwardMeta(ForwardMeta):
     dec_batch: Optional[paddle.Tensor] = None
     #
     total_enc_len: Optional[paddle.Tensor] = None
+
+
+@dataclass
+class DCUForwardMeta(ForwardMeta):
+    """
+    DCUForwardMeta is used to store the global meta information of the forward pass, and some DCU-specific meta info.
+    """
+
+    # Accumulated offsets
+    cum_offsets: Optional[paddle.Tensor] = None
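A tiny self-contained illustration (mock classes, not the FastDeploy types) of the pattern this hunk follows: the platform-specific meta subclasses the shared ForwardMeta and adds only what the platform's kernels need, so code written against the base type keeps working unchanged.

from dataclasses import dataclass
from typing import Optional

@dataclass
class ForwardMetaMock:
    # stand-in for the shared ForwardMeta fields
    ids_remove_padding: Optional[list] = None

@dataclass
class DCUForwardMetaMock(ForwardMetaMock):
    # DCU-specific accumulated padding offsets, mirroring cum_offsets above
    cum_offsets: Optional[list] = None

meta = DCUForwardMetaMock(ids_remove_padding=[3, 7, 11], cum_offsets=[0, 2])
assert isinstance(meta, ForwardMetaMock)  # still accepted wherever the base meta is expected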

View File

@@ -154,7 +154,7 @@ class BlockAttentionBackend(AttentionBackend):
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
-            forward_meta.padding_offset,
+            forward_meta.batch_id_per_token,
             forward_meta.cum_offsets,
             forward_meta.cu_seqlens_q,
             forward_meta.cu_seqlens_k,

View File

@@ -101,11 +101,12 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Triton compute Fused MoE.
         """
+        gate_out = gate(x.cast("float32"))
         token_num = x.shape[0]
         top_k = layer.top_k
         num_local_experts = layer.num_local_experts
@@ -113,7 +114,6 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
         moe_intermediate_size = layer.moe_intermediate_size
         hidden_size = layer.hidden_size
-        gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
         scores = paddle.nn.functional.softmax(gate_out, axis=-1)
         scores += layer.gate_correction_bias
         topk_weights, topk_ids = paddle.topk(scores, k=top_k, axis=-1, sorted=False)
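To make the signature change concrete, here is a small sketch (hypothetical sizes, and a plain bias-free nn.Linear standing in for the real gate layer, which may differ) contrasting the removed explicit matmul with the new in-method gate call; the softmax / correction-bias / top-k routing that follows is unchanged by the hunk.

import paddle
import paddle.nn as nn

# Hypothetical sizes, for illustration only
hidden_size, num_experts, token_num = 64, 8, 4
x = paddle.randn([token_num, hidden_size])
gate = nn.Linear(hidden_size, num_experts, bias_attr=False)

# New path (added above): the gate layer itself produces the logits from an fp32 copy of x.
gate_out_new = gate(x.cast("float32"))
# Old path (removed above): an explicit matmul against the stored gate weight.
# For a bias-free linear gate the two are numerically equivalent.
gate_out_old = paddle.matmul(x.cast("float32"), gate.weight)
assert gate_out_new.shape == gate_out_old.shape == [token_num, num_experts]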

View File

@@ -21,6 +21,8 @@ def native_top_p_sampling(probs: paddle.Tensor, top_p: paddle.Tensor) -> tuple[p
     sorted_indices = paddle.argsort(probs, descending=True)
     sorted_probs = paddle.sort(probs, descending=True)
     cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)
+    if probs.shape[0] != top_p.shape[0]:
+        top_p = paddle.slice(top_p, [0], [0], [probs.shape[0]])
     sorted_indices_to_remove = cumulative_probs > top_p
     sorted_indices_to_remove = paddle.cast(sorted_indices_to_remove, dtype="int64")
     sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
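For reference, a compact NumPy sketch (not FastDeploy's implementation; the final draw from the kept tokens is omitted) of the same top-p masking, including the new batch-size guard: when more top_p entries arrive than there are rows in probs, the extras are dropped before the comparison.

import numpy as np

def top_p_mask(probs: np.ndarray, top_p: np.ndarray) -> np.ndarray:
    """Boolean mask of tokens to drop; True means the token falls outside the top-p nucleus."""
    if probs.shape[0] != top_p.shape[0]:
        top_p = top_p[: probs.shape[0]]              # same batch-size guard as the added lines
    order = np.argsort(-probs, axis=-1)              # descending sort, as in the Paddle code
    sorted_probs = np.take_along_axis(probs, order, axis=-1)
    cumulative = np.cumsum(sorted_probs, axis=-1)
    remove = cumulative > top_p[:, None]             # top_p assumed shape [batch]
    remove[:, 1:] = remove[:, :-1]                   # shift right so the top token always survives
    remove[:, 0] = False
    mask = np.zeros_like(probs, dtype=bool)
    np.put_along_axis(mask, order, remove, axis=-1)  # scatter back to vocabulary order
    return mask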

View File

@@ -218,7 +218,7 @@ def post_process_normal(
             model_output.stop_flags,
         )
-    if current_platform.is_cuda() or current_platform.is_iluvatar():
+    if current_platform.is_cuda() or current_platform.is_iluvatar() or current_platform.is_dcu():
         set_stop_value_multi_ends(
             sampler_output.sampled_token_ids,
             model_output.stop_flags,

View File

@@ -0,0 +1,81 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import DCUForwardMeta
from fastdeploy.worker.gpu_model_runner import GPUModelRunner
class DCUModelRunner(GPUModelRunner):
def __init__(
self,
fd_config: FDConfig,
device: str, # logic device
device_id: int, # physical device id
rank: int,
local_rank: int,
):
super(DCUModelRunner, self).__init__(
fd_config=fd_config, device=device, device_id=device_id, rank=rank, local_rank=local_rank
)
def initialize_forward_meta(self):
"""
Initialize forward meta and attention meta data
"""
# Initialize forward meta
self.forward_meta = DCUForwardMeta(
input_ids=self.share_inputs["input_ids"],
ids_remove_padding=self.share_inputs["ids_remove_padding"],
rotary_embs=self.share_inputs["rope_emb"],
attn_backend=self.attn_backends[0],
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
batch_id_per_token=self.share_inputs["batch_id_per_token"],
cum_offsets=self.share_inputs["cum_offsets"],
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"],
)
# Update Batch type for cuda graph
only_decode_batch = True
prefill_exists = None
# mix ep in single node
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
only_decode_batch_list = []
prefill_exists = self.exist_prefill()
paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
only_decode_batch = all(only_decode_batch_list)
self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
self.forward_meta.step_use_cudagraph = (
self.use_cudagraph
and only_decode_batch
and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
)
# Initialzie attention meta data
for attn_backend in self.attn_backends:
attn_backend.init_attention_metadata(self.forward_meta)
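The step_use_cudagraph gating above depends on agreement across expert-parallel ranks. A stripped-down, framework-free sketch (plain Python stand-ins for the gathered flags; not FastDeploy code) of the same decision rule:

def decide_cudagraph(use_cudagraph: bool, local_has_prefill: bool, gathered_decode_only_flags: list[bool]) -> bool:
    # Each rank reports whether its batch is decode-only; the graph is only used when every rank agrees
    # and the local rank itself has no prefill work.
    only_decode_batch = all(gathered_decode_only_flags)
    return use_cudagraph and only_decode_batch and not local_has_prefill

# e.g. flags gathered as [True, True, False] -> one rank still has prefill work, so no cudagraph
assert decide_cudagraph(True, False, [True, True, False]) is False
assert decide_cudagraph(True, False, [True, True, True]) is True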

View File

@@ -14,12 +14,14 @@
 # limitations under the License.
 """

+import gc
 import time

 import paddle

 from fastdeploy.config import FDConfig
-from fastdeploy.utils import get_logger
+from fastdeploy.utils import get_logger, set_random_seed
+from fastdeploy.worker.dcu_model_runner import DCUModelRunner
 from fastdeploy.worker.gpu_worker import GpuWorker

 logger = get_logger("dcu_worker", "dcu_worker.log")
@@ -41,6 +43,41 @@ class DcuWorker(GpuWorker):
         )
         pass

+    def init_device(self):
+        """
+        Initialize device and construct model runner
+        """
+        self.max_chips_per_node = 8
+        if self.device_config.device_type == "cuda" and paddle.device.is_compiled_with_cuda():
+            # Set environment variable
+            self.device_ids = self.parallel_config.device_ids.split(",")
+            self.device = f"gpu:{self.local_rank % self.max_chips_per_node}"
+            paddle.device.set_device(self.device)
+            paddle.set_default_dtype(self.parallel_config.dtype)
+
+            gc.collect()
+            paddle.device.cuda.empty_cache()
+            if (
+                self.parallel_config.enable_custom_all_reduce
+                and self.parallel_config.tensor_parallel_size > 1
+                and paddle.is_compiled_with_cuda()
+            ):
+                from fastdeploy.distributed.communication import use_custom_allreduce
+
+                use_custom_allreduce()
+        else:
+            raise RuntimeError(f"Not support device type: {self.device_config.device}")
+        set_random_seed(self.fd_config.model_config.seed)
+
+        # Construct model runner
+        self.model_runner: DCUModelRunner = DCUModelRunner(
+            fd_config=self.fd_config,
+            device=self.device,
+            device_id=self.device_ids[self.local_rank % self.max_chips_per_node],
+            rank=self.rank,
+            local_rank=self.local_rank,
+        )
+
     def determine_available_memory(self) -> int:
         """
         Profiles the peak memory usage of the model to determine how much

View File

@@ -46,6 +46,11 @@ from fastdeploy.platforms import current_platform
 if current_platform.is_iluvatar():
     from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx
+    recover_decode_task = None
+    share_external_data = None
+elif current_platform.is_dcu():
+    from fastdeploy.model_executor.ops.gpu import set_value_by_flags_and_idx
     recover_decode_task = None
     share_external_data = None
 else:

scripts/run_ci_dcu.sh
View File

@@ -0,0 +1,112 @@
#!/bin/bash
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
echo "$DIR"
function stop_processes() {
    ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
    ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
    lsof -t -i :8188 | xargs kill -9 || true
}
echo "Clean up processes..."
stop_processes
echo "Clean up completed."
export model_path=${MODEL_PATH}/paddle/ERNIE-4.5-21B-A3B-Paddle
python -m pip install paddlepaddle_dcu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/dcu/
python -m pip install https://paddle-whl.bj.bcebos.com/stable/dcu/triton/triton-3.0.0%2Bdas.opt4.0da70a2.dtk2504-cp310-cp310-manylinux_2_28_x86_64.whl
python -m pip install git+https://github.com/zhoutianzi666/UseTritonInPaddle.git
python -c "import use_triton_in_paddle; use_triton_in_paddle.make_triton_compatible_with_paddle()"
echo "pip install requirements_dcu"
python -m pip install -r requirements_dcu.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
echo "build whl"
bash build.sh || exit 1
unset http_proxy
unset https_proxy
unset no_proxy
rm -rf log/*
rm -f core*
# Empty the message queue
ipcrm --all=msg
echo "Start server..."
export FD_ATTENTION_BACKEND="BLOCK_ATTN"
python -m fastdeploy.entrypoints.openai.api_server \
    --model ${model_path} \
    --port 8188 \
    --tensor-parallel-size 4 \
    --gpu-memory-utilization 0.8 \
    --quantization wint8 > server.log 2>&1 &
echo "Waiting 90 seconds..."
sleep 90
if grep -q "Failed to launch worker processes" server.log; then
echo "Failed to launch worker processes..."
stop_processes
cat server.log
cat log/workerlog.0
exit 1
fi
if grep -q "Traceback (most recent call last):" server.log; then
echo "Some errors occurred..."
stop_processes
cat server.log
cat log/workerlog.0
exit 1
fi
# Health check
TIMEOUT=$((5 * 60))
INTERVAL=10 # Check interval (seconds)
ENDPOINT="http://0.0.0.0:8188/health"
START_TIME=$(date +%s) # Record the start timestamp
echo "Start the server health check, maximum waiting time: ${TIMEOUT} seconds..."
while true; do
    # Used to calculate the elapsed time
    CURRENT_TIME=$(date +%s)
    ELAPSED=$((CURRENT_TIME - START_TIME))

    # Timeout
    if [ $ELAPSED -ge $TIMEOUT ]; then
        echo -e "\nServer start timeout: the service still has not started after $((TIMEOUT/60)) minutes!"
        cat server.log
        cat log/workerlog.0
        exit 1
    fi

    HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" -m 2 "$ENDPOINT" || true)
    if [ "$HTTP_CODE" = "200" ]; then
        echo -e "\nThe server launched successfully! Total time: $((ELAPSED+90)) seconds."
        break
    else
        sleep $INTERVAL
    fi
done
cat server.log
echo -e "\n"
echo "Start inference..."
python test/ci_use/DCU/run_ernie.py
exit_code=$?
echo "exit_code is ${exit_code}.\n"
echo "Stop server..."
stop_processes
echo "Stop server done."
if [ ${exit_code} -ne 0 ]; then
    echo "Exited with an error, please refer to log/workerlog.0"
    cat log/workerlog.0
    exit 1
fi

View File

@@ -0,0 +1,39 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import openai
ip = "0.0.0.0"
service_http_port = "8188"  # the port configured for the service

client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", api_key="EMPTY_API_KEY")

# Non-streaming chat completion
response = client.chat.completions.create(
model="default",
messages=[
{"role": "user", "content": "The largest ocean is"},
],
temperature=1,
top_p=0,
max_tokens=64,
stream=False,
)
print(f"response is: {response}", flush=True)
generate_context = response.choices[0].message.content
print(f"\ngenerate_context is: {generate_context}", flush=True)
assert "pacific ocean" in generate_context.lower(), "The answer was incorrect!"
print("Test successfully!", flush=True)