Files
FastDeploy/test/ce/server/gsm8k.py
Divano 5885285e57 Ce add benchmark test (#3262)
* add repetition early stop cases

* add repetition early stop cases

* add bad cases

* add bad cases

* add evil cases

* add benchmark gsm8k
2025-08-07 15:28:30 +08:00

189 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/bin/env python3
# -*- coding: utf-8 -*-
# @author DDDivano
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
import os
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from urllib.parse import urlparse, urlunparse
import openai
from datasets import load_dataset
from tqdm import tqdm
# Per-model baseline accuracy (fraction in [0, 1]) used as the pass/fail reference.
BASELINE = {
    "0.3B": 0.05,
    "21B": 0.49,
    "300B": 0.96,
}
# MODEL selects the baseline row; URL points at the inference service endpoint.
baseline = BASELINE.get(os.environ.get("MODEL"), None)
base_url = os.environ.get("URL", None)
atol = 0.03  # allowed absolute deviation of averaged accuracy from the baseline
# Fail fast with actionable messages when required environment is missing.
if baseline is None:
    raise ValueError(f"Invalid MODEL value '{os.environ.get('MODEL')}', expected one of {list(BASELINE.keys())}")
if base_url is None:
    raise ValueError(
        "Environment variable 'URL' is not set. "
        "Please specify the inference service address, e.g., 'http://localhost:8191/v1'."
    )
def strip_path_suffix(url: str, suffix: str = "chat/completions") -> str:
    """Remove a trailing path suffix (e.g. "chat/completions") from *url*.

    The suffix is only stripped when it occupies the end of the path as a
    whole segment (preceded by "/"); any trailing slash on the remaining
    path is dropped. Query, params and fragment are discarded.
    """
    parts = urlparse(url)
    path = parts.path
    tail = "/" + suffix
    if path.endswith(tail):
        # Cut exactly the "/<suffix>" segment off the end.
        path = path[: len(path) - len(tail)]
    # Rebuild the URL from scheme/netloc/path only.
    return urlunparse((parts.scheme, parts.netloc, path.rstrip("/"), "", "", ""))
# ========== OpenAI client configuration ==========
client = openai.OpenAI(
    api_key="DDDivano",
    # base_url="http://<host-placeholder>:8187/v1"
    # Accept a URL given with or without the "chat/completions" suffix.
    base_url=strip_path_suffix(base_url),
)
model_name = "eb"  # served model name on the endpoint
max_samples = 690  # cap on the number of GSM8K samples evaluated
max_tokens = 12288  # per-request generation limit
max_workers = 33  # thread-pool size for concurrent requests
# ========== Load dataset ==========
# Local parquet export of GSM8K; truncated to at most max_samples rows.
dataset = load_dataset("parquet", data_files="gsm8k.parquet", split="train")
dataset = dataset.select(range(min(len(dataset), max_samples)))
# ========== Extract the final "#### <number>" answer from the ground truth ==========
def extract_gt_answer(text):
    """Return the GSM8K ground-truth answer ("#### 1,234" -> "1234"), or None."""
    found = re.search(r"####\s*([\d,]+(?:\.\d+)?)", text)
    if found is None:
        return None
    return found.group(1).replace(",", "").strip()
# ========== Extract the number from the last line of the model output ==========
def extract_model_answer(text):
    """Return the first number found on the model output's last line, or None."""
    if not text:
        return None
    # Strip thousands separators and dollar signs before matching.
    cleaned = text.replace(",", "").replace("$", "")
    rows = cleaned.strip().splitlines()
    target = rows[-1] if rows else cleaned
    found = re.search(r"-?\d+(?:\.\d+)?", target)
    if found is None:
        return None
    return found.group(0)
# ========== Numeric answer comparison ==========
def is_answer_equal(pred, gt, tol=1e-6):
    """Compare predicted and ground-truth answers numerically within *tol*.

    Falls back to plain equality when either value is not parseable as a
    float. Returns False if either side is None.
    """
    if pred is None or gt is None:
        return False
    try:
        return abs(float(pred) - float(gt)) < tol
    except (TypeError, ValueError):
        # Narrowed from a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit; float() can only raise these two.
        return pred == gt
# ========== Build the prompt ==========
def build_prompt(sample):
    """Format one GSM8K sample into the Chinese evaluation prompt."""
    question = sample["question"]
    return f"以下是一个数学问题,请直接给出最终答案。一定要把最终答案数字在最后输出。\n\n问题:{question}\n\n答案:"
# ========== Model request ==========
def query_model(prompt):
    """Send *prompt* to the chat endpoint and return the reply text.

    Best-effort: any failure is surfaced as an "[Error] ..." string rather
    than raised, so one bad request does not abort the whole benchmark.
    """
    conversation = [
        {"role": "system", "content": "你是一个数学专家,擅长严谨地解答数学问题。"},
        {"role": "user", "content": prompt},
    ]
    try:
        completion = client.chat.completions.create(
            model=model_name,
            messages=conversation,
            temperature=1.0,
            top_p=0.8,
            max_tokens=max_tokens,
        )
        return completion.choices[0].message.content.strip()
    except Exception as exc:
        return f"[Error] {exc}"
# ========== Per-sample evaluation ==========
def evaluate_sample(sample):
    """Run one GSM8K sample end-to-end and return its result record."""
    output = query_model(build_prompt(sample))
    gt_value = extract_gt_answer(sample["answer"])
    pred_value = extract_model_answer(output)
    # Keep both raw texts alongside parsed values for post-hoc debugging.
    return {
        "question": sample["question"],
        "gt_answer": gt_value,
        "model_answer": pred_value,
        "raw_gt_answer": sample["answer"],
        "raw_model_output": output,
        "is_correct": is_answer_equal(pred_value, gt_value),
    }
# ========== Main flow ==========
acc = []
times = 3  # run the full benchmark this many times and average the accuracy
for i in range(times):
    correct = 0
    total = 0
    results = []
    print(f"🚀 Starting evaluation with {max_workers} threads...")
    # Fan the samples out across a thread pool; requests are I/O-bound.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(evaluate_sample, sample) for sample in dataset]
        for future in tqdm(as_completed(futures), total=len(futures), desc="Evaluating"):
            result = future.result()
            results.append(result)
            total += 1
            if result["is_correct"]:
                correct += 1
            else:
                # Log every miss with full context for later inspection.
                print("\n❌ Wrong prediction:")
                print(f"Q: {result['question']}")
                print(f"GT: {result['gt_answer']}")
                print(f"Model: {result['model_answer']}")
                print(f"Full GT: {result['raw_gt_answer']}")
                print(f"Model Output: {result['raw_model_output']}")
    # ========== Report accuracy for this pass ==========
    accuracy = correct / total * 100 if total > 0 else 0.0
    print(f"\n🎯 Evaluation Complete: Accuracy = {accuracy:.2f}% ({correct}/{total})")
    acc.append(accuracy)
# Convert the averaged percentage back to a 0-1 fraction to match BASELINE.
avg_acc = round(sum(acc) / times / 100, 4)
print(f"平均准确率:{avg_acc * 100:.2f}%")
# CI gate: fail the job when averaged accuracy drifts beyond atol of baseline.
assert (
    abs(avg_acc - baseline) <= atol
), f"模型准确率 {avg_acc:.2f} 与基准 {baseline:.2f} 相差 {abs(avg_acc - baseline):.2f},超出容忍范围 {atol:.2f}"
# with open("eval_result_math.json", "w", encoding="utf-8") as f:
#     json.dump(results, f, indent=2, ensure_ascii=False)