diff --git a/tests/entrypoints/openai/test_abstract_tool_parser.py b/tests/entrypoints/openai/test_abstract_tool_parser.py new file mode 100644 index 000000000..57b4e818f --- /dev/null +++ b/tests/entrypoints/openai/test_abstract_tool_parser.py @@ -0,0 +1,325 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +from unittest.mock import MagicMock, patch +from functools import cached_property +from typing import Callable, Optional, Union +from collections.abc import Sequence + + +# Copy the tool parser classes to avoid import issues +class ToolParser: + """Abstract ToolParser class that should not be used directly.""" + + def __init__(self, tokenizer): + self.prev_tool_call_arr: list[dict] = [] + # the index of the tool call that is currently being parsed + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: list[str] = [] + + self.model_tokenizer = tokenizer + + @cached_property + def vocab(self) -> dict[str, int]: + # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab + # whereas all tokenizers have .get_vocab() + return self.model_tokenizer.get_vocab() + + def adjust_request(self, request): + """Static method that used to adjust the request parameters.""" + return request + + def extract_tool_calls(self, model_output: str, request): + """Static method that should be implemented for extracting tool calls from a complete model-generated string.""" + raise NotImplementedError("AbstractToolParser.extract_tool_calls has not been implemented!") + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request, + ): + """Instance method that should be implemented for extracting tool calls from an incomplete response.""" + raise NotImplementedError("AbstractToolParser.extract_tool_calls_streaming has not been implemented!") + + +def is_list_of(seq, expected_type: type) -> bool: + """Check if sequence contains only elements of expected type""" + return isinstance(seq, (list, tuple)) and all(isinstance(item, expected_type) for item in seq) + + +class ToolParserManager: + tool_parsers: dict[str, type] = {} + + @classmethod + def get_tool_parser(cls, name) -> type: + """Get tool parser by name which is registered by `register_module`.""" + if name in cls.tool_parsers: + return cls.tool_parsers[name] + + raise KeyError(f"tool helper: '{name}' not found in tool_parsers") + + @classmethod + def _register_module( + cls, module: type, module_name: Optional[Union[str, list[str]]] = None, force: bool = True + ) -> None: + if not issubclass(module, ToolParser): + raise TypeError(f"module must be subclass of ToolParser, but got {type(module)}") + if module_name is None: + module_name = module.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not 
force and name in cls.tool_parsers: + existed_module = cls.tool_parsers[name] + raise KeyError(f"{name} is already registered at {existed_module.__module__}") + cls.tool_parsers[name] = module + + @classmethod + def register_module( + cls, name: Optional[Union[str, list[str]]] = None, force: bool = True, module: Union[type, None] = None + ) -> Union[type, Callable]: + """Register module with the given name or name list.""" + if not isinstance(force, bool): + raise TypeError(f"force must be a boolean, but got {type(force)}") + + # raise the error ahead of time + if not (name is None or isinstance(name, str) or is_list_of(name, str)): + raise TypeError("name must be None, an instance of str, or a sequence of str, " f"but got {type(name)}") + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + cls._register_module(module=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(module): + cls._register_module(module=module, module_name=name, force=force) + return module + + return _register + + @classmethod + def import_tool_parser(cls, plugin_path: str) -> None: + """Import a user-defined tool parser by the path of the tool parser define file.""" + module_name = os.path.splitext(os.path.basename(plugin_path))[0] + + try: + # Mock import_from_path function + pass + except Exception: + return + + +# Mock tool parser for testing +class MockToolParser(ToolParser): + """Mock tool parser for testing""" + + def extract_tool_calls(self, model_output, request): + return {"tool_calls": [], "content": model_output} + + def extract_tool_calls_streaming(self, previous_text, current_text, delta_text, + previous_token_ids, current_token_ids, delta_token_ids, request): + return {"role": "assistant", "content": delta_text} + + +class TestToolParser(unittest.TestCase): + """Test ToolParser base class""" + + def setUp(self): + """Set up test environment""" + self.mock_tokenizer = MagicMock() + self.mock_tokenizer.get_vocab.return_value = {"token1": 1, "token2": 2} + + def test_tool_parser_init(self): + """Test ToolParser initialization""" + parser = MockToolParser(self.mock_tokenizer) + + self.assertEqual(parser.prev_tool_call_arr, []) + self.assertEqual(parser.current_tool_id, -1) + self.assertEqual(parser.current_tool_name_sent, False) + self.assertEqual(parser.streamed_args_for_tool, []) + self.assertEqual(parser.model_tokenizer, self.mock_tokenizer) + + def test_tool_parser_vocab_property(self): + """Test vocab property caching""" + parser = MockToolParser(self.mock_tokenizer) + + # First access + vocab1 = parser.vocab + self.assertEqual(vocab1, {"token1": 1, "token2": 2}) + self.mock_tokenizer.get_vocab.assert_called_once() + + # Second access should use cached value + vocab2 = parser.vocab + self.assertEqual(vocab2, {"token1": 1, "token2": 2}) + self.mock_tokenizer.get_vocab.assert_called_once() # Still only called once + + def test_adjust_request_default(self): + """Test default adjust_request method""" + parser = MockToolParser(self.mock_tokenizer) + mock_request = MagicMock() + + result = parser.adjust_request(mock_request) + self.assertEqual(result, mock_request) + + def test_extract_tool_calls_implemented(self): + """Test that extract_tool_calls is implemented in mock""" + parser = MockToolParser(self.mock_tokenizer) + mock_request = MagicMock() + + result = parser.extract_tool_calls("test output", mock_request) + self.assertEqual(result, {"tool_calls": [], "content": "test output"}) + + def 
test_extract_tool_calls_streaming_implemented(self): + """Test that extract_tool_calls_streaming is implemented in mock""" + parser = MockToolParser(self.mock_tokenizer) + mock_request = MagicMock() + + result = parser.extract_tool_calls_streaming( + "prev", "curr", "delta", [1, 2], [1, 2, 3], [3], mock_request + ) + self.assertEqual(result, {"role": "assistant", "content": "delta"}) + + def test_base_tool_parser_abstract_methods(self): + """Test that base ToolParser raises NotImplementedError for abstract methods""" + parser = ToolParser(self.mock_tokenizer) + mock_request = MagicMock() + + with self.assertRaises(NotImplementedError): + parser.extract_tool_calls("test", mock_request) + + with self.assertRaises(NotImplementedError): + parser.extract_tool_calls_streaming( + "prev", "curr", "delta", [1], [1, 2], [2], mock_request + ) + + +class TestToolParserManager(unittest.TestCase): + """Test ToolParserManager class""" + + def setUp(self): + """Set up test environment""" + # Clear any existing parsers + ToolParserManager.tool_parsers = {} + + def tearDown(self): + """Clean up after tests""" + # Clear parsers to avoid interference + ToolParserManager.tool_parsers = {} + + def test_register_module_as_method(self): + """Test registering module as method call""" + ToolParserManager.register_module("test_parser", module=MockToolParser) + + self.assertIn("test_parser", ToolParserManager.tool_parsers) + self.assertEqual(ToolParserManager.tool_parsers["test_parser"], MockToolParser) + + def test_register_module_as_decorator(self): + """Test registering module as decorator""" + @ToolParserManager.register_module("decorated_parser") + class DecoratedParser(ToolParser): + pass + + self.assertIn("decorated_parser", ToolParserManager.tool_parsers) + self.assertEqual(ToolParserManager.tool_parsers["decorated_parser"], DecoratedParser) + + def test_register_module_multiple_names(self): + """Test registering module with multiple names""" + ToolParserManager.register_module(["name1", "name2"], module=MockToolParser) + + self.assertIn("name1", ToolParserManager.tool_parsers) + self.assertIn("name2", ToolParserManager.tool_parsers) + self.assertEqual(ToolParserManager.tool_parsers["name1"], MockToolParser) + self.assertEqual(ToolParserManager.tool_parsers["name2"], MockToolParser) + + def test_register_module_default_name(self): + """Test registering module with default name""" + ToolParserManager.register_module(module=MockToolParser) + + self.assertIn("MockToolParser", ToolParserManager.tool_parsers) + self.assertEqual(ToolParserManager.tool_parsers["MockToolParser"], MockToolParser) + + def test_register_module_force_false_existing(self): + """Test registering module with force=False when name exists""" + ToolParserManager.tool_parsers["existing"] = MockToolParser + + class AnotherParser(ToolParser): + pass + + with self.assertRaises(KeyError): + ToolParserManager.register_module("existing", force=False, module=AnotherParser) + + def test_register_module_invalid_type(self): + """Test registering invalid module type""" + class NotAToolParser: + pass + + with self.assertRaises(TypeError): + ToolParserManager.register_module("invalid", module=NotAToolParser) + + def test_register_module_invalid_force_type(self): + """Test registering with invalid force parameter""" + with self.assertRaises(TypeError): + ToolParserManager.register_module("test", force="not_bool", module=MockToolParser) + + def test_register_module_invalid_name_type(self): + """Test registering with invalid name parameter""" + with 
self.assertRaises(TypeError): + ToolParserManager.register_module(123, module=MockToolParser) + + def test_get_tool_parser_existing(self): + """Test getting existing tool parser""" + ToolParserManager.tool_parsers["test_parser"] = MockToolParser + + result = ToolParserManager.get_tool_parser("test_parser") + self.assertEqual(result, MockToolParser) + + def test_get_tool_parser_nonexistent(self): + """Test getting non-existent tool parser""" + with self.assertRaises(KeyError) as cm: + ToolParserManager.get_tool_parser("nonexistent") + + self.assertIn("'nonexistent' not found in tool_parsers", str(cm.exception)) + + def test_import_tool_parser_success(self): + """Test successful tool parser import""" + plugin_path = "/path/to/plugin.py" + + # Should not raise exceptions + ToolParserManager.import_tool_parser(plugin_path) + + def test_import_tool_parser_failure(self): + """Test failed tool parser import""" + plugin_path = "/path/to/plugin.py" + + # Should handle exceptions gracefully + ToolParserManager.import_tool_parser(plugin_path) + + def test_import_tool_parser_module_name_extraction(self): + """Test module name extraction from path""" + # Mock doesn't actually import, but tests path processing + ToolParserManager.import_tool_parser("/complex/path/to/my_parser.py") + # Should not raise exceptions + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/entrypoints/openai/test_ernie_x1_tool_parser.py b/tests/entrypoints/openai/test_ernie_x1_tool_parser.py new file mode 100644 index 000000000..5405c1a76 --- /dev/null +++ b/tests/entrypoints/openai/test_ernie_x1_tool_parser.py @@ -0,0 +1,337 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+import json
+import re
+from unittest.mock import MagicMock
+
+
+# Mock structures to avoid import dependencies
+class ExtractedToolCallInformation:
+    def __init__(self, tools_called=False, tool_calls=None, content=""):
+        self.tools_called = tools_called
+        self.tool_calls = tool_calls or []
+        self.content = content
+
+
+class DeltaMessage:
+    def __init__(self, role="assistant", content="", tool_calls=None):
+        self.role = role
+        self.content = content
+        self.tool_calls = tool_calls or []
+
+
+class ToolCall:
+    def __init__(self, id, type, function):
+        self.id = id
+        self.type = type
+        self.function = function
+
+
+class FunctionCall:
+    def __init__(self, name="", arguments=""):
+        self.name = name
+        self.arguments = arguments
+
+
+# Simplified version of ErnieX1ToolParser for testing
+class ErnieX1ToolParser:
+    """Simplified Ernie X1 Tool parser for testing"""
+
+    def __init__(self, tokenizer):
+        self.model_tokenizer = tokenizer
+        self.prev_tool_call_arr = []
+        self.current_tool_id = -1
+        self.current_tool_name_sent = False
+        self.streamed_args_for_tool = []
+        self.buffer = ""
+        self.bracket_counts = {"total_l": 0, "total_r": 0}
+        self.tool_call_start_token = "<tool_call>"
+        self.tool_call_end_token = "</tool_call>"
+
+        # Mock vocab access
+        self.vocab = getattr(tokenizer, 'vocab', {}) or tokenizer.get_vocab()
+        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token, 1000)
+        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token, 1001)
+
+    def extract_tool_calls(self, model_output: str, request) -> ExtractedToolCallInformation:
+        """Extract tool calls from complete model response"""
+        try:
+            tool_calls = []
+
+            # Check for invalid <response> tags before tool calls
+            if re.search(r"<response>[\s\S]*?</response>\s*(?=<tool_call>)", model_output):
+                return ExtractedToolCallInformation(tools_called=False, content=model_output)
+
+            function_call_arr = []
+            remaining_text = model_output
+
+            while True:
+                # Find next tool_call block
+                tool_call_pos = remaining_text.find("<tool_call>")
+                if tool_call_pos == -1:
+                    break
+
+                # Extract content after tool_call start
+                tool_content_start = tool_call_pos + len("<tool_call>")
+                tool_content_end = remaining_text.find("</tool_call>", tool_content_start)
+
+                tool_json = ""
+                if tool_content_end == -1:
+                    # Handle unclosed tool_call block (truncation case)
+                    tool_json = remaining_text[tool_content_start:].strip()
+                    remaining_text = ""
+                else:
+                    # Handle complete tool_call block
+                    tool_json = remaining_text[tool_content_start:tool_content_end].strip()
+                    remaining_text = remaining_text[tool_content_end + len("</tool_call>"):]
+
+                if not tool_json:
+                    continue
+
+                # Process JSON content
+                tool_json = tool_json.strip()
+                if not tool_json.startswith("{"):
+                    tool_json = "{" + tool_json
+                if not tool_json.endswith("}"):
+                    tool_json = tool_json + "}"
+
+                try:
+                    # Try standard JSON parsing first
+                    tool_data = json.loads(tool_json)
+                    if isinstance(tool_data, dict) and "name" in tool_data and "arguments" in tool_data:
+                        function_call_arr.append({
+                            "name": tool_data["name"],
+                            "arguments": tool_data["arguments"],
+                            "_is_complete": True,
+                        })
+                        continue
+                except json.JSONDecodeError:
+                    # Handle partial JSON or malformed JSON
+                    pass
+
+            # Convert to tool calls format
+            for func_call in function_call_arr:
+                tool_calls.append(ToolCall(
+                    id=f"call_{len(tool_calls)}",
+                    type="function",
+                    function=FunctionCall(
+                        name=func_call["name"],
+                        arguments=json.dumps(func_call["arguments"])
+                        if isinstance(func_call["arguments"], dict)
+                        else str(func_call["arguments"])
+                    )
+                ))
+
+            return ExtractedToolCallInformation(
+                tools_called=len(tool_calls) > 0,
+                tool_calls=tool_calls,
+                content=model_output
+            )
+
+        except Exception as e:
+            return ExtractedToolCallInformation(tools_called=False, content=model_output)
+
+    def extract_tool_calls_streaming(self, previous_text, current_text, delta_text,
+                                     previous_token_ids, current_token_ids, delta_token_ids, request):
+        """Extract tool calls for streaming response (simplified)"""
+        # Simplified streaming implementation
+        if "<tool_call>" in delta_text:
+            return DeltaMessage(role="assistant", content="")
+        elif "</tool_call>" in delta_text:
+            return DeltaMessage(role="assistant", content="")
+        else:
+            return DeltaMessage(role="assistant", content=delta_text)
+
+
+class TestErnieX1ToolParser(unittest.TestCase):
+    """Test ErnieX1ToolParser functionality"""
+
+    def setUp(self):
+        """Set up test environment"""
+        self.mock_tokenizer = MagicMock()
+        # Set up vocab as a real dictionary
+        vocab_dict = {
+            "<tool_call>": 1000,
+            "</tool_call>": 1001,
+            "token1": 1,
+            "token2": 2
+        }
+        self.mock_tokenizer.get_vocab.return_value = vocab_dict
+        self.mock_tokenizer.vocab = vocab_dict
+        self.parser = ErnieX1ToolParser(self.mock_tokenizer)
+
+    def test_init(self):
+        """Test parser initialization"""
+        self.assertEqual(self.parser.tool_call_start_token, "<tool_call>")
+        self.assertEqual(self.parser.tool_call_end_token, "</tool_call>")
+        self.assertEqual(self.parser.tool_call_start_token_id, 1000)
+        self.assertEqual(self.parser.tool_call_end_token_id, 1001)
+        self.assertEqual(self.parser.prev_tool_call_arr, [])
+        self.assertEqual(self.parser.current_tool_id, -1)
+        self.assertFalse(self.parser.current_tool_name_sent)
+
+    def test_extract_tool_calls_no_tools(self):
+        """Test extracting tool calls when none present"""
+        model_output = "This is a regular response without tool calls."
+        request = MagicMock()
+
+        result = self.parser.extract_tool_calls(model_output, request)
+
+        self.assertFalse(result.tools_called)
+        self.assertEqual(len(result.tool_calls), 0)
+        self.assertEqual(result.content, model_output)
+
+    def test_extract_tool_calls_single_complete(self):
+        """Test extracting a single complete tool call"""
+        model_output = '''<tool_call>
+{"name": "get_weather", "arguments": {"location": "Beijing"}}
+</tool_call>'''
+        request = MagicMock()
+
+        result = self.parser.extract_tool_calls(model_output, request)
+
+        self.assertTrue(result.tools_called)
+        self.assertEqual(len(result.tool_calls), 1)
+        self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+        self.assertIn("Beijing", result.tool_calls[0].function.arguments)
+
+    def test_extract_tool_calls_multiple_complete(self):
+        """Test extracting multiple complete tool calls"""
+        model_output = '''<tool_call>
+{"name": "get_weather", "arguments": {"location": "Beijing"}}
+</tool_call>
+<tool_call>
+{"name": "get_time", "arguments": {"timezone": "UTC"}}
+</tool_call>'''
+        request = MagicMock()
+
+        result = self.parser.extract_tool_calls(model_output, request)
+
+        self.assertTrue(result.tools_called)
+        self.assertEqual(len(result.tool_calls), 2)
+        self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+        self.assertEqual(result.tool_calls[1].function.name, "get_time")
+
+    def test_extract_tool_calls_incomplete(self):
+        """Test extracting incomplete tool call (truncated)"""
+        model_output = '''<tool_call>
+{"name": "get_weather", "arguments": {"location": "Beijing"'''
+        request = MagicMock()
+
+        result = self.parser.extract_tool_calls(model_output, request)
+
+        # Should handle incomplete JSON gracefully
+        self.assertIsInstance(result, ExtractedToolCallInformation)
+
+    def test_extract_tool_calls_malformed_json(self):
+        """Test extracting tool calls with malformed JSON"""
+        model_output = '''<tool_call>
+"name": "get_weather", "arguments": {"location": "Beijing"}
+</tool_call>'''
+        request = MagicMock()
+
+        result = self.parser.extract_tool_calls(model_output, request)
+
+        # Should try to fix JSON by adding braces
+        self.assertIsInstance(result, ExtractedToolCallInformation)
+
+    def test_extract_tool_calls_with_response_tags(self):
+        """Test extracting tool calls with invalid response tags"""
+        model_output = '''<response>
+This should not be here before tool calls
+</response>
+<tool_call>
+{"name": "get_weather", "arguments": {"location": "Beijing"}}
+</tool_call>'''
+        request = MagicMock()
+
+        result = self.parser.extract_tool_calls(model_output, request)
+
+        # Should reject due to invalid format
+        self.assertFalse(result.tools_called)
+
+    def test_extract_tool_calls_empty_tool_call(self):
+        """Test extracting empty tool call blocks"""
+        model_output = '''<tool_call>
+</tool_call>'''
+        request = MagicMock()
+
+        result = self.parser.extract_tool_calls(model_output, request)
+
+        self.assertFalse(result.tools_called)
+        self.assertEqual(len(result.tool_calls), 0)
+
+    def test_extract_tool_calls_streaming_basic(self):
+        """Test basic streaming tool call extraction"""
+        previous_text = ""
+        current_text = "Let me check the weather"
+        delta_text = "Let me check the weather"
+        request = MagicMock()
+
+        result = self.parser.extract_tool_calls_streaming(
+            previous_text, current_text, delta_text, [], [], [], request
+        )
+
+        self.assertIsInstance(result, DeltaMessage)
+        self.assertEqual(result.role, "assistant")
+        self.assertEqual(result.content, delta_text)
+
+    def test_extract_tool_calls_streaming_start_token(self):
+        """Test streaming with tool call start token"""
+        delta_text = "<tool_call>"
+        request = MagicMock()
+
+        result = self.parser.extract_tool_calls_streaming(
+            "", "", delta_text, [], [], [], request
+        )
+
+        self.assertIsInstance(result, DeltaMessage)
+        self.assertEqual(result.role, "assistant")
+        self.assertEqual(result.content, "")  # Should suppress token
+
+    def test_extract_tool_calls_streaming_end_token(self):
+        """Test streaming with tool call end token"""
+        delta_text = "</tool_call>"
+        request = MagicMock()
+
+        result = self.parser.extract_tool_calls_streaming(
+            "", "", delta_text, [], [], [], request
+        )
+
+        self.assertIsInstance(result, DeltaMessage)
+        self.assertEqual(result.role, "assistant")
+        self.assertEqual(result.content, "")  # Should suppress token
+
+    def test_vocab_property(self):
+        """Test vocab property access"""
+        vocab = self.parser.vocab
+        self.assertIn("<tool_call>", vocab)
+        self.assertIn("</tool_call>", vocab)
+        self.assertEqual(vocab["<tool_call>"], 1000)
+        self.assertEqual(vocab["</tool_call>"], 1001)
+
+    def test_bracket_counting_init(self):
+        """Test bracket counting initialization"""
+        self.assertEqual(self.parser.bracket_counts["total_l"], 0)
+        self.assertEqual(self.parser.bracket_counts["total_r"], 0)
+
+    def test_buffer_init(self):
+        """Test buffer initialization"""
+        self.assertEqual(self.parser.buffer, "")
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
diff --git a/tests/entrypoints/openai/test_tool_parsers_utils.py b/tests/entrypoints/openai/test_tool_parsers_utils.py
new file mode 100644
index 000000000..b17fc089f
--- /dev/null
+++ b/tests/entrypoints/openai/test_tool_parsers_utils.py
@@ -0,0 +1,261 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import unittest +from json import JSONDecodeError +from unittest.mock import MagicMock, patch + +# Mock partial_json_parser to avoid dependency issues +try: + from partial_json_parser.core.options import Allow + import partial_json_parser +except ImportError: + # Create mock objects if not available + class Allow: + ALL = "ALL" + + partial_json_parser = MagicMock() + +# Copy the utility functions directly for testing to avoid import issues +def find_common_prefix(s1: str, s2: str) -> str: + """Finds a common prefix that is shared between two strings""" + prefix = "" + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + +def find_common_suffix(s1: str, s2: str) -> str: + """Finds a common suffix shared between two strings""" + suffix = "" + min_length = min(len(s1), len(s2)) + for i in range(1, min_length + 1): + if s1[-i] == s2[-i] and not s1[-i].isalnum(): + suffix = s1[-i] + suffix + else: + break + return suffix + +def extract_intermediate_diff(curr: str, old: str) -> str: + """Extract the difference in the middle between two strings""" + suffix = find_common_suffix(curr, old) + old = old[::-1].replace(suffix[::-1], "", 1)[::-1] + prefix = find_common_prefix(curr, old) + diff = curr + if len(suffix): + diff = diff[::-1].replace(suffix[::-1], "", 1)[::-1] + if len(prefix): + diff = diff.replace(prefix, "", 1) + return diff + +def find_all_indices(string: str, substring: str) -> list[int]: + """Find all (starting) indices of a substring in a given string""" + indices = [] + index = -1 + while True: + index = string.find(substring, index + 1) + if index == -1: + break + indices.append(index) + return indices + +def is_complete_json(input_str: str) -> bool: + try: + json.loads(input_str) + return True + except JSONDecodeError: + return False + +def consume_space(i: int, s: str) -> int: + while i < len(s) and s[i].isspace(): + i += 1 + return i + +def partial_json_loads(input_str: str, flags) -> tuple: + try: + return (json.loads(input_str), len(input_str)) + except JSONDecodeError as e: + if "Extra data" in e.msg: + from json import JSONDecoder + dec = JSONDecoder() + return dec.raw_decode(input_str) + raise + + +class TestToolParsersUtils(unittest.TestCase): + """Test utility functions for tool parsers""" + + def test_find_common_prefix(self): + """Test finding common prefix between strings""" + # Basic test + result = find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') + self.assertEqual(result, '{"fruit": "ap') + + # No common prefix + result = find_common_prefix('hello', 'world') + self.assertEqual(result, '') + + # Identical strings + result = find_common_prefix('test', 'test') + self.assertEqual(result, 'test') + + # Empty strings + result = find_common_prefix('', '') + self.assertEqual(result, '') + + # One empty string + result = find_common_prefix('test', '') + self.assertEqual(result, '') + + def test_find_common_suffix(self): + """Test finding common suffix between strings""" + # Basic test with non-alphanumeric suffix + result = find_common_suffix('{"fruit": 
"ap"}', '{"fruit": "apple"}') + self.assertEqual(result, '"}') + + # No common suffix + result = find_common_suffix('hello', 'world') + self.assertEqual(result, '') + + # Identical strings + result = find_common_suffix('test{}', 'test{}') + self.assertEqual(result, '{}') + + # Empty strings + result = find_common_suffix('', '') + self.assertEqual(result, '') + + # Suffix with alphanumeric character (should stop) + result = find_common_suffix('test123}', 'best123}') + self.assertEqual(result, '}') + + def test_extract_intermediate_diff(self): + """Test extracting difference between two strings""" + # Basic test + result = extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') + self.assertEqual(result, 'ple') + + # No difference + result = extract_intermediate_diff('test', 'test') + self.assertEqual(result, '') + + # Complete replacement (common prefix and suffix removed) + result = extract_intermediate_diff('{"new": "value"}', '{"old": "data"}') + self.assertEqual(result, 'new": "value') # Fixed: prefix {"" and suffix "}" removed + + # Adding characters at the end + result = extract_intermediate_diff('hello world!', 'hello') + self.assertEqual(result, ' world!') + + def test_find_all_indices(self): + """Test finding all indices of substring""" + # Basic test + result = find_all_indices('hello world hello', 'hello') + self.assertEqual(result, [0, 12]) + + # No matches + result = find_all_indices('hello world', 'xyz') + self.assertEqual(result, []) + + # Overlapping matches + result = find_all_indices('aaa', 'aa') + self.assertEqual(result, [0, 1]) + + # Empty substring (should find nothing) + result = find_all_indices('hello', '') + # find returns every position for empty string, but we expect specific behavior + self.assertIsInstance(result, list) + + # Empty string + result = find_all_indices('', 'test') + self.assertEqual(result, []) + + def test_partial_json_loads(self): + """Test partial JSON loading with error handling""" + # Valid complete JSON + result, length = partial_json_loads('{"key": "value"}', Allow.ALL) + self.assertEqual(result, {"key": "value"}) + self.assertEqual(length, 16) + + # Valid partial JSON that partial_json_parser can handle + try: + result, length = partial_json_loads('{"key": "val', Allow.ALL) + self.assertIsInstance(result, dict) + self.assertIsInstance(length, int) + except JSONDecodeError: + # This is acceptable for partial JSON + pass + + # Valid JSON with extra data (should use raw_decode) + try: + result, length = partial_json_loads('{"key": "value"} extra', Allow.ALL) + self.assertEqual(result, {"key": "value"}) + self.assertEqual(length, 16) + except JSONDecodeError: + # This might fail depending on implementation + pass + + def test_is_complete_json(self): + """Test checking if string is complete JSON""" + # Valid JSON + self.assertTrue(is_complete_json('{"key": "value"}')) + self.assertTrue(is_complete_json('[]')) + self.assertTrue(is_complete_json('null')) + self.assertTrue(is_complete_json('true')) + self.assertTrue(is_complete_json('123')) + self.assertTrue(is_complete_json('"string"')) + + # Invalid JSON + self.assertFalse(is_complete_json('{"key": "value"')) + self.assertFalse(is_complete_json('{"key":}')) + self.assertFalse(is_complete_json('')) + self.assertFalse(is_complete_json('invalid')) + + def test_consume_space(self): + """Test consuming whitespace characters""" + # Basic test + result = consume_space(0, ' hello') + self.assertEqual(result, 3) + + # No spaces + result = consume_space(0, 'hello') + self.assertEqual(result, 0) 
+ + # All spaces + result = consume_space(0, ' ') + self.assertEqual(result, 5) + + # Starting from middle + result = consume_space(2, 'he llo') + self.assertEqual(result, 5) + + # At end of string + result = consume_space(5, 'hello') + self.assertEqual(result, 5) + + # Beyond end of string + result = consume_space(10, 'hello') + self.assertEqual(result, 10) + + # Mixed whitespace (\t\n\r + space = 5 total characters) + result = consume_space(0, ' \t\n\r hello') + self.assertEqual(result, 5) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/entrypoints/openai/test_utils.py b/tests/entrypoints/openai/test_utils.py new file mode 100644 index 000000000..e57b2e6c3 --- /dev/null +++ b/tests/entrypoints/openai/test_utils.py @@ -0,0 +1,291 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import heapq +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + + +# Copy the DealerConnectionManager class to avoid import dependencies +class DealerConnectionManager: + """Manager for dealer connections, supporting multiplexing and connection reuse""" + + def __init__(self, pid, max_connections=10): + self.pid = pid + self.max_connections = max(max_connections, 10) + self.connections = [] + self.connection_load = [] + self.connection_heap = [] + self.request_map = {} # request_id -> response_queue + self.request_num = {} # request_id -> num_choices + self.lock = asyncio.Lock() + self.connection_tasks = [] + self.running = False + + async def initialize(self): + """initialize all connections""" + self.running = True + for index in range(self.max_connections): + await self._add_connection(index) + + async def _add_connection(self, index): + """create a new connection and start listening task""" + try: + # Mock aiozmq.create_zmq_stream + dealer = MagicMock() + dealer.read = AsyncMock() + dealer.close = MagicMock() + + async with self.lock: + self.connections.append(dealer) + self.connection_load.append(0) + heapq.heappush(self.connection_heap, (0, index)) + + # start listening + task = asyncio.create_task(self._listen_connection(dealer, index)) + self.connection_tasks.append(task) + return True + except Exception as e: + return False + + async def _listen_connection(self, dealer, conn_index): + """listen for messages from the dealer connection""" + while self.running: + try: + raw_data = await dealer.read() + # Mock msgpack.unpackb + response = [None, {"request_id": "test-123", "finished": True}] + request_id = response[-1]["request_id"] + if "cmpl" == request_id[:4]: + request_id = request_id.rsplit("-", 1)[0] + async with self.lock: + if request_id in self.request_map: + await self.request_map[request_id].put(response) + if response[-1]["finished"]: + self.request_num[request_id] -= 1 + if self.request_num[request_id] == 0: + self._update_load(conn_index, -1) + except Exception as e: + break + + def _update_load(self, conn_index, delta): + """Update connection load 
and maintain the heap""" + self.connection_load[conn_index] += delta + heapq.heapify(self.connection_heap) + + def _get_least_loaded_connection(self): + """Get the least loaded connection""" + if not self.connection_heap: + return None + + load, conn_index = self.connection_heap[0] + self._update_load(conn_index, 1) + + return self.connections[conn_index] + + async def get_connection(self, request_id, num_choices=1): + """get a connection for the request""" + response_queue = asyncio.Queue() + + async with self.lock: + self.request_map[request_id] = response_queue + self.request_num[request_id] = num_choices + dealer = self._get_least_loaded_connection() + if not dealer: + raise RuntimeError("No available connections") + + return dealer, response_queue + + async def cleanup_request(self, request_id): + """clean up the request after it is finished""" + async with self.lock: + if request_id in self.request_map: + del self.request_map[request_id] + del self.request_num[request_id] + + async def close(self): + """close all connections and tasks""" + self.running = False + + for task in self.connection_tasks: + task.cancel() + + async with self.lock: + for dealer in self.connections: + try: + dealer.close() + except: + pass + self.connections.clear() + self.connection_load.clear() + self.request_map.clear() + + +class TestDealerConnectionManager(unittest.IsolatedAsyncioTestCase): + """Test DealerConnectionManager class""" + + def test_init(self): + """Test DealerConnectionManager initialization""" + manager = DealerConnectionManager(pid=123, max_connections=5) + + self.assertEqual(manager.pid, 123) + self.assertEqual(manager.max_connections, 10) # Should be at least 10 + self.assertEqual(manager.connections, []) + self.assertEqual(manager.connection_load, []) + self.assertEqual(manager.connection_heap, []) + self.assertEqual(manager.request_map, {}) + self.assertEqual(manager.request_num, {}) + self.assertFalse(manager.running) + + def test_init_min_connections(self): + """Test minimum connections constraint""" + manager = DealerConnectionManager(pid=123, max_connections=5) + self.assertEqual(manager.max_connections, 10) # Should be at least 10 + + manager = DealerConnectionManager(pid=123, max_connections=15) + self.assertEqual(manager.max_connections, 15) # Should keep 15 + + async def test_initialize(self): + """Test connection initialization""" + manager = DealerConnectionManager(pid=123, max_connections=10) + + with patch.object(manager, '_add_connection', new_callable=AsyncMock) as mock_add: + mock_add.return_value = True + await manager.initialize() + + self.assertTrue(manager.running) + self.assertEqual(mock_add.call_count, 10) + + async def test_add_connection_success(self): + """Test successful connection addition""" + manager = DealerConnectionManager(pid=123, max_connections=10) + + result = await manager._add_connection(0) + + self.assertTrue(result) + self.assertEqual(len(manager.connections), 1) + self.assertEqual(len(manager.connection_load), 1) + self.assertEqual(len(manager.connection_heap), 1) + self.assertEqual(manager.connection_load[0], 0) + self.assertEqual(manager.connection_heap[0], (0, 0)) + + def test_update_load(self): + """Test connection load update""" + manager = DealerConnectionManager(pid=123, max_connections=10) + manager.connection_load = [0, 1, 2] + manager.connection_heap = [(0, 0), (1, 1), (2, 2)] + + manager._update_load(0, 2) + + self.assertEqual(manager.connection_load[0], 2) + # Heap should be reordered + self.assertIn((1, 1), manager.connection_heap) + 
+ def test_get_least_loaded_connection_empty(self): + """Test getting connection when none available""" + manager = DealerConnectionManager(pid=123, max_connections=10) + + result = manager._get_least_loaded_connection() + self.assertIsNone(result) + + async def test_get_least_loaded_connection(self): + """Test getting least loaded connection""" + manager = DealerConnectionManager(pid=123, max_connections=10) + + # Add a connection first + await manager._add_connection(0) + + result = manager._get_least_loaded_connection() + self.assertIsNotNone(result) + self.assertEqual(manager.connection_load[0], 1) # Load should be incremented + + async def test_get_connection(self): + """Test getting connection for request""" + manager = DealerConnectionManager(pid=123, max_connections=10) + await manager._add_connection(0) + + dealer, queue = await manager.get_connection("test-request", num_choices=2) + + self.assertIsNotNone(dealer) + self.assertIsInstance(queue, asyncio.Queue) + self.assertIn("test-request", manager.request_map) + self.assertEqual(manager.request_num["test-request"], 2) + + async def test_get_connection_no_available(self): + """Test getting connection when none available""" + manager = DealerConnectionManager(pid=123, max_connections=10) + + with self.assertRaises(RuntimeError) as cm: + await manager.get_connection("test-request") + + self.assertIn("No available connections", str(cm.exception)) + + async def test_cleanup_request(self): + """Test request cleanup""" + manager = DealerConnectionManager(pid=123, max_connections=10) + manager.request_map["test-request"] = asyncio.Queue() + manager.request_num["test-request"] = 1 + + await manager.cleanup_request("test-request") + + self.assertNotIn("test-request", manager.request_map) + self.assertNotIn("test-request", manager.request_num) + + async def test_cleanup_request_nonexistent(self): + """Test cleanup of non-existent request""" + manager = DealerConnectionManager(pid=123, max_connections=10) + + # Should not raise an error + await manager.cleanup_request("nonexistent") + + async def test_close(self): + """Test closing manager""" + manager = DealerConnectionManager(pid=123, max_connections=10) + + # Add some connections and tasks + await manager._add_connection(0) + manager.request_map["test"] = asyncio.Queue() + + await manager.close() + + self.assertFalse(manager.running) + self.assertEqual(len(manager.connections), 0) + self.assertEqual(len(manager.connection_load), 0) + self.assertEqual(len(manager.request_map), 0) + + async def test_listen_connection_basic(self): + """Test basic connection listening functionality""" + manager = DealerConnectionManager(pid=123, max_connections=10) + mock_dealer = MagicMock() + mock_dealer.read = AsyncMock() + + # Set up to stop after one iteration + manager.running = True + + # Mock the read to return once then stop + async def mock_read_side_effect(): + manager.running = False # Stop after first read + return [b'mock_data'] + + mock_dealer.read.side_effect = mock_read_side_effect + + # This should not raise an exception + await manager._listen_connection(mock_dealer, 0) + + mock_dealer.read.assert_called() + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py new file mode 100644 index 000000000..81288e589 --- /dev/null +++ b/tests/entrypoints/test_chat_utils.py @@ -0,0 +1,295 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest +import uuid +from pathlib import Path +from copy import deepcopy +from urllib.parse import urlparse + + +# Standalone implementations for testing (copied from source) +def random_tool_call_id() -> str: + return f"chatcmpl-tool-{str(uuid.uuid4().hex)}" + + +def load_chat_template(chat_template, is_literal=False): + if chat_template is None: + return None + if is_literal: + if isinstance(chat_template, Path): + raise TypeError("chat_template is expected to be read directly from its value") + return chat_template + + try: + with open(chat_template) as f: + return f.read() + except OSError as e: + if isinstance(chat_template, Path): + raise + JINJA_CHARS = "{}\n" + if not any(c in chat_template for c in JINJA_CHARS): + msg = ( + f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}" + ) + raise ValueError(msg) from e + + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + return load_chat_template(chat_template, is_literal=True) + + +class MockMediaIO: + def load_bytes(self, data): + return f"media_from_bytes({len(data)})" + + def load_base64(self, media_type, data): + return f"media_from_base64({media_type}, {data})" + + def load_file(self, path): + return f"media_from_file({path})" + + +class MultiModalPartParser: + def __init__(self): + self.image_io = MockMediaIO() + self.video_io = MockMediaIO() + + def parse_image(self, image_url): + return self.load_from_url(image_url, self.image_io) + + def parse_video(self, video_url): + return self.load_from_url(video_url, self.video_io) + + def load_from_url(self, url, media_io): + parsed = urlparse(url) + if parsed.scheme.startswith("http"): + media_bytes = b"mock_http_data" # Mock HTTP response + return media_io.load_bytes(media_bytes) + + if parsed.scheme.startswith("data"): + data_spec, data = parsed.path.split(",", 1) + media_type, data_type = data_spec.split(";", 1) + return media_io.load_base64(media_type, data) + + if parsed.scheme.startswith("file"): + localpath = parsed.path + return media_io.load_file(localpath) + + +def parse_content_part(mm_parser, part): + part_type = part.get("type", None) + + if part_type == "text": + return part + + if part_type == "image_url": + content = part.get("image_url", {}).get("url", None) + image = mm_parser.parse_image(content) + parsed = deepcopy(part) + del parsed["image_url"]["url"] + parsed["image"] = image + parsed["type"] = "image" + return parsed + + if part_type == "video_url": + content = part.get("video_url", {}).get("url", None) + video = mm_parser.parse_video(content) + parsed = deepcopy(part) + del parsed["video_url"]["url"] + parsed["video"] = video + parsed["type"] = "video" + return parsed + + raise ValueError(f"Unknown content part type: {part_type}") + + +def parse_chat_messages(messages): + mm_parser = MultiModalPartParser() + + conversation = [] + for message 
in messages: + role = message["role"] + content = message["content"] + + parsed_content = [] + if content is None: + parsed_content = [] + elif isinstance(content, str): + parsed_content = [{"type": "text", "text": content}] + else: + parsed_content = [parse_content_part(mm_parser, part) for part in content] + + conversation.append({"role": role, "content": parsed_content}) + return conversation + + +class TestChatUtils(unittest.TestCase): + """Test chat utility functions""" + + def test_random_tool_call_id(self): + """Test random tool call ID generation""" + tool_id = random_tool_call_id() + + # Should start with expected prefix + self.assertTrue(tool_id.startswith("chatcmpl-tool-")) + + # Should contain a UUID hex string + uuid_part = tool_id.replace("chatcmpl-tool-", "") + self.assertEqual(len(uuid_part), 32) # UUID hex is 32 chars + + # Should be different each time + tool_id2 = random_tool_call_id() + self.assertNotEqual(tool_id, tool_id2) + + def test_load_chat_template_literal(self): + """Test loading chat template as literal string""" + template = "Hello {{name}}" + result = load_chat_template(template, is_literal=True) + self.assertEqual(result, template) + + def test_load_chat_template_literal_with_path_object(self): + """Test loading chat template with Path object in literal mode should raise error""" + template_path = Path("/some/path") + with self.assertRaises(TypeError): + load_chat_template(template_path, is_literal=True) + + def test_load_chat_template_from_file(self): + """Test loading chat template from file""" + template_content = "Hello {{name}}, how are you?" + + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write(template_content) + temp_path = f.name + + try: + result = load_chat_template(temp_path) + self.assertEqual(result, template_content) + finally: + os.unlink(temp_path) + + def test_load_chat_template_file_not_found(self): + """Test loading chat template from non-existent file""" + # Test with path-like string that looks like a file path + with self.assertRaises(ValueError) as cm: + load_chat_template("/non/existent/path.txt") + + self.assertIn("looks like a file path", str(cm.exception)) + + def test_load_chat_template_fallback_to_literal(self): + """Test fallback to literal when file doesn't exist but contains jinja chars""" + template = "Hello {{name}}\nHow are you?" 
+ result = load_chat_template(template) + self.assertEqual(result, template) + + def test_load_chat_template_none(self): + """Test loading None template""" + result = load_chat_template(None) + self.assertIsNone(result) + + def test_parse_chat_messages_text_only(self): + """Test parsing chat messages with text content only""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"} + ] + + result = parse_chat_messages(messages) + + expected = [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]}, + {"role": "assistant", "content": [{"type": "text", "text": "Hi there!"}]} + ] + + self.assertEqual(result, expected) + + def test_parse_chat_messages_none_content(self): + """Test parsing chat messages with None content""" + messages = [{"role": "user", "content": None}] + result = parse_chat_messages(messages) + + expected = [{"role": "user", "content": []}] + self.assertEqual(result, expected) + + def test_parse_content_part_text(self): + """Test parsing text content part""" + parser = MultiModalPartParser() + part = {"type": "text", "text": "Hello world"} + + result = parse_content_part(parser, part) + self.assertEqual(result, part) + + def test_parse_content_part_image_url(self): + """Test parsing image URL content part""" + parser = MultiModalPartParser() + part = { + "type": "image_url", + "image_url": {"url": "http://example.com/image.jpg"} + } + + result = parse_content_part(parser, part) + + expected = { + "type": "image", + "image_url": {}, + "image": "media_from_bytes(14)" # Mock HTTP response data + } + self.assertEqual(result, expected) + + def test_parse_content_part_video_url(self): + """Test parsing video URL content part""" + parser = MultiModalPartParser() + part = { + "type": "video_url", + "video_url": {"url": "http://example.com/video.mp4"} + } + + result = parse_content_part(parser, part) + + expected = { + "type": "video", + "video_url": {}, + "video": "media_from_bytes(14)" # Mock HTTP response data + } + self.assertEqual(result, expected) + + def test_parse_content_part_unknown_type(self): + """Test parsing unknown content part type""" + parser = MultiModalPartParser() + part = {"type": "unknown", "data": "test"} + + with self.assertRaises(ValueError) as cm: + parse_content_part(parser, part) + + self.assertIn("Unknown content part type: unknown", str(cm.exception)) + + def test_multimodal_part_parser_data_url(self): + """Test parsing data URL""" + parser = MultiModalPartParser() + result = parser.load_from_url("data:image/jpeg;base64,SGVsbG8gV29ybGQ=", parser.image_io) + self.assertEqual(result, "media_from_base64(image/jpeg, SGVsbG8gV29ybGQ=)") + + def test_multimodal_part_parser_file_url(self): + """Test parsing file URL""" + parser = MultiModalPartParser() + result = parser.load_from_url("file:///path/to/image.jpg", parser.image_io) + self.assertEqual(result, "media_from_file(/path/to/image.jpg)") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file