mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
[Sync] Update to latest code (#2679)
* [Sync] Update to latest code * Add new code files * Add new code files * update code * Try to fix build.sh * Try to fix build.sh * Update code * Update requirements.txt * Update code --------- Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
This commit is contained in:
@@ -28,15 +28,17 @@ class ModelRegistry:
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class):
|
||||
def register(cls, model_class, suffix=""):
|
||||
"""register model class"""
|
||||
if issubclass(
|
||||
model_class,
|
||||
ModelForCasualLM) and model_class is not ModelForCasualLM:
|
||||
cls._registry[model_class.name()] = model_class
|
||||
cls._registry[f"{model_class.name()}{suffix}"] = model_class
|
||||
return model_class
|
||||
|
||||
@classmethod
|
||||
def get_class(cls, name):
|
||||
"""get model class"""
|
||||
if name not in cls._registry:
|
||||
raise ValueError(f"Model '{name}' is not registered!")
|
||||
return cls._registry[name]
|
||||
|
Reference in New Issue
Block a user