mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] support top_k_top_p sampling (#2753)
* support top_k_top_p sampling * fix * add api param * add api para * fix * fix * fix * fix * fix * fix * fix
This commit is contained in:
@@ -16,8 +16,8 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastdeploy import envs
|
||||
|
||||
@@ -52,6 +52,7 @@ class SamplingParams:
|
||||
the model more random. Zero means greedy sampling.
|
||||
top_p: Float that controls the cumulative probability of the top tokens
|
||||
to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
|
||||
top_k: Int that controls the number of top tokens to consider. Must be a positive integer.
|
||||
seed: Random seed to use for the generation.
|
||||
stop: list of strings that stop the generation when they are generated.
|
||||
The returned output will not contain the stop strings.
|
||||
@@ -81,7 +82,8 @@ class SamplingParams:
|
||||
frequency_penalty: float = None
|
||||
repetition_penalty: float = None
|
||||
temperature: float = None
|
||||
top_p: float = None
|
||||
top_p: float = 1.0
|
||||
top_k: int = 0
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||
@@ -111,6 +113,7 @@ class SamplingParams:
|
||||
repetition_penalty,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
seed=None,
|
||||
stop=None,
|
||||
stop_token_ids=None,
|
||||
@@ -129,7 +132,8 @@ class SamplingParams:
|
||||
repetition_penalty=repetition_penalty
|
||||
if repetition_penalty is not None else 1.0,
|
||||
temperature=temperature if temperature is not None else 1.0,
|
||||
top_p=top_p if top_p is not None else 0.7,
|
||||
top_p=top_p if top_p is not None else 1.0,
|
||||
top_k=top_k if top_k is not None else 0,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
|
||||
@@ -292,6 +292,7 @@ class CompletionRequest(BaseModel):
|
||||
suffix: Optional[dict] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
response_format: Optional[AnyResponseFormat] = None
|
||||
@@ -405,6 +406,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
stream_options: Optional[StreamOptions] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
@@ -27,34 +27,56 @@ if current_platform.is_gcu():
|
||||
|
||||
def top_p_sampling(
|
||||
x: paddle.Tensor,
|
||||
ps: paddle.Tensor,
|
||||
top_p: paddle.Tensor,
|
||||
top_k: Optional[paddle.Tensor] = None,
|
||||
threshold: Optional[paddle.Tensor] = None,
|
||||
topp_seed: Optional[paddle.Tensor] = None,
|
||||
seed: int = -1,
|
||||
k: int = 0,
|
||||
mode: Literal['truncated', 'non-truncated'] = "truncated",
|
||||
order: Literal['top_k_first', 'joint'] = "top_k_first",
|
||||
) -> tuple[paddle.Tensor, paddle.Tensor]:
|
||||
"""
|
||||
top_p_sampling
|
||||
x(Tensor): An input 2-D Tensor with type float32, float16 and bfloat16.
|
||||
top_p(Tensor): A 1-D Tensor with type float32, float16 and bfloat16,
|
||||
used to specify the top_p corresponding to each query.
|
||||
top_k(Tensor|None, optional): A 1-D Tensor with type int64,
|
||||
used to specify the top_k corresponding to each query.
|
||||
Only used when FD_SAMPLING_CLASS is `rejection`.
|
||||
threshold(Tensor|None, optional): A 1-D Tensor with type float32, float16 and bfloat16,
|
||||
used to avoid sampling low score tokens.
|
||||
topp_seed(Tensor|None, optional): A 1-D Tensor with type int64,
|
||||
used to specify the random seed for each query.
|
||||
seed(int, optional): the random seed. Default is -1,
|
||||
k(int): the number of top_k scores/ids to be returned. Default is 0.
|
||||
Only used when FD_SAMPLING_CLASS is `air`.
|
||||
mode(str): The mode to choose sampling strategy. If the mode is `truncated`, sampling will truncate the probability at top_p_value.
|
||||
If the mode is `non-truncated`, it will not be truncated. Default is `truncated`.
|
||||
Only used when FD_SAMPLING_CLASS is `air` or `base`.
|
||||
order(str): The order of applying top-k and top-p sampling, should be either `top_k_first` or `joint`.
|
||||
If `top_k_first`, we first apply top-k filter, then apply top-p sampling on the top-k results.
|
||||
If `joint`, we apply top-k and top-p filter simultaneously in each round. Default is `top_k_first`.
|
||||
Only used when FD_SAMPLING_CLASS is `rejection`.
|
||||
|
||||
"""
|
||||
top_p_class = envs.FD_SAMPLING_CLASS.lower()
|
||||
if top_p_class == "air":
|
||||
_, ids = air_top_p_sampling(x,
|
||||
ps,
|
||||
top_p,
|
||||
threshold,
|
||||
topp_seed,
|
||||
seed=seed,
|
||||
k=k,
|
||||
mode=mode)
|
||||
elif top_p_class == "rejection":
|
||||
ids = rejection_top_p_sampling(x, ps, seed)
|
||||
ids = rejection_top_p_sampling(x, top_p, top_k, seed, order)
|
||||
_ = None
|
||||
else:
|
||||
if current_platform.is_gcu():
|
||||
_, ids = gcu_top_p_sampling(x, ps)
|
||||
_, ids = gcu_top_p_sampling(x, top_p)
|
||||
else:
|
||||
_, ids = paddle.tensor.top_p_sampling(x,
|
||||
ps,
|
||||
top_p,
|
||||
threshold=threshold,
|
||||
topp_seed=topp_seed,
|
||||
seed=seed,
|
||||
@@ -65,7 +87,7 @@ def top_p_sampling(
|
||||
|
||||
def air_top_p_sampling(
|
||||
x: paddle.Tensor,
|
||||
ps: paddle.Tensor,
|
||||
top_p: paddle.Tensor,
|
||||
threshold: Optional[paddle.Tensor] = None,
|
||||
topp_seed: Optional[paddle.Tensor] = None,
|
||||
seed: int = -1,
|
||||
@@ -77,7 +99,7 @@ def air_top_p_sampling(
|
||||
"""
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import air_top_p_sampling
|
||||
out, ids = air_top_p_sampling(x, ps, threshold, topp_seed, seed, k,
|
||||
out, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed, k,
|
||||
mode)
|
||||
except ImportError:
|
||||
raise RuntimeError("Cannot import air_top_p_sampling op.")
|
||||
@@ -86,19 +108,46 @@ def air_top_p_sampling(
|
||||
|
||||
def rejection_top_p_sampling(
|
||||
x: paddle.Tensor,
|
||||
ps: paddle.Tensor,
|
||||
top_p: paddle.Tensor,
|
||||
top_k: Optional[paddle.Tensor] = None,
|
||||
seed: int = -1,
|
||||
order: Literal['top_k_first', 'joint'] = "top_k_first",
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
rejection_top_p_sampling
|
||||
"""
|
||||
assert top_p is not None, "Top_p should not be none when FD_SAMPLING_CLASS is rejection"
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import rejection_top_p_sampling
|
||||
ids = rejection_top_p_sampling(
|
||||
x,
|
||||
ps,
|
||||
seed,
|
||||
)
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
rejection_top_p_sampling, top_k_renorm_probs)
|
||||
|
||||
if top_k is None:
|
||||
ids = rejection_top_p_sampling(
|
||||
x,
|
||||
top_p,
|
||||
None,
|
||||
seed,
|
||||
)
|
||||
elif top_k is not None and top_p is not None:
|
||||
if order == "top_k_first":
|
||||
renorm_probs = top_k_renorm_probs(x, top_k)
|
||||
ids = rejection_top_p_sampling(
|
||||
renorm_probs,
|
||||
top_p,
|
||||
None,
|
||||
seed,
|
||||
)
|
||||
else:
|
||||
ids = rejection_top_p_sampling(
|
||||
x,
|
||||
top_p,
|
||||
top_k,
|
||||
seed,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Top_p cannot be none."
|
||||
)
|
||||
except ImportError:
|
||||
raise RuntimeError("Cannot import rejection_top_p_sampling op.")
|
||||
return ids
|
||||
|
||||
@@ -214,7 +214,7 @@ class Sampler(nn.Layer):
|
||||
|
||||
probs = F.softmax(logits)
|
||||
|
||||
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
|
||||
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
|
||||
|
||||
self.processor.update_output_tokens(next_tokens, skip_idx_list)
|
||||
return next_tokens
|
||||
@@ -367,5 +367,5 @@ class MTPSampler(nn.Layer):
|
||||
)
|
||||
probs = F.softmax(logits)
|
||||
|
||||
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
|
||||
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
|
||||
return next_tokens
|
||||
|
||||
@@ -154,12 +154,29 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
-1].disaggregate_info["role"] == "prefill":
|
||||
os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
|
||||
|
||||
top_k_reqs = []
|
||||
top_p_reqs = []
|
||||
max_num_seqs = self.parallel_config.max_num_seqs
|
||||
top_p_buffer = paddle.full([max_num_seqs, 1],
|
||||
self.model_config.top_p,
|
||||
dtype='float32')
|
||||
top_k_buffer = paddle.full([max_num_seqs, 1],
|
||||
0,
|
||||
dtype='int64')
|
||||
|
||||
req_len = len(req_dicts)
|
||||
for i in range(req_len):
|
||||
request = req_dicts[i]
|
||||
idx = request.idx
|
||||
length = len(request.prompt_token_ids)
|
||||
|
||||
if sampling_params := request.sampling_params:
|
||||
if sampling_params.top_p < 1:
|
||||
top_p_reqs.append(idx)
|
||||
top_k = sampling_params.top_k
|
||||
if top_k > 0:
|
||||
top_k_reqs.append(idx)
|
||||
|
||||
prefill_tokens = []
|
||||
if (request.guided_json is not None
|
||||
or request.guided_regex is not None
|
||||
@@ -234,8 +251,8 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
request.eos_token_ids.append(request.eos_token_ids[0])
|
||||
self.share_inputs["eos_token_id"][:] = np.array(
|
||||
request.eos_token_ids, dtype="int64").reshape(-1, 1)
|
||||
|
||||
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
|
||||
top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
|
||||
top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
|
||||
self.share_inputs["temperature"][idx:idx + 1] = request.get(
|
||||
"temperature", 0.95)
|
||||
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
|
||||
@@ -286,6 +303,16 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
if self.speculative_method in ["mtp"]:
|
||||
self.proposer.insert_prefill_inputs(req_dicts)
|
||||
|
||||
if len(top_k_reqs) == 0:
|
||||
self.share_inputs["top_k"] = None
|
||||
else:
|
||||
self.share_inputs["top_k"] = top_k_buffer
|
||||
|
||||
if len(top_p_reqs) == 0:
|
||||
self.share_inputs["top_p"] = None
|
||||
else:
|
||||
self.share_inputs["top_p"] = top_p_buffer
|
||||
|
||||
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
|
||||
expected_decode_len: int):
|
||||
""" Set dummy prefill inputs to share_inputs """
|
||||
@@ -340,8 +367,11 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["eos_token_id"] = paddle.full(
|
||||
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
|
||||
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
|
||||
self.model_config.top_p,
|
||||
dtype='float32')
|
||||
self.model_config.top_p,
|
||||
dtype='float32')
|
||||
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
|
||||
0,
|
||||
dtype='int64')
|
||||
self.share_inputs["temperature"] = paddle.full(
|
||||
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
|
||||
self.share_inputs["penalty_score"] = paddle.full(
|
||||
@@ -563,6 +593,7 @@ class GCUModelRunner(ModelRunnerBase):
|
||||
self.sampling_metadata = SamplingMetadata(
|
||||
temperature=self.share_inputs["temperature"],
|
||||
top_p=self.share_inputs["top_p"],
|
||||
top_k=self.share_inputs["top_k"],
|
||||
step_idx=self.share_inputs["step_idx"],
|
||||
pre_token_ids=self.share_inputs["pre_ids"],
|
||||
frequency_penalties=self.share_inputs["frequency_score"],
|
||||
|
||||
@@ -161,6 +161,15 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
-1].disaggregate_info["role"] == "prefill":
|
||||
os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
|
||||
|
||||
top_k_reqs = []
|
||||
top_p_reqs = []
|
||||
max_num_seqs = self.parallel_config.max_num_seqs
|
||||
top_p_buffer = paddle.full([max_num_seqs, 1],
|
||||
self.model_config.top_p,
|
||||
dtype='float32')
|
||||
top_k_buffer = paddle.full([max_num_seqs, 1],
|
||||
0,
|
||||
dtype='int64')
|
||||
req_len = len(req_dicts)
|
||||
for i in range(req_len):
|
||||
request = req_dicts[i]
|
||||
@@ -168,6 +177,13 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
length = len(request.prompt_token_ids)
|
||||
assert length > 0, "The prompt requested must not be empty."
|
||||
|
||||
if sampling_params := request.sampling_params:
|
||||
if sampling_params.top_p < 1:
|
||||
top_p_reqs.append(idx)
|
||||
top_k = sampling_params.top_k
|
||||
if top_k > 0:
|
||||
top_k_reqs.append(idx)
|
||||
|
||||
prefill_tokens = []
|
||||
if (request.guided_json is not None
|
||||
or request.guided_regex is not None
|
||||
@@ -242,8 +258,8 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
request.eos_token_ids.append(request.eos_token_ids[0])
|
||||
self.share_inputs["eos_token_id"][:] = np.array(
|
||||
request.eos_token_ids, dtype="int64").reshape(-1, 1)
|
||||
|
||||
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
|
||||
top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
|
||||
top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
|
||||
self.share_inputs["temperature"][idx:idx + 1] = request.get(
|
||||
"temperature", 0.95)
|
||||
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
|
||||
@@ -294,6 +310,16 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
if self.speculative_method in ["mtp"]:
|
||||
self.proposer.insert_prefill_inputs(req_dicts)
|
||||
|
||||
if len(top_k_reqs) == 0:
|
||||
self.share_inputs["top_k"] = None
|
||||
else:
|
||||
self.share_inputs["top_k"] = top_k_buffer
|
||||
|
||||
if len(top_p_reqs) == 0:
|
||||
self.share_inputs["top_p"] = None
|
||||
else:
|
||||
self.share_inputs["top_p"] = top_p_buffer
|
||||
|
||||
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
|
||||
expected_decode_len: int):
|
||||
""" Set dummy prefill inputs to share_inputs """
|
||||
@@ -349,8 +375,11 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["eos_token_id"] = paddle.full(
|
||||
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
|
||||
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
|
||||
self.model_config.top_p,
|
||||
dtype='float32')
|
||||
self.model_config.top_p,
|
||||
dtype='float32')
|
||||
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
|
||||
0,
|
||||
dtype='int64')
|
||||
self.share_inputs["temperature"] = paddle.full(
|
||||
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
|
||||
self.share_inputs["penalty_score"] = paddle.full(
|
||||
@@ -574,6 +603,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.sampling_metadata = SamplingMetadata(
|
||||
temperature=self.share_inputs["temperature"],
|
||||
top_p=self.share_inputs["top_p"],
|
||||
top_k=self.share_inputs["top_k"],
|
||||
step_idx=self.share_inputs["step_idx"],
|
||||
pre_token_ids=self.share_inputs["pre_ids"],
|
||||
frequency_penalties=self.share_inputs["frequency_score"],
|
||||
|
||||
@@ -29,9 +29,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import \
|
||||
AttentionBackend
|
||||
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
|
||||
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
|
||||
from fastdeploy.model_executor.layers.sample.sampler import (Sampler,
|
||||
SpeculativeSampler
|
||||
)
|
||||
from fastdeploy.model_executor.layers.sample.sampler import (
|
||||
Sampler, SpeculativeSampler)
|
||||
from fastdeploy.model_executor.model_loader import get_model_from_loader
|
||||
from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx
|
||||
from fastdeploy.model_executor.pre_and_post_process import (post_process,
|
||||
@@ -145,12 +144,29 @@ class IluvatarModelRunner(ModelRunnerBase):
|
||||
-1].disaggregate_info["role"] == "prefill":
|
||||
os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
|
||||
|
||||
top_k_reqs = []
|
||||
top_p_reqs = []
|
||||
max_num_seqs = self.parallel_config.max_num_seqs
|
||||
top_p_buffer = paddle.full([max_num_seqs, 1],
|
||||
self.model_config.top_p,
|
||||
dtype='float32')
|
||||
top_k_buffer = paddle.full([max_num_seqs, 1],
|
||||
0,
|
||||
dtype='int64')
|
||||
|
||||
req_len = len(req_dicts)
|
||||
for i in range(req_len):
|
||||
request = req_dicts[i]
|
||||
idx = request.idx
|
||||
length = len(request.prompt_token_ids)
|
||||
|
||||
if sampling_params := request.sampling_params:
|
||||
if sampling_params.top_p < 1:
|
||||
top_p_reqs.append(idx)
|
||||
top_k = sampling_params.top_k
|
||||
if top_k > 0:
|
||||
top_k_reqs.append(idx)
|
||||
|
||||
prefill_tokens = []
|
||||
if (request.guided_json is not None
|
||||
or request.guided_regex is not None
|
||||
@@ -225,8 +241,8 @@ class IluvatarModelRunner(ModelRunnerBase):
|
||||
request.eos_token_ids.append(request.eos_token_ids[0])
|
||||
self.share_inputs["eos_token_id"][:] = np.array(
|
||||
request.eos_token_ids, dtype="int64").reshape(-1, 1)
|
||||
|
||||
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
|
||||
top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
|
||||
top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
|
||||
self.share_inputs["temperature"][idx:idx + 1] = request.get(
|
||||
"temperature", 0.95)
|
||||
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
|
||||
@@ -273,6 +289,15 @@ class IluvatarModelRunner(ModelRunnerBase):
|
||||
idx, request.get("logits_processor"), prefill_tokens)
|
||||
|
||||
self.share_inputs["not_need_stop"][0] = True
|
||||
if len(top_k_reqs) == 0:
|
||||
self.share_inputs["top_k"] = None
|
||||
else:
|
||||
self.share_inputs["top_k"] = top_k_buffer
|
||||
|
||||
if len(top_p_reqs) == 0:
|
||||
self.share_inputs["top_p"] = None
|
||||
else:
|
||||
self.share_inputs["top_p"] = top_p_buffer
|
||||
|
||||
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
|
||||
expected_decode_len: int):
|
||||
@@ -329,8 +354,11 @@ class IluvatarModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["eos_token_id"] = paddle.full(
|
||||
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
|
||||
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
|
||||
self.model_config.top_p,
|
||||
dtype='float32')
|
||||
self.model_config.top_p,
|
||||
dtype='float32')
|
||||
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
|
||||
0,
|
||||
dtype='int64')
|
||||
self.share_inputs["temperature"] = paddle.full(
|
||||
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
|
||||
self.share_inputs["penalty_score"] = paddle.full(
|
||||
@@ -558,6 +586,7 @@ class IluvatarModelRunner(ModelRunnerBase):
|
||||
self.sampling_metadata = SamplingMetadata(
|
||||
temperature=self.share_inputs["temperature"],
|
||||
top_p=self.share_inputs["top_p"],
|
||||
top_k=self.share_inputs["top_k"],
|
||||
step_idx=self.share_inputs["step_idx"],
|
||||
pre_token_ids=self.share_inputs["pre_ids"],
|
||||
frequency_penalties=self.share_inputs["frequency_score"],
|
||||
|
||||
@@ -14,14 +14,14 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import argparse
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
import paddle.distributed.fleet as fleet
|
||||
from fastdeploy.config import ModelConfig
|
||||
|
||||
from fastdeploy.config import ModelConfig
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("worker", "worker.log")
|
||||
|
||||
@@ -282,11 +282,26 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
|
||||
def process_prefill_inputs(self, req_dicts: List[Request]):
|
||||
""" Process inputs for prefill tasks and update share_inputs buffer """
|
||||
top_k_reqs = []
|
||||
top_p_reqs = []
|
||||
max_num_seqs = self.parallel_config.max_num_seqs
|
||||
top_p_buffer = paddle.full([max_num_seqs, 1],
|
||||
self.model_config.top_p,
|
||||
dtype='float32')
|
||||
top_k_buffer = paddle.full([max_num_seqs, 1],
|
||||
0,
|
||||
dtype='int64')
|
||||
req_len = len(req_dicts)
|
||||
for i in range(req_len):
|
||||
request = req_dicts[i]
|
||||
idx = request.idx
|
||||
length = request.prompt_token_ids_len
|
||||
if sampling_params := request.sampling_params:
|
||||
if sampling_params.top_p < 1:
|
||||
top_p_reqs.append(idx)
|
||||
top_k = sampling_params.top_k
|
||||
if top_k > 0:
|
||||
top_k_reqs.append(idx)
|
||||
self.share_inputs["input_ids"][idx:idx + 1, :length] = np.array(
|
||||
request.prompt_token_ids)
|
||||
if len(request.eos_token_ids
|
||||
@@ -295,7 +310,8 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["eos_token_id"][:] = np.array(
|
||||
request.eos_token_ids, dtype="int64").reshape(-1, 1)
|
||||
self.share_inputs["pre_ids"][idx:idx + 1] = -1
|
||||
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
|
||||
top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
|
||||
top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
|
||||
self.share_inputs["temperature"][idx:idx + 1] = request.get(
|
||||
"temperature", 0.95)
|
||||
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
|
||||
@@ -344,6 +360,15 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
request.get("stop_token_ids"), dtype="int64")
|
||||
|
||||
self.share_inputs["not_need_stop"][0] = True
|
||||
if len(top_k_reqs) == 0:
|
||||
self.share_inputs["top_k"] = None
|
||||
else:
|
||||
self.share_inputs["top_k"] = top_k_buffer
|
||||
|
||||
if len(top_p_reqs) == 0:
|
||||
self.share_inputs["top_p"] = None
|
||||
else:
|
||||
self.share_inputs["top_p"] = top_p_buffer
|
||||
|
||||
def _init_share_inputs(self, max_num_seqs: int):
|
||||
"""Initialize all share buffers for model inputs.
|
||||
@@ -363,8 +388,11 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["eos_token_id"] = paddle.full(
|
||||
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
|
||||
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
|
||||
self.model_config.top_p,
|
||||
dtype='float32')
|
||||
self.model_config.top_p,
|
||||
dtype='float32')
|
||||
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
|
||||
0,
|
||||
dtype='int64')
|
||||
self.share_inputs["temperature"] = paddle.full(
|
||||
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
|
||||
self.share_inputs["penalty_score"] = paddle.full(
|
||||
@@ -514,6 +542,7 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
self.sampling_metadata = SamplingMetadata(
|
||||
temperature=self.share_inputs["temperature"],
|
||||
top_p=self.share_inputs["top_p"],
|
||||
top_k=self.share_inputs["top_k"],
|
||||
step_idx=self.share_inputs["step_idx"],
|
||||
pre_token_ids=self.share_inputs["pre_ids"],
|
||||
frequency_penalties=self.share_inputs["frequency_score"],
|
||||
|
||||
Reference in New Issue
Block a user