mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 08:16:42 +08:00
[Feature] Support return logprob of generated tokens (#2784)
* online chat support logprobs * check xpu * check vl_gpu_model_runner * only cuda support logprob * get_worker() check platform --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -42,6 +42,8 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
|
||||
"""
|
||||
get worker of different device
|
||||
"""
|
||||
if fd_config.model_config.enable_logprob and not current_platform.is_cuda():
|
||||
raise NotImplementedError("Only CUDA platform supports logprob.")
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.worker.gpu_worker import GpuWorker
|
||||
return GpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
|
||||
@@ -550,6 +552,9 @@ def parse_args():
|
||||
"'ipc_snapshot': load from disk snapshot of IPC weights, "
|
||||
"'meta': provide RL traing worker, no_weights_load"
|
||||
"'normal':normal load weight")
|
||||
parser.add_argument("--enable_logprob",
|
||||
action='store_true',
|
||||
help="Enable output of token-level log probabilities.")
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
@@ -771,6 +776,8 @@ def initialize_fd_config(config) -> FDConfig:
|
||||
"No quantization config found and use original weight and act dtype."
|
||||
)
|
||||
|
||||
model_config.enable_logprob = config.enable_logprob
|
||||
|
||||
model_config.architectures = model_config_dict.get("architectures")
|
||||
|
||||
# Update load config
|
||||
|
Reference in New Issue
Block a user