Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-27 02:20:31 +08:00)
			
		
		
		
	
		
			
				
	
	
		
			83 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			83 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """
 | |
| 
 | |
| from abc import ABC, abstractmethod
 | |
| from typing import Any, Optional
 | |
| 
 | |
| 
 | |
class QuantMethodBase(ABC):
    """Abstract interface implemented by every quantization method.

    A concrete method allocates its weights on a layer (``create_weights``)
    and runs the layer's computation with those weights (``apply``).
    ``process_loaded_weights`` is an optional post-loading hook that does
    nothing unless a subclass overrides it.
    """

    @abstractmethod
    def create_weights(self, layer, *weight_args, **extra_weight_attrs):
        """Allocate the weights this method needs for ``layer``.

        Implementations attach the created weights to ``layer`` as
        attributes.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self, layer, *args, **kwargs):
        """Run ``layer``'s computation on the input tensor using its weights.

        ``create_weights`` must already have been called on ``layer``.
        """
        raise NotImplementedError

    def process_loaded_weights(self, layer, weights):
        """Post-process weights once they have been loaded.

        Subclasses may override this, for example to transpose weights into
        the layout their computation expects. The base implementation is a
        no-op and returns ``None``.
        """
        return None
 | |
| 
 | |
| 
 | |
class QuantConfigBase(ABC):
    """Abstract base for quantization configurations.

    A subclass describes one quantization scheme. It reports the scheme's
    name, builds itself from the model's quantization config, and selects
    the ``QuantMethodBase`` used for a given layer.
    """

    def __init__(self):
        # Delegate to the next class in the MRO (cooperative inheritance).
        super().__init__()

    @abstractmethod
    def name(self) -> str:
        """Return the name of this quantization method."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_config(cls, config: dict) -> "QuantConfigBase":
        """Build a config instance from the model's quantization config."""
        raise NotImplementedError

    @staticmethod
    def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
        """Return the value of the first entry of ``keys`` present in ``config``.

        Keys are checked in the order given; later keys are only consulted
        when every earlier one is absent.

        Raises:
            ValueError: If none of ``keys`` appears in ``config``.
        """
        for candidate in keys:
            if candidate in config:
                return config[candidate]
        raise ValueError(f"Cannot find any of {keys} in the model's quantization config.")

    @abstractmethod
    def get_quant_method(self, layer, prefix) -> Optional[QuantMethodBase]:
        """Choose the quantization method for a quantized layer.

        Args:
            layer: The layer that the quant method will handle.
            prefix: The layer's full name in the state dict.

        Returns:
            The quantization method to use, or None when the given layer
            does not support a quant method.
        """
        raise NotImplementedError
 | 
