fix MultimodalRegistry (#3699)

This commit is contained in:
Yuanle Liu
2025-08-29 11:01:30 +08:00
committed by GitHub
parent 43d5bd62b4
commit 2fb2c0f46a
2 changed files with 2 additions and 25 deletions

View File

@@ -14,32 +14,13 @@
# limitations under the License.
"""
from typing import Callable
class MultimodalRegistry:
"""
A registry for multimodal models
"""
mm_models: set[str] = {"Ernie4_5_VLMoeForConditionalGeneration"}
@classmethod
def register_model(cls, name: str = "") -> Callable:
"""
Register model with the given name, class name is used if name is not provided.
"""
def _register(model):
nonlocal name
if len(name) == 0:
name = model.__name__
if name in cls.mm_models:
raise ValueError(f"multimodal model {name} is already registered")
cls.mm_models.add(name)
return model
return _register
mm_models: set[str] = {"Ernie4_5_VLMoeForConditionalGeneration", "Ernie5MoeForCausalLM"}
@classmethod
def contains_model(cls, name: str) -> bool: