FastDeploy/tests/model_executor/test_logits_processor.py
李泳桦 a012e3608b [Feature] support logits processors (#4515)
* [feat] provide an interface for logits processors and a builtin LogitBiasLogitsProcessor

* [chore] fix code style

* [fix] add unit test & fix existing bugs

* [feat] add engine/worker arg --logits-processors

* [fix] redefine user args as logits_processors_args and fix some bugs

* [fix] fix test_sampler

* Update fastdeploy/model_executor/logits_processor/builtin.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/model_executor/logits_processor/__init__.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/model_executor/test_logits_processor.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* [fix] fix typo

* Update fastdeploy/engine/sampling_params.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* [fix] fix brackets

* [chore] redefine logits processor interface: pass the entire share_inputs into LP, do not copy share_inputs and logits

* [doc] add docs

* [fix] fix logit bias processor not applied when decoding is too fast & add docs and tests

* [fix] fix redundant code

* [feat] skip apply() if no bias is specified

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-29 00:08:53 +08:00
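
For orientation, the contract this test exercises can be sketched as follows. Only the LogitBiasLogitsProcessor(fd_config) constructor and the update_state(share_inputs) / apply(logits) calls are confirmed by the test below; the class shape, attribute names, and control flow in this sketch are assumptions inferred from the test and the commit message, not FastDeploy's actual definitions.

# Hedged sketch of the logits-processor contract this test relies on.
# Assumptions are marked; only the three calls used in the test are confirmed.
class SketchLogitBiasProcessor:
    def __init__(self, fd_config):
        # Assumption: the processor only needs the model dtype from fd_config,
        # mirroring what the test's Mock provides.
        self.dtype = fd_config.model_config.dtype
        self.stop_flags = None
        self.args = None

    def update_state(self, share_inputs):
        # Per the commit message, the entire share_inputs dict is passed in
        # (no copies); per-request arguments live in "logits_processors_args".
        self.stop_flags = share_inputs["stop_flags"]
        self.args = share_inputs["logits_processors_args"]

    def apply(self, logits):
        # Per the commit message, apply() is skipped when no request specifies
        # a bias; otherwise each running request's logit_bias is added to its
        # row of the compacted logits batch, in slot order.
        batch_id = 0
        for slot_id, stopped in enumerate(self.stop_flags):
            if stopped:
                continue
            for token_id, bias in (self.args[slot_id].get("logit_bias") or {}).items():
                logits[batch_id, token_id] += bias
            batch_id += 1
        return logits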


import random
import unittest
from unittest.mock import Mock

import paddle

from fastdeploy.engine.request import Request
from fastdeploy.model_executor.logits_processor.builtin import LogitBiasLogitsProcessor


class TestLogitsProcessor(unittest.TestCase):
    def setUp(self):
        self.vocab_size = 10
        self.max_num_seqs = 16
        self.dtype = "float32"
        # All slots start stopped (empty batch); each slot owns a dict of
        # per-request logits-processor arguments.
        self.share_inputs = {
            "stop_flags": paddle.to_tensor([True for _ in range(self.max_num_seqs)]),
            "logits_processors_args": [{} for _ in range(self.max_num_seqs)],
        }
    def create_request(self, **kwargs):
        """Create a mock request carrying the given attributes (idx, logit_bias, ...)."""
        request = Mock(spec=Request)
        for k, v in kwargs.items():
            setattr(request, k, v)
        return request

    def create_logits(self):
        return paddle.randn([self.get_batch_size(), self.vocab_size], dtype=self.dtype)

    def add_request(self, req):
        self.share_inputs["stop_flags"][req.idx] = False
        self.share_inputs["logits_processors_args"][req.idx]["logit_bias"] = req.logit_bias

    def del_request(self, req):
        self.share_inputs["stop_flags"][req.idx] = True
        self.share_inputs["logits_processors_args"][req.idx] = {}

    def get_batch_size(self):
        # Number of running requests = slots whose stop flag is not set.
        return self.max_num_seqs - int(self.share_inputs["stop_flags"].cast("int32").sum())

    def expected_with_biases(self, logits):
        """Reference behavior: for every running slot (stop_flag False), add its
        logit_bias entries to the matching row of the compacted logits batch."""
        expected = logits.clone()
        batch_id = 0
        for slot_id, stopped in enumerate(self.share_inputs["stop_flags"]):
            if not stopped:
                logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias") or {}
                for token_id, bias in logit_bias.items():
                    expected[batch_id, token_id] += bias
                batch_id += 1
        return expected
    def test_logit_bias_logit_processor(self):
        fd_config = Mock()
        fd_config.model_config.dtype = self.dtype
        logits_processor = LogitBiasLogitsProcessor(fd_config)

        print("Phase 1: Empty batch")
        logits = self.create_logits()
        logits_processor.update_state(self.share_inputs)
        processed_logits = logits_processor.apply(logits)
        self.assertTrue(paddle.all(processed_logits == logits), "Logits should remain unchanged with empty batch")

        print("Phase 2: Add first request")
        request1 = self.create_request(
            request_id="req1", idx=0, logit_bias={random.randint(0, self.vocab_size - 1): random.random() - 0.5}
        )
        self.add_request(request1)
        logits = self.create_logits()
        original_logits = logits.clone()
        logits_processor.update_state(self.share_inputs)
        processed_logits = logits_processor.apply(logits)
        expected_logits = self.expected_with_biases(original_logits)
        self.assertTrue(
            paddle.all(processed_logits == expected_logits),
            f"Logits should be modified with req1 biases\n"
            f"original: {original_logits}\n"
            f"processed: {processed_logits}\n"
            f"expected: {expected_logits}\n"
            f"diff: {processed_logits - expected_logits}",
        )
print("Phase 3: Add second request with multiple tokens to apply bias")
request2 = self.create_request(
request_id="req2",
idx=1,
logit_bias=dict(
zip(random.choices(range(0, self.vocab_size), k=3), [random.random() - 0.5 for _ in range(3)])
),
)
self.add_request(request2)
logits = self.create_logits()
original_logits = logits.clone()
expected_logits = logits.clone()
logits_processor.update_state(self.share_inputs)
processed_logits = logits_processor.apply(logits)
batch_id = 0
for slot_id, flag in enumerate(self.share_inputs["stop_flags"]):
if not flag:
logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias") or {}
for token_id, bias in logit_bias.items():
expected_logits[batch_id, token_id] += bias
batch_id += 1
self.assertTrue(
paddle.all(processed_logits == expected_logits),
"Logits should be modified with req1 and req2 biases\n"
f"original: {original_logits}\n"
f"processed: {processed_logits}\n"
f"expected: {expected_logits}\n"
f"diff: {processed_logits-expected_logits}",
)
print("Phase 4: Remove first request")
self.del_request(request1)
logits = self.create_logits()
original_logits = logits.clone()
expected_logits = logits.clone()
logits_processor.update_state(self.share_inputs)
processed_logits = logits_processor.apply(logits)
batch_id = 0
for slot_id, flag in enumerate(self.share_inputs["stop_flags"]):
if not flag:
logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias") or {}
for token_id, bias in logit_bias.items():
expected_logits[batch_id, token_id] += bias
batch_id += 1
self.assertTrue(
paddle.all(processed_logits == expected_logits),
"Logits should only have biases from request2 after removal\n"
f"original: {original_logits}\n"
f"processed: {processed_logits}\n"
f"expected: {expected_logits}\n"
f"diff: {processed_logits-expected_logits}",
)
print("Phase 5: Add third request with no logit bias")
request3 = self.create_request(request_id="req3", idx=0, logit_bias=None)
self.add_request(request3)
logits = self.create_logits()
original_logits = logits.clone()
expected_logits = logits.clone()
logits_processor.update_state(self.share_inputs)
processed_logits = logits_processor.apply(logits)
batch_id = 0
for slot_id, flag in enumerate(self.share_inputs["stop_flags"]):
if not flag:
logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias") or {}
for token_id, bias in logit_bias.items():
expected_logits[batch_id, token_id] += bias
batch_id += 1
self.assertTrue(
paddle.all(processed_logits == expected_logits),
"Logits should remain unchanged with request having no bias\n"
f"original: {original_logits}\n"
f"processed: {processed_logits}\n"
f"expected: {expected_logits}\n"
f"diff: {processed_logits-expected_logits}",
)
print("All test phases completed successfully!")
if __name__ == "__main__":
unittest.main()