mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
enable dcu ci (#3402)
This commit is contained in:
@@ -46,7 +46,11 @@ __global__ void GetPaddingOffsetKernel(int *batch_id_per_token,
|
||||
const int ti = threadIdx.x;
|
||||
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
|
||||
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
|
||||
#ifdef PADDLE_WITH_HIP
|
||||
batch_id_per_token[bi * max_seq_len - cum_offset + i] = cum_offset;
|
||||
#else
|
||||
batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
|
||||
#endif
|
||||
}
|
||||
if (ti == 0) {
|
||||
cum_offsets_out[bi] = cum_offset;
|
||||
|
@@ -197,3 +197,13 @@ class XPUForwardMeta(ForwardMeta):
|
||||
dec_batch: Optional[paddle.Tensor] = None
|
||||
#
|
||||
total_enc_len: Optional[paddle.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DCUForwardMeta(ForwardMeta):
|
||||
"""
|
||||
DCUForwardMeta is used to store the global meta information of the forward, and some DCU specific meta info.
|
||||
"""
|
||||
|
||||
# Accumulated offset
|
||||
cum_offsets: Optional[paddle.Tensor] = None
|
||||
|
@@ -154,7 +154,7 @@ class BlockAttentionBackend(AttentionBackend):
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.padding_offset,
|
||||
forward_meta.batch_id_per_token,
|
||||
forward_meta.cum_offsets,
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.cu_seqlens_k,
|
||||
|
@@ -101,11 +101,12 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
"""
|
||||
gate_out = gate(x.cast("float32"))
|
||||
token_num = x.shape[0]
|
||||
top_k = layer.top_k
|
||||
num_local_experts = layer.num_local_experts
|
||||
@@ -113,7 +114,6 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
|
||||
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
|
||||
scores = paddle.nn.functional.softmax(gate_out, axis=-1)
|
||||
scores += layer.gate_correction_bias
|
||||
topk_weights, topk_ids = paddle.topk(scores, k=top_k, axis=-1, sorted=False)
|
||||
|
@@ -21,6 +21,8 @@ def native_top_p_sampling(probs: paddle.Tensor, top_p: paddle.Tensor) -> tuple[p
|
||||
sorted_indices = paddle.argsort(probs, descending=True)
|
||||
sorted_probs = paddle.sort(probs, descending=True)
|
||||
cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)
|
||||
if probs.shape[0] != top_p.shape[0]:
|
||||
top_p = paddle.slice(top_p, [0], [0], [probs.shape[0]])
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
sorted_indices_to_remove = paddle.cast(sorted_indices_to_remove, dtype="int64")
|
||||
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
|
||||
|
@@ -218,7 +218,7 @@ def post_process_normal(
|
||||
model_output.stop_flags,
|
||||
)
|
||||
|
||||
if current_platform.is_cuda() or current_platform.is_iluvatar():
|
||||
if current_platform.is_cuda() or current_platform.is_iluvatar() or current_platform.is_dcu():
|
||||
set_stop_value_multi_ends(
|
||||
sampler_output.sampled_token_ids,
|
||||
model_output.stop_flags,
|
||||
|
81
fastdeploy/worker/dcu_model_runner.py
Normal file
81
fastdeploy/worker/dcu_model_runner.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.forward_meta import DCUForwardMeta
|
||||
from fastdeploy.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
|
||||
class DCUModelRunner(GPUModelRunner):
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
device: str, # logic device
|
||||
device_id: int, # physical device id
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
):
|
||||
super(DCUModelRunner, self).__init__(
|
||||
fd_config=fd_config, device=device, device_id=device_id, rank=rank, local_rank=local_rank
|
||||
)
|
||||
|
||||
def initialize_forward_meta(self):
|
||||
"""
|
||||
Initialize forward meta and attention meta data
|
||||
"""
|
||||
# Initialize forward meta
|
||||
self.forward_meta = DCUForwardMeta(
|
||||
input_ids=self.share_inputs["input_ids"],
|
||||
ids_remove_padding=self.share_inputs["ids_remove_padding"],
|
||||
rotary_embs=self.share_inputs["rope_emb"],
|
||||
attn_backend=self.attn_backends[0],
|
||||
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
|
||||
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
|
||||
decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
|
||||
max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
|
||||
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
|
||||
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
|
||||
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
|
||||
batch_id_per_token=self.share_inputs["batch_id_per_token"],
|
||||
cum_offsets=self.share_inputs["cum_offsets"],
|
||||
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
|
||||
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
|
||||
block_tables=self.share_inputs["block_tables"],
|
||||
caches=self.share_inputs["caches"],
|
||||
)
|
||||
|
||||
# Update Batch type for cuda graph
|
||||
only_decode_batch = True
|
||||
prefill_exists = None
|
||||
# mix ep in single node
|
||||
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
|
||||
only_decode_batch_list = []
|
||||
prefill_exists = self.exist_prefill()
|
||||
paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
|
||||
only_decode_batch = all(only_decode_batch_list)
|
||||
self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
|
||||
|
||||
self.forward_meta.step_use_cudagraph = (
|
||||
self.use_cudagraph
|
||||
and only_decode_batch
|
||||
and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
|
||||
)
|
||||
|
||||
# Initialzie attention meta data
|
||||
for attn_backend in self.attn_backends:
|
||||
attn_backend.init_attention_metadata(self.forward_meta)
|
@@ -14,12 +14,14 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import time
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.utils import get_logger
|
||||
from fastdeploy.utils import get_logger, set_random_seed
|
||||
from fastdeploy.worker.dcu_model_runner import DCUModelRunner
|
||||
from fastdeploy.worker.gpu_worker import GpuWorker
|
||||
|
||||
logger = get_logger("dcu_worker", "dcu_worker.log")
|
||||
@@ -41,6 +43,41 @@ class DcuWorker(GpuWorker):
|
||||
)
|
||||
pass
|
||||
|
||||
def init_device(self):
|
||||
"""
|
||||
Initialize device and construct model runner
|
||||
"""
|
||||
self.max_chips_per_node = 8
|
||||
if self.device_config.device_type == "cuda" and paddle.device.is_compiled_with_cuda():
|
||||
# Set evironment variable
|
||||
self.device_ids = self.parallel_config.device_ids.split(",")
|
||||
self.device = f"gpu:{self.local_rank % self.max_chips_per_node}"
|
||||
paddle.device.set_device(self.device)
|
||||
paddle.set_default_dtype(self.parallel_config.dtype)
|
||||
|
||||
gc.collect()
|
||||
paddle.device.cuda.empty_cache()
|
||||
if (
|
||||
self.parallel_config.enable_custom_all_reduce
|
||||
and self.parallel_config.tensor_parallel_size > 1
|
||||
and paddle.is_compiled_with_cuda()
|
||||
):
|
||||
from fastdeploy.distributed.communication import use_custom_allreduce
|
||||
|
||||
use_custom_allreduce()
|
||||
else:
|
||||
raise RuntimeError(f"Not support device type: {self.device_config.device}")
|
||||
|
||||
set_random_seed(self.fd_config.model_config.seed)
|
||||
# Construct model runner
|
||||
self.model_runner: DCUModelRunner = DCUModelRunner(
|
||||
fd_config=self.fd_config,
|
||||
device=self.device,
|
||||
device_id=self.device_ids[self.local_rank % self.max_chips_per_node],
|
||||
rank=self.rank,
|
||||
local_rank=self.local_rank,
|
||||
)
|
||||
|
||||
def determine_available_memory(self) -> int:
|
||||
"""
|
||||
Profiles the peak memory usage of the model to determine how much
|
||||
|
@@ -46,6 +46,11 @@ from fastdeploy.platforms import current_platform
|
||||
if current_platform.is_iluvatar():
|
||||
from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx
|
||||
|
||||
recover_decode_task = None
|
||||
share_external_data = None
|
||||
elif current_platform.is_dcu():
|
||||
from fastdeploy.model_executor.ops.gpu import set_value_by_flags_and_idx
|
||||
|
||||
recover_decode_task = None
|
||||
share_external_data = None
|
||||
else:
|
||||
|
112
scripts/run_ci_dcu.sh
Normal file
112
scripts/run_ci_dcu.sh
Normal file
@@ -0,0 +1,112 @@
|
||||
#!/bin/bash
|
||||
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
echo "$DIR"
|
||||
|
||||
function stop_processes() {
|
||||
ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
|
||||
ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
|
||||
lsof -t -i :8188 | xargs kill -9 || true
|
||||
}
|
||||
|
||||
echo "Clean up processes..."
|
||||
stop_processes
|
||||
echo "Clean up completed."
|
||||
|
||||
export model_path=${MODEL_PATH}/paddle/ERNIE-4.5-21B-A3B-Paddle
|
||||
|
||||
python -m pip install paddlepaddle_dcu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/dcu/
|
||||
python -m pip install https://paddle-whl.bj.bcebos.com/stable/dcu/triton/triton-3.0.0%2Bdas.opt4.0da70a2.dtk2504-cp310-cp310-manylinux_2_28_x86_64.whl
|
||||
|
||||
python -m pip install git+https://github.com/zhoutianzi666/UseTritonInPaddle.git
|
||||
python -c "import use_triton_in_paddle; use_triton_in_paddle.make_triton_compatible_with_paddle()"
|
||||
|
||||
echo "pip install requirements_dcu"
|
||||
python -m pip install -r requirements_dcu.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
echo "build whl"
|
||||
bash build.sh || exit 1
|
||||
|
||||
unset http_proxy
|
||||
unset https_proxy
|
||||
unset no_proxy
|
||||
|
||||
|
||||
rm -rf log/*
|
||||
rm -f core*
|
||||
|
||||
# Empty the message queue
|
||||
ipcrm --all=msg
|
||||
echo "Start server..."
|
||||
export FD_ATTENTION_BACKEND="BLOCK_ATTN"
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ${model_path} \
|
||||
--port 8188 \
|
||||
--tensor-parallel-size 4 \
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--quantization wint8 > server.log 2>&1 &
|
||||
|
||||
echo "Waiting 90 seconds..."
|
||||
sleep 90
|
||||
|
||||
if grep -q "Failed to launch worker processes" server.log; then
|
||||
echo "Failed to launch worker processes..."
|
||||
stop_processes
|
||||
cat server.log
|
||||
cat log/workerlog.0
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if grep -q "Traceback (most recent call last):" server.log; then
|
||||
echo "Some errors occurred..."
|
||||
stop_processes
|
||||
cat server.log
|
||||
cat log/workerlog.0
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Health check
|
||||
TIMEOUT=$((5 * 60))
|
||||
INTERVAL=10 # Check interval (seconds)
|
||||
ENDPOINT="http://0.0.0.0:8188/health"
|
||||
START_TIME=$(date +%s) # Record the start timestamp
|
||||
echo "Start the server health check, maximum waiting time: ${TIMEOUT} seconds..."
|
||||
while true; do
|
||||
# Used to calculate the time cost
|
||||
CURRENT_TIME=$(date +%s)
|
||||
ELAPSED=$((CURRENT_TIME - START_TIME))
|
||||
|
||||
# Timeout
|
||||
if [ $ELAPSED -ge $TIMEOUT ]; then
|
||||
echo -e "\nServer start timeout: After $((TIMEOUT/60)) minutes, the service still doesn't start!"
|
||||
cat server.log
|
||||
cat log/workerlog.0
|
||||
exit 1
|
||||
fi
|
||||
|
||||
HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" -m 2 "$ENDPOINT" || true)
|
||||
|
||||
if [ "$HTTP_CODE" = "200" ]; then
|
||||
echo -e "\nThe server was successfully launched! Totally takes $((ELAPSED+90)) seconds."
|
||||
break
|
||||
else
|
||||
sleep $INTERVAL
|
||||
fi
|
||||
done
|
||||
|
||||
cat server.log
|
||||
echo -e "\n"
|
||||
|
||||
echo "Start inference..."
|
||||
python test/ci_use/DCU/run_ernie.py
|
||||
exit_code=$?
|
||||
echo "exit_code is ${exit_code}.\n"
|
||||
|
||||
echo "Stop server..."
|
||||
stop_processes
|
||||
echo "Stop server done."
|
||||
|
||||
if [ ${exit_code} -ne 0 ]; then
|
||||
echo "Exit with error, please refer to log/workerlog.0"
|
||||
cat log/workerlog.0
|
||||
exit 1
|
||||
fi
|
39
test/ci_use/DCU/run_ernie.py
Normal file
39
test/ci_use/DCU/run_ernie.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import openai
|
||||
|
||||
ip = "0.0.0.0"
|
||||
service_http_port = "8188" # 服务配置的
|
||||
client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", api_key="EMPTY_API_KEY")
|
||||
|
||||
# 非流式对话
|
||||
response = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "user", "content": "The largest ocean is"},
|
||||
],
|
||||
temperature=1,
|
||||
top_p=0,
|
||||
max_tokens=64,
|
||||
stream=False,
|
||||
)
|
||||
print(f"response is: {response}", flush=True)
|
||||
|
||||
generate_context = response.choices[0].message.content
|
||||
print(f"\ngenerate_context is: {generate_context}", flush=True)
|
||||
|
||||
assert "pacific ocean" in generate_context.lower(), "The answer was incorrect!"
|
||||
|
||||
print("Test successfully!", flush=True)
|
Reference in New Issue
Block a user