【Feature】add fd plugins && rm model_classes (#3123)

* add fd plugins && rm model_classed

* fix reviews

* add docs

* fix

* fix unitest ci
This commit is contained in:
gaoziyuan
2025-08-04 10:53:20 +08:00
committed by GitHub
parent 1582814905
commit 4021d66ea5
25 changed files with 524 additions and 59 deletions

View File

@@ -20,6 +20,7 @@ from typing import Dict, Union
import numpy as np
import paddle
from paddle import nn
from paddleformers.transformers import PretrainedModel
class ModelRegistry:
@@ -27,21 +28,46 @@ class ModelRegistry:
Used to register and retrieve model classes.
"""
_registry = {}
_arch_to_model_cls = {}
_arch_to_pretrained_model_cls = {}
@classmethod
def register(cls, model_class):
def register_model_class(cls, model_class):
"""register model class"""
if issubclass(model_class, ModelForCasualLM) and model_class is not ModelForCasualLM:
cls._registry[model_class.name()] = model_class
cls._arch_to_model_cls[model_class.name()] = model_class
return model_class
@classmethod
def register_pretrained_model(cls, pretrained_model):
"""register pretrained model class"""
if (
issubclass(pretrained_model, PretrainedModel)
and pretrained_model is not PretrainedModel
and hasattr(pretrained_model, "arch_name")
):
cls._arch_to_pretrained_model_cls[pretrained_model.arch_name()] = pretrained_model
return pretrained_model
@classmethod
def get_pretrain_cls(cls, architectures: str):
"""get_pretrain_cls"""
return cls._arch_to_pretrained_model_cls[architectures]
@classmethod
def get_class(cls, name):
"""get model class"""
if name not in cls._registry:
if name not in cls._arch_to_model_cls:
raise ValueError(f"Model '{name}' is not registered!")
return cls._registry[name]
return cls._arch_to_model_cls[name]
@classmethod
def get_supported_archs(cls):
assert len(cls._arch_to_model_cls) == len(
cls._arch_to_model_cls
), "model class / pretrained model registry num is not same"
return [key for key in cls._arch_to_model_cls.keys()]
class ModelForCasualLM(nn.Layer, ABC):