mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Feature] Online Chat API Support Return logprobs (#2777)
* online chat support logprobs
* check xpu
* check vl_gpu_model_runner and xpu_model_runner
* get_worker() check platform
This commit is contained in:
@@ -42,6 +42,8 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
|
||||
"""
|
||||
get worker of different device
|
||||
"""
|
||||
if fd_config.model_config.enable_logprob and not current_platform.is_cuda():
|
||||
raise NotImplementedError("Only CUDA platform supports logprob.")
|
||||
if current_platform.is_dcu():
|
||||
from fastdeploy.worker.dcu_worker import DcuWorker
|
||||
return DcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
|
||||
@@ -571,6 +573,9 @@ def parse_args():
|
||||
"'ipc_snapshot': load from disk snapshot of IPC weights, "
|
||||
"'meta': provide RL traing worker, no_weights_load"
|
||||
"'normal':normal load weight")
|
||||
parser.add_argument("--enable_logprob",
|
||||
action='store_true',
|
||||
help="Enable output of token-level log probabilities.")
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
@@ -799,6 +804,8 @@ def initialize_fd_config(config_or_args) -> FDConfig:
|
||||
"No quantization config found and use original weight and act dtype."
|
||||
)
|
||||
|
||||
model_config.enable_logprob = config_or_args.enable_logprob
|
||||
|
||||
model_config.architectures = model_config_dict.get("architectures")
|
||||
|
||||
# Update load config
|
||||
|
Reference in New Issue
Block a user