mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Union
|
||||
|
||||
@@ -25,14 +26,13 @@ class ModelRegistry:
|
||||
"""
|
||||
Used to register and retrieve model classes.
|
||||
"""
|
||||
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class):
|
||||
"""register model class"""
|
||||
if issubclass(
|
||||
model_class,
|
||||
ModelForCasualLM) and model_class is not ModelForCasualLM:
|
||||
if issubclass(model_class, ModelForCasualLM) and model_class is not ModelForCasualLM:
|
||||
cls._registry[model_class.name()] = model_class
|
||||
return model_class
|
||||
|
||||
@@ -59,8 +59,7 @@ class ModelForCasualLM(nn.Layer, ABC):
|
||||
self.fd_config = configs
|
||||
|
||||
@abstractmethod
|
||||
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray,
|
||||
paddle.Tensor]]):
|
||||
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
|
||||
"""
|
||||
Load model parameters from a given state dictionary.
|
||||
|
||||
|
Reference in New Issue
Block a user