mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 09:31:35 +08:00
[Bugfix] Fix uninitialized decoded_token and add corresponding unit test (#3201)
* Update test_base_chat.py (#3183) * [Bugfix] Fix uninitialized decoded_token and add corresponding unit test. --------- Co-authored-by: Divano <dddivano@outlook.com>
This commit is contained in:
@@ -285,6 +285,10 @@ class LLM:
|
|||||||
self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
|
self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
|
||||||
return req_ids
|
return req_ids
|
||||||
|
|
||||||
|
def _decode_token(self, token_id: int) -> str:
|
||||||
|
"""Decodes a single token ID into its string representation."""
|
||||||
|
return self.llm_engine.data_processor.process_logprob_response([token_id], clean_up_tokenization_spaces=False)
|
||||||
|
|
||||||
def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: int) -> list[dict[int, Logprob]]:
|
def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: int) -> list[dict[int, Logprob]]:
|
||||||
"""
|
"""
|
||||||
Constructs a list of dictionaries mapping token IDs to Logprob objects,
|
Constructs a list of dictionaries mapping token IDs to Logprob objects,
|
||||||
@@ -318,8 +322,9 @@ class LLM:
|
|||||||
sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs)
|
sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs)
|
||||||
result = []
|
result = []
|
||||||
for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs):
|
for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs):
|
||||||
|
|
||||||
logprob_dict = {
|
logprob_dict = {
|
||||||
token_id: Logprob(logprob=logprob, rank=i + 1, decoded_token=None)
|
token_id: Logprob(logprob=logprob, rank=i + 1, decoded_token=self._decode_token(token_id))
|
||||||
for i, (token_id, logprob) in enumerate(zip(token_ids, logprobs))
|
for i, (token_id, logprob) in enumerate(zip(token_ids, logprobs))
|
||||||
}
|
}
|
||||||
result.append(logprob_dict)
|
result.append(logprob_dict)
|
||||||
|
221
test/ce/server/test_base_chat.py
Normal file
221
test/ce/server/test_base_chat.py
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
#!/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# @author DDDivano
|
||||||
|
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
|
||||||
|
|
||||||
|
"""
|
||||||
|
Basic checks for the FastDeploy (fd) web API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from core import TEMPLATE, URL, build_request_payload, send_request
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_response():
    """Check that a streaming chat completion yields non-empty text.

    Sends a system + user conversation with ``stream`` enabled and reads the
    response line by line. Lines without the ``data: `` prefix are skipped,
    a ``[DONE]`` payload ends the loop, and every other payload is parsed as
    JSON so the ``delta.content`` of its first choice can be collected.
    """
    data = {
        "stream": True,
        "messages": [
            {"role": "system", "content": "你是一个知识渊博的 AI 助手"},
            {"role": "user", "content": "讲讲爱因斯坦的相对论"},
        ],
        "max_tokens": 10,
    }
    payload = build_request_payload(TEMPLATE, data)
    resp = send_request(URL, payload, stream=True)

    prefix = "data: "
    pieces = []
    for line in resp.iter_lines(decode_unicode=True):
        # Blank lines and non-data lines never start with the prefix.
        if not line.startswith(prefix):
            continue
        body = line[len(prefix) :].strip()
        if body == "[DONE]":
            break
        chunk = json.loads(body)
        # Tolerate chunks whose "choices" is missing or empty, and deltas
        # whose "content" is absent or null, instead of raising
        # IndexError / TypeError while accumulating.
        choices = chunk.get("choices") or [{}]
        delta = choices[0].get("delta") or {}
        pieces.append(delta.get("content") or "")

    output = "".join(pieces)
    print("Stream输出:", output)
    # The earlier `"相对论" in output or len(output) > 0` check is logically
    # the same as requiring a non-empty output, so state that directly.
    assert output
|
||||||
|
|
||||||
|
|
||||||
|
def test_system_prompt_effect():
    """Check that a "one sentence" system prompt keeps the reply short.

    Sends a non-streaming request and requires the returned message content
    to be shorter than 50 characters.
    """
    conversation = [
        {"role": "system", "content": "请用一句话回答"},
        {"role": "user", "content": "什么是人工智能?"},
    ]
    request_fields = {"stream": False, "messages": conversation, "max_tokens": 30}
    payload = build_request_payload(TEMPLATE, request_fields)

    reply = send_request(URL, payload).json()
    first_choice = reply["choices"][0]
    content = first_choice["message"]["content"]

    print("内容输出:", content)
    assert len(content) < 50
|
||||||
|
|
||||||
|
|
||||||
|
def test_logprobs_enabled():
    """Check that requesting logprobs returns per-token entries.

    Sends a non-streaming request with ``logprobs`` enabled and
    ``top_logprobs`` set to 5, then verifies that the first choice carries a
    ``logprobs`` object whose ``content`` is a non-empty list in which every
    item has a ``token`` key.
    """
    data = {
        "stream": False,
        "logprobs": True,
        "top_logprobs": 5,
        "messages": [{"role": "user", "content": "非洲的首都是?"}],
        "max_tokens": 3,
    }
    payload = build_request_payload(TEMPLATE, data)
    resp = send_request(URL, payload).json()

    logprob_data = resp["choices"][0].get("logprobs")
    print("LogProbs:", logprob_data)
    assert logprob_data is not None

    content_logprobs = logprob_data.get("content", [])
    assert isinstance(content_logprobs, list)
    # all() over an empty list is vacuously true, so without this check a
    # response that carries no logprob entries at all would still pass.
    assert content_logprobs, "logprobs.content should not be empty"
    assert all("token" in item for item in content_logprobs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop_sequence():
    """Check that generation is cut off at the configured stop string.

    The prompt asks for text in which the stop string "果冻" sits between a
    first and a second segment; the returned content must not contain the
    second segment's marker.
    """
    user_message = {
        "role": "user",
        "content": "你要严格按照我接下来的话输出,输出冒号后面的内容,请输出:这是第一段。果冻这是第二段啦啦啦啦啦。",
    }
    request_fields = {
        "stream": False,
        "stop": ["果冻"],
        "messages": [user_message],
        "max_tokens": 20,
        "top_p": 0,
    }
    payload = build_request_payload(TEMPLATE, request_fields)

    reply = send_request(URL, payload).json()
    content = reply["choices"][0]["message"]["content"]

    print("截断输出:", content)
    assert "第二段" not in content
|
||||||
|
|
||||||
|
|
||||||
|
def test_sampling_parameters():
    """Check the answer to a simple sum with temperature and top_p set to 0.

    The reply to "1+1" must contain either the digit "2" or the character "二".
    """
    question = {"role": "user", "content": "1+1=?,直接回答答案"}
    request_fields = {
        "stream": False,
        "temperature": 0,
        "top_p": 0,
        "messages": [question],
        "max_tokens": 50,
    }
    payload = build_request_payload(TEMPLATE, request_fields)

    reply = send_request(URL, payload).json()
    answer = reply["choices"][0]["message"]["content"]

    print("Sampling输出:", answer)
    assert "2" in answer or "二" in answer
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_turn_conversation():
    """Check that a follow-up question is answered using earlier turns.

    The conversation introduces Newton, then asks what theory "he" proposed;
    the reply must mention one of the two expected keywords.
    """
    history = [
        {"role": "user", "content": "牛顿是谁?"},
        {"role": "assistant", "content": "牛顿是一位物理学家。"},
        {"role": "user", "content": "他提出了什么理论?"},
    ]
    request_fields = {"stream": False, "messages": history, "max_tokens": 30}
    payload = build_request_payload(TEMPLATE, request_fields)

    reply = send_request(URL, payload).json()
    content = reply["choices"][0]["message"]["content"]

    print("多轮记忆:", content)
    expected_keywords = ("三大运动定律", "万有引力")
    assert any(keyword in content for keyword in expected_keywords)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bad_words_filtering():
    """Check that words listed in ``bad_words`` never appear in the output.

    Two prompts are sent, each asking the model to repeat text that contains
    the banned words. Both requests pass ``bad_words``, and the generated
    content of each must contain none of them.

    Previously the second request had ``bad_words`` commented out while still
    asserting the banned words were absent, so the assertion did not match
    what was being requested; the parameter is now sent for both prompts.
    """
    banned_tokens = ["和", "呀"]
    prompts = [
        "请输出冒号后面的字: 我爱吃果冻,和苹果,香蕉,和荔枝",
        "请输出冒号后面的字,一模一样: 我爱吃果冻,苹果,香蕉,和荔枝呀呀呀",
    ]

    for prompt in prompts:
        data = {
            "stream": False,
            "messages": [
                {"role": "system", "content": "你是一个助手,回答简洁清楚"},
                {"role": "user", "content": prompt},
            ],
            "top_p": 0,
            "max_tokens": 69,
            "bad_words": banned_tokens,
        }

        payload = build_request_payload(TEMPLATE, data)
        response = send_request(URL, payload).json()

        content = response["choices"][0]["message"]["content"]
        print("生成内容:", content)

        for word in banned_tokens:
            assert word not in content, f"bad_word '{word}' 不应出现在生成结果中"

        print("test_bad_words_filtering 通过:生成结果未包含被禁词")
|
||||||
|
|
||||||
|
|
||||||
|
def test_bad_words_filtering1():
    """Check ``bad_words`` filtering and its absence in two separate requests.

    First request: ``bad_words`` is set, and none of the banned words may
    appear in the generated content.
    Second request: ``bad_words`` is not sent, and the content is expected to
    contain "呀呀" from the text the prompt asks the model to repeat.
    """
    banned_tokens = ["和", "呀"]
    system_message = {"role": "system", "content": "你是一个助手,回答简洁清楚"}

    # --- Request 1: banned words must be filtered out. ---
    data = {
        "stream": False,
        "messages": [
            system_message,
            {"role": "user", "content": "请输出冒号后面的字: 我爱吃果冻,和苹果,香蕉,和荔枝"},
        ],
        "top_p": 0,
        "max_tokens": 69,
        "bad_words": banned_tokens,
    }

    payload = build_request_payload(TEMPLATE, data)
    response = send_request(URL, payload).json()

    content = response["choices"][0]["message"]["content"]
    print("生成内容:", content)

    for word in banned_tokens:
        assert word not in content, f"bad_word '{word}' 不应出现在生成结果中"

    print("test_bad_words_filtering 通过:生成结果未包含被禁词")

    # --- Request 2: no bad_words, so the repeated text should survive. ---
    # A dedicated name is used instead of reusing the loop variable above.
    expected = "呀呀"
    data = {
        "stream": False,
        "messages": [
            system_message,
            {"role": "user", "content": "请输出冒号后面的字,一模一样: 我爱吃果冻,苹果,香蕉,和荔枝呀呀呀"},
        ],
        "top_p": 0,
        "max_tokens": 69,
    }

    payload = build_request_payload(TEMPLATE, data)
    response = send_request(URL, payload).json()

    content = response["choices"][0]["message"]["content"]
    print("生成内容:", content)

    assert expected in content, f" '{expected}' 应出现在生成结果中"

    # The earlier message here claimed the output contained no banned words,
    # which is the opposite of what this half of the test just verified.
    print("test_bad_words_filtering1 通过:未设置 bad_words 时生成结果包含预期内容")
|
78
test/entrypoints/openai/test_build_sample_logprobs.py
Normal file
78
test/entrypoints/openai/test_build_sample_logprobs.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from fastdeploy.entrypoints.llm import LLM
|
||||||
|
from fastdeploy.worker.output import Logprob, LogprobsLists
|
||||||
|
|
||||||
|
|
||||||
|
def get_patch_path(cls, method="__init__"):
    """Return the dotted path ``<module>.<qualified class name>.<method>``.

    The module and qualified name are read from ``cls``; ``method`` defaults
    to ``"__init__"``.
    """
    owner_module = cls.__module__
    owner_name = cls.__qualname__
    return "{}.{}.{}".format(owner_module, owner_name, method)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildSampleLogprobs(unittest.TestCase):
    """Unit tests for ``LLM._build_sample_logprobs`` and ``LLM._decode_token``."""

    def setUp(self):
        """Create an ``LLM`` with its constructor patched out and a mocked engine."""
        init_target = get_patch_path(LLM)
        with patch(init_target, return_value=None):
            self.llm = LLM()

        # Mock the data processor: decoding a list of IDs yields
        # "token_<first id>" regardless of keyword arguments.
        engine = MagicMock()
        engine.data_processor.process_logprob_response.side_effect = lambda ids, **kwargs: f"token_{ids[0]}"
        self.llm.llm_engine = engine

    def test_build_sample_logprobs_basic(self):
        """A valid ``topk_logprobs`` produces ranked entries with decoded tokens."""
        logprobs_lists = LogprobsLists(
            logprob_token_ids=[[100, 101, 102]],
            logprobs=[[-0.1, -0.5, -1.0]],
            sampled_token_ranks=[0],
        )

        actual = self.llm._build_sample_logprobs(logprobs_lists, topk_logprobs=2)

        self.assertEqual(
            actual,
            [
                {
                    101: Logprob(logprob=-0.5, rank=1, decoded_token="token_101"),
                    102: Logprob(logprob=-1.0, rank=2, decoded_token="token_102"),
                }
            ],
        )

    def test_build_sample_logprobs_empty_input(self):
        """An empty ``logprob_token_ids`` is expected to give ``None``."""
        empty_lists = MagicMock(spec=LogprobsLists)
        empty_lists.logprob_token_ids = []

        self.assertIsNone(self.llm._build_sample_logprobs(empty_lists, topk_logprobs=2))

    def test_build_sample_logprobs_invalid_topk(self):
        """A ``topk`` larger than the first row of ``logprob_token_ids`` is expected to give ``None``."""
        short_lists = MagicMock(spec=LogprobsLists)
        short_lists.logprob_token_ids = [[100]]

        self.assertIsNone(self.llm._build_sample_logprobs(short_lists, topk_logprobs=2))

    def test_decode_token(self):
        """Decoding a single token ID goes through the mocked data processor."""
        self.assertEqual(self.llm._decode_token(123), "token_123")
|
||||||
|
|
||||||
|
|
||||||
|
# Run the unittest runner when this module is executed directly.
if __name__ == "__main__":
    unittest.main()
|
Reference in New Issue
Block a user