mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-31 11:56:44 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			101 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			101 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """
 | |
| from abc import ABC, abstractmethod
 | |
| from typing import Dict, Union
 | |
| 
 | |
| import numpy as np
 | |
| import paddle
 | |
| from paddle import nn
 | |
| 
 | |
| 
 | |
| class ModelRegistry:
 | |
|     """
 | |
|     Used to register and retrieve model classes.
 | |
|     """
 | |
|     _registry = {}
 | |
| 
 | |
|     @classmethod
 | |
|     def register(cls, model_class):
 | |
|         """register model class"""
 | |
|         if issubclass(
 | |
|                 model_class,
 | |
|                 ModelForCasualLM) and model_class is not ModelForCasualLM:
 | |
|             cls._registry[model_class.name()] = model_class
 | |
|         return model_class
 | |
| 
 | |
|     @classmethod
 | |
|     def get_class(cls, name):
 | |
|         """get model class"""
 | |
|         if name not in cls._registry:
 | |
|             raise ValueError(f"Model '{name}' is not registered!")
 | |
|         return cls._registry[name]
 | |
| 
 | |
| 
 | |
| class ModelForCasualLM(nn.Layer, ABC):
 | |
|     """
 | |
|     Base class for LM
 | |
|     """
 | |
| 
 | |
|     def __init__(self, configs):
 | |
|         """
 | |
|         Args:
 | |
|             configs (dict): Configurations including parameters such as max_dec_len, min_dec_len, decode_strategy,
 | |
|                 ori_vocab_size, use_topp_sampling, etc.
 | |
|         """
 | |
|         super(ModelForCasualLM, self).__init__()
 | |
| 
 | |
|     @abstractmethod
 | |
|     def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray,
 | |
|                                                          paddle.Tensor]]):
 | |
|         """
 | |
|         Load model parameters from a given state dictionary.
 | |
| 
 | |
|         Args:
 | |
|             state_dict (dict[str, np.ndarray | paddle.Tensor]):
 | |
|                 A dictionary containing model parameters, where keys are parameter names
 | |
|                 and values are NumPy arrays or PaddlePaddle tensors.
 | |
|         """
 | |
|         raise NotImplementedError
 | |
| 
 | |
|     @abstractmethod
 | |
|     def forward(
 | |
|         self,
 | |
|         input_ids=None,
 | |
|         pos_emb=None,
 | |
|         **model_kwargs,
 | |
|     ):
 | |
|         """
 | |
|         Defines the forward pass of the model for generating text.
 | |
| 
 | |
|         Args:
 | |
|             input_ids (Tensor, optional): The input token ids to the model.
 | |
|             pos_emb (Tensor, optional): position Embeddings for model.
 | |
|             **model_kwargs: Additional keyword arguments for the model.
 | |
| 
 | |
|         Returns:
 | |
|             Tensor or list of Tensors: Generated tokens or decoded outputs.
 | |
|         """
 | |
|         raise NotImplementedError
 | |
| 
 | |
|     @abstractmethod
 | |
|     def compute_logits(self, hidden_state, **logits_prosessor_kwargs):
 | |
|         raise NotImplementedError
 | |
| 
 | |
|     @classmethod
 | |
|     @abstractmethod
 | |
|     def name(self):
 | |
|         raise NotImplementedError
 | 
