mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-08 18:11:00 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
@@ -16,30 +16,50 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from .model_base import ModelForCasualLM, ModelRegistry
|
||||
|
||||
inference_runner_supported_models = ["Qwen2ForCausalLM"]
|
||||
inference_runner_supported_models = [
|
||||
"Ernie4_5_MoeForCausalLM",
|
||||
"Ernie4_5_MTPForCausalLM",
|
||||
"Qwen2ForCausalLM",
|
||||
"Qwen3MoeForCausalLM",
|
||||
"Ernie4_5_ForCausalLM",
|
||||
"Qwen3ForCausalLM",
|
||||
]
|
||||
|
||||
|
||||
def _find_py_files(root_dir):
|
||||
root_path = Path(root_dir)
|
||||
py_files = []
|
||||
for py_file in root_path.rglob("*.py"):
|
||||
rel_path = py_file.relative_to(root_dir)
|
||||
if "__init__" in str(py_file):
|
||||
continue
|
||||
dotted_path = str(rel_path).replace("/", ".").replace("\\",
|
||||
".").replace(
|
||||
".py", "")
|
||||
py_files.append(dotted_path)
|
||||
return py_files
|
||||
|
||||
|
||||
def auto_models_registry():
|
||||
"""
|
||||
auto registry all models in this folder
|
||||
"""
|
||||
for module_file in os.listdir(os.path.dirname(__file__)):
|
||||
if module_file.endswith('.py') and module_file != '__init__.py':
|
||||
module_name = module_file[:-3]
|
||||
try:
|
||||
module = importlib.import_module(
|
||||
f'fastdeploy.model_executor.models.{module_name}')
|
||||
for attr_name in dir(module):
|
||||
attr = getattr(module, attr_name)
|
||||
if inspect.isclass(attr) and issubclass(
|
||||
attr,
|
||||
ModelForCasualLM) and attr is not ModelForCasualLM:
|
||||
ModelRegistry.register(attr)
|
||||
except ImportError:
|
||||
raise ImportError(f"{module_name=} import error")
|
||||
for module_file in _find_py_files(os.path.dirname(__file__)):
|
||||
try:
|
||||
module = importlib.import_module(
|
||||
f'fastdeploy.model_executor.models.{module_file}')
|
||||
for attr_name in dir(module):
|
||||
attr = getattr(module, attr_name)
|
||||
if inspect.isclass(attr) and issubclass(
|
||||
attr,
|
||||
ModelForCasualLM) and attr is not ModelForCasualLM:
|
||||
ModelRegistry.register(attr)
|
||||
except ImportError:
|
||||
raise ImportError(f"{module_file=} import error")
|
||||
|
||||
|
||||
auto_models_registry()
|
||||
|
Reference in New Issue
Block a user