[Feature] Online Chat API Support Return logprobs (#2777)

* online chat support logprobs

* check xpu

* check vl_gpu_model_runner and xpu_model_runner

* get_worker() check platform
This commit is contained in:
chen
2025-07-10 16:33:40 +08:00
committed by GitHub
parent 24f934f1f9
commit d33105baeb
22 changed files with 608 additions and 114 deletions

View File

@@ -42,6 +42,8 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
"""
get worker of different device
"""
if fd_config.model_config.enable_logprob and not current_platform.is_cuda():
raise NotImplementedError("Only CUDA platform supports logprob.")
if current_platform.is_dcu():
from fastdeploy.worker.dcu_worker import DcuWorker
return DcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
@@ -571,6 +573,9 @@ def parse_args():
"'ipc_snapshot': load from disk snapshot of IPC weights, "
"'meta': provide RL traing worker, no_weights_load"
"'normal':normal load weight")
parser.add_argument("--enable_logprob",
action='store_true',
help="Enable output of token-level log probabilities.")
args = parser.parse_args()
return args
@@ -799,6 +804,8 @@ def initialize_fd_config(config_or_args) -> FDConfig:
"No quantization config found and use original weight and act dtype."
)
model_config.enable_logprob = config_or_args.enable_logprob
model_config.architectures = model_config_dict.get("architectures")
# Update load config