mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 00:33:03 +08:00
fix MultimodalRegistry (#3699)
This commit is contained in:
@@ -14,32 +14,13 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
|
||||
class MultimodalRegistry:
|
||||
"""
|
||||
A registry for multimodal models
|
||||
"""
|
||||
|
||||
mm_models: set[str] = {"Ernie4_5_VLMoeForConditionalGeneration"}
|
||||
|
||||
@classmethod
|
||||
def register_model(cls, name: str = "") -> Callable:
|
||||
"""
|
||||
Register model with the given name, class name is used if name is not provided.
|
||||
"""
|
||||
|
||||
def _register(model):
|
||||
nonlocal name
|
||||
if len(name) == 0:
|
||||
name = model.__name__
|
||||
if name in cls.mm_models:
|
||||
raise ValueError(f"multimodal model {name} is already registered")
|
||||
cls.mm_models.add(name)
|
||||
return model
|
||||
|
||||
return _register
|
||||
mm_models: set[str] = {"Ernie4_5_VLMoeForConditionalGeneration", "Ernie5MoeForCausalLM"}
|
||||
|
||||
@classmethod
|
||||
def contains_model(cls, name: str) -> bool:
|
||||
|
Reference in New Issue
Block a user