diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py
index 8365c6985..3e150abf2 100644
--- a/fastdeploy/entrypoints/llm.py
+++ b/fastdeploy/entrypoints/llm.py
@@ -285,6 +285,10 @@ class LLM:
         self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
         return req_ids
 
+    def _decode_token(self, token_id: int) -> str:
+        """Decodes a single token ID into its string representation."""
+        return self.llm_engine.data_processor.process_logprob_response([token_id], clean_up_tokenization_spaces=False)
+
     def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: int) -> list[dict[int, Logprob]]:
         """
         Constructs a list of dictionaries mapping token IDs to Logprob objects,
@@ -318,8 +322,8 @@ class LLM:
         sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs)
         result = []
         for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs):
             logprob_dict = {
-                token_id: Logprob(logprob=logprob, rank=i + 1, decoded_token=None)
+                token_id: Logprob(logprob=logprob, rank=i + 1, decoded_token=self._decode_token(token_id))
                 for i, (token_id, logprob) in enumerate(zip(token_ids, logprobs))
             }
             result.append(logprob_dict)
diff --git a/test/ce/server/test_base_chat.py b/test/ce/server/test_base_chat.py
new file mode 100644
index 000000000..12be895fe
--- /dev/null
+++ b/test/ce/server/test_base_chat.py
@@ -0,0 +1,221 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# @author DDDivano
+# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
+
+"""
+Basic sanity checks for the FastDeploy web chat API.
+"""
+
+import json
+
+from core import TEMPLATE, URL, build_request_payload, send_request
+
+
+def test_stream_response():
+    data = {
+        "stream": True,
+        "messages": [
+            {"role": "system", "content": "你是一个知识渊博的 AI 助手"},
+            {"role": "user", "content": "讲讲爱因斯坦的相对论"},
+        ],
+        "max_tokens": 10,
+    }
+    payload = build_request_payload(TEMPLATE, data)
+    resp = send_request(URL, payload, stream=True)
+
+    output = ""
+    for line in resp.iter_lines(decode_unicode=True):
+        if line.strip() == "" or not line.startswith("data: "):
+            continue
+        line = line[len("data: ") :]
+        if line.strip() == "[DONE]":
+            break
+        chunk = json.loads(line)
+        delta = chunk.get("choices", [{}])[0].get("delta", {})
+        output += delta.get("content", "")
+
+    print("Stream output:", output)
+    # With max_tokens=10 the keyword may be truncated, so only require non-empty output.
+    assert len(output) > 0
+
+
+def test_system_prompt_effect():
+    data = {
+        "stream": False,
+        "messages": [
+            {"role": "system", "content": "请用一句话回答"},
+            {"role": "user", "content": "什么是人工智能?"},
+        ],
+        "max_tokens": 30,
+    }
+    payload = build_request_payload(TEMPLATE, data)
+    resp = send_request(URL, payload).json()
+    content = resp["choices"][0]["message"]["content"]
+    print("Content output:", content)
+    assert len(content) < 50
+
+
+def test_logprobs_enabled():
+    data = {
+        "stream": False,
+        "logprobs": True,
+        "top_logprobs": 5,
+        "messages": [{"role": "user", "content": "非洲的首都是?"}],
+        "max_tokens": 3,
+    }
+    payload = build_request_payload(TEMPLATE, data)
+    resp = send_request(URL, payload).json()
+    logprob_data = resp["choices"][0].get("logprobs")
+    print("LogProbs:", logprob_data)
+    assert logprob_data is not None
+    content_logprobs = logprob_data.get("content", [])
+    assert isinstance(content_logprobs, list)
+    assert all("token" in item for item in content_logprobs)
+
+
+def test_stop_sequence():
+    data = {
+        "stream": False,
+        "stop": ["果冻"],
+        "messages": [
+            {
+                "role": "user",
+                "content": "你要严格按照我接下来的话输出,输出冒号后面的内容,请输出:这是第一段。果冻这是第二段啦啦啦啦啦。",
+            },
+        ],
"max_tokens": 20, + "top_p": 0, + } + payload = build_request_payload(TEMPLATE, data) + resp = send_request(URL, payload).json() + content = resp["choices"][0]["message"]["content"] + print("截断输出:", content) + assert "第二段" not in content + + +def test_sampling_parameters(): + data = { + "stream": False, + "temperature": 0, + "top_p": 0, + "messages": [ + {"role": "user", "content": "1+1=?,直接回答答案"}, + ], + "max_tokens": 50, + } + payload = build_request_payload(TEMPLATE, data) + resp = send_request(URL, payload).json() + answer = resp["choices"][0]["message"]["content"] + print("Sampling输出:", answer) + assert any(ans in answer for ans in ["2", "二"]) + + +def test_multi_turn_conversation(): + data = { + "stream": False, + "messages": [ + {"role": "user", "content": "牛顿是谁?"}, + {"role": "assistant", "content": "牛顿是一位物理学家。"}, + {"role": "user", "content": "他提出了什么理论?"}, + ], + "max_tokens": 30, + } + payload = build_request_payload(TEMPLATE, data) + resp = send_request(URL, payload).json() + content = resp["choices"][0]["message"]["content"] + print("多轮记忆:", content) + assert "三大运动定律" in content or "万有引力" in content + + +def test_bad_words_filtering(): + banned_tokens = ["和", "呀"] + + data = { + "stream": False, + "messages": [ + {"role": "system", "content": "你是一个助手,回答简洁清楚"}, + {"role": "user", "content": "请输出冒号后面的字: 我爱吃果冻,和苹果,香蕉,和荔枝"}, + ], + "top_p": 0, + "max_tokens": 69, + "bad_words": banned_tokens, + } + + payload = build_request_payload(TEMPLATE, data) + response = send_request(URL, payload).json() + + content = response["choices"][0]["message"]["content"] + print("生成内容:", content) + + for word in banned_tokens: + assert word not in content, f"bad_word '{word}' 不应出现在生成结果中" + + print("test_bad_words_filtering 通过:生成结果未包含被禁词") + + data = { + "stream": False, + "messages": [ + {"role": "system", "content": "你是一个助手,回答简洁清楚"}, + {"role": "user", "content": "请输出冒号后面的字,一模一样: 我爱吃果冻,苹果,香蕉,和荔枝呀呀呀"}, + ], + "top_p": 0, + "max_tokens": 69, + # "bad_words": banned_tokens, + } + + payload = build_request_payload(TEMPLATE, data) + response = send_request(URL, payload).json() + + content = response["choices"][0]["message"]["content"] + print("生成内容:", content) + + for word in banned_tokens: + assert word not in content, f"bad_word '{word}' 不应出现在生成结果中" + + print("test_bad_words_filtering 通过:生成结果未包含被禁词") + + +def test_bad_words_filtering1(): + banned_tokens = ["和", "呀"] + + data = { + "stream": False, + "messages": [ + {"role": "system", "content": "你是一个助手,回答简洁清楚"}, + {"role": "user", "content": "请输出冒号后面的字: 我爱吃果冻,和苹果,香蕉,和荔枝"}, + ], + "top_p": 0, + "max_tokens": 69, + "bad_words": banned_tokens, + } + + payload = build_request_payload(TEMPLATE, data) + response = send_request(URL, payload).json() + + content = response["choices"][0]["message"]["content"] + print("生成内容:", content) + + for word in banned_tokens: + assert word not in content, f"bad_word '{word}' 不应出现在生成结果中" + + print("test_bad_words_filtering 通过:生成结果未包含被禁词") + word = "呀呀" + data = { + "stream": False, + "messages": [ + {"role": "system", "content": "你是一个助手,回答简洁清楚"}, + {"role": "user", "content": "请输出冒号后面的字,一模一样: 我爱吃果冻,苹果,香蕉,和荔枝呀呀呀"}, + ], + "top_p": 0, + "max_tokens": 69, + } + + payload = build_request_payload(TEMPLATE, data) + response = send_request(URL, payload).json() + + content = response["choices"][0]["message"]["content"] + print("生成内容:", content) + + assert word in content, f" '{word}' 应出现在生成结果中" + + print("test_bad_words_filtering 通过:生成结果未包含被禁词") diff --git a/test/entrypoints/openai/test_build_sample_logprobs.py 
new file mode 100644
index 000000000..76ff8e87b
--- /dev/null
+++ b/test/entrypoints/openai/test_build_sample_logprobs.py
@@ -0,0 +1,78 @@
+import unittest
+from unittest.mock import MagicMock, patch
+
+from fastdeploy.entrypoints.llm import LLM
+from fastdeploy.worker.output import Logprob, LogprobsLists
+
+
+def get_patch_path(cls, method="__init__"):
+    return f"{cls.__module__}.{cls.__qualname__}.{method}"
+
+
+class TestBuildSampleLogprobs(unittest.TestCase):
+
+    def setUp(self):
+        """
+        Set up the test environment by creating an instance of the LLM class with a mocked __init__.
+        """
+        patch_llm = get_patch_path(LLM)
+        with patch(patch_llm, return_value=None):
+            self.llm = LLM()
+        # Mock the data_processor so decoding returns a deterministic string.
+        self.llm.llm_engine = MagicMock()
+        self.llm.llm_engine.data_processor.process_logprob_response.side_effect = (
+            lambda ids, **kwargs: f"token_{ids[0]}"
+        )
+
+    def test_build_sample_logprobs_basic(self):
+        """
+        Test case for building sample logprobs when `topk_logprobs` is valid.
+        """
+        logprob_token_ids = [[100, 101, 102]]
+        logprobs = [[-0.1, -0.5, -1.0]]
+        sampled_token_ranks = [0]
+
+        logprobs_lists = LogprobsLists(
+            logprob_token_ids=logprob_token_ids, logprobs=logprobs, sampled_token_ranks=sampled_token_ranks
+        )
+
+        result = self.llm._build_sample_logprobs(logprobs_lists, topk_logprobs=2)
+
+        expected = [
+            {
+                101: Logprob(logprob=-0.5, rank=1, decoded_token="token_101"),
+                102: Logprob(logprob=-1.0, rank=2, decoded_token="token_102"),
+            }
+        ]
+
+        self.assertEqual(result, expected)
+
+    def test_build_sample_logprobs_empty_input(self):
+        """
+        Test case where `logprob_token_ids` is empty.
+        """
+        logprobs_lists = MagicMock(spec=LogprobsLists)
+        logprobs_lists.logprob_token_ids = []
+        result = self.llm._build_sample_logprobs(logprobs_lists, topk_logprobs=2)
+        self.assertIsNone(result)
+
+    def test_build_sample_logprobs_invalid_topk(self):
+        """
+        Test case where `topk_logprobs` exceeds the length of the first row of `logprob_token_ids`.
+        """
+        logprobs_lists = MagicMock(spec=LogprobsLists)
+        logprobs_lists.logprob_token_ids = [[100]]
+        result = self.llm._build_sample_logprobs(logprobs_lists, topk_logprobs=2)
+        self.assertIsNone(result)
+
+    def test_decode_token(self):
+        """
+        Test case for decoding a single token ID.
+        """
+        token_id = 123
+        decoded = self.llm._decode_token(token_id)
+        self.assertEqual(decoded, "token_123")
+
+
+if __name__ == "__main__":
+    unittest.main()
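
For reviewers who want to see the llm.py change end to end, the snippet below is a minimal sketch (not part of the patch) that mirrors the unit test above: the `MagicMock` data processor is a stand-in for the real tokenizer, and the `token_<id>` strings are placeholder decodings, not real vocabulary entries.

```python
from unittest.mock import MagicMock, patch

from fastdeploy.entrypoints.llm import LLM
from fastdeploy.worker.output import LogprobsLists

# Build an LLM instance without engine startup, the same trick the unit test uses.
with patch.object(LLM, "__init__", return_value=None):
    llm = LLM()

# Stand-in for the real tokenizer-backed data processor.
llm.llm_engine = MagicMock()
llm.llm_engine.data_processor.process_logprob_response.side_effect = (
    lambda ids, **kwargs: f"token_{ids[0]}"
)

# Column 0 holds the sampled token; columns 1..k hold the top-k candidates,
# which is why _build_sample_logprobs slices from column 1 onward.
lists = LogprobsLists(
    logprob_token_ids=[[100, 101, 102]],
    logprobs=[[-0.1, -0.5, -1.0]],
    sampled_token_ranks=[0],
)

print(llm._build_sample_logprobs(lists, topk_logprobs=2))
# Before this patch every Logprob carried decoded_token=None; with it, each
# entry holds the decoded string, e.g.:
# [{101: Logprob(logprob=-0.5, rank=1, decoded_token='token_101'),
#   102: Logprob(logprob=-1.0, rank=2, decoded_token='token_102')}]
```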