minghaipeng
2025-01-08 06:51:43 +00:00
parent c7e1d58699
commit c249b98aaa
2 changed files with 2 additions and 3 deletions

@@ -31,6 +31,7 @@ class Req(BaseModel):
    req_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    input_ids: Optional[List[int]] = None
    text: Optional[str] = None
    stop_sequences: Optional[List] = None
    messages: Optional[List] = None
    max_dec_len: Optional[int] = None
    seq_len: Optional[int] = None

@@ -102,9 +102,7 @@ class TritonTokenProcessor(engine.TokenProcessor):
        for i in range(len(batch_result)):
            is_end = batch_result[i].get("is_end", 0)
            token_ids = batch_result[i]["token_ids"]
-           return_all_tokens = batch_result[i].get("return_all_tokens", False)
-           cache_special_token = False if is_end == 1 else True
-           if is_end != 1 and (cache_special_token or return_all_tokens or self.cfg.disable_streaming):
+           if is_end != 1:
                if batch_result[i]["req_id"] not in self.token_buffer:
                    self.token_buffer[batch_result[i]["req_id"]] = list()
                    self.score_buffer[batch_result[i]["req_id"]] = list()
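
The hunk above drops the return_all_tokens / cache_special_token / disable_streaming gate, so the only remaining check before buffering is whether the request has finished. Below is a minimal sketch of the post-commit buffering step, for illustration only: it uses the dict keys visible in the diff (req_id, token_ids, is_end), a hypothetical standalone function in place of the TritonTokenProcessor method, and an assumed extend of the buffer that the hunk itself cuts off before.

from typing import Dict, List

def buffer_unfinished(batch_result: List[dict],
                      token_buffer: Dict[str, List[int]],
                      score_buffer: Dict[str, list]) -> None:
    # Hypothetical standalone version of the buffering step after this commit;
    # the real logic lives inside TritonTokenProcessor.
    for result in batch_result:
        is_end = result.get("is_end", 0)
        token_ids = result["token_ids"]
        # Post-commit condition: buffer every chunk until the request ends.
        if is_end != 1:
            req_id = result["req_id"]
            if req_id not in token_buffer:
                token_buffer[req_id] = list()
                score_buffer[req_id] = list()
            # Assumed continuation: the hunk ends at buffer initialization,
            # but the incoming tokens are presumably appended afterwards.
            token_buffer[req_id].extend(token_ids)

With the compound gate gone, the streaming configuration no longer affects whether intermediate tokens are buffered at this point.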