[Intel HPU] add example benchmark scripts for hpu (#5304)

* [Intel HPU] add example benchmark scripts for hpu

* Revise the code based on the copilot comments

* update code based on comments

* update ci ops version
This commit is contained in:
fmiao2372
2025-12-02 18:00:01 +08:00
committed by GitHub
parent fb7f951612
commit 429dd2b1db
12 changed files with 983 additions and 2 deletions

View File

@@ -0,0 +1,246 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Metric evaluation for Fastdeploy + ERNIE-4.5-Turbo"""
# adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py
import argparse
import ast
import json
import re
import time
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import requests
from tqdm import tqdm
INVALID = -9999999
def call_generate(prompt, **kwargs):
"""
Generates response based on the input prompt.
Args:
prompt (str): The input prompt text.
**kwargs: Keyword arguments, including server IP address and port number.
Returns:
str: The response generated based on the prompt.
"""
url = f"http://{kwargs['ip']}:{kwargs['port']}/v1/chat/completions"
headers = {"Content-Type": "application/json"}
data = {
"messages": [
{
"role": "user",
"content": prompt,
}
],
"temperature": 0.6,
"max_tokens": 2047,
"top_p": 0.95,
"do_sample": True,
}
response = requests.post(url, headers=headers, data=json.dumps(data))
out = response.json()
return out["choices"][0]["message"]["content"]
def get_one_example(lines, i, include_answer):
"""
Retrieves a question-answer example from the given list of text lines.
Args:
lines (list of dict): A list of question-answer pairs.
i (int): The index of the question-answer pair to retrieve from lines.
include_answer (bool): Whether to include the answer in the returned string.
Returns:
str: A formatted question-answer string in the format "Question: <question>\nAnswer: <answer>".
"""
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
if include_answer:
ret += " " + lines[i]["answer"]
return ret
def get_few_shot_examples(lines, k):
"""
Selects k examples from the given list of text lines and concatenates them into a single string.
Args:
lines (list): A list containing text lines.
k (int): The number of examples to select.
Returns:
str: A string composed of k examples, separated by two newline characters.
"""
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret
def get_answer_value(answer_str):
"""
Extracts numerical values from an answer string and returns them.
Args:
answer_str (str): The string containing the answer.
Returns:
The extracted numerical value; returns "INVALID" if extraction fails.
"""
answer_str = answer_str.replace(",", "")
numbers = re.findall(r"\d+", answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def read_jsonl(filename: str):
"""
Reads a JSONL file.
Args:
filename (str): Path to the JSONL file.
Yields:
dict: A dictionary object corresponding to each line in the JSONL file.
"""
with open(filename) as fin:
for line in fin:
if line.startswith("#"):
continue
yield json.loads(line)
def main(args):
"""
Process inputs and generate answers by calling the model in parallel using a thread pool.
Args:
args (argparse.Namespace):
- num_questions (int): Number of questions to process.
- num_shots (int): Number of few-shot learning examples.
- ip (str): IP address of the model service.
- port (int): Port number of the model service.
- parallel (int): Number of questions to process in parallel.
- result_file (str): File path to store the results.
Returns:
None
"""
# Read data
filename = "test.jsonl"
lines = list(read_jsonl(filename))
# Construct prompts
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
labels = []
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
states = [None] * len(labels)
# Use thread pool
def get_one_answer(i):
answer = call_generate(
prompt=few_shot_examples + questions[i],
# stop=["Question", "Assistant:", "<|separator|>"],
ip=args.ip,
port=args.port,
)
states[i] = answer
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(questions))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)
latency = time.time() - tic
preds = []
with open(args.acc_log, "w") as fout:
for i in range(len(states)):
preds.append(get_answer_value(states[i]))
answer = get_answer_value(states[i])
fout.write("\n################################################################\n")
fout.write("-----------prompt--------------\n")
fout.write(f"{few_shot_examples + questions[i]}\n")
fout.write("-----------answer--------------\n")
fout.write(f"answer= {states[i]}\n")
fout.write("-----------accuracy--------------\n")
fout.write(f"Correct={answer==labels[i]}, pred={answer}, label={labels[i]} \n")
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
# Print results
print(f"Accuracy: {acc:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Latency: {latency:.3f} s")
with open(args.result_file, "a") as fout:
value = {
"task": "gsm8k",
"backend": "paddlepaddle",
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default="127.0.0.1")
parser.add_argument("--port", type=str, default="8188")
parser.add_argument("--num-shots", type=int, default=10)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=1319)
parser.add_argument("--result-file", type=str, default="result.jsonl")
parser.add_argument("--parallel", type=int, default=1)
parser.add_argument("--acc-log", type=str, default="accuracy.log")
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,72 @@
#!/bin/bash
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# set -x
model="ERNIE-4.5-21B-A3B-Paddle"
model_log_name="ERNIE-4.5-21B-A3B-Paddle"
model_yaml="yaml/eb45-21b-a3b-32k-bf16.yaml"
# model="ERNIE-4.5-300B-A47B-Paddle"
# model_log_name="ERNIE-4.5-300B-A47B-Paddle"
# model_yaml="yaml/eb45-300b-a47b-32k-bf16.yaml"
export SERVER_PORT=8188
export no_proxy=localhost,127.0.0.1,0.0.0.0,10.0.0.0/8,192.168.1.0/24
input_lengths=(1024 2048)
output_lengths=(1024)
batch_sizes=(1 2 4 8 16 32 64 128)
workspace=$(pwd)
cd $workspace
log_home=$workspace/benchmark_fastdeploy_logs/$(TZ='Asia/Shanghai' date '+WW%V')_$(TZ='Asia/Shanghai' date +%F-%H-%M-%S)_${model_log_name}_FixedLen
mkdir -p ${log_home}
for input_length in "${input_lengths[@]}"
do
for output_length in "${output_lengths[@]}"
do
for batch_size in "${batch_sizes[@]}"
do
> log/hpu_model_runner_profile.log
num_prompts=$(( batch_size * 3))
log_name_prefix="benchmarkdata_${model_log_name}_inputlength_${input_length}_outputlength_${output_length}_batchsize_${batch_size}_numprompts_${num_prompts}"
log_name=${log_name_prefix}_$(TZ='Asia/Shanghai' date +%F-%H-%M-%S)
echo "running benchmark with input length ${input_length}, output length ${output_length}, batch size ${batch_size}, log name ${log_name}"
cmd="python ../../benchmarks/benchmark_serving.py \
--backend openai-chat \
--model $model \
--endpoint /v1/chat/completions \
--host 0.0.0.0 \
--port ${SERVER_PORT} \
--dataset-name random \
--random-input-len ${input_length} \
--random-output-len ${output_length} \
--random-range-ratio 0 \
--hyperparameter-path ../../benchmarks/${model_yaml} \
--percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len \
--metric-percentiles 80,95,99,99.9,99.95,99.99 \
--num-prompts ${num_prompts} \
--max-concurrency ${batch_size} \
--ignore-eos"
echo $cmd | tee -a ${log_home}/${log_name}.log
eval $cmd >> ${log_home}/${log_name}.log 2>&1
cp log/hpu_model_runner_profile.log ${log_home}/${log_name}_profile.log
done
done
done

View File

@@ -0,0 +1,64 @@
#!/bin/bash
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# set -x
model="ERNIE-4.5-21B-A3B-Paddle"
model_log_name="ERNIE-4.5-21B-A3B-Paddle"
model_yaml="yaml/eb45-21b-a3b-32k-bf16.yaml"
# model="ERNIE-4.5-300B-A47B-Paddle"
# model_log_name="ERNIE-4.5-300B-A47B-Paddle"
# model_yaml="yaml/eb45-300b-a47b-32k-bf16.yaml"
export SERVER_PORT=8188
export no_proxy=.intel.com,intel.com,localhost,127.0.0.1,0.0.0.0,10.0.0.0/8,192.168.1.0/24
CARD_NUM=$1
if [[ "$CARD_NUM" == "1" ]]; then
batch_size=128
else
batch_size=64
fi
num_prompts=2000
workspace=$(pwd)
cd $workspace
log_home=$workspace/benchmark_fastdeploy_logs/$(TZ='Asia/Shanghai' date '+WW%V')_$(TZ='Asia/Shanghai' date +%F-%H-%M-%S)_${model_log_name}
mkdir -p ${log_home}
log_name_prefix="benchmarkdata_${model_log_name}_sharegpt"
log_name=${log_name_prefix}_$(TZ='Asia/Shanghai' date +%F-%H-%M-%S)
echo "running benchmark with sharegpt log name ${log_name}"
cmd="python ../../benchmarks/benchmark_serving.py \
--backend openai-chat \
--model $model \
--endpoint /v1/chat/completions \
--host 0.0.0.0 \
--port ${SERVER_PORT} \
--dataset-name EBChat \
--dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json \
--hyperparameter-path ../../benchmarks/${model_yaml} \
--percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len \
--metric-percentiles 80,95,99,99.9,99.95,99.99 \
--max-concurrency ${batch_size} \
--num-prompts ${num_prompts} \
--sharegpt-output-len 4096 \
--save-result "
echo $cmd | tee -a ${log_home}/${log_name}.log
eval $cmd >> ${log_home}/${log_name}.log 2>&1
cp log/hpu_model_runner_profile.log ${log_home}/${log_name}_profile.log

View File

@@ -0,0 +1,49 @@
#!/bin/bash
export GC_KERNEL_PATH=/usr/lib/habanalabs/libtpc_kernels.so
export GC_KERNEL_PATH=/usr/local/lib/python3.10/dist-packages/paddle_custom_device/intel_hpu/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
export INTEL_HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PADDLE_DISTRI_BACKEND=xccl
export PADDLE_XCCL_BACKEND=intel_hpu
# export FLAGS_intel_hpu_recipe_cache_config=/tmp/recipe,false,10240
export FLAGS_intel_hpu_recipe_cache_num=20480
export SERVER_PORT=8188
export ENGINE_WORKER_QUEUE_PORT=8002
export METRICS_PORT=8001
export CACHE_QUEUE_PORT=8003
export HABANA_PROFILE=0
export HPU_VISIBLE_DEVICES=0
rm -rf log 2>/dev/null
FD_ENC_DEC_BLOCK_NUM=8 HPU_PERF_BREAKDOWN_SYNC_MODE=1 HPU_WARMUP_BUCKET=1 HPU_WARMUP_MODEL_LEN=4096 FD_ATTENTION_BACKEND=HPU_ATTN \
python -m fastdeploy.entrypoints.openai.api_server \
--model ERNIE-4.5-21B-A3B-Paddle \
--port ${SERVER_PORT} \
--engine-worker-queue-port ${ENGINE_WORKER_QUEUE_PORT} \
--metrics-port ${METRICS_PORT} \
--cache-queue-port ${CACHE_QUEUE_PORT} \
--tensor-parallel-size 1 \
--max-model-len 32768 \
--max-num-seqs 128 \
--block-size 128 \
--num-gpu-blocks-override 3100 \
--kv-cache-ratio 0.991 \
--no-enable-prefix-caching \
--graph-optimization-config '{"use_cudagraph":false}'
# (2k + 1k) / 128(block_size) * 128(batch) = 3072
# export HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# rm -rf log 2>/dev/null
# FD_ENC_DEC_BLOCK_NUM=8 HPU_PERF_BREAKDOWN_SYNC_MODE=1 HPU_WARMUP_BUCKET=1 HPU_WARMUP_MODEL_LEN=3072 FD_ATTENTION_BACKEND=HPU_ATTN \
# python -m fastdeploy.entrypoints.openai.api_server \
# --model ERNIE-4.5-300B-A47B-Paddle \
# --port ${SERVER_PORT} \
# --engine-worker-queue-port ${ENGINE_WORKER_QUEUE_PORT} \
# --metrics-port ${METRICS_PORT} \
# --cache-queue-port ${CACHE_QUEUE_PORT} \
# --tensor-parallel-size 8 \
# --max-model-len 32768 \
# --max-num-seqs 128 \
# --block-size 128 \
# --num-gpu-blocks-override 3100 \
# --kv-cache-ratio 0.991 \
# --no-enable-prefix-caching \
# --graph-optimization-config '{"use_cudagraph":false}'

View File

@@ -0,0 +1,35 @@
#!/bin/bash
export GC_KERNEL_PATH=/usr/lib/habanalabs/libtpc_kernels.so
export GC_KERNEL_PATH=/usr/local/lib/python3.10/dist-packages/paddle_custom_device/intel_hpu/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
export INTEL_HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PADDLE_DISTRI_BACKEND=xccl
export PADDLE_XCCL_BACKEND=intel_hpu
# export FLAGS_intel_hpu_recipe_cache_config=/tmp/recipe,false,10240
export FLAGS_intel_hpu_recipe_cache_num=20480
export SERVER_PORT=8188
export ENGINE_WORKER_QUEUE_PORT=8002
export METRICS_PORT=8001
export CACHE_QUEUE_PORT=8003
export HABANA_PROFILE=0
CARD_NUM=$1
if [[ "$CARD_NUM" == "1" ]]; then
export HPU_VISIBLE_DEVICES=0
export MODEL="ERNIE-4.5-21B-A3B-Paddle"
export GPU_BLOCKS=5000
elif [[ "$CARD_NUM" == "8" ]]; then
export HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export MODEL="ERNIE-4.5-300B-A47B-Paddle"
export GPU_BLOCKS=3000
else
exit 0
fi
rm -rf log 2>/dev/null
FD_ENC_DEC_BLOCK_NUM=8 HPU_PERF_BREAKDOWN_SYNC_MODE=1 HPU_WARMUP_BUCKET=0 FD_ATTENTION_BACKEND=HPU_ATTN ENABLE_V1_KVCACHE_SCHEDULER=0 \
python -m fastdeploy.entrypoints.openai.api_server --model ${MODEL} --port ${SERVER_PORT} \
--engine-worker-queue-port ${ENGINE_WORKER_QUEUE_PORT} --metrics-port ${METRICS_PORT} \
--cache-queue-port ${CACHE_QUEUE_PORT} --tensor-parallel-size ${CARD_NUM} --max-model-len 16384 \
--max-num-seqs 128 --block-size 128 --kv-cache-ratio 0.5 --num-gpu-blocks-override ${GPU_BLOCKS} \
--graph-optimization-config '{"use_cudagraph":false}'

View File

@@ -0,0 +1,173 @@
import csv
import os
import re
import sys
from datetime import datetime
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
log_patterns = [
re.compile(
r"benchmarkdata_(.+?)_inputlength_(\d+)_outputlength_(\d+)_batchsize_(\d+)_numprompts_(\d+)_.*_profile\.log$"
),
]
def draw_time_graph(log_dir, log_filename, max_num_seqs, mode):
# Store extracted time and BT values
timestamps_model = []
times_model = []
bt_values_model = []
block_list_shapes_model = []
block_indices_shapes_model = []
timestamps_pp = []
times_pp = []
bt_values_pp = []
# Use regex to extract Model execution time and BT information
pattern_model = re.compile(
r"(\d+-\d+-\d+ \d+:\d+:\d+,\d+) .* Model execution time\(ms\): ([\d\.]+), BT=(\d+), block_list_shape=\[(\d+)\], block_indices_shape=\[(\d+)\]"
)
pattern_pp = re.compile(
r"(\d+-\d+-\d+ \d+:\d+:\d+,\d+) .* PostProcessing execution time\(ms\): ([\d\.]+), BT=(\d+)"
)
# Read log file
with open(os.path.join(log_dir, log_filename), "r") as file:
for line in file:
match_model = pattern_model.search(line)
if match_model:
bt_value = int(match_model.group(3))
timestamps_model.append(datetime.strptime(match_model.group(1), "%Y-%m-%d %H:%M:%S,%f"))
if mode == "prefill" and bt_value <= max_num_seqs:
times_model.append(None)
bt_values_model.append(None)
continue
if mode == "decode" and bt_value > max_num_seqs:
times_model.append(None)
bt_values_model.append(None)
continue
times_model.append(float(match_model.group(2)))
bt_values_model.append(bt_value)
block_list_shapes_model.append(int(match_model.group(4)))
block_indices_shapes_model.append(int(match_model.group(5)))
else:
match_pp = pattern_pp.search(line)
if match_pp:
bt_value = int(match_pp.group(3))
timestamps_pp.append(datetime.strptime(match_pp.group(1), "%Y-%m-%d %H:%M:%S,%f"))
if mode == "prefill" and bt_value <= max_num_seqs:
times_pp.append(None)
bt_values_pp.append(None)
continue
if mode == "decode" and bt_value > max_num_seqs:
times_pp.append(None)
bt_values_pp.append(None)
continue
times_pp.append(float(match_pp.group(2)))
bt_values_pp.append(bt_value)
# Plot graphs
plt.figure(figsize=(15, 7))
date_format = mdates.DateFormatter("%m-%d %H:%M:%S")
# Plot time graph
plt.subplot(2, 1, 1)
ax1 = plt.gca()
ax2 = ax1.twinx()
ax1.plot(timestamps_model, times_model, label="Model Execution Time (ms)", color="blue")
ax2.plot(timestamps_pp, times_pp, label="PostProcessing Time (ms)", color="red")
ax1.set_ylabel("Model Execution Time (ms)")
ax2.set_ylabel("PostProcessing Time (ms)")
ax1.xaxis.set_major_formatter(date_format)
# Merge legends
lines_1, labels_1 = ax1.get_legend_handles_labels()
lines_2, labels_2 = ax2.get_legend_handles_labels()
ax1.legend(lines_1 + lines_2, labels_1 + labels_2)
# Plot BT value graph
plt.subplot(2, 1, 2)
plt.plot(timestamps_model, bt_values_model, label="BT [" + mode + "]", color="orange")
plt.ylabel("BT Value")
plt.xlabel(log_filename, fontsize=8)
plt.gca().xaxis.set_major_formatter(date_format)
plt.legend()
plt.tight_layout()
output_filename = log_filename[:-4] + "_analysis_" + mode + ".png"
plt.savefig(os.path.join(log_dir, output_filename), dpi=300)
plt.close()
# Write to CSV file
if mode == "all":
csv_filename = log_filename[:-4] + "_analysis.csv"
with open(os.path.join(log_dir, csv_filename), "w", newline="") as csvfile:
writer = csv.writer(csvfile)
writer.writerow(
[
"Timestamp",
"ModelTime(ms)",
"BT",
"block_list_shape",
"block_indices_shape",
"Timestamp",
"PostProcessing(ms)",
"BT",
]
)
for i in range(len(times_model)):
writer.writerow(
[
timestamps_model[i],
times_model[i],
bt_values_model[i],
block_list_shapes_model[i],
block_indices_shapes_model[i],
timestamps_pp[i],
times_pp[i],
bt_values_pp[i],
]
)
def main():
if len(sys.argv) > 1:
log_dir = sys.argv[1]
else:
log_dir = "."
try:
from natsort import natsorted
natsort_available = True
except ImportError:
natsort_available = False
files = []
for f in os.listdir(log_dir):
for pat in log_patterns:
if pat.match(f):
files.append(f)
break
if natsort_available:
files = natsorted(files)
else:
import re as _re
def natural_key(s):
return [int(text) if text.isdigit() else text.lower() for text in _re.split("([0-9]+)", s)]
files.sort(key=natural_key)
for file in files:
for idx, pat in enumerate(log_patterns):
m = pat.match(file)
if m:
draw_time_graph(log_dir, file, 128, "prefill")
draw_time_graph(log_dir, file, 128, "decode")
draw_time_graph(log_dir, file, 128, "all")
if __name__ == "__main__":
print("Starting to draw logs...")
main()

View File

@@ -0,0 +1,5 @@
max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 8
max_num_batched_tokens: 32768

View File

@@ -0,0 +1,70 @@
# Intel HPU serving benchmark
These scripts are used to launch FastDeploy Paddle large model inference service for performance and stress testing.
## Main HPU-Specific Parameter
- `HPU_WARMUP_BUCKET`: Whether to enable warmup (1 means enabled)
- `HPU_WARMUP_MODEL_LEN`: Model length for warmup (including input and output)
- `MAX_PREFILL_NUM`: Maximum batch in prefill stage, default 3
- `BATCH_STEP_PREFILL`: Batch step in prefill stage, default 1
- `SEQUENCE_STEP_PREFILL`: Sequence step in prefill stage, default 128, same as block size
- `CONTEXT_BLOCK_STEP_PREFILL`: Step size for block hit when prefill caching is enabled, default 1
- `BATCH_STEP_DECODE`: Batch step in decode stage, default 4
- `BLOCK_STEP_DECODE`: Block step in decode stage, default 16
- `FLAGS_intel_hpu_recipe_cache_num`: Limit for HPU recipe cache number
- `FLAGS_intel_hpu_recipe_cache_config`: HPU recipe cache config, can be used for warmup optimization
- `GC_KERNEL_PATH`: The default path of the HPU TPC kernels library
- `HABANA_PROFILE`: Whether to enable profiler (1 means enabled)
- `PROFILE_START`: Profiler start step.
- `PROFILE_END`: Profiler end step.
## Usage
### 1. Start server
There are different setup scripts are provided to start the vllm server, one for RandomDataset and the other for ShareGPT.
Before running, please make sure to correctly set the model path and port number in the script.
```bash
./benchmark_paddle_hpu_server.sh
./benchmark_paddle_hpu_server_sharegpt.sh
```
You can use HPU_VISIBLE_DEVICES in the script to select the HPU card.
### 2. Run client
Correspondingly, there are different client test scripts. `benchmark_paddle_hpu_cli.sh` supports both variable and fixed length tests.
Before running, please make sure to correctly set the model path, port number, and input/output settings in the script.
```bash
./benchmark_paddle_hpu_cli.sh
./benchmark_paddle_hpu_cli_sharegpt.sh
```
### 3. Parse logs
After batch testing, run the following script to automatically parse the logs and generate a CSV file.
```python
python parse_benchmark_logs.py benchmark_fastdeploy_logs/[the targeted folder]
```
The performance data will be saved as a CSV file.
### 4. Analyse logs
During HPU_MODEL_RUNNER execution, performance logs are generated. The following script can parse these logs and produce performance graphs to help identify bottlenecks.
```python
python draw_benchmark_data.py benchmark_fastdeploy_logs/[the targeted folder]
```
The script will save the model execution times and batch tokens as a CSV file and plot them in a graph.
### 5. Accuracy test
Accuracy testing uses GSM8K. Use the following conversion to generate the test file.
```python
>>> import pandas as pd
>>> df = pd.read_parquet('tests/ce/accuracy_cases/gsm8k.parquet', engine='pyarrow')
>>> df.to_json('test.jsonl', orient='records', lines=True)
```
Run the following command to perform the accuracy test.
```bash
python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 64
```
### 6. Offline demo
To run a offline demo on HPU quickly, after set model_path in offline_demo.py, run the start script directly.
```bash
./run_offline_demo.sh
```

View File

@@ -0,0 +1,53 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM
model_name_or_path = "ERNIE-4.5-21B-A3B-Paddle"
# model_name_or_path = "ERNIE-4.5-300B-A47B-Paddle"
# Hyperparameter settings
input_bs = 1
input_seq = None # 1000
max_out_tokens = 128
server_max_bs = 128
TP = 1
# num_gpu_blocks_override = ceil((input_seq + max_out_tokens) / 128) * server_max_bs
num_gpu_blocks_override = 2000
sampling_params = SamplingParams(max_tokens=max_out_tokens)
graph_optimization_config = {"use_cudagraph": False}
llm = LLM(
model=model_name_or_path,
tensor_parallel_size=TP,
engine_worker_queue_port=8602,
num_gpu_blocks_override=num_gpu_blocks_override,
block_size=128,
max_model_len=32768,
max_num_seqs=server_max_bs,
graph_optimization_config=graph_optimization_config,
)
if input_seq is None:
prompt = "user: who are you?"
else:
prompt = "hi " * input_seq
prompts = [prompt] * input_bs
for i in range(2):
output = llm.generate(prompts=prompts, use_tqdm=True, sampling_params=sampling_params)
print(output)

View File

@@ -0,0 +1,195 @@
import csv
import os
import re
import sys
log_patterns = [
re.compile(
r"benchmarkdata_(.+?)_inputlength_(\d+)_outputlength_(\d+)_batchsize_(\d+)_numprompts_(\d+)_.*(?<!_profile)\.log$"
),
re.compile(r"benchmarkdata_(.+?)_sharegpt_prompts_(\d+)_concurrency_(\d+)_.*(?<!_profile)\.log$"),
]
metrics = [
("Mean Decode", r"Mean Decode:\s+([\d\.]+)"),
("Mean TTFT (ms)", r"Mean TTFT \(ms\):\s+([\d\.]+)"),
("Mean S_TTFT (ms)", r"Mean S_TTFT \(ms\):\s+([\d\.]+)"),
("Mean TPOT (ms)", r"Mean TPOT \(ms\):\s+([\d\.]+)"),
("Mean ITL (ms)", r"Mean ITL \(ms\):\s+([\d\.]+)"),
("Mean S_ITL (ms)", r"Mean S_ITL \(ms\):\s+([\d\.]+)"),
("Mean E2EL (ms)", r"Mean E2EL \(ms\):\s+([\d\.]+)"),
("Mean S_E2EL (ms)", r"Mean S_E2EL \(ms\):\s+([\d\.]+)"),
("Mean Input Length", r"Mean Input Length:\s+([\d\.]+)"),
("Mean Output Length", r"Mean Output Length:\s+([\d\.]+)"),
("Request throughput (req/s)", r"Request throughput \(req/s\):\s+([\d\.]+)"),
("Output token throughput (tok/s)", r"Output token throughput \(tok/s\):\s+([\d\.]+)"),
("Total Token throughput (tok/s)", r"Total Token throughput \(tok/s\):\s+([\d\.]+)"),
]
def parse_benchmark_log_file(filepath):
with open(filepath, "r", encoding="utf-8") as f:
content = f.read()
result = {}
for name, pattern in metrics:
match = re.search(pattern, content)
result[name] = match.group(1) if match else ""
return result
def parse_profile_log_file(file_path):
prepare_input_times = []
model_times = []
postprocessing_times = []
steppaddle_times = []
with open(file_path, "r") as file:
for line in file:
prepare_input_match = re.search(r"_prepare_inputs time\(ms\): (\d+\.\d+)", line)
model_match = re.search(r"Model execution time\(ms\): (\d+\.\d+)", line)
postprocessing_match = re.search(r"PostProcessing execution time\(ms\): (\d+\.\d+)", line)
steppaddle_match = re.search(r"StepPaddle execution time\(ms\): (\d+\.\d+)", line)
if prepare_input_match:
prepare_input_times.append(float(prepare_input_match.group(1)))
if model_match:
model_times.append(float(model_match.group(1)))
if postprocessing_match:
postprocessing_times.append(float(postprocessing_match.group(1)))
if steppaddle_match:
steppaddle_times.append(float(steppaddle_match.group(1)))
return prepare_input_times, model_times, postprocessing_times, steppaddle_times
def calculate_times(times, separate_first):
if len(times) < 2:
return times[0], None
if separate_first:
first_time = times[0]
average_time = sum(times[1:]) / len(times[1:])
return first_time, average_time
else:
return None, sum(times) / len(times)
def main():
if len(sys.argv) > 1:
log_dir = sys.argv[1]
else:
log_dir = "."
try:
from natsort import natsorted
natsort_available = True
except ImportError:
natsort_available = False
all_files = set(os.listdir(log_dir))
files = []
for f in os.listdir(log_dir):
for pat in log_patterns:
if pat.match(f):
files.append(f)
break
if natsort_available:
files = natsorted(files)
else:
import re as _re
def natural_key(s):
return [int(text) if text.isdigit() else text.lower() for text in _re.split("([0-9]+)", s)]
files.sort(key=natural_key)
rows = []
for file in files:
m = None
matched_idx = -1
for idx, pat in enumerate(log_patterns):
m = pat.match(file)
if m:
matched_idx = idx
break
if not m:
continue
# model_name, input_len, output_len, batch_size, num_prompts
# model_name, num_prompts, max_concurrency
if matched_idx == 0:
model_name, input_len, output_len, batch_size, num_prompts = m.groups()
elif matched_idx == 1:
model_name, num_prompts, max_concurrency = m.groups()
input_len = "-"
output_len = "-"
if file.endswith(".log"):
profile_file = file[:-4] + "_profile.log"
else:
profile_file = ""
model_first = model_average = postprocessing_average = steppaddle_average = ""
if profile_file in all_files:
prepare_input_times, model_times, postprocessing_times, steppaddle_times = parse_profile_log_file(
os.path.join(log_dir, profile_file)
)
_, pia = calculate_times(prepare_input_times, False)
mf, ma = calculate_times(model_times, True)
_, pa = calculate_times(postprocessing_times, False)
_, sa = calculate_times(steppaddle_times, False)
prepare_input_average = pia if pia is not None else ""
model_first = mf if mf is not None else ""
model_average = ma if ma is not None else ""
postprocessing_average = pa if pa is not None else ""
steppaddle_average = sa if sa is not None else ""
data = parse_benchmark_log_file(os.path.join(log_dir, file))
data["dataset"] = "Fixed-Length" if matched_idx == 0 else "ShareGPT"
data["model_name"] = model_name
data["input_length"] = input_len
data["output_length"] = output_len
data["batch_size"] = batch_size if matched_idx == 0 else max_concurrency
data["num_prompts"] = num_prompts
data["prepare_input_average"] = prepare_input_average
data["model_execute_first"] = model_first
data["model_execute_average"] = model_average
data["postprocessing_execute_average"] = postprocessing_average
data["steppaddle_execute_average"] = steppaddle_average
rows.append(data)
import datetime
import pytz
shanghai_tz = pytz.timezone("Asia/Shanghai")
now = datetime.datetime.now(shanghai_tz)
ts = now.strftime("%Y%m%d_%H%M%S")
log_dir_name = os.path.basename(os.path.abspath(log_dir))
if log_dir_name == "" or log_dir == "." or log_dir == "/":
csv_filename = f"benchmark_summary_{ts}.csv"
else:
csv_filename = f"benchmark_summary_{log_dir_name}_{ts}.csv"
fieldnames = (
[
"model_name",
"dataset",
"input_length",
"output_length",
"batch_size",
"num_prompts",
]
+ [name for name, _ in metrics]
+ [
"prepare_input_average",
"model_execute_first",
"model_execute_average",
"postprocessing_execute_average",
"steppaddle_execute_average",
]
)
with open(csv_filename, "w", newline="", encoding="utf-8") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for row in rows:
writer.writerow(row)
print(f"CSV saved as: {csv_filename}")
if __name__ == "__main__":
print("Starting to parse logs...")
main()

View File

@@ -0,0 +1,19 @@
#!/bin/bash
export GC_KERNEL_PATH=/usr/lib/habanalabs/libtpc_kernels.so
export GC_KERNEL_PATH=/usr/local/lib/python3.10/dist-packages/paddle_custom_device/intel_hpu/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
export INTEL_HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PADDLE_DISTRI_BACKEND=xccl
export PADDLE_XCCL_BACKEND=intel_hpu
# export HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export HPU_VISIBLE_DEVICES=0
export HABANA_PROFILE=0
export PROFILE_START=1
export PROFILE_END=3
# export HABANA_LOGS=hpu_logs
# export LOG_LEVEL_ALL=0
# export FLAGS_intel_hpu_runtime_debug=1
# export FLAGS_intel_hpu_reciperunner_debug=1
rm -rf log
FD_ATTENTION_BACKEND=HPU_ATTN python offline_demo.py

View File

@@ -26,8 +26,8 @@ python -m pip uninstall fastdeploy_intel_hpu -y
#to install paddlepaddle
pip install paddlepaddle==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
#to install paddlecustomdevice? (paddle-intel-hpu)
pip install https://paddle-qa.bj.bcebos.com/suijiaxin/HPU/paddle_intel_hpu-0.0.1-cp310-cp310-linux_x86_64.whl
pip install https://paddle-qa.bj.bcebos.com/suijiaxin/HPU/paddlenlp_ops-0.0.0-cp310-cp310-linux_x86_64.whl
pip install https://paddle-qa.bj.bcebos.com/suijiaxin/HPU/paddle_intel_hpu-0.0.2-cp310-cp310-linux_x86_64.whl
pip install https://paddle-qa.bj.bcebos.com/suijiaxin/HPU/paddlenlp_ops-0.0.2-cp310-cp310-linux_x86_64.whl
#to build and install fastdeploy
echo "build whl"