diff --git a/custom_ops/gpu_ops/get_padding_offset.cu b/custom_ops/gpu_ops/get_padding_offset.cu
index f505e1c32..560310148 100644
--- a/custom_ops/gpu_ops/get_padding_offset.cu
+++ b/custom_ops/gpu_ops/get_padding_offset.cu
@@ -46,7 +46,11 @@ __global__ void GetPaddingOffsetKernel(int *batch_id_per_token,
   const int ti = threadIdx.x;
   int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
   for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
+#ifdef PADDLE_WITH_HIP
+    batch_id_per_token[bi * max_seq_len - cum_offset + i] = cum_offset;
+#else
     batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
+#endif
   }
   if (ti == 0) {
     cum_offsets_out[bi] = cum_offset;
diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py
index eb8f4b5f8..968495733 100644
--- a/fastdeploy/model_executor/forward_meta.py
+++ b/fastdeploy/model_executor/forward_meta.py
@@ -197,3 +197,13 @@ class XPUForwardMeta(ForwardMeta):
     dec_batch: Optional[paddle.Tensor] = None
     # total_enc_len: Optional[paddle.Tensor] = None
+
+
+@dataclass
+class DCUForwardMeta(ForwardMeta):
+    """
+    DCUForwardMeta is used to store the global meta information of the forward, and some DCU specific meta info.
+    """
+
+    # Accumulated offset
+    cum_offsets: Optional[paddle.Tensor] = None
diff --git a/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py b/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py
index 2802e97ba..f5800d156 100644
--- a/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py
@@ -154,7 +154,7 @@ class BlockAttentionBackend(AttentionBackend):
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
-            forward_meta.padding_offset,
+            forward_meta.batch_id_per_token,
             forward_meta.cum_offsets,
             forward_meta.cu_seqlens_q,
             forward_meta.cu_seqlens_k,
diff --git a/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py b/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py
index cd6a51161..0038ed149 100644
--- a/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py
+++ b/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py
@@ -101,11 +101,12 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
-        gate_out: paddle.Tensor,
+        gate: nn.Layer,
     ) -> paddle.Tensor:
         """
         Triton compute Fused MoE.
""" + gate_out = gate(x.cast("float32")) token_num = x.shape[0] top_k = layer.top_k num_local_experts = layer.num_local_experts @@ -113,7 +114,6 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase): moe_intermediate_size = layer.moe_intermediate_size hidden_size = layer.hidden_size - gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight) scores = paddle.nn.functional.softmax(gate_out, axis=-1) scores += layer.gate_correction_bias topk_weights, topk_ids = paddle.topk(scores, k=top_k, axis=-1, sorted=False) diff --git a/fastdeploy/model_executor/layers/backends/dcu/top_p_sampling.py b/fastdeploy/model_executor/layers/backends/dcu/top_p_sampling.py index 1eafe1351..59ac109a4 100644 --- a/fastdeploy/model_executor/layers/backends/dcu/top_p_sampling.py +++ b/fastdeploy/model_executor/layers/backends/dcu/top_p_sampling.py @@ -21,6 +21,8 @@ def native_top_p_sampling(probs: paddle.Tensor, top_p: paddle.Tensor) -> tuple[p sorted_indices = paddle.argsort(probs, descending=True) sorted_probs = paddle.sort(probs, descending=True) cumulative_probs = paddle.cumsum(sorted_probs, axis=-1) + if probs.shape[0] != top_p.shape[0]: + top_p = paddle.slice(top_p, [0], [0], [probs.shape[0]]) sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove = paddle.cast(sorted_indices_to_remove, dtype="int64") sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index a0ff96475..92c43ede4 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -218,7 +218,7 @@ def post_process_normal( model_output.stop_flags, ) - if current_platform.is_cuda() or current_platform.is_iluvatar(): + if current_platform.is_cuda() or current_platform.is_iluvatar() or current_platform.is_dcu(): set_stop_value_multi_ends( sampler_output.sampled_token_ids, model_output.stop_flags, diff --git a/fastdeploy/worker/dcu_model_runner.py b/fastdeploy/worker/dcu_model_runner.py new file mode 100644 index 000000000..df01b1bd7 --- /dev/null +++ b/fastdeploy/worker/dcu_model_runner.py @@ -0,0 +1,81 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import paddle + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.forward_meta import DCUForwardMeta +from fastdeploy.worker.gpu_model_runner import GPUModelRunner + + +class DCUModelRunner(GPUModelRunner): + def __init__( + self, + fd_config: FDConfig, + device: str, # logic device + device_id: int, # physical device id + rank: int, + local_rank: int, + ): + super(DCUModelRunner, self).__init__( + fd_config=fd_config, device=device, device_id=device_id, rank=rank, local_rank=local_rank + ) + + def initialize_forward_meta(self): + """ + Initialize forward meta and attention meta data + """ + # Initialize forward meta + self.forward_meta = DCUForwardMeta( + input_ids=self.share_inputs["input_ids"], + ids_remove_padding=self.share_inputs["ids_remove_padding"], + rotary_embs=self.share_inputs["rope_emb"], + attn_backend=self.attn_backends[0], + decoder_batch_ids=self.share_inputs["decoder_batch_ids"], + decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"], + decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"], + max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + batch_id_per_token=self.share_inputs["batch_id_per_token"], + cum_offsets=self.share_inputs["cum_offsets"], + cu_seqlens_q=self.share_inputs["cu_seqlens_q"], + cu_seqlens_k=self.share_inputs["cu_seqlens_k"], + block_tables=self.share_inputs["block_tables"], + caches=self.share_inputs["caches"], + ) + + # Update Batch type for cuda graph + only_decode_batch = True + prefill_exists = None + # mix ep in single node + if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed": + only_decode_batch_list = [] + prefill_exists = self.exist_prefill() + paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists) + only_decode_batch = all(only_decode_batch_list) + self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill" + + self.forward_meta.step_use_cudagraph = ( + self.use_cudagraph + and only_decode_batch + and not (prefill_exists if prefill_exists is not None else self.exist_prefill()) + ) + + # Initialzie attention meta data + for attn_backend in self.attn_backends: + attn_backend.init_attention_metadata(self.forward_meta) diff --git a/fastdeploy/worker/dcu_worker.py b/fastdeploy/worker/dcu_worker.py index 58f13bdfb..60fab7a95 100644 --- a/fastdeploy/worker/dcu_worker.py +++ b/fastdeploy/worker/dcu_worker.py @@ -14,12 +14,14 @@ # limitations under the License. 
""" +import gc import time import paddle from fastdeploy.config import FDConfig -from fastdeploy.utils import get_logger +from fastdeploy.utils import get_logger, set_random_seed +from fastdeploy.worker.dcu_model_runner import DCUModelRunner from fastdeploy.worker.gpu_worker import GpuWorker logger = get_logger("dcu_worker", "dcu_worker.log") @@ -41,6 +43,41 @@ class DcuWorker(GpuWorker): ) pass + def init_device(self): + """ + Initialize device and construct model runner + """ + self.max_chips_per_node = 8 + if self.device_config.device_type == "cuda" and paddle.device.is_compiled_with_cuda(): + # Set evironment variable + self.device_ids = self.parallel_config.device_ids.split(",") + self.device = f"gpu:{self.local_rank % self.max_chips_per_node}" + paddle.device.set_device(self.device) + paddle.set_default_dtype(self.parallel_config.dtype) + + gc.collect() + paddle.device.cuda.empty_cache() + if ( + self.parallel_config.enable_custom_all_reduce + and self.parallel_config.tensor_parallel_size > 1 + and paddle.is_compiled_with_cuda() + ): + from fastdeploy.distributed.communication import use_custom_allreduce + + use_custom_allreduce() + else: + raise RuntimeError(f"Not support device type: {self.device_config.device}") + + set_random_seed(self.fd_config.model_config.seed) + # Construct model runner + self.model_runner: DCUModelRunner = DCUModelRunner( + fd_config=self.fd_config, + device=self.device, + device_id=self.device_ids[self.local_rank % self.max_chips_per_node], + rank=self.rank, + local_rank=self.local_rank, + ) + def determine_available_memory(self) -> int: """ Profiles the peak memory usage of the model to determine how much diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 171d038d8..95649f0a6 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -46,6 +46,11 @@ from fastdeploy.platforms import current_platform if current_platform.is_iluvatar(): from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx + recover_decode_task = None + share_external_data = None +elif current_platform.is_dcu(): + from fastdeploy.model_executor.ops.gpu import set_value_by_flags_and_idx + recover_decode_task = None share_external_data = None else: diff --git a/scripts/run_ci_dcu.sh b/scripts/run_ci_dcu.sh new file mode 100644 index 000000000..bb13ff95b --- /dev/null +++ b/scripts/run_ci_dcu.sh @@ -0,0 +1,112 @@ +#!/bin/bash +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +echo "$DIR" + +function stop_processes() { + ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true + ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true + lsof -t -i :8188 | xargs kill -9 || true +} + +echo "Clean up processes..." +stop_processes +echo "Clean up completed." 
+
+export model_path=${MODEL_PATH}/paddle/ERNIE-4.5-21B-A3B-Paddle
+
+python -m pip install paddlepaddle_dcu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/dcu/
+python -m pip install https://paddle-whl.bj.bcebos.com/stable/dcu/triton/triton-3.0.0%2Bdas.opt4.0da70a2.dtk2504-cp310-cp310-manylinux_2_28_x86_64.whl
+
+python -m pip install git+https://github.com/zhoutianzi666/UseTritonInPaddle.git
+python -c "import use_triton_in_paddle; use_triton_in_paddle.make_triton_compatible_with_paddle()"
+
+echo "pip install requirements_dcu"
+python -m pip install -r requirements_dcu.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
+
+echo "build whl"
+bash build.sh || exit 1
+
+unset http_proxy
+unset https_proxy
+unset no_proxy
+
+
+rm -rf log/*
+rm -f core*
+
+# Empty the message queue
+ipcrm --all=msg
+
+echo "Start server..."
+export FD_ATTENTION_BACKEND="BLOCK_ATTN"
+python -m fastdeploy.entrypoints.openai.api_server \
+    --model ${model_path} \
+    --port 8188 \
+    --tensor-parallel-size 4 \
+    --gpu-memory-utilization 0.8 \
+    --quantization wint8 > server.log 2>&1 &
+
+echo "Waiting 90 seconds..."
+sleep 90
+
+if grep -q "Failed to launch worker processes" server.log; then
+    echo "Failed to launch worker processes..."
+    stop_processes
+    cat server.log
+    cat log/workerlog.0
+    exit 1
+fi
+
+if grep -q "Traceback (most recent call last):" server.log; then
+    echo "Some errors occurred..."
+    stop_processes
+    cat server.log
+    cat log/workerlog.0
+    exit 1
+fi
+
+# Health check
+TIMEOUT=$((5 * 60))
+INTERVAL=10  # Check interval (seconds)
+ENDPOINT="http://0.0.0.0:8188/health"
+START_TIME=$(date +%s)  # Record the start timestamp
+echo "Start the server health check, maximum waiting time: ${TIMEOUT} seconds..."
+while true; do
+    # Used to calculate the time cost
+    CURRENT_TIME=$(date +%s)
+    ELAPSED=$((CURRENT_TIME - START_TIME))
+
+    # Timeout
+    if [ $ELAPSED -ge $TIMEOUT ]; then
+        echo -e "\nServer start timeout: after $((TIMEOUT/60)) minutes, the service still hasn't started!"
+        cat server.log
+        cat log/workerlog.0
+        exit 1
+    fi
+
+    HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" -m 2 "$ENDPOINT" || true)
+
+    if [ "$HTTP_CODE" = "200" ]; then
+        echo -e "\nThe server was successfully launched! It took $((ELAPSED+90)) seconds in total."
+        break
+    else
+        sleep $INTERVAL
+    fi
+done
+
+cat server.log
+echo -e "\n"
+
+echo "Start inference..."
+python test/ci_use/DCU/run_ernie.py
+exit_code=$?
+echo -e "exit_code is ${exit_code}.\n"
+
+echo "Stop server..."
+stop_processes
+echo "Stop server done."
+
+if [ ${exit_code} -ne 0 ]; then
+    echo "Exit with error, please refer to log/workerlog.0"
+    cat log/workerlog.0
+    exit 1
+fi
diff --git a/test/ci_use/DCU/run_ernie.py b/test/ci_use/DCU/run_ernie.py
new file mode 100644
index 000000000..4120d74dc
--- /dev/null
+++ b/test/ci_use/DCU/run_ernie.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import openai
+
+ip = "0.0.0.0"
+service_http_port = "8188"  # Port configured for the service
+client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", api_key="EMPTY_API_KEY")
+
+# Non-streaming chat completion
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "user", "content": "The largest ocean is"},
+    ],
+    temperature=1,
+    top_p=0,
+    max_tokens=64,
+    stream=False,
+)
+print(f"response is: {response}", flush=True)
+
+generate_context = response.choices[0].message.content
+print(f"\ngenerate_context is: {generate_context}", flush=True)
+
+assert "pacific ocean" in generate_context.lower(), "The answer was incorrect!"
+
+print("Test passed successfully!", flush=True)