[CI]【Hackathon 9th Sprint No.13】NO.13 功能模块 fastdeploy/model_executor/ops/triton_ops/triton_utils.py 单测补充 (#5035)

* Add unit tests for triton_utils.py

* update name

* update

* update

* update
This commit is contained in:
Winters Montagne
2025-11-17 11:43:31 +08:00
committed by GitHub
parent 36216e62f0
commit 02c83d65db
2 changed files with 518 additions and 1 deletions

View File

@@ -710,7 +710,7 @@ class KernelInterface:
+ f""" -s"{address_hint} {value_hint} {const_args}" """
+ f""" -g "{lanuch_grid}" """
)
all_tune_config = list(self.tune_config)
all_tune_config = [{key: value} for key, value in self.tune_config.items()]
if len(all_tune_config) == 0:
# when user do not specify config, we use const_hint_dict as config.
all_tune_config = [const_hint_dict]

View File

@@ -0,0 +1,517 @@
import os
import sys
import unittest
from unittest.mock import MagicMock, patch
import paddle
import triton
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
KernelInterface,
SubstituteTemplate,
build_package,
extract_triton_kernel,
find_so_path,
get_dtype_str,
get_op_name_with_suffix,
get_pointer_hint,
get_value_hint,
multi_process_do,
paddle_use_triton,
rename_c_to_cu,
rendering_common_template,
)
TRITON_UTILS_PATH = "fastdeploy.model_executor.ops.triton_ops.triton_utils"
MOCK_GENERATED_DIR = "/tmp/generated"
class TestTritonUtils(unittest.TestCase):
# Test case to validate KernelInterface initialization
@patch("triton.runtime.jit.JITFunction")
@patch("os.system")
@patch("multiprocessing.Process")
def test_kernel_interface_initialization(self, mock_process, mock_system, mock_jit):
# Mock function for testing
def mock_function(a, b):
return a + b
# Add type annotations to the mock function
mock_function.__annotations__ = {"a": int, "b": int}
# Initialize KernelInterface with the mock function
kernel_interface = KernelInterface(mock_function, other_config={})
# Validate that the function and argument names are correctly initialized
self.assertIsNotNone(kernel_interface.func)
self.assertEqual(kernel_interface.key_args, ["1"])
self.assertIn("a", kernel_interface.arg_names)
self.assertIn("b", kernel_interface.arg_names)
# Test case for validating the paddle_use_triton decorator
@patch("triton.runtime.jit.JITFunction")
@patch("os.system")
def test_paddle_use_triton_decorator(self, mock_system, mock_jit):
mock_jit.return_value.fn = MagicMock()
# Apply the paddle_use_triton decorator to a mock function
@paddle_use_triton()
def mock_kernel(a, b):
return a + b
# Validate the result of the decorator
self.assertIsInstance(mock_kernel, KernelInterface)
# Test case for validating the build_package function
@patch("builtins.open", new_callable=MagicMock)
@patch("os.system")
def test_build_package(self, mock_system, mock_open):
MOCK_NAME = "test_package"
mock_system.return_value = 0
# Call build_package with mocked directory and package name
build_package(MOCK_GENERATED_DIR, MOCK_NAME)
# Assert that system command was called correctly
mock_system.assert_called_with(f"cd {MOCK_GENERATED_DIR} && {sys.executable} setup_cuda.py build")
# Test case for extracting Triton kernel with a JIT kernel
@patch("builtins.open", new_callable=MagicMock)
def test_extract_triton_kernel_with_jit_kernel(self, mock_open):
@triton.jit
def mock_kernel(x, y):
return x + y
mock_file = MagicMock()
mock_file.write = MagicMock()
mock_open.return_value = mock_file
file_name = "kernel.py"
# Extract the Triton kernel and write to the specified file
extract_triton_kernel(mock_kernel, file_name)
# Assert that file write was performed as expected
mock_open.assert_called_with(file_name, "w")
# Test case for extracting Triton kernel with a Python kernel
@patch("builtins.open", new_callable=MagicMock)
def test_extract_triton_kernel_with_python_kernel(self, mock_open):
def mock_kernel(x, y):
return x + y
mock_file = MagicMock()
mock_file.write = MagicMock()
mock_open.return_value = mock_file
file_name = "kernel.py"
# Extract the Triton kernel and write to the specified file
extract_triton_kernel(mock_kernel, file_name)
# Assert that file write was performed as expected
mock_open.assert_called_with(file_name, "w")
# Test case for validating multi-process execution of commands
@patch("os.system")
@patch("multiprocessing.Process")
def test_multi_process_do(self, mock_process, mock_system):
commands = ["echo 'hello'"] * 5
mock_system.return_value = 0
mock_process_instance = MagicMock()
mock_process.return_value = mock_process_instance
# Call multi_process_do with the list of commands
multi_process_do(commands)
# Assert the expected behavior of mock_process
self.assertEqual(mock_process.call_count, 40)
mock_process_instance.start.assert_called()
mock_process_instance.join.assert_called()
# Test case for renaming .c files to .cu in the specified directory
@patch("os.listdir")
@patch("os.rename")
def test_rename_c_to_cu(self, mock_rename, mock_listdir):
mock_listdir.return_value = ["file1.c"]
# Call rename_c_to_cu to rename the files in the directory
rename_c_to_cu(MOCK_GENERATED_DIR)
# Assert that the rename operation was performed correctly
mock_rename.assert_called_with(
os.path.join(MOCK_GENERATED_DIR, "file1.c"), os.path.join(MOCK_GENERATED_DIR, "file1.cu")
)
# Test case for validating template substitution
def test_substitute_template(self):
template = "Hello, ${name}! Welcome to ${place}."
values = {"name": "Alice", "place": "Wonderland"}
# Call SubstituteTemplate to replace placeholders in the template
result = SubstituteTemplate(template, values)
# Assert that the substitution worked as expected
self.assertEqual(result, "Hello, Alice! Welcome to Wonderland.")
# Test case for finding shared object (.so) file in a directory
@patch("os.walk")
def test_find_so_path_found(self, mock_os_walk):
mock_os_walk.return_value = [("/path/to/dir", [], ["file1.so", "file2.so"])]
# Call find_so_path to locate the .so file
so_path = find_so_path("/path/to/dir", "file1")
# Assert that the correct path is returned
self.assertEqual(so_path, "/path/to/dir/file1.so")
# Test case for handling the scenario when the .so file is not found
@patch("os.walk")
def test_find_so_path_not_found(self, mock_os_walk):
mock_os_walk.return_value = [("/path/to/dir", [], ["file1.txt", "file2.txt"])]
# Call find_so_path when the .so file is not present
so_path = find_so_path("/path/to/dir", "file")
# Assert that None is returned when the .so file is not found
self.assertIsNone(so_path)
# Test case for getting the operator name with suffix
def test_get_op_name_with_suffix(self):
result = get_op_name_with_suffix("op_name", [16, 1, 32])
# Assert the correct suffix is added to the operator name
self.assertEqual(result, "op_name16_1_16")
# Test case for getting value hints from a list
def test_get_value_hint(self):
result = get_value_hint([16, 1, 32])
# Assert that the correct value hint string is generated
self.assertEqual(result, "i64:16,i64:1,i64:16,")
# Test case for getting the string representation of a data type
def test_get_dtype_str(self):
result = get_dtype_str(paddle.float32)
# Assert the correct dtype string is returned
self.assertEqual(result, "_fp32")
with self.assertRaises(ValueError):
get_dtype_str(paddle.bool)
# Test case for getting pointer hints for different data types
def test_get_pointer_hint(self):
result = get_pointer_hint([paddle.float16, paddle.int32, paddle.uint8])
# Assert the correct pointer hint string is generated
self.assertEqual(result, "*fp16:16,*i32:16,*u8:16,")
class TestRenderingCommonTemplate(unittest.TestCase):
def mock_function(self):
def func(a: int, b: float = 2.0, c: bool = True, d: str = "test"):
pass
return func
# Test case for rendering a function template without return tensor
def test_rendering_with_no_return_tensor(self):
func = self.mock_function()
prepare_attr_for_triton_kernel = "prepare_attr_code"
prepare_ptr_for_triton_kernel = "prepare_ptr_code"
result = rendering_common_template(func, prepare_attr_for_triton_kernel, prepare_ptr_for_triton_kernel)
self.assertIn('Outputs({"useless"}', result)
# Test case for rendering a function template with return tensor
def test_rendering_with_return_tensor(self):
func = self.mock_function()
prepare_attr_for_triton_kernel = "prepare_attr_code"
prepare_ptr_for_triton_kernel = "prepare_ptr_code"
return_tensor_names = "out_tensor"
result = rendering_common_template(
func,
prepare_attr_for_triton_kernel,
prepare_ptr_for_triton_kernel,
return_tensor_names=return_tensor_names,
)
self.assertIn('Outputs({"out_tensor"})', result)
self.assertIn("std::vector<std::vector<int64_t>> ${op_name}_InferShape", result)
self.assertIn("std::vector<paddle::DataType> ${op_name}_InferDtype", result)
# Test case for rendering a function template with d2s inference code
def test_rendering_with_d2s_infer_code(self):
func = self.mock_function()
prepare_attr_for_triton_kernel = "prepare_attr_code"
prepare_ptr_for_triton_kernel = "prepare_ptr_code"
return_tensor_names = "out_tensor"
d2s_infer_code = "existing_infer_code"
result = rendering_common_template(
func,
prepare_attr_for_triton_kernel,
prepare_ptr_for_triton_kernel,
return_tensor_names=return_tensor_names,
d2s_infer_code=d2s_infer_code,
)
self.assertIn("existing_infer_code", result)
# Test case for rendering a function template with default parameters
def test_rendering_with_default_parameters(self):
func = self.mock_function()
prepare_attr_for_triton_kernel = "prepare_attr_code"
prepare_ptr_for_triton_kernel = "prepare_ptr_code"
result = rendering_common_template(func, prepare_attr_for_triton_kernel, prepare_ptr_for_triton_kernel)
self.assertIn("float b", result)
self.assertIn("bool c", result)
self.assertIn("std::string d", result)
# Test case for rendering a function template with an invalid function
def test_rendering_with_invalid_function(self):
def invalid_func():
pass
prepare_attr_for_triton_kernel = "prepare_attr_code"
prepare_ptr_for_triton_kernel = "prepare_ptr_code"
result = rendering_common_template(invalid_func, prepare_attr_for_triton_kernel, prepare_ptr_for_triton_kernel)
self.assertIn("useless", result)
# Test case for rendering a function template with multiple return tensors
def test_rendering_with_multiple_return_tensors(self):
func = self.mock_function()
prepare_attr_for_triton_kernel = "prepare_attr_code"
prepare_ptr_for_triton_kernel = "prepare_ptr_code"
return_tensor_names = "out_tensor, aux_tensor"
result = rendering_common_template(
func,
prepare_attr_for_triton_kernel,
prepare_ptr_for_triton_kernel,
return_tensor_names=return_tensor_names,
)
self.assertIn('Outputs({"out_tensor","aux_tensor"})', result)
# Test case for rendering a function template with edge case return tensor names
def test_rendering_with_edge_case_return_tensor_names(self):
func = self.mock_function()
prepare_attr_for_triton_kernel = "prepare_attr_code"
prepare_ptr_for_triton_kernel = "prepare_ptr_code"
return_tensor_names = ""
result = rendering_common_template(
func,
prepare_attr_for_triton_kernel,
prepare_ptr_for_triton_kernel,
return_tensor_names=return_tensor_names,
)
self.assertIn('Outputs({""}', result)
class TestKernelInterface(unittest.TestCase):
# A mock kernel function with constant arguments
def mock_kernel_func(
self,
a,
b: int,
config_key0: triton.language.core.constexpr,
config_key1: triton.language.core.constexpr,
config_key2: triton.language.core.constexpr,
):
return a + b
# Test case for when values do not match in decorator
@patch(f"{TRITON_UTILS_PATH}.OpProtoHolder.instance")
@patch(f"{TRITON_UTILS_PATH}.multi_process_do")
@patch(f"{TRITON_UTILS_PATH}.build_package")
@patch(f"{TRITON_UTILS_PATH}.find_so_path")
@patch(f"{TRITON_UTILS_PATH}.extract_triton_kernel")
@patch("paddle.distributed.get_rank")
@patch("os.path")
@patch("os.makedirs")
@patch("builtins.open", new_callable=MagicMock)
@patch("os.system")
@patch("os.rename")
@patch("os.listdir")
def test_with_values_do_not_match(
self,
mock_listdir,
mock_rename,
mock_system,
mock_open,
mock_makedirs,
mock_os_path,
mock_get_rank,
mock_extract_triton_kernel,
mock_find_so_path,
mock_build_package,
mock_multi_process_do,
mock_op_proto_instance,
):
mock_system.return_value = 0
mock_get_rank.return_value = 0
mock_extract_triton_kernel.return_value = None
mock_find_so_path.return_value = None
mock_build_package.return_value = None
mock_multi_process_do.return_value = None
mock_op_proto_map = {"mock_op": "some_value"}
mock_op_proto_instance_return_value = MagicMock()
mock_op_proto_instance_return_value.op_proto_map = mock_op_proto_map
mock_op_proto_instance.return_value = mock_op_proto_instance_return_value
kernel_interface = KernelInterface(self.mock_kernel_func, other_config={})
op_name_and_grid = [
"mock_op",
"custom_template",
[1, "1", 1],
{"config_key0": "config_value0", "config_key1": "config_value1", "config_key2": "config_value2"},
]
kernel_interface[op_name_and_grid]
self.assertIsNotNone(kernel_interface.func)
self.assertIn("a", kernel_interface.arg_names)
self.assertIn("b", kernel_interface.arg_names)
self.assertIn("config_key0", kernel_interface.arg_names)
# Test if ValueError is raised when values do not match
with self.assertRaises(ValueError):
kernel_interface.decorator("mock_op", "custom_template", "config_value", "config_value1", "config_value2")
mock_extract_triton_kernel.assert_called_once_with(
self.mock_kernel_func, "/tmp/triton_cache/rank0/mock_op/triton_kernels.py"
)
mock_open.assert_called_once_with("/tmp/triton_cache/rank0/mock_op/mock_op.cu", "w")
# Test case for when parameter values match in decorator
@patch(f"{TRITON_UTILS_PATH}.OpProtoHolder.instance")
@patch(f"{TRITON_UTILS_PATH}.multi_process_do")
@patch(f"{TRITON_UTILS_PATH}.build_package")
@patch(f"{TRITON_UTILS_PATH}.find_so_path")
@patch(f"{TRITON_UTILS_PATH}.extract_triton_kernel")
@patch("paddle.distributed.get_rank")
@patch("os.path")
@patch("os.makedirs")
@patch("builtins.open", new_callable=MagicMock)
@patch("os.system")
@patch("os.rename")
@patch("os.listdir")
def test_with_values_match(
self,
mock_listdir,
mock_rename,
mock_system,
mock_open,
mock_makedirs,
mock_os_path,
mock_get_rank,
mock_extract_triton_kernel,
mock_find_so_path,
mock_build_package,
mock_multi_process_do,
mock_op_proto_instance,
):
mock_system.return_value = 0
mock_get_rank.return_value = 0
mock_extract_triton_kernel.return_value = None
mock_find_so_path.return_value = None
mock_build_package.return_value = None
mock_multi_process_do.return_value = None
mock_op_proto_map = {"mock_op": "some_value"}
mock_op_proto_instance_return_value = MagicMock()
mock_op_proto_instance_return_value.op_proto_map = mock_op_proto_map
mock_op_proto_instance.return_value = mock_op_proto_instance_return_value
kernel_interface = KernelInterface(self.mock_kernel_func, other_config={})
op_name_and_grid = [
"mock_op",
"custom_template",
[1, "1", 1],
{"config_key0": "config_value0", "config_key1": "config_value1"},
]
kernel_interface[op_name_and_grid]
self.assertIsNotNone(kernel_interface.func)
self.assertIn("a", kernel_interface.arg_names)
self.assertIn("b", kernel_interface.arg_names)
self.assertIn("config_key0", kernel_interface.arg_names)
# Validate if the decorator works correctly when values match
kernel_interface.decorator("mock_op", "custom_template", "config_value0", "config_value1", "config_value2")
mock_op_proto_instance.assert_called_once_with()
mock_extract_triton_kernel.assert_called_once_with(
self.mock_kernel_func, "/tmp/triton_cache/rank0/mock_op/triton_kernels.py"
)
mock_open.assert_called_once_with("/tmp/triton_cache/rank0/mock_op/mock_op.cu", "w")
mock_system.assert_called()
mock_build_package.assert_called_once_with("/tmp/triton_cache/rank0/mock_op", "mock_op_package")
# Test case for when parameter values match in decorator
@patch(f"{TRITON_UTILS_PATH}.OpProtoHolder.instance")
@patch(f"{TRITON_UTILS_PATH}.multi_process_do")
@patch(f"{TRITON_UTILS_PATH}.build_package")
@patch(f"{TRITON_UTILS_PATH}.find_so_path")
@patch(f"{TRITON_UTILS_PATH}.extract_triton_kernel")
@patch("paddle.distributed.get_rank")
@patch("os.path")
@patch("os.makedirs")
@patch("builtins.open", new_callable=MagicMock)
@patch("os.system")
@patch("os.rename")
@patch("os.listdir")
def test_with_missing_parameter(
self,
mock_listdir,
mock_rename,
mock_system,
mock_open,
mock_makedirs,
mock_os_path,
mock_get_rank,
mock_extract_triton_kernel,
mock_find_so_path,
mock_build_package,
mock_multi_process_do,
mock_op_proto_instance,
):
mock_system.return_value = 0
mock_get_rank.return_value = 0
mock_extract_triton_kernel.return_value = None
mock_find_so_path.return_value = None
mock_build_package.return_value = None
mock_multi_process_do.return_value = None
mock_op_proto_map = {"mock_op": "some_value"}
mock_op_proto_instance_return_value = MagicMock()
mock_op_proto_instance_return_value.op_proto_map = mock_op_proto_map
mock_op_proto_instance.return_value = mock_op_proto_instance_return_value
kernel_interface = KernelInterface(self.mock_kernel_func, other_config={})
op_name_and_grid = ["mock_op", "custom_template", [1, "1", 1]]
kernel_interface[op_name_and_grid]
self.assertIsNotNone(kernel_interface.func)
self.assertIn("a", kernel_interface.arg_names)
self.assertIn("b", kernel_interface.arg_names)
self.assertIn("config_key0", kernel_interface.arg_names)
with self.assertRaises(AssertionError):
kernel_interface.decorator("mock_op", "custom_template")
mock_extract_triton_kernel.assert_called_once_with(
self.mock_kernel_func, "/tmp/triton_cache/rank0/mock_op/triton_kernels.py"
)
mock_open.assert_called_once_with("/tmp/triton_cache/rank0/mock_op/mock_op.cu", "w")
if __name__ == "__main__":
unittest.main()