mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[plugin] Custom model_runner/model support (#3186)
* Support custom model and model_runner * Fix merge * Add tests and update docs * Fix code style * Fix unit tests * Load model in RL
This commit is contained in:
@@ -26,13 +26,19 @@ from fastdeploy import envs
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.plugins.model_runner import load_model_runner_plugins
|
||||
from fastdeploy.utils import get_logger
|
||||
from fastdeploy.worker.gpu_model_runner import GPUModelRunner
|
||||
from fastdeploy.worker.model_runner_base import ModelRunnerBase
|
||||
from fastdeploy.worker.output import ModelRunnerOutput
|
||||
from fastdeploy.worker.worker_base import WorkerBase
|
||||
|
||||
logger = get_logger("gpu_worker", "gpu_worker.log")
|
||||
|
||||
# Select the model runner class used by GpuWorker: prefer the class returned
# by the model-runner plugin loader and fall back to the built-in
# GPUModelRunner when the loader fails.
try:
    ModelRunner = load_model_runner_plugins()
except Exception as exc:
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still propagate, and record why the
    # plugin path was not taken instead of silently discarding the error.
    # NOTE(review): presumably the loader also raises when no plugin is
    # installed, which would make this the common path -- hence info level;
    # confirm against fastdeploy.plugins.model_runner.
    logger.info(f"Custom model runner plugin not loaded ({exc!r}); falling back to GPUModelRunner.")
    from fastdeploy.worker.gpu_model_runner import GPUModelRunner as ModelRunner
|
||||
|
||||
|
||||
class GpuWorker(WorkerBase):
|
||||
def __init__(
|
||||
@@ -70,7 +76,7 @@ class GpuWorker(WorkerBase):
|
||||
raise RuntimeError(f"Not support device type: {self.device_config.device}")
|
||||
|
||||
# Construct model runner
|
||||
self.model_runner: GPUModelRunner = GPUModelRunner(
|
||||
self.model_runner: ModelRunnerBase = ModelRunner(
|
||||
fd_config=self.fd_config,
|
||||
device=self.device,
|
||||
device_id=self.device_ids[self.local_rank % self.max_chips_per_node],
|
||||
|
Reference in New Issue
Block a user