mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
fix bug
This commit is contained in:
@@ -31,6 +31,7 @@ class Req(BaseModel):
|
||||
req_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
input_ids: Optional[List[int]] = None
|
||||
text: Optional[str] = None
|
||||
stop_sequences: Optional[List] = None
|
||||
messages: Optional[List] = None
|
||||
max_dec_len: Optional[int] = None
|
||||
seq_len: Optional[int] = None
|
||||
|
||||
@@ -102,9 +102,7 @@ class TritonTokenProcessor(engine.TokenProcessor):
|
||||
for i in range(len(batch_result)):
|
||||
is_end = batch_result[i].get("is_end", 0)
|
||||
token_ids = batch_result[i]["token_ids"]
|
||||
return_all_tokens = batch_result[i].get("return_all_tokens", False)
|
||||
cache_special_token = False if is_end == 1 else True
|
||||
if is_end != 1 and (cache_special_token or return_all_tokens or self.cfg.disable_streaming):
|
||||
if is_end != 1:
|
||||
if batch_result[i]["req_id"] not in self.token_buffer:
|
||||
self.token_buffer[batch_result[i]["req_id"]] = list()
|
||||
self.score_buffer[batch_result[i]["req_id"]] = list()
|
||||
|
||||
Reference in New Issue
Block a user