mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
【Feature】add fd plugins && rm model_classes (#3123)
* add fd plugins && rm model_classed * fix reviews * add docs * fix * fix unitest ci
This commit is contained in:
@@ -20,6 +20,7 @@ from typing import Dict, Union
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddleformers.transformers import PretrainedModel
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
@@ -27,21 +28,46 @@ class ModelRegistry:
|
||||
Used to register and retrieve model classes.
|
||||
"""
|
||||
|
||||
_registry = {}
|
||||
_arch_to_model_cls = {}
|
||||
_arch_to_pretrained_model_cls = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class):
|
||||
def register_model_class(cls, model_class):
|
||||
"""register model class"""
|
||||
if issubclass(model_class, ModelForCasualLM) and model_class is not ModelForCasualLM:
|
||||
cls._registry[model_class.name()] = model_class
|
||||
cls._arch_to_model_cls[model_class.name()] = model_class
|
||||
return model_class
|
||||
|
||||
@classmethod
|
||||
def register_pretrained_model(cls, pretrained_model):
|
||||
"""register pretrained model class"""
|
||||
if (
|
||||
issubclass(pretrained_model, PretrainedModel)
|
||||
and pretrained_model is not PretrainedModel
|
||||
and hasattr(pretrained_model, "arch_name")
|
||||
):
|
||||
cls._arch_to_pretrained_model_cls[pretrained_model.arch_name()] = pretrained_model
|
||||
|
||||
return pretrained_model
|
||||
|
||||
@classmethod
|
||||
def get_pretrain_cls(cls, architectures: str):
|
||||
"""get_pretrain_cls"""
|
||||
return cls._arch_to_pretrained_model_cls[architectures]
|
||||
|
||||
@classmethod
|
||||
def get_class(cls, name):
|
||||
"""get model class"""
|
||||
if name not in cls._registry:
|
||||
if name not in cls._arch_to_model_cls:
|
||||
raise ValueError(f"Model '{name}' is not registered!")
|
||||
return cls._registry[name]
|
||||
return cls._arch_to_model_cls[name]
|
||||
|
||||
@classmethod
|
||||
def get_supported_archs(cls):
|
||||
assert len(cls._arch_to_model_cls) == len(
|
||||
cls._arch_to_model_cls
|
||||
), "model class / pretrained model registry num is not same"
|
||||
return [key for key in cls._arch_to_model_cls.keys()]
|
||||
|
||||
|
||||
class ModelForCasualLM(nn.Layer, ABC):
|
||||
|
Reference in New Issue
Block a user