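"""Unit tests for the builtin LogitBiasLogitsProcessor.

The processor is driven through the engine's ``share_inputs`` dict:
``update_state`` reads the per-slot ``stop_flags`` and
``logits_processors_args``, and ``apply`` adds each active request's
``logit_bias`` to that request's row of the logits.
"""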
import random
import unittest
from unittest.mock import Mock

import paddle

from fastdeploy.engine.request import Request
from fastdeploy.model_executor.logits_processor.builtin import LogitBiasLogitsProcessor


class TestLogitsProcessor(unittest.TestCase):
    def setUp(self):
        self.vocab_size = 10
        self.max_num_seqs = 16
        self.dtype = "float32"
        self.share_inputs = {
            "stop_flags": paddle.to_tensor([True for _ in range(self.max_num_seqs)]),
            "logits_processors_args": [{} for _ in range(self.max_num_seqs)],
        }

    def create_request(self, **kwargs):
        """Create a mock request with the specified attributes (request_id, idx, logit_bias, ...)."""
        request = Mock(spec=Request)
        for k, v in kwargs.items():
            setattr(request, k, v)
        return request

    def create_logits(self):
        """Random logits for the currently active batch."""
        return paddle.randn([self.get_batch_size(), self.vocab_size], dtype=self.dtype)

    def add_request(self, req):
        """Mark the request's slot as active and register its logit bias."""
        self.share_inputs["stop_flags"][req.idx] = False
        self.share_inputs["logits_processors_args"][req.idx]["logit_bias"] = req.logit_bias

    def del_request(self, req):
        """Mark the request's slot as stopped and clear its processor args."""
        self.share_inputs["stop_flags"][req.idx] = True
        self.share_inputs["logits_processors_args"][req.idx] = {}

    def get_batch_size(self):
        """Number of active (not stopped) request slots."""
        return self.max_num_seqs - int(self.share_inputs["stop_flags"].astype("int64").sum())
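
    def expected_logits_with_bias(self, logits):
        """Reference computation of the biased logits.

        Re-applies each active slot's ``logit_bias`` to a clone of ``logits``
        row by row, in batch order, so each test phase can compare the
        processor's output against an independently computed expectation.
        """
        expected = logits.clone()
        batch_id = 0
        for slot_id, stopped in enumerate(self.share_inputs["stop_flags"]):
            if not stopped:
                logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias") or {}
                for token_id, bias in logit_bias.items():
                    expected[batch_id, token_id] += bias
                batch_id += 1
        return expected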

    def test_logit_bias_logit_processor(self):
        fd_config = Mock()
        fd_config.model_config.dtype = self.dtype
        logits_processor = LogitBiasLogitsProcessor(fd_config)
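
        # Each phase mutates share_inputs, refreshes the processor via
        # update_state(), then checks apply() against expected_logits_with_bias().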
print("Phase 1: Empty batch")
|
|
logits = self.create_logits()
|
|
logits_processor.update_state(self.share_inputs)
|
|
processed_logits = logits_processor.apply(logits)
|
|
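        # With no active requests there are no biases, so apply() may take its
        # no-bias fast path and return the logits untouched.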
        self.assertTrue(paddle.all(processed_logits == logits), "Logits should remain unchanged with an empty batch")
print("Phase 2: Add first request")
|
|
request1 = self.create_request(
|
|
request_id="req1", idx=0, logit_bias={random.randint(0, self.vocab_size - 1): random.random() - 0.5}
|
|
)
|
|
self.add_request(request1)
|
|
logits = self.create_logits()
|
|
original_logits = logits.clone()
|
|
expected_logits = logits.clone()
|
|
logits_processor.update_state(self.share_inputs)
|
|
processed_logits = logits_processor.apply(logits)
|
|
batch_id = 0
|
|
for slot_id, flag in enumerate(self.share_inputs["stop_flags"]):
|
|
if not flag:
|
|
logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias", {})
|
|
for token_id, bias in logit_bias.items():
|
|
expected_logits[batch_id, token_id] += bias
|
|
batch_id += 1
|
|
self.assertTrue(
|
|
paddle.all(processed_logits == expected_logits),
|
|
f"Logits should be modified with req1 biases\n"
|
|
f"original: {original_logits}\n"
|
|
f"processed: {processed_logits}\n"
|
|
f"expected: {expected_logits}\n"
|
|
f"diff: {processed_logits-expected_logits}",
|
|
)
|
|
|
|
print("Phase 3: Add second request with multiple tokens to apply bias")
|
|
request2 = self.create_request(
|
|
request_id="req2",
|
|
idx=1,
|
|
logit_bias=dict(
|
|
zip(random.choices(range(0, self.vocab_size), k=3), [random.random() - 0.5 for _ in range(3)])
|
|
),
|
|
)
|
|
self.add_request(request2)
|
|
logits = self.create_logits()
|
|
original_logits = logits.clone()
|
|
expected_logits = logits.clone()
|
|
logits_processor.update_state(self.share_inputs)
|
|
processed_logits = logits_processor.apply(logits)
|
|
batch_id = 0
|
|
for slot_id, flag in enumerate(self.share_inputs["stop_flags"]):
|
|
if not flag:
|
|
logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias") or {}
|
|
for token_id, bias in logit_bias.items():
|
|
expected_logits[batch_id, token_id] += bias
|
|
batch_id += 1
|
|
self.assertTrue(
|
|
paddle.all(processed_logits == expected_logits),
|
|
"Logits should be modified with req1 and req2 biases\n"
|
|
f"original: {original_logits}\n"
|
|
f"processed: {processed_logits}\n"
|
|
f"expected: {expected_logits}\n"
|
|
f"diff: {processed_logits-expected_logits}",
|
|
)
|
|
|
|
print("Phase 4: Remove first request")
|
|
self.del_request(request1)
|
|
logits = self.create_logits()
|
|
original_logits = logits.clone()
|
|
expected_logits = logits.clone()
|
|
logits_processor.update_state(self.share_inputs)
|
|
processed_logits = logits_processor.apply(logits)
|
|
batch_id = 0
|
|
for slot_id, flag in enumerate(self.share_inputs["stop_flags"]):
|
|
if not flag:
|
|
logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias") or {}
|
|
for token_id, bias in logit_bias.items():
|
|
expected_logits[batch_id, token_id] += bias
|
|
batch_id += 1
|
|
self.assertTrue(
|
|
paddle.all(processed_logits == expected_logits),
|
|
"Logits should only have biases from request2 after removal\n"
|
|
f"original: {original_logits}\n"
|
|
f"processed: {processed_logits}\n"
|
|
f"expected: {expected_logits}\n"
|
|
f"diff: {processed_logits-expected_logits}",
|
|
)
|
|
|
|
print("Phase 5: Add third request with no logit bias")
|
|
request3 = self.create_request(request_id="req3", idx=0, logit_bias=None)
|
|
self.add_request(request3)
|
|
logits = self.create_logits()
|
|
original_logits = logits.clone()
|
|
expected_logits = logits.clone()
|
|
logits_processor.update_state(self.share_inputs)
|
|
processed_logits = logits_processor.apply(logits)
|
|
batch_id = 0
|
|
for slot_id, flag in enumerate(self.share_inputs["stop_flags"]):
|
|
if not flag:
|
|
logit_bias = self.share_inputs["logits_processors_args"][slot_id].get("logit_bias") or {}
|
|
for token_id, bias in logit_bias.items():
|
|
expected_logits[batch_id, token_id] += bias
|
|
batch_id += 1
|
|
self.assertTrue(
|
|
paddle.all(processed_logits == expected_logits),
|
|
"Logits should remain unchanged with request having no bias\n"
|
|
f"original: {original_logits}\n"
|
|
f"processed: {processed_logits}\n"
|
|
f"expected: {expected_logits}\n"
|
|
f"diff: {processed_logits-expected_logits}",
|
|
)
|
|
|
|
print("All test phases completed successfully!")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|