[Feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post processing (#3552)
* [feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post processing
* infer engine support temp_scaled_logprobs and top_p_normalized_logprobs
* delete some code
* code check
* code check and add doc
* fix tokenizer.decoder(-1), return 'Invalid Token'
* add ci for temp_scaled and top_p logprobs
* check test
* check seq len time shape
* logprob clip inf

Co-authored-by: sunlei1024 <sunlei5788@gmail.com>
@@ -45,8 +45,9 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
   -H "Content-Type: application/json" \
   -d '{
     "messages": [
-        {"role": "user", "content": "Hello!"}, "logprobs": true, "top_logprobs": 5
-    ]
+        {"role": "user", "content": "Hello!"}
+    ],
+    "logprobs": true, "top_logprobs": 0
   }'
 ```

@@ -193,6 +194,12 @@ max_streaming_response_tokens: Optional[int] = None

 disable_chat_template: Optional[bool] = False
 # Whether to disable chat template rendering, using raw input directly (default False means template is enabled).
+
+temp_scaled_logprobs: Optional[bool] = False
+# Whether to divide the logits by the temperature when computing logprobs (default False: no temperature scaling).
+
+top_p_normalized_logprobs: Optional[bool] = False
+# Whether to apply top-p normalization when computing logprobs (default False: no top-p normalization).
 ```

 ### Differences in Return Fields
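For readers of the docs hunk above, a minimal client-side sketch (not part of this commit) of a request that exercises the two new flags; it assumes the host/port from the example and that the `requests` package is available:

```python
# Hedged sketch: send the new logprob post-processing flags to the
# OpenAI-compatible chat endpoint documented above.
import requests

payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "logprobs": True,
    "top_logprobs": 5,
    "temp_scaled_logprobs": True,        # divide logits by temperature before log_softmax
    "top_p_normalized_logprobs": True,   # renormalize probabilities within the top-p set
}
resp = requests.post("http://0.0.0.0:8188/v1/chat/completions", json=payload, timeout=60)
print(resp.json()["choices"][0]["logprobs"])
```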
@@ -45,8 +45,9 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
   -H "Content-Type: application/json" \
   -d '{
     "messages": [
-        {"role": "user", "content": "Hello!"}, "logprobs": true, "top_logprobs": 5
-    ]
+        {"role": "user", "content": "Hello!"}
+    ],
+    "logprobs": true, "top_logprobs": 0
   }'
 ```

@@ -192,6 +193,12 @@ max_streaming_response_tokens: Optional[int] = None

 disable_chat_template: Optional[bool] = False
 # Whether to disable chat template rendering and use the raw input directly (default False: template enabled).
+
+temp_scaled_logprobs: Optional[bool] = False
+# Whether to divide the logits by the temperature when computing logprobs (default False: no temperature scaling).
+
+top_p_normalized_logprobs: Optional[bool] = False
+# Whether to apply top-p normalization when computing logprobs (default False: no top-p normalization).
 ```

 ### Differences in Return Fields
@@ -98,6 +98,9 @@ class SamplingParams:
     reasoning_max_tokens: Optional[int] = None
     min_tokens: int = 1
     logprobs: Optional[int] = None
+    # For logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    top_p_normalized_logprobs: bool = False
     bad_words: Optional[List[str]] = None
     _bad_words_token_ids: Optional[List[int]] = None

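For context, a hedged sketch of setting the new SamplingParams fields in offline use. The `LLM`/`generate` API, the model path, and the note about enabling logprobs on the engine are assumptions based on FastDeploy's offline examples, not on this diff:

```python
# Hedged sketch, not from this commit: offline usage of the new SamplingParams flags.
from fastdeploy import LLM, SamplingParams  # assumed offline API

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    logprobs=5,                      # number of top logprobs to return per token
    temp_scaled_logprobs=True,       # log_softmax over logits / temperature
    top_p_normalized_logprobs=True,  # renormalize within the top-p nucleus
)

llm = LLM(model="./your-model-path")  # placeholder path; the engine may also need logprobs enabled
outputs = llm.generate(["Hello!"], sampling_params)
```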
@@ -403,6 +403,9 @@ class CompletionRequest(BaseModel):
     echo: Optional[bool] = False
     frequency_penalty: Optional[float] = None
     logprobs: Optional[int] = None
+    # For logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    top_p_normalized_logprobs: bool = False
     max_tokens: Optional[int] = None
     n: int = 1
     presence_penalty: Optional[float] = None
@@ -534,6 +537,11 @@ class ChatCompletionRequest(BaseModel):
     frequency_penalty: Optional[float] = None
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = 0
+
+    # For logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    top_p_normalized_logprobs: bool = False
+
     # remove max_tokens when field is removed from OpenAI API
     max_tokens: Optional[int] = Field(
         default=None,
@@ -591,6 +599,8 @@ class ChatCompletionRequest(BaseModel):

         req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
         req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
+        req_dict["temp_scaled_logprobs"] = self.temp_scaled_logprobs
+        req_dict["top_p_normalized_logprobs"] = self.top_p_normalized_logprobs

         # parse request model into dict, priority: request params > metadata params
         if self.metadata is not None:
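To make the mapping above concrete, a hypothetical stand-alone helper (not the actual protocol method) that mirrors how the chat-completion fields flow into the internal request dict:

```python
# Hypothetical helper mirroring the assignments in the hunk above.
def sketch_logprob_fields(request: dict) -> dict:
    req_dict = {}
    req_dict["max_tokens"] = request.get("max_completion_tokens") or request.get("max_tokens")
    # top_logprobs only takes effect when logprobs is enabled.
    req_dict["logprobs"] = request.get("top_logprobs") if request.get("logprobs") else None
    req_dict["temp_scaled_logprobs"] = request.get("temp_scaled_logprobs", False)
    req_dict["top_p_normalized_logprobs"] = request.get("top_p_normalized_logprobs", False)
    return req_dict


print(sketch_logprob_fields({"logprobs": True, "top_logprobs": 5, "temp_scaled_logprobs": True}))
# {'max_tokens': None, 'logprobs': 5, 'temp_scaled_logprobs': True, 'top_p_normalized_logprobs': False}
```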
@@ -15,7 +15,7 @@
 """

 from dataclasses import dataclass
-from typing import Optional
+from typing import Dict, Optional

 import paddle

@@ -51,3 +51,6 @@ class SamplingMetadata:
     stop_flags: Optional[paddle.Tensor] = None
     prompt_ids: Optional[paddle.Tensor] = None
     prompt_lens: Optional[paddle.Tensor] = None
+    temp_scaled_logprobs: Optional[paddle.Tensor] = None
+    top_p_normalized_logprobs: Optional[paddle.Tensor] = None
+    share_inputs: Optional[Dict[str, paddle.Tensor]] = None
@@ -40,6 +40,18 @@ from fastdeploy.platforms import current_platform
 from fastdeploy.worker.output import LogprobsTensors, SamplerOutput


+def top_p_normalize_probs_paddle(
+    probs: paddle.Tensor,
+    top_ps: paddle.Tensor,
+):
+    probs_idx = probs.argsort(axis=-1, descending=True)
+    probs_sort = paddle.take_along_axis(probs, probs_idx, axis=-1)
+    probs_sum = paddle.cumsum(probs_sort, axis=-1)
+    probs_sort = paddle.where((probs_sum - probs_sort) > top_ps, paddle.zeros_like(probs_sort), probs_sort)
+    probs_sort.divide_(probs_sort.sum(axis=-1, keepdim=True))
+    return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
+
+
 class SamplerProcessor:
     """
     SamplingProcessor for guided decoding.
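The new `top_p_normalize_probs_paddle` keeps, in descending probability order, every token whose preceding cumulative mass is still within `top_p`, zeroes the rest, renormalizes, and scatters the values back to their original positions. A NumPy reference sketch of the same logic (illustration only, not part of the commit):

```python
import numpy as np

def top_p_normalize_probs_np(probs: np.ndarray, top_p: float) -> np.ndarray:
    """Reference (1-D) version of the top-p renormalization above."""
    order = np.argsort(-probs)                 # indices in descending probability order
    sorted_probs = probs[order]
    cumulative = np.cumsum(sorted_probs)
    # Keep a token if the mass accumulated *before* it does not exceed top_p
    # (mirrors the `(probs_sum - probs_sort) > top_ps` zeroing condition).
    kept = np.where(cumulative - sorted_probs <= top_p, sorted_probs, 0.0)
    kept /= kept.sum()
    out = np.zeros_like(probs)
    out[order] = kept                          # scatter back to original positions
    return out

print(top_p_normalize_probs_np(np.array([0.05, 0.5, 0.3, 0.15]), top_p=0.7))
# -> [0.    0.625 0.375 0.   ]
```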
@@ -207,9 +219,45 @@ class Sampler(nn.Layer):
         """pre process before running"""
         self.processor.pre_process(skip_idx_list)

-    def compute_logprobs(self, logits: paddle.Tensor) -> paddle.Tensor:
+    def compute_logprobs(
+        self,
+        logits: paddle.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> paddle.Tensor:
         """ """
-        return F.log_softmax(logits, axis=-1)
+        last_logits = logits
+        real_bsz = last_logits.shape[0]
+        temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
+        top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
+        share_inputs = sampling_metadata.share_inputs
+        if temp_scaled_logprobs is not None:
+            real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
+            temperature = sampling_metadata.temperature[:real_bsz]
+            temp_temperature = paddle.where(real_bsz_temp_scaled, temperature, paddle.ones_like(temperature))
+            last_logits = last_logits / temp_temperature
+
+        last_logprobs = F.log_softmax(last_logits, axis=-1)
+        top_p_logprob = None
+        top_p_req_mask = None
+
+        if top_p_normalized_logprobs is not None and share_inputs is not None:
+            seq_lens_this_time = share_inputs["seq_lens_this_time"].reshape([-1, 1])[:real_bsz]
+            seq_lens_encoder = share_inputs["seq_lens_encoder"].reshape([-1, 1])[:real_bsz]
+            seq_lens_decoder = share_inputs["seq_lens_decoder"].reshape([-1, 1])[:real_bsz]
+            seq_lens_time_sum = seq_lens_this_time + seq_lens_encoder + seq_lens_decoder
+            real_req_mask = seq_lens_time_sum > 0
+            top_p_req_mask = paddle.logical_and(top_p_normalized_logprobs[:real_bsz], real_req_mask)
+            real_req_top_p = sampling_metadata.top_p[:real_bsz]
+            # Normalize logprobs if top_p normalization is enabled
+            # NOTE: only normalize logprobs when top_p is set and not equal to 1.0
+            top_p_req_mask = paddle.logical_and(top_p_req_mask, real_req_top_p != 1.0)
+            if top_p_req_mask.any():
+                probs = F.softmax(last_logits, axis=-1)
+                probs = top_p_normalize_probs_paddle(probs, real_req_top_p)
+                top_p_logprob = paddle.log(probs)
+        if top_p_logprob is not None:
+            last_logprobs = paddle.where(top_p_req_mask, top_p_logprob, last_logprobs)
+        return last_logprobs

     def gather_logprobs(
         self,
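A small NumPy illustration (not engine code) of the per-request temperature branch in `compute_logprobs`: rows with `temp_scaled_logprobs` set have their logits divided by that row's temperature before `log_softmax`; other rows are effectively divided by 1.0:

```python
import numpy as np

def log_softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max()
    return x - np.log(np.exp(x).sum())

logits = np.array([2.0, 1.0, 0.5])
temperature = 0.8

plain = log_softmax(logits)                 # temp_scaled_logprobs=False
scaled = log_softmax(logits / temperature)  # temp_scaled_logprobs=True

print(plain.round(4))   # [-0.4644 -1.4644 -1.9644]
print(scaled.round(4))  # [-0.3645 -1.6145 -2.2395]  (sharper distribution for T < 1)
```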
@@ -234,6 +282,7 @@ class Sampler(nn.Layer):
             Sampled token rank tensor, (num tokens)
         """
         assert token_ids.dtype == paddle.int64
+        logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
         # Get with the logprob of the prompt or sampled token.
         token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)

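The added `clip_` matters because tokens zeroed by the top-p renormalization have probability 0, so their logprob is `-inf`; clamping to the dtype minimum keeps the later gather and rank computations finite. A NumPy stand-in (illustration only):

```python
import numpy as np

probs = np.array([0.625, 0.375, 0.0, 0.0], dtype=np.float32)
with np.errstate(divide="ignore"):           # log(0) -> -inf
    logprobs = np.log(probs)
print(logprobs)                              # [-0.47 -0.98 -inf -inf] (approx.)
logprobs = np.clip(logprobs, np.finfo(np.float32).min, None)
print(logprobs)                              # -inf replaced by ~-3.4e38
```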
@@ -260,7 +309,7 @@ class Sampler(nn.Layer):
         """ """
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
-            raw_logprobs = self.compute_logprobs(logits)
+            raw_logprobs = self.compute_logprobs(logits, sampling_metadata)

         logits = self.processor.apply_token_mask(logits, skip_idx_list)

@@ -323,6 +323,10 @@ class GPUModelRunner(ModelRunnerBase):
             self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
             self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
             self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0)
+            self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = request.get("temp_scaled_logprobs", False)
+            self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = request.get(
+                "top_p_normalized_logprobs", False
+            )

             self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
             self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
@@ -496,6 +500,12 @@ class GPUModelRunner(ModelRunnerBase):
             self.share_inputs["presence_score"][idx : idx + 1] = get_attr_from_request(
                 request, "presence_penalty", 0.0
             )
+            self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request(
+                request, "temp_scaled_logprobs", False
+            )
+            self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = get_attr_from_request(
+                request, "top_p_normalized_logprobs", False
+            )

             self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
             self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
@@ -634,6 +644,8 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["presence_score"] = paddle.full(
             [max_num_seqs, 1], self.model_config.presence_score, dtype="float32"
         )
+        self.share_inputs["temp_scaled_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype="bool")
+        self.share_inputs["top_p_normalized_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype="bool")

         self.share_inputs["min_dec_len"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64")
         self.share_inputs["max_dec_len"] = paddle.full(
@@ -853,6 +865,9 @@ class GPUModelRunner(ModelRunnerBase):
             max_num_logprobs=20 if self.enable_logprob else None,
             enable_early_stop=self.enable_early_stop,
             stop_flags=self.share_inputs["stop_flags"],
+            temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"],
+            top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"],
+            share_inputs=self.share_inputs,
         )

     def load_model(self) -> None:
@@ -154,8 +154,101 @@ def test_stream_without_logprobs():
         assert result_chunk["choices"][0]["logprobs"] is None


+def test_stream_with_temp_scaled_logprobs():
+    """
+    Verify the first token's logprob info in a streaming response with temp_scaled_logprobs enabled.
+    """
+    data = {
+        "stream": True,
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
+        ],
+        "max_tokens": 3,
+        "temperature": 0.8,
+        "top_p": 0,
+        "temp_scaled_logprobs": True,
+    }
+
+    payload = build_request_payload(TEMPLATE, data)
+    response = send_request(URL, payload)
+
+    # Parse the first streamed chunk that contains content
+    result_chunk = {}
+    for line in response.iter_lines():
+        if not line:
+            continue
+        decoded = line.decode("utf-8").removeprefix("data: ")
+        if decoded == "[DONE]":
+            break
+
+        chunk = json.loads(decoded)
+        content = chunk["choices"][0]["delta"].get("content")
+        if content:
+            result_chunk = chunk
+            print(json.dumps(result_chunk, indent=2, ensure_ascii=False))
+            break
+
+    # Check the logprob fields
+    assert result_chunk["choices"][0]["delta"]["content"] == "牛顿"
+    assert result_chunk["choices"][0]["logprobs"]["content"][0]["token"] == "牛顿"
+    assert result_chunk["choices"][0]["logprobs"]["content"][0]["logprob"] == -0.006811376195400953
+    assert result_chunk["choices"][0]["logprobs"]["content"][0]["top_logprobs"][0] == {
+        "token": "牛顿",
+        "logprob": -0.006811376195400953,
+        "bytes": [231, 137, 155, 233, 161, 191],
+    }
+
+
+def test_stream_with_top_p_normalized_logprobs():
+    """
+    Verify the first token's logprob info in a streaming response with top_p_normalized_logprobs enabled.
+    """
+    data = {
+        "stream": True,
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
+        ],
+        "max_tokens": 3,
+        "top_p": 0,
+        "top_p_normalized_logprobs": True,
+    }
+
+    payload = build_request_payload(TEMPLATE, data)
+    response = send_request(URL, payload)
+
+    # Parse the first streamed chunk that contains content
+    result_chunk = {}
+    for line in response.iter_lines():
+        if not line:
+            continue
+        decoded = line.decode("utf-8").removeprefix("data: ")
+        if decoded == "[DONE]":
+            break
+
+        chunk = json.loads(decoded)
+        content = chunk["choices"][0]["delta"].get("content")
+        if content:
+            result_chunk = chunk
+            print(json.dumps(result_chunk, indent=2, ensure_ascii=False))
+            break
+
+    # Check the logprob fields
+    assert result_chunk["choices"][0]["delta"]["content"] == "牛顿"
+    assert result_chunk["choices"][0]["logprobs"]["content"][0]["token"] == "牛顿"
+    assert result_chunk["choices"][0]["logprobs"]["content"][0]["logprob"] == 0.0
+    assert result_chunk["choices"][0]["logprobs"]["content"][0]["top_logprobs"][0] == {
+        "token": "牛顿",
+        "logprob": 0.0,
+        "bytes": [231, 137, 155, 233, 161, 191],
+    }
+
+
 if __name__ == "__main__":
     test_unstream_with_logprobs()
     test_unstream_without_logprobs()
     test_stream_with_logprobs()
     test_stream_without_logprobs()
+    test_stream_with_temp_scaled_logprobs()
+    test_stream_with_top_p_normalized_logprobs()