fix rollout_model and add rl ut (#2882)
@@ -32,6 +32,7 @@ from typing import Dict, List, Optional, Tuple
 import numpy as np
 import paddle
 import zmq
+from opentelemetry import trace
 from tqdm import tqdm

 from fastdeploy.engine.args_utils import EngineArgs
@@ -42,13 +43,13 @@ from fastdeploy.input.preprocess import InputPreprocessor
 from fastdeploy.inter_communicator import (EngineCacheQueue, EngineWorkerQueue,
                                            IPCSignal, ZmqClient)
 from fastdeploy.metrics.metrics import main_process_metrics
+from fastdeploy.metrics.trace_util import start_span, start_span_request
 from fastdeploy.model_executor.guided_decoding import schema_checker
 from fastdeploy.output.token_processor import (TokenProcessor,
                                                WarmUpTokenProcessor)
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
 from fastdeploy.utils import EngineError, console_logger, llm_logger
-from fastdeploy.metrics.trace_util import extract_from_metadata, start_span, start_span_request
-from opentelemetry import trace
+

 class LLMEngine(object):
     """
@@ -359,9 +360,9 @@ class LLMEngine(object):
         request, insert_task = None, []
         results: List[Tuple[str, Optional[str]]] = list()
         if data:
-            request = Request.from_dict(data)
-            start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
+            request = Request.from_dict(data)
+            start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)

         llm_logger.debug(f"Receive request: {request}")
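The span helper used above comes from fastdeploy.metrics.trace_util, whose body is not part of this diff. As a rough sketch only, a producer-side start_span with this call signature could be built on the opentelemetry API imported earlier; the real implementation may differ:

    # Hypothetical sketch; the actual trace_util implementation is not shown in this diff.
    from opentelemetry import trace

    tracer = trace.get_tracer(__name__)

    def start_span(name, data, kind=trace.SpanKind.PRODUCER):
        # Open a short-lived span for the enqueue event and attach request metadata.
        with tracer.start_as_current_span(name, kind=kind) as span:
            span.set_attribute("request_id", str(data.get("request_id", "")))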
@@ -694,7 +695,7 @@ class LLMEngine(object):
         Insert tasks to engine.
         """
         for task in tasks:
-            start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
+            start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
         # TODO: return to the scheduler
         if allocated:
             current_tasks = []
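The DEQUEUE span is the consumer-side counterpart of the ENQUEUE_ZMQ span above. A sketch of how the two sides could be linked across the queue with standard W3C trace-context propagation, assuming the task carries a metadata dict (the field name and helper names here are illustrative, not FastDeploy's):

    # Hypothetical sketch of producer/consumer span linkage; names are assumptions.
    from opentelemetry import trace
    from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

    tracer = trace.get_tracer(__name__)
    propagator = TraceContextTextMapPropagator()

    def enqueue(task: dict):
        with tracer.start_as_current_span("ENQUEUE_ZMQ", kind=trace.SpanKind.PRODUCER):
            propagator.inject(task.setdefault("metadata", {}))  # store traceparent in the task

    def dequeue(task: dict):
        ctx = propagator.extract(task.get("metadata", {}))  # restore the producer's context
        with tracer.start_as_current_span("DEQUEUE", context=ctx, kind=trace.SpanKind.CONSUMER):
            pass  # insert the task into the engine here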
@@ -1033,10 +1034,9 @@ class LLMEngine(object):
                  f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}"
                  f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
                  f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
-                 f" --graph_optimiaztion_config '{self.cfg.graph_optimization_config.to_json_string()}'"
+                 f" --graph_optimization_config '{self.cfg.graph_optimization_config.to_json_string()}'"
                  f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
-                 f" --load_strategy {self.cfg.model_config.load_strategy}"
-                 f" --enable_mm {self.cfg.enable_mm}")
+                 f" --load_strategy {self.cfg.model_config.load_strategy}")

         worker_append_flag = {
@@ -1051,6 +1051,7 @@ class LLMEngine(object):
             "disable_any_whitespace": self.cfg.disable_any_whitespace,
             "enable-custom-all-reduce": self.cfg.parallel_config.enable_custom_all_reduce,
             "enable_logprob": self.cfg.enable_logprob,
+            "enable_mm": self.cfg.enable_mm,
         }
         for worker_flag, value in worker_append_flag.items():
             if value:
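Taken together, the last two hunks move enable_mm from an always-emitted value argument into worker_append_flag, where boolean options are appended to the worker launch command only when truthy. A minimal self-contained sketch of that pattern (the config values below are made up for illustration):

    # Hypothetical sketch of the append-flag pattern; values are illustrative only.
    arguments = " --load_strategy ipc"  # value arguments are always emitted

    worker_append_flag = {
        "enable_logprob": True,
        "enable_mm": False,
    }
    for worker_flag, value in worker_append_flag.items():
        if value:
            arguments += f" --{worker_flag}"  # flag-only option, no value attached

    print(arguments)  # " --load_strategy ipc --enable_logprob"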