[V1 Loader] support weight_only (#3413)
* support wint4/wint8
* delete smoe case
* update ci
* print log
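
For context, the diff below routes weight-only (wint4/wint8) quantization through the v1 loader: parameters are first created in the checkpoint dtype, shard-by-shard loading is tracked, and quantization (paddle.nn.quant.weight_quantize in the real code) runs in process_weights_after_loading once a tensor is fully loaded. The following is a framework-free sketch of that lifecycle; absmax_quantize_int8 and the shapes are illustrative stand-ins, not FastDeploy APIs.

```python
# Illustrative sketch of the v1 "quantize after loading" lifecycle.
# NumPy stands in for Paddle tensors; absmax_quantize_int8 is a toy
# replacement for paddle.nn.quant.weight_quantize.
import numpy as np

def absmax_quantize_int8(w):
    """Per-channel symmetric int8 quantization along the output (last) axis."""
    scale = np.abs(w).max(axis=0) / 127.0                 # one scale per output channel
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

# 1. create_weights: allocate the full-precision parameter (no quantization yet)
weight = np.zeros((512, 1024), dtype=np.float32)          # [in_dim, out_dim]

# 2. weight_loader: copy checkpoint shards into the parameter
weight[:] = np.random.randn(512, 1024).astype(np.float32)

# 3. process_weights_after_loading: quantize, then replace the fp parameter
qweight, weight_scale = absmax_quantize_int8(weight)
print(qweight.dtype, qweight.shape, weight_scale.shape)   # int8 (512, 1024) (1024,)
```
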
		| @@ -22,7 +22,7 @@ from paddle import nn | ||||
| from paddle.distributed import fleet | ||||
|  | ||||
| from fastdeploy.config import FDConfig | ||||
| from fastdeploy.model_executor.models.utils import set_weight_attrs | ||||
| from fastdeploy.model_executor.utils import set_weight_attrs | ||||
|  | ||||
| from .utils import get_tensor | ||||
|  | ||||
|   | ||||
| @@ -23,7 +23,7 @@ from paddle import nn | ||||
| from fastdeploy.config import FDConfig | ||||
| from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce | ||||
| from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase | ||||
| from fastdeploy.model_executor.models.utils import ( | ||||
| from fastdeploy.model_executor.utils import ( | ||||
|     default_weight_loader, | ||||
|     set_weight_attrs, | ||||
|     slice_fn, | ||||
| @@ -39,6 +39,7 @@ class UnquantizedLinearMethod(QuantMethodBase): | ||||
|     def create_weights(self, layer: nn.Layer, **extra_weight_attrs): | ||||
|         """ | ||||
|         extra_weight_attrs is a dictionary that may include parameters like: | ||||
|         - split_axis: axis along which to split the tensor in a distributed environment | ||||
|         - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns) | ||||
|         - weight_loader: a callable or method responsible for loading the weight data | ||||
|         """ | ||||
| @@ -48,12 +49,16 @@ class UnquantizedLinearMethod(QuantMethodBase): | ||||
|             is_bias=False, | ||||
|             default_initializer=paddle.nn.initializer.Constant(0), | ||||
|         ) | ||||
|         split_axis = extra_weight_attrs.get("split_axis") | ||||
|         if hasattr(layer, "nranks") and layer.nranks > 0: | ||||
|             _set_var_distributed(layer.weight, split_axis=split_axis) | ||||
|         set_weight_attrs( | ||||
|             layer.weight, | ||||
|             {"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))}, | ||||
|             { | ||||
|                 **extra_weight_attrs, | ||||
|                 "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)), | ||||
|             }, | ||||
|         ) | ||||
|         if hasattr(layer, "nranks") and layer.nranks > 1: | ||||
|             set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")}) | ||||
|  | ||||
|     def process_loaded_weights(self, layer, weights) -> None: | ||||
|         # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation | ||||
| @@ -340,7 +345,6 @@ class ColumnParallelLinear(LinearBase): | ||||
|             ), | ||||
|         ) | ||||
|         if self.nranks > 0: | ||||
|             _set_var_distributed(self.weight, split_axis=1) | ||||
|             if self.with_bias: | ||||
|                 # col parallel | ||||
|                 _set_var_distributed(self.bias, split_axis=1) | ||||
| @@ -399,28 +403,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear): | ||||
|         ) | ||||
|  | ||||
|     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): | ||||
|         output_dim = getattr(param, "output_dim", None) | ||||
|         shard_dim = -1 if output_dim else 0 | ||||
|         output_size = param.shape[shard_dim] | ||||
|         if loaded_shard_id is None: | ||||
|             # Loaded weight is already fused on disk. | ||||
|             if self.nranks != 1: | ||||
|                 shard_offsets = [ | ||||
|                     # (shard_id, shard_offset, shard_size) | ||||
|                     ("gate", 0, self.output_size * self.nranks // 2), | ||||
|                     ("up", self.output_size * self.nranks // 2, self.output_size * self.nranks // 2), | ||||
|                 ] | ||||
|                 for shard_id, shard_offset, shard_size in shard_offsets: | ||||
|                     loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size] | ||||
|                     self.weight_loader(param, loaded_weight_shard, shard_id) | ||||
|             else: | ||||
|                 loaded_weight = get_tensor(loaded_weight) | ||||
|                 param.copy_(loaded_weight, False) | ||||
|             shard_offsets = [ | ||||
|                 # (shard_id, shard_offset, shard_size) | ||||
|                 ("gate", 0, output_size * self.nranks // 2), | ||||
|                 ("up", output_size * self.nranks // 2, output_size * self.nranks // 2), | ||||
|             ] | ||||
|             for shard_id, shard_offset, shard_size in shard_offsets: | ||||
|                 loaded_weight_shard = slice_fn( | ||||
|                     loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size | ||||
|                 ) | ||||
|                 self.weight_loader(param, loaded_weight_shard, shard_id) | ||||
|         else: | ||||
|             # 1.fused gate_up in disk | ||||
|             # 2.split gate up | ||||
|             # split gate up | ||||
|             assert loaded_shard_id in ["gate", "up"] | ||||
|             output_dim = getattr(param, "output_dim", None) | ||||
|             # Tensor parallelism splits the weight along the output_dim | ||||
|             if output_dim is not None: | ||||
|                 dim = -1 | ||||
|             if self.nranks != 1: | ||||
|                 dim = -1 if output_dim else 0 | ||||
|                 if isinstance(loaded_weight, np.ndarray): | ||||
|                     size = loaded_weight.shape[dim] | ||||
|                 else: | ||||
| @@ -428,15 +431,20 @@ class MergedColumnParallelLinear(ColumnParallelLinear): | ||||
|                 block_size = size // self.nranks | ||||
|                 shard_offset = self.local_rank * block_size | ||||
|                 shard_size = (self.local_rank + 1) * block_size | ||||
|                 loaded_weight = loaded_weight[..., shard_offset:shard_size] | ||||
|                 loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size) | ||||
|  | ||||
|             loaded_weight = get_tensor(loaded_weight) | ||||
|  | ||||
|             if not param._is_initialized(): | ||||
|                 param.initialize() | ||||
|             param_shard_size = output_size // 2 | ||||
|             if loaded_shard_id == "gate": | ||||
|                 param = param[:, : self.output_size // 2] | ||||
|             elif loaded_shard_id == "up": | ||||
|                 param = param[:, self.output_size // 2 :] | ||||
|  | ||||
|                 param_shard_offset = 0 | ||||
|             else: | ||||
|                 # loaded_shard_id == "up" | ||||
|                 param_shard_offset = param_shard_size | ||||
|             if hasattr(param, "tensor_track"): | ||||
|                 param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size) | ||||
|             param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size) | ||||
|             assert param.shape == loaded_weight.shape, ( | ||||
|                 f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" | ||||
|             ) | ||||
| @@ -513,30 +521,25 @@ class QKVParallelLinear(ColumnParallelLinear): | ||||
|  | ||||
|     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): | ||||
|         output_dim = getattr(param, "output_dim", None) | ||||
|         head_dim = param.shape[output_dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) | ||||
|         if loaded_shard_id is None: | ||||
|             # Loaded weight is already fused on disk | ||||
|             if self.nranks != 1: | ||||
|                 shard_offsets = [ | ||||
|                     # (shard_id, shard_offset, shard_size) | ||||
|                     ("q", 0, self.num_heads * self.head_dim), | ||||
|                     ("k", self.num_heads * self.head_dim, self.kv_num_heads * self.head_dim), | ||||
|                     ("v", (self.num_heads + self.kv_num_heads) * self.head_dim, self.kv_num_heads * self.head_dim), | ||||
|                 ] | ||||
|                 for shard_id, shard_offset, shard_size in shard_offsets: | ||||
|                     loaded_weight_shard = loaded_weight_shard = slice_fn( | ||||
|                         loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size | ||||
|                     ) | ||||
|                     self.weight_loader(param, loaded_weight_shard, shard_id) | ||||
|             else: | ||||
|                 loaded_weight = get_tensor(loaded_weight) | ||||
|                 split_loaded_weight = loaded_weight | ||||
|                 param.copy_(split_loaded_weight, False) | ||||
|             shard_offsets = [ | ||||
|                 # (shard_id, shard_offset, shard_size) | ||||
|                 ("q", 0, self.num_heads * head_dim), | ||||
|                 ("k", self.num_heads * head_dim, self.kv_num_heads * head_dim), | ||||
|                 ("v", (self.num_heads + self.kv_num_heads) * head_dim, self.kv_num_heads * head_dim), | ||||
|             ] | ||||
|             for shard_id, shard_offset, shard_size in shard_offsets: | ||||
|                 loaded_weight_shard = slice_fn( | ||||
|                     loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size | ||||
|                 ) | ||||
|                 self.weight_loader(param, loaded_weight_shard, shard_id) | ||||
|         else: | ||||
|             # 1.fused qkv in disk | ||||
|             # 2.split q k v | ||||
|             # split q k v | ||||
|             assert loaded_shard_id in ["q", "k", "v"] | ||||
|             # Tensor parallelism splits the weight along the output_dim | ||||
|             if output_dim is not None: | ||||
|             if self.nranks != 1: | ||||
|                 dim = -1 if output_dim else 0 | ||||
|                 if isinstance(loaded_weight, np.ndarray): | ||||
|                     size = loaded_weight.shape[dim] | ||||
| @@ -545,20 +548,25 @@ class QKVParallelLinear(ColumnParallelLinear): | ||||
|                 block_size = size // self.nranks | ||||
|                 shard_offset = self.local_rank * block_size | ||||
|                 shard_size = (self.local_rank + 1) * block_size | ||||
|                 loaded_weight = loaded_weight[..., shard_offset:shard_size] | ||||
|                 loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size) | ||||
|  | ||||
|             loaded_weight = get_tensor(loaded_weight) | ||||
|             if not param._is_initialized(): | ||||
|                 param.initialize() | ||||
|  | ||||
|             if loaded_shard_id == "q": | ||||
|  | ||||
|                 param_shard_offset = 0 | ||||
|                 param_shard_size = self.num_heads_per_rank * self.head_dim | ||||
|                 param_shard_size = self.num_heads_per_rank * head_dim | ||||
|             elif loaded_shard_id == "k": | ||||
|                 param_shard_offset = self.num_heads_per_rank * self.head_dim | ||||
|                 param_shard_size = self.kv_num_heads_per_rank * self.head_dim | ||||
|                 param_shard_offset = self.num_heads_per_rank * head_dim | ||||
|                 param_shard_size = self.kv_num_heads_per_rank * head_dim | ||||
|             else: | ||||
|                 # loaded_shard_id == "v" | ||||
|                 param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim | ||||
|                 param_shard_size = self.kv_num_heads_per_rank * self.head_dim | ||||
|                 param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * head_dim | ||||
|                 param_shard_size = self.kv_num_heads_per_rank * head_dim | ||||
|             if hasattr(param, "tensor_track"): | ||||
|                 param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size) | ||||
|             param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size) | ||||
|             assert param.shape == loaded_weight.shape, ( | ||||
|                 f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" | ||||
| @@ -706,7 +714,6 @@ class RowParallelLinear(LinearBase): | ||||
|             ), | ||||
|         ) | ||||
|         if self.nranks > 0: | ||||
|             _set_var_distributed(self.weight, split_axis=0) | ||||
|             if self.with_bias: | ||||
|                 # col parallel | ||||
|                 _set_var_distributed(self.bias, split_axis=0) | ||||
| @@ -732,7 +739,7 @@ class RowParallelLinear(LinearBase): | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class KVBatchLinear(LinearBase): | ||||
| class KVBatchLinear(nn.Layer): | ||||
|     """ | ||||
|     KVBatchLinear Layer for handling combined KV projections with bmm. | ||||
|     """ | ||||
| @@ -740,13 +747,12 @@ class KVBatchLinear(LinearBase): | ||||
|     def __init__( | ||||
|         self, | ||||
|         fd_config: FDConfig, | ||||
|         kv_b_proj: nn.Layer, | ||||
|         prefix: str = "", | ||||
|         kv_lora_rank: int = None, | ||||
|         num_attention_heads: int = None, | ||||
|         qk_nope_head_dim: int = None, | ||||
|         v_head_dim: int = None, | ||||
|         with_bias: bool = False, | ||||
|         skip_quant: bool = False, | ||||
|     ): | ||||
|         """ | ||||
|         Initializes a KV batch linear layer that internally splits into K and V projections. | ||||
| @@ -761,6 +767,7 @@ class KVBatchLinear(LinearBase): | ||||
|             with_bias (bool): Whether to include bias or not. Defaults to False. | ||||
|             skip_quant (bool): Whether to skip quantization. Defaults to False. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.nranks = fd_config.parallel_config.tensor_parallel_size | ||||
|         self.kv_lora_rank = kv_lora_rank | ||||
|         self.num_attention_heads = num_attention_heads | ||||
| @@ -770,69 +777,27 @@ class KVBatchLinear(LinearBase): | ||||
|         self.num_heads_per_partition = divide(num_attention_heads, self.nranks) | ||||
|         self.local_rank = fd_config.parallel_config.tensor_parallel_rank | ||||
|  | ||||
|         # Initialize parent with combined dimensions | ||||
|         super().__init__( | ||||
|             fd_config=fd_config, | ||||
|             prefix=prefix, | ||||
|             input_size=None,  # Will be determined from weight shape | ||||
|             output_size=None,  # Will be determined from weight shape | ||||
|             with_bias=with_bias, | ||||
|             add_bias=False, | ||||
|             skip_quant=skip_quant, | ||||
|         ) | ||||
|         self.weight_dtype = self._dtype | ||||
|         self.kv_b_proj = kv_b_proj | ||||
|  | ||||
|         self.weight_dtype = self._helper.get_default_dtype() | ||||
|  | ||||
|         # Override weight keys to use the combined kv_b_proj | ||||
|         self.weight_key = f"{prefix}.weight"  # e.g., "kv_b_proj.weight" | ||||
|         self.k_weight_key = f"{prefix.replace('kv_b_proj', 'k_b_proj')}.weight" | ||||
|         self.v_weight_key = f"{prefix.replace('kv_b_proj', 'v_b_proj')}.weight" | ||||
|  | ||||
|         self.k_b_proj_weight = self.create_parameter( | ||||
|             shape=[self.num_heads_per_partition, self.qk_nope_head_dim, self.kv_lora_rank], | ||||
|             dtype=self.weight_dtype, | ||||
|             is_bias=False, | ||||
|             default_initializer=paddle.nn.initializer.Constant(0), | ||||
|         ) | ||||
|     def process_weights_after_loading(self): | ||||
|  | ||||
|         self.v_b_proj_weight = self.create_parameter( | ||||
|             shape=[self.num_heads_per_partition, self.kv_lora_rank, self.v_head_dim], | ||||
|             dtype=self.weight_dtype, | ||||
|             is_bias=False, | ||||
|             default_initializer=paddle.nn.initializer.Constant(0), | ||||
|         ) | ||||
|         w = self.kv_b_proj.weight.reshape( | ||||
|             [ | ||||
|                 self.kv_lora_rank, | ||||
|                 self.num_heads_per_partition, | ||||
|                 -1, | ||||
|             ] | ||||
|         ).transpose(perm=[1, 2, 0]) | ||||
|         self.kv_b_proj = None | ||||
|  | ||||
|         set_weight_attrs( | ||||
|             self.k_b_proj_weight, | ||||
|             {"weight_loader": self.weight_loader}, | ||||
|         ) | ||||
|         if w.dtype != self.weight_dtype: | ||||
|             w = w.cast(self.weight_dtype) | ||||
|  | ||||
|         if self.nranks > 0: | ||||
|             _set_var_distributed(self.k_b_proj_weight, split_axis=1) | ||||
|             set_weight_attrs(self.k_b_proj_weight, {"output_dim": True}) | ||||
|  | ||||
|     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): | ||||
|         output_dim = getattr(param, "output_dim", None) | ||||
|         # Tensor parallelism splits the weight along the output_dim | ||||
|         if output_dim is not None: | ||||
|             dim = -1 | ||||
|             size = loaded_weight.get_shape()[dim] | ||||
|             block_size = size // self.nranks | ||||
|             shard_offset = self.local_rank * block_size | ||||
|             shard_size = (self.local_rank + 1) * block_size | ||||
|             loaded_weight = loaded_weight[..., shard_offset:shard_size] | ||||
|         w = ( | ||||
|             get_tensor(loaded_weight) | ||||
|             .reshape( | ||||
|                 [ | ||||
|                     self.kv_lora_rank, | ||||
|                     self.num_heads_per_partition, | ||||
|                     -1, | ||||
|                 ] | ||||
|             ) | ||||
|             .transpose(perm=[1, 2, 0]) | ||||
|         ) | ||||
|         if param.dtype != w.dtype: | ||||
|             w = w.cast(param.dtype) | ||||
|         # Split into K and V weights | ||||
|         # wk_b: [num_heads, qk_nope_head_dim, kv_lora_rank] | ||||
|         wk_b = w[:, : self.qk_nope_head_dim, :] | ||||
| @@ -840,9 +805,8 @@ class KVBatchLinear(LinearBase): | ||||
|             raise ValueError("self.v_head_dim should not be None") | ||||
|         # wv_b: [num_heads, kv_lora_rank, v_head_dim] | ||||
|         wv_b = w[:, -self.v_head_dim :, :].transpose(perm=[0, 2, 1]) | ||||
|  | ||||
|         self.k_b_proj_weight.set_value(wk_b) | ||||
|         self.v_b_proj_weight.set_value(wv_b) | ||||
|         self.k_b_proj_weight = wk_b | ||||
|         self.v_b_proj_weight = wv_b | ||||
|  | ||||
|     def load_state_dict(self, state_dict: dict): | ||||
|         """ | ||||
| @@ -916,7 +880,7 @@ class KVBatchLinear(LinearBase): | ||||
|         out = paddle.bmm(x, self.v_b_proj_weight) | ||||
|         return out | ||||
|  | ||||
|     def forward_cuda(self, x: paddle.Tensor, proj_type: str = "k") -> paddle.Tensor: | ||||
|     def forward(self, x: paddle.Tensor, proj_type: str = "k") -> paddle.Tensor: | ||||
|         """ | ||||
|         Forward function that can handle both K and V projections | ||||
|  | ||||
|   | ||||
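
KVBatchLinear now keeps a reference to the already-loaded kv_b_proj layer and derives its K and V projection weights in process_weights_after_loading via a reshape/transpose/split. The transformation can be reproduced with NumPy as below; sizes are illustrative, and the sketch assumes kv_b_proj.weight is stored as [kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)], as implied by the reshape in the diff.

```python
import numpy as np

kv_lora_rank, num_heads, qk_nope_head_dim, v_head_dim = 512, 16, 128, 128

# kv_b_proj.weight: [kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)]
kv_b = np.random.randn(kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))

# Reshape to [kv_lora_rank, num_heads, qk_nope + v], then move heads to the front.
w = kv_b.reshape(kv_lora_rank, num_heads, -1).transpose(1, 2, 0)

wk_b = w[:, :qk_nope_head_dim, :]                  # [num_heads, qk_nope_head_dim, kv_lora_rank]
wv_b = w[:, -v_head_dim:, :].transpose(0, 2, 1)    # [num_heads, kv_lora_rank, v_head_dim]
print(wk_b.shape, wv_b.shape)                      # (16, 128, 512) (16, 512, 128)
```
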
| @@ -22,7 +22,7 @@ from paddle import nn | ||||
| from paddle.distributed import fleet | ||||
|  | ||||
| from fastdeploy.config import FDConfig | ||||
| from fastdeploy.model_executor.models.utils import set_weight_attrs | ||||
| from fastdeploy.model_executor.utils import set_weight_attrs | ||||
|  | ||||
| from .utils import get_tensor | ||||
|  | ||||
|   | ||||
| @@ -19,7 +19,7 @@ from abc import abstractmethod | ||||
| import paddle | ||||
| from paddle import nn | ||||
|  | ||||
| from fastdeploy.model_executor.layers.utils import set_weight_attrs | ||||
| from fastdeploy.model_executor.utils import set_weight_attrs | ||||
| from fastdeploy.platforms import current_platform | ||||
|  | ||||
| from ..quantization.quant_base import QuantMethodBase | ||||
| @@ -185,9 +185,11 @@ class UnquantizedFusedMoEMethod(MoEMethodBase): | ||||
|         if current_platform.is_cuda(): | ||||
|             self.up_gate_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2] | ||||
|             self.down_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size, layer.hidden_size] | ||||
|             extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 1, "down": 0, "up": 1}} | ||||
|         else: | ||||
|             self.up_gate_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size * 2, layer.hidden_size] | ||||
|             self.down_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size] | ||||
|             extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}} | ||||
|  | ||||
|         layer.up_gate_proj_weight = layer.create_parameter( | ||||
|             shape=self.up_gate_proj_weight_shape, | ||||
| @@ -203,10 +205,3 @@ class UnquantizedFusedMoEMethod(MoEMethodBase): | ||||
|  | ||||
|         set_weight_attrs(layer.up_gate_proj_weight, extra_weight_attrs) | ||||
|         set_weight_attrs(layer.down_proj_weight, extra_weight_attrs) | ||||
|  | ||||
|         if layer.moe_use_gate_correction_bias: | ||||
|             gate_correction_bias_shape = [1, layer.num_experts] | ||||
|             layer.gate_correction_bias = layer.create_parameter( | ||||
|                 shape=gate_correction_bias_shape, | ||||
|                 dtype="float32", | ||||
|             ) | ||||
|   | ||||
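
The SHARD_ID_TO_SHARDED_DIM attribute added above records, per shard id, which axis of a single expert's 2-D weight is partitioned; the mapping flips between the CUDA layout ([hidden, out]) and the transposed layout used on other platforms. A minimal NumPy illustration of slicing a loaded expert weight along the mapped axis (names and sizes are illustrative):

```python
import numpy as np

# Per-layout map from shard id to the partitioned axis of one expert's weight.
SHARD_ID_TO_SHARDED_DIM_CUDA = {"gate": 1, "down": 0, "up": 1}    # [hidden, out] layout
SHARD_ID_TO_SHARDED_DIM_OTHER = {"gate": 0, "down": 1, "up": 0}   # [out, hidden] layout

tp_size, tp_rank = 2, 0
loaded_gate = np.random.randn(64, 128)   # one expert's gate weight from the checkpoint (CUDA layout)

dim = SHARD_ID_TO_SHARDED_DIM_CUDA["gate"]
block = loaded_gate.shape[dim] // tp_size
idx = np.arange(tp_rank * block, (tp_rank + 1) * block)
local = np.take(loaded_gate, idx, axis=dim)
print(local.shape)                       # (64, 64): the gate shard is split along axis 1 across ranks
```
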
| @@ -38,6 +38,8 @@ elif current_platform.is_iluvatar(): | ||||
|         moe_expert_reduce, | ||||
|     ) | ||||
|  | ||||
| from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs | ||||
|  | ||||
|  | ||||
| # used for deepseek_v3 | ||||
| def get_moe_scores( | ||||
| @@ -93,8 +95,8 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): | ||||
|             return fastdeploy.model_executor.ops.iluvatar.moe_expert_ffn( | ||||
|                 permute_input, | ||||
|                 token_nums_per_expert, | ||||
|                 layer.up_gate_proj_weight, | ||||
|                 layer.down_proj_weight, | ||||
|                 getattr(layer, self.added_weight_attrs[0]), | ||||
|                 getattr(layer, self.added_weight_attrs[1]), | ||||
|                 None, | ||||
|                 (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), | ||||
|                 (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), | ||||
| @@ -106,8 +108,8 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): | ||||
|         return fastdeploy.model_executor.ops.gpu.moe_expert_ffn( | ||||
|             permute_input, | ||||
|             token_nums_per_expert, | ||||
|             layer.up_gate_proj_weight, | ||||
|             layer.down_proj_weight, | ||||
|             getattr(layer, self.added_weight_attrs[0]), | ||||
|             getattr(layer, self.added_weight_attrs[1]), | ||||
|             None, | ||||
|             (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), | ||||
|             (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), | ||||
| @@ -392,12 +394,12 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod): | ||||
|         Paddle cutlass create weight process. | ||||
|         """ | ||||
|         self.weight_dtype = "int8" | ||||
|         self.ffn1_weight_shape = [ | ||||
|         self.up_gate_proj_weight_shape = [ | ||||
|             layer.num_local_experts, | ||||
|             layer.hidden_size // 2, | ||||
|             layer.moe_intermediate_size * 2, | ||||
|         ] | ||||
|         self.ffn2_weight_shape = [ | ||||
|         self.down_proj_weight_shape = [ | ||||
|             layer.num_local_experts, | ||||
|             layer.moe_intermediate_size // 2, | ||||
|             layer.hidden_size, | ||||
| @@ -406,7 +408,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod): | ||||
|             layer, | ||||
|             self.added_weight_attrs[0], | ||||
|             layer.create_parameter( | ||||
|                 shape=self.ffn1_weight_shape, | ||||
|                 shape=self.up_gate_proj_weight_shape, | ||||
|                 dtype=self.weight_dtype, | ||||
|                 default_initializer=paddle.nn.initializer.Constant(0), | ||||
|             ), | ||||
| @@ -415,7 +417,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod): | ||||
|             layer, | ||||
|             self.added_weight_attrs[1], | ||||
|             layer.create_parameter( | ||||
|                 shape=self.ffn2_weight_shape, | ||||
|                 shape=self.down_proj_weight_shape, | ||||
|                 dtype=self.weight_dtype, | ||||
|                 default_initializer=paddle.nn.initializer.Constant(0), | ||||
|             ), | ||||
| @@ -625,71 +627,177 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod): | ||||
|         Paddle cutlass create weight process. | ||||
|         """ | ||||
|         self.default_dtype = layer._helper.get_default_dtype() | ||||
|         self.weight_dtype = "int8" | ||||
|  | ||||
|         up_gate_proj_weight_name = self.added_weight_attrs[0] | ||||
|         down_proj_weight_name = self.added_weight_attrs[1] | ||||
|         if self.moe_quant_type == "weight_only_int4": | ||||
|             self.ffn1_weight_shape = [ | ||||
|             self.up_gate_proj_weight_shape = [ | ||||
|                 layer.num_local_experts, | ||||
|                 layer.moe_intermediate_size, | ||||
|                 layer.hidden_size, | ||||
|             ] | ||||
|         else: | ||||
|             self.ffn1_weight_shape = [ | ||||
|             self.up_gate_proj_weight_shape = [ | ||||
|                 layer.num_local_experts, | ||||
|                 layer.moe_intermediate_size * 2, | ||||
|                 layer.hidden_size, | ||||
|             ] | ||||
|         if self.moe_quant_type == "weight_only_int4": | ||||
|             self.ffn2_weight_shape = [ | ||||
|             self.down_proj_weight_shape = [ | ||||
|                 layer.num_local_experts, | ||||
|                 layer.hidden_size // 2, | ||||
|                 layer.moe_intermediate_size, | ||||
|             ] | ||||
|         else: | ||||
|             self.ffn2_weight_shape = [ | ||||
|             self.down_proj_weight_shape = [ | ||||
|                 layer.num_local_experts, | ||||
|                 layer.hidden_size, | ||||
|                 layer.moe_intermediate_size, | ||||
|             ] | ||||
|         self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2] | ||||
|         self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size] | ||||
|  | ||||
|         if layer.fd_config.load_config.load_choices == "default_v1": | ||||
|             layer.up_gate_proj_weight = layer.create_parameter( | ||||
|                 shape=[layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2], | ||||
|                 dtype=layer.weight_dtype, | ||||
|                 default_initializer=paddle.nn.initializer.Constant(0), | ||||
|             ) | ||||
|  | ||||
|             layer.down_proj_weight = layer.create_parameter( | ||||
|                 shape=[layer.num_experts, layer.moe_intermediate_size, layer.hidden_size], | ||||
|                 dtype=layer.weight_dtype, | ||||
|                 default_initializer=paddle.nn.initializer.Constant(0), | ||||
|             ) | ||||
|  | ||||
|             set_weight_attrs( | ||||
|                 layer.up_gate_proj_weight, | ||||
|                 { | ||||
|                     **extra_weight_attrs, | ||||
|                     "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True), | ||||
|                 }, | ||||
|             ) | ||||
|             set_weight_attrs( | ||||
|                 layer.down_proj_weight, | ||||
|                 { | ||||
|                     **extra_weight_attrs, | ||||
|                     "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False), | ||||
|                 }, | ||||
|             ) | ||||
|         else: | ||||
|             self.weight_dtype = "int8" | ||||
|  | ||||
|             up_gate_proj_weight_name = self.added_weight_attrs[0] | ||||
|             down_proj_weight_name = self.added_weight_attrs[1] | ||||
|             up_gate_proj_scale_name = self.added_scale_attrs[0] | ||||
|             down_proj_scale_name = self.added_scale_attrs[1] | ||||
|  | ||||
|             setattr( | ||||
|                 layer, | ||||
|                 up_gate_proj_weight_name, | ||||
|                 layer.create_parameter( | ||||
|                     shape=self.up_gate_proj_weight_shape, | ||||
|                     dtype=self.weight_dtype, | ||||
|                     default_initializer=paddle.nn.initializer.Constant(0), | ||||
|                 ), | ||||
|             ) | ||||
|             setattr( | ||||
|                 layer, | ||||
|                 down_proj_weight_name, | ||||
|                 layer.create_parameter( | ||||
|                     shape=self.down_proj_weight_shape, | ||||
|                     dtype=self.weight_dtype, | ||||
|                     default_initializer=paddle.nn.initializer.Constant(0), | ||||
|                 ), | ||||
|             ) | ||||
|             # weight_scale | ||||
|             setattr( | ||||
|                 layer, | ||||
|                 up_gate_proj_scale_name, | ||||
|                 layer.create_parameter( | ||||
|                     shape=self.up_gate_proj_scale_shape, | ||||
|                     dtype=self.default_dtype, | ||||
|                     default_initializer=paddle.nn.initializer.Constant(0), | ||||
|                 ), | ||||
|             ) | ||||
|             setattr( | ||||
|                 layer, | ||||
|                 down_proj_scale_name, | ||||
|                 layer.create_parameter( | ||||
|                     shape=self.down_proj_scale_shape, | ||||
|                     dtype=self.default_dtype, | ||||
|                     default_initializer=paddle.nn.initializer.Constant(0), | ||||
|                 ), | ||||
|             ) | ||||
|  | ||||
|             moe_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}} | ||||
|             set_weight_attrs(layer.up_gate_proj_weight, moe_extra_weight_attrs) | ||||
|             set_weight_attrs(layer.down_proj_weight, moe_extra_weight_attrs) | ||||
|             scale_extra_weight_attrs = { | ||||
|                 **extra_weight_attrs, | ||||
|                 "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "up": 0, "down": None}, | ||||
|             } | ||||
|             set_weight_attrs(layer.up_gate_proj_weight_scale, scale_extra_weight_attrs) | ||||
|             set_weight_attrs(layer.down_proj_weight_scale, scale_extra_weight_attrs) | ||||
|  | ||||
|     def process_weights_after_loading(self, layer): | ||||
|         """ """ | ||||
|         if not layer.fd_config.load_config.load_choices == "default_v1": | ||||
|             return | ||||
|         weight_id_map = {"gate_up": 0, "down": 1} | ||||
|         if ( | ||||
|             hasattr(layer.up_gate_proj_weight, "tensor_track") | ||||
|             and layer.up_gate_proj_weight.tensor_track is not None | ||||
|             and layer.up_gate_proj_weight.tensor_track.is_fully_copied() | ||||
|         ): | ||||
|             weight_type = "gate_up" | ||||
|         else: | ||||
|             weight_type = "down" | ||||
|  | ||||
|         # 1.init shape and type | ||||
|         # weight | ||||
|         weight_name = self.added_weight_attrs[weight_id_map[weight_type]] | ||||
|         unquantized_weight_name = weight_name.replace("quant_weight", "weight") | ||||
|         weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape | ||||
|         weight_dtype = "int8" | ||||
|         # scale | ||||
|         scale_name = self.added_scale_attrs[weight_id_map[weight_type]] | ||||
|         scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape | ||||
|         scale_dtype = self.default_dtype | ||||
|  | ||||
|         # 2.crate tmp tensor | ||||
|  | ||||
|         weight = paddle.empty(weight_shape, dtype=weight_dtype) | ||||
|         scale = paddle.empty(scale_shape, dtype=scale_dtype) | ||||
|  | ||||
|         # 3.quantize weight | ||||
|  | ||||
|         for expert_id in range(layer.num_experts): | ||||
|             weight[expert_id], scale[expert_id] = weight_quantize( | ||||
|                 getattr(layer, unquantized_weight_name)[expert_id], algo=self.moe_quant_type | ||||
|             ) | ||||
|  | ||||
|         free_tensor(getattr(layer, unquantized_weight_name)) | ||||
|  | ||||
|         # create weight | ||||
|         setattr( | ||||
|             layer, | ||||
|             up_gate_proj_weight_name, | ||||
|             weight_name, | ||||
|             layer.create_parameter( | ||||
|                 shape=self.ffn1_weight_shape, | ||||
|                 dtype=self.weight_dtype, | ||||
|                 shape=weight_shape, | ||||
|                 dtype=weight_dtype, | ||||
|                 default_initializer=paddle.nn.initializer.Constant(0), | ||||
|             ), | ||||
|         ) | ||||
|         # create scale | ||||
|         setattr( | ||||
|             layer, | ||||
|             down_proj_weight_name, | ||||
|             scale_name, | ||||
|             layer.create_parameter( | ||||
|                 shape=self.ffn2_weight_shape, | ||||
|                 dtype=self.weight_dtype, | ||||
|                 default_initializer=paddle.nn.initializer.Constant(0), | ||||
|             ), | ||||
|         ) | ||||
|         # weight_scale | ||||
|         setattr( | ||||
|             layer, | ||||
|             self.added_scale_attrs[0], | ||||
|             layer.create_parameter( | ||||
|                 shape=[layer.num_local_experts, layer.moe_intermediate_size * 2], | ||||
|                 dtype=self.default_dtype, | ||||
|                 default_initializer=paddle.nn.initializer.Constant(0), | ||||
|             ), | ||||
|         ) | ||||
|         setattr( | ||||
|             layer, | ||||
|             self.added_scale_attrs[1], | ||||
|             layer.create_parameter( | ||||
|                 shape=[layer.num_local_experts, layer.hidden_size], | ||||
|                 dtype=self.default_dtype, | ||||
|                 shape=scale_shape, | ||||
|                 dtype=scale_dtype, | ||||
|                 default_initializer=paddle.nn.initializer.Constant(0), | ||||
|             ), | ||||
|         ) | ||||
|         getattr(layer, weight_name).copy_(weight, False) | ||||
|         getattr(layer, scale_name).copy_(scale, False) | ||||
|  | ||||
|     def process_loaded_weights(self, layer: nn.Layer, state_dict): | ||||
|         """ | ||||
|   | ||||
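
For the v1 path, CutlassWeightOnlyMoEMethod keeps the unquantized expert weights until loading is complete and then quantizes expert by expert, swapping in int8 weights plus per-channel scales. The loop below mimics that flow; toy_weight_quantize is a NumPy stand-in for paddle.nn.quant.weight_quantize (the real kernel also repacks the int8 storage layout, which this toy ignores), and all sizes are illustrative.

```python
import numpy as np

def toy_weight_quantize(w):
    """Per-channel symmetric int8 quantization; simplified stand-in for weight_quantize."""
    scale = np.abs(w).max(axis=0) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

num_experts, inter, hidden = 4, 48, 128
# Unquantized up_gate_proj weights as loaded: [experts, hidden, 2 * inter]
up_gate = np.random.randn(num_experts, hidden, inter * 2).astype(np.float32)

qweight = np.empty((num_experts, hidden, inter * 2), dtype=np.int8)
scale = np.empty((num_experts, inter * 2), dtype=np.float32)   # one scale per output channel
for expert_id in range(num_experts):
    qweight[expert_id], scale[expert_id] = toy_weight_quantize(up_gate[expert_id])

# The original full-precision tensor would now be freed and replaced
# by the (qweight, scale) pair registered on the layer.
print(qweight.shape, scale.shape)   # (4, 128, 96) (4, 96)
```
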
| @@ -23,6 +23,7 @@ from paddleformers.utils.log import logger | ||||
|  | ||||
| from fastdeploy import envs | ||||
| from fastdeploy.model_executor.layers.utils import get_tensor | ||||
| from fastdeploy.model_executor.utils import slice_fn | ||||
| from fastdeploy.platforms import current_platform | ||||
| from fastdeploy.worker.experts_manager import RedundantExpertManger | ||||
|  | ||||
| @@ -78,6 +79,7 @@ class FusedMoE(nn.Layer): | ||||
|         routed_scaling_factor: float = 1.0, | ||||
|         layer_idx: int = -1, | ||||
|         moe_tag: str = "", | ||||
|         gate_correction_bias=None, | ||||
|         weight_key_map: dict = {}, | ||||
|     ): | ||||
|         """ | ||||
| @@ -155,9 +157,10 @@ class FusedMoE(nn.Layer): | ||||
|             # It's for RL to build model | ||||
|             self.init_moe_weights() | ||||
|         else: | ||||
|             self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None) | ||||
|             if self.gate_correction_bias_key is not None: | ||||
|                 self.gate_correction_bias = self.create_parameter(shape=[1, self.num_experts], dtype="float32") | ||||
|             if gate_correction_bias is not None: | ||||
|                 self.gate_correction_bias = gate_correction_bias | ||||
|             else: | ||||
|                 self.gate_correction_bias = None | ||||
|             if moe_quant_config: | ||||
|                 if ( | ||||
|                     moe_quant_config | ||||
| @@ -179,54 +182,72 @@ class FusedMoE(nn.Layer): | ||||
|     def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str] = None): | ||||
|         from fastdeploy.platforms import current_platform | ||||
|  | ||||
|         if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"): | ||||
|             SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM | ||||
|         elif current_platform.is_cuda(): | ||||
|             SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1} | ||||
|         else: | ||||
|             SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0} | ||||
|  | ||||
|         if not param._is_initialized(): | ||||
|             param.initialize() | ||||
|  | ||||
|         if shard_id is None: | ||||
|             # 1.gate up fused in disk | ||||
|             if self.tp_size > 1: | ||||
|                 shard_offsets = [ | ||||
|                     # (shard_id, shard_offset, shard_size) | ||||
|                     ("gate", 0, self.moe_intermediate_size * self.tp_size), | ||||
|                     ("up", self.moe_intermediate_size * self.tp_size, self.moe_intermediate_size * self.tp_size), | ||||
|                 ] | ||||
|                 for shard_id, shard_offset, shard_size in shard_offsets: | ||||
|                     loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size] | ||||
|                     self.weight_loader(param, loaded_weight_shard, expert_id, shard_id) | ||||
|             else: | ||||
|                 expert_param = param[expert_id - self.expert_id_offset] | ||||
|                 loaded_weight = get_tensor(loaded_weight) | ||||
|                 expert_param.copy_(loaded_weight, False) | ||||
|             output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]] | ||||
|             shard_offsets = [ | ||||
|                 # (shard_id, shard_offset, shard_size) | ||||
|                 ("gate", 0, output_size // 2 * self.tp_size), | ||||
|                 ("up", output_size // 2 * self.tp_size, output_size // 2 * self.tp_size), | ||||
|             ] | ||||
|             for shard_id, shard_offset, shard_size in shard_offsets: | ||||
|                 loaded_weight_shard = slice_fn( | ||||
|                     loaded_weight, SHARD_ID_TO_SHARDED_DIM[shard_id], shard_offset, shard_offset + shard_size | ||||
|                 ) | ||||
|                 self.weight_loader(param, loaded_weight_shard, expert_id, shard_id) | ||||
|         else: | ||||
|             # 2.gate up splited in disk | ||||
|             assert shard_id in ["gate", "down", "up"] | ||||
|             if current_platform.is_cuda(): | ||||
|                 SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1} | ||||
|             else: | ||||
|                 SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0} | ||||
|             self._load_expert_weight( | ||||
|                 param=param, | ||||
|                 expert_id=expert_id, | ||||
|                 shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id], | ||||
|                 loaded_weight=loaded_weight, | ||||
|                 shard_id=shard_id, | ||||
|                 shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id], | ||||
|             ) | ||||
|  | ||||
|     def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id): | ||||
|         tensor_size = expert_param.shape[shard_dim] // 2 | ||||
|         if shard_id == "gate": | ||||
|             expert_param = expert_param[..., :tensor_size] if shard_dim else expert_param[:tensor_size, ...] | ||||
|         elif shard_id == "up": | ||||
|             expert_param = expert_param[..., tensor_size:] if shard_dim else expert_param[tensor_size:, ...] | ||||
|  | ||||
|     def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None): | ||||
|         dim = -1 if shard_dim else 0 | ||||
|         if self.tp_size > 1: | ||||
|             if isinstance(loaded_weight, np.ndarray): | ||||
|                 size = loaded_weight.shape[-1] | ||||
|                 size = loaded_weight.shape[dim] | ||||
|             else: | ||||
|                 size = loaded_weight.get_shape()[-1] | ||||
|                 size = loaded_weight.get_shape()[dim] | ||||
|             block_size = size // self.tp_size | ||||
|             shard_offset = self.tp_rank * block_size | ||||
|             shard_size = (self.tp_rank + 1) * block_size | ||||
|             loaded_weight = loaded_weight[..., shard_offset:shard_size] | ||||
|             loaded_weight = slice_fn(loaded_weight, shard_dim, shard_offset, shard_size) | ||||
|  | ||||
|         loaded_weight = get_tensor(loaded_weight) | ||||
|  | ||||
|         expert_param = param[expert_id - self.expert_id_offset] | ||||
|         param_shard_size = expert_param.shape[dim] // 2 | ||||
|         if shard_id == "gate": | ||||
|             param_shard_offset = 0 | ||||
|         else: | ||||
|             # shard_id == "up": | ||||
|             param_shard_offset = param_shard_size | ||||
|         expert_param = slice_fn( | ||||
|             expert_param, shard_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size | ||||
|         ) | ||||
|         if hasattr(param, "tensor_track"): | ||||
|             # for dyn quant | ||||
|             param.tensor_track.mark( | ||||
|                 start=param_shard_offset, | ||||
|                 end=param_shard_offset + param_shard_size, | ||||
|                 batch_id=expert_id - self.expert_id_offset, | ||||
|             ) | ||||
|  | ||||
|         # To ensure compatibility across backends, apply an extra transpose for GCU and XPU | ||||
|         if expert_param.shape != loaded_weight.shape: | ||||
|             loaded_weight = loaded_weight.transpose([1, 0]) | ||||
| @@ -235,17 +256,22 @@ class FusedMoE(nn.Layer): | ||||
|         ) | ||||
|         expert_param.copy_(loaded_weight, False) | ||||
|  | ||||
|     def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id): | ||||
|         if self.tp_size > 1: | ||||
|     def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None): | ||||
|         if self.tp_size > 1 and shard_dim is not None: | ||||
|             dim = -1 if shard_dim else 0 | ||||
|             if isinstance(loaded_weight, np.ndarray): | ||||
|                 size = loaded_weight.shape[shard_dim] | ||||
|                 size = loaded_weight.shape[dim] | ||||
|             else: | ||||
|                 size = loaded_weight.get_shape()[shard_dim] | ||||
|                 size = loaded_weight.get_shape()[dim] | ||||
|             block_size = size // self.tp_size | ||||
|             shard_offset = self.tp_rank * block_size | ||||
|             shard_size = (self.tp_rank + 1) * block_size | ||||
|             loaded_weight = loaded_weight[shard_offset:shard_size, ...] | ||||
|             loaded_weight = slice_fn(loaded_weight, shard_dim, shard_offset, shard_size) | ||||
|         loaded_weight = get_tensor(loaded_weight) | ||||
|         expert_param = param[expert_id - self.expert_id_offset] | ||||
|         if hasattr(param, "tensor_track"): | ||||
|             # for dyn quant | ||||
|             param.tensor_track.mark(start=0, batch_id=expert_id - self.expert_id_offset) | ||||
|         # To ensure compatibility across backends, apply an extra transpose for GCU and XPU | ||||
|         if expert_param.shape != loaded_weight.shape: | ||||
|             loaded_weight = loaded_weight.transpose([1, 0]) | ||||
| @@ -258,15 +284,14 @@ class FusedMoE(nn.Layer): | ||||
|         self, | ||||
|         param, | ||||
|         expert_id, | ||||
|         shard_dim, | ||||
|         loaded_weight, | ||||
|         shard_id, | ||||
|         shard_dim=None, | ||||
|     ): | ||||
|         expert_param = param[expert_id - self.expert_id_offset] | ||||
|         if shard_id == "down": | ||||
|             self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id) | ||||
|             self._load_down_weight(param, expert_id, loaded_weight, shard_id, shard_dim) | ||||
|         elif shard_id in ["gate", "up"]: | ||||
|             self._load_gate_up_weight(expert_param, shard_dim, loaded_weight, shard_id) | ||||
|             self._load_gate_up_weight(param, expert_id, loaded_weight, shard_id, shard_dim) | ||||
|  | ||||
|     @classmethod | ||||
|     def make_expert_params_mapping( | ||||
| @@ -314,13 +339,6 @@ class FusedMoE(nn.Layer): | ||||
|         Combines weight shape initialization and parameter creation into a single function. | ||||
|         """ | ||||
|         # Initialize weight shapes | ||||
|         gate_correction_bias_shape = [1, self.num_experts] | ||||
|  | ||||
|         if self.fd_config.model_config.moe_use_aux_free: | ||||
|             self.gate_correction_bias = self.create_parameter( | ||||
|                 shape=gate_correction_bias_shape, | ||||
|                 dtype="float32", | ||||
|             ) | ||||
|         up_gate_proj_output_dim = self.moe_intermediate_size * 2 | ||||
|         if self.moe_quant_type in ["block_wise_fp8", "wint8"]: | ||||
|             up_gate_proj_weight_shape = [ | ||||
| @@ -535,19 +553,6 @@ class FusedMoE(nn.Layer): | ||||
|         """ | ||||
|         load_state_dict function. | ||||
|         """ | ||||
|         if not is_rearrange: | ||||
|             if self.moe_use_gate_correction_bias: | ||||
|                 gate_correction_bias_tensor = self.extract_gate_correction_bias( | ||||
|                     self.gate_correction_bias_key, state_dict | ||||
|                 ) | ||||
|                 if self.gate_correction_bias.shape != gate_correction_bias_tensor.shape: | ||||
|                     gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(self.gate_correction_bias.shape) | ||||
|                 self.gate_correction_bias.set_value(gate_correction_bias_tensor) | ||||
|             else: | ||||
|                 self.gate_correction_bias = None | ||||
|         else: | ||||
|             self.gate_correction_bias = None | ||||
|  | ||||
|         if is_supported_moe_backend is not None and is_supported_moe_backend(self.quant_method): | ||||
|             if self.fd_config.model_config.is_quantized: | ||||
|                 if getattr(self.fd_config.quant_config, "is_permuted", True): | ||||
|   | ||||
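
Loading progress in FusedMoE (and in the merged linear layers earlier in the diff) is recorded through a tensor_track attribute: each shard copy marks the range it filled, and process_weights_after_loading only runs the quantization path once the tensor reports it is fully copied. The class below is a deliberately simplified, hypothetical illustration of the idea; the real TensorTracker in fastdeploy.model_executor.utils is constructed with shape/output_dim, and its mark() also accepts a batch_id for per-expert tracking.

```python
class ToyTensorTracker:
    """Minimal 'mark ranges until fully copied' tracker (illustration only)."""

    def __init__(self, length):
        self.length = length
        self.filled = [False] * length

    def mark(self, start, end=None):
        end = self.length if end is None else end
        for i in range(start, end):
            self.filled[i] = True

    def is_fully_copied(self):
        return all(self.filled)

track = ToyTensorTracker(length=8)
track.mark(start=0, end=4)       # the "gate" half has been loaded
print(track.is_fully_copied())   # False
track.mark(start=4, end=8)       # the "up" half has been loaded
print(track.is_fully_copied())   # True
```
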
| @@ -21,6 +21,11 @@ from typing import Optional | ||||
| import paddle | ||||
| from paddle.nn.quant import weight_only_linear, weight_quantize | ||||
|  | ||||
| from fastdeploy.model_executor.layers.linear import ( | ||||
|     MergedColumnParallelLinear, | ||||
|     QKVParallelLinear, | ||||
| ) | ||||
| from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs | ||||
| from fastdeploy.platforms import current_platform | ||||
|  | ||||
| from ..moe import FusedMoE | ||||
| @@ -135,9 +140,7 @@ class WINT8Config(WeightOnlyConfig): | ||||
|     weight only int8 config | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|     ) -> None: | ||||
|     def __init__(self) -> None: | ||||
|         super().__init__("weight_only_int8") | ||||
|  | ||||
|     @classmethod | ||||
| @@ -179,27 +182,89 @@ class WeightOnlyLinearMethod(QuantMethodBase): | ||||
|         self.quant_config = quant_config | ||||
|  | ||||
|     def create_weights(self, layer, **extra_weight_attrs): | ||||
|         if layer.fd_config.load_config.load_choices == "default_v1": | ||||
|             layer.weight = layer.create_parameter( | ||||
|                 shape=layer.weight_shape, | ||||
|                 dtype=layer.weight_dtype, | ||||
|                 is_bias=False, | ||||
|                 default_initializer=paddle.nn.initializer.Constant(0), | ||||
|             ) | ||||
|             quant_attrs = extra_weight_attrs | ||||
|             if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear): | ||||
|                 quant_attrs = { | ||||
|                     **extra_weight_attrs, | ||||
|                     "tensor_track": TensorTracker( | ||||
|                         shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim") | ||||
|                     ), | ||||
|                 } | ||||
|             set_weight_attrs( | ||||
|                 layer.weight, | ||||
|                 quant_attrs, | ||||
|             ) | ||||
|         else: | ||||
|             # The scale shape should be equal to the output dim of weight using Per-Channel Quantization. | ||||
|             weight_scale_shape = [layer.weight_shape[1]] | ||||
|             layer.weight_shape.reverse() | ||||
|             if self.quant_config.name() == "wint4": | ||||
|                 layer.weight_shape[0] //= 2 | ||||
|             layer.weight_dtype = "int8" | ||||
|             layer.weight = layer.create_parameter( | ||||
|                 shape=layer.weight_shape, | ||||
|                 dtype=layer.weight_dtype, | ||||
|                 is_bias=False, | ||||
|                 default_initializer=paddle.nn.initializer.Constant(0), | ||||
|             ) | ||||
|  | ||||
|         # The scale shape should be equal to the output dim of weight using Per-Channel Quantization. | ||||
|         weight_scale_shape = [layer.weight_shape[1]] | ||||
|             output_dim = extra_weight_attrs.get("output_dim") | ||||
|             output_dim = not output_dim | ||||
|             weight_loader = extra_weight_attrs.get("weight_loader") | ||||
|             set_weight_attrs( | ||||
|                 layer.weight, | ||||
|                 { | ||||
|                     "weight_loader": weight_loader, | ||||
|                     "output_dim": output_dim, | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
|         layer.weight_shape.reverse() | ||||
|         if self.quant_config.name() == "wint4": | ||||
|             layer.weight_shape[0] //= 2 | ||||
|         layer.weight_dtype = "int8" | ||||
|             layer.weight_scale = layer.create_parameter( | ||||
|                 shape=weight_scale_shape, | ||||
|                 dtype=layer._dtype, | ||||
|                 is_bias=False, | ||||
|             ) | ||||
|  | ||||
|             set_weight_attrs( | ||||
|                 layer.weight_scale, | ||||
|                 { | ||||
|                     "weight_loader": weight_loader, | ||||
|                     "output_dim": output_dim, | ||||
|                 }, | ||||
|             ) | ||||
|  | ||||
|     def process_weights_after_loading(self, layer) -> None: | ||||
|         if not layer.fd_config.load_config.load_choices == "default_v1": | ||||
|             return | ||||
|         quanted_weight_tensor, weight_scale_tensor = weight_quantize( | ||||
|             layer.weight, | ||||
|             algo=self.quant_config.algo, | ||||
|             arch=self.quant_config.weight_only_linear_arch, | ||||
|         ) | ||||
|  | ||||
|         free_tensor(layer.weight) | ||||
|  | ||||
|         layer.weight = layer.create_parameter( | ||||
|             shape=layer.weight_shape, | ||||
|             dtype=layer.weight_dtype, | ||||
|             shape=quanted_weight_tensor.shape, | ||||
|             dtype="int8", | ||||
|             is_bias=False, | ||||
|             default_initializer=paddle.nn.initializer.Constant(0), | ||||
|         ) | ||||
|  | ||||
|         layer.weight_scale = layer.create_parameter( | ||||
|             shape=weight_scale_shape, | ||||
|             shape=weight_scale_tensor.shape, | ||||
|             dtype=layer._dtype, | ||||
|             is_bias=False, | ||||
|             default_initializer=paddle.nn.initializer.Constant(0), | ||||
|         ) | ||||
|         layer.weight.copy_(quanted_weight_tensor, False) | ||||
|         layer.weight_scale.copy_(weight_scale_tensor, False) | ||||
|  | ||||
|     @abstractmethod | ||||
|     def process_loaded_weights(self, layer, weights) -> None: | ||||
|   | ||||
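
In the pre-quantized (non-"default_v1") branch of WeightOnlyLinearMethod.create_weights, the stored weight is transposed and, for wint4, its leading dimension is halved because two int4 values are packed into one int8 byte, while the per-channel scale keeps one entry per output channel. A small sketch of that shape bookkeeping, assuming weight_shape is [in_dim, out_dim] as in the hunk above:

```python
def quantized_storage_shape(weight_shape, algo):
    """Shape bookkeeping for the stored int8 weight of a weight-only linear layer."""
    in_dim, out_dim = weight_shape
    scale_shape = [out_dim]          # per-channel: one scale per output channel
    stored = [out_dim, in_dim]       # layer.weight_shape.reverse()
    if algo == "wint4":
        stored[0] //= 2              # two int4 values packed per int8 element
    return stored, scale_shape

print(quantized_storage_shape([4096, 11008], "wint8"))   # ([11008, 4096], [11008])
print(quantized_storage_shape([4096, 11008], "wint4"))   # ([5504, 4096], [11008])
```
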
| @@ -15,7 +15,7 @@ | ||||
| """ | ||||
|  | ||||
| import functools | ||||
| from typing import Any, Optional, Tuple, Union | ||||
| from typing import Tuple, Union | ||||
|  | ||||
| import numpy as np | ||||
| import paddle | ||||
| @@ -45,14 +45,6 @@ if cache_params != "none": | ||||
|     c8_state_dict = paddle.load(cache_params, return_numpy=True) | ||||
|  | ||||
|  | ||||
| # TODO(lulinjun): delete it, import from fastdeploy.model_executor.models.utils after supporting all backends | ||||
| def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]): | ||||
|     if param_attr_map is None: | ||||
|         return | ||||
|     for key, value in param_attr_map.items(): | ||||
|         setattr(param, key, value) | ||||
|  | ||||
|  | ||||
| def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]: | ||||
|     """ | ||||
|     Only used in deep_gemm block wise quant weight. | ||||
|   | ||||
| @@ -14,8 +14,6 @@ | ||||
| # limitations under the License. | ||||
| """ | ||||
|  | ||||
| import contextlib | ||||
|  | ||||
| import paddle | ||||
| from paddle import nn | ||||
| from paddleformers.utils.log import logger | ||||
| @@ -56,15 +54,12 @@ class DefaultModelLoaderV1(BaseModelLoader): | ||||
|     def load_model(self, fd_config: FDConfig) -> nn.Layer: | ||||
|         architectures = fd_config.model_config.architectures[0] | ||||
|         logger.info(f"Starting to load model {architectures}") | ||||
|         context = paddle.LazyGuard() | ||||
|         if fd_config.load_config.dynamic_load_weight: | ||||
|             # register rl model | ||||
|             import fastdeploy.rl  # noqa | ||||
|  | ||||
|             architectures = architectures + "RL" | ||||
|             context = paddle.LazyGuard() | ||||
|  | ||||
|         else: | ||||
|             context = contextlib.nullcontext() | ||||
|  | ||||
|         with context: | ||||
|             model_cls = ModelRegistry.get_class(architectures) | ||||
| @@ -75,6 +70,5 @@ class DefaultModelLoaderV1(BaseModelLoader): | ||||
|         # RL model not need set_state_dict | ||||
|         if fd_config.load_config.dynamic_load_weight: | ||||
|             return model | ||||
|  | ||||
|         self.load_weights(model, fd_config) | ||||
|         return model | ||||
|   | ||||
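A note on the loader change above: DefaultModelLoaderV1 now always constructs the model under paddle.LazyGuard instead of falling back to a null context, so parameters are only materialized when weights are actually copied in. A minimal sketch of the pattern (illustrative layer size, not from this diff):

    import paddle
    from paddle import nn

    # Under LazyGuard, parameter tensors are not allocated or initialized at construction
    # time; they are materialized later, e.g. when checkpoint weights are copied in.
    with paddle.LazyGuard():
        layer = nn.Linear(4096, 4096)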
| @@ -17,6 +17,7 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import math | ||||
| import re | ||||
| from functools import partial | ||||
|  | ||||
| import paddle | ||||
| @@ -122,6 +123,25 @@ class DeepSeekV3MoE(nn.Layer): | ||||
|             "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", | ||||
|         } | ||||
|  | ||||
|         self.gate = ReplicatedLinear( | ||||
|             fd_config=fd_config, | ||||
|             prefix=f"{prefix}.gate", | ||||
|             input_size=fd_config.model_config.hidden_size, | ||||
|             output_size=fd_config.model_config.n_routed_experts, | ||||
|             with_bias=False, | ||||
|             skip_quant=True, | ||||
|             weight_dtype="float32", | ||||
|         ) | ||||
|  | ||||
|         if fd_config.model_config.topk_method == "noaux_tc": | ||||
|             self.gate.e_score_correction_bias = self.create_parameter( | ||||
|                 shape=[1, fd_config.model_config.n_routed_experts], | ||||
|                 dtype="float32", | ||||
|                 default_initializer=paddle.nn.initializer.Constant(0), | ||||
|             ) | ||||
|         else: | ||||
|             self.gate.e_score_correction_bias = None | ||||
|  | ||||
|         self.experts = FusedMoE( | ||||
|             fd_config=fd_config, | ||||
|             reduce_results=False, | ||||
| @@ -133,19 +153,10 @@ class DeepSeekV3MoE(nn.Layer): | ||||
|             n_group=fd_config.model_config.n_group, | ||||
|             routed_scaling_factor=fd_config.model_config.routed_scaling_factor, | ||||
|             layer_idx=layer_id, | ||||
|             gate_correction_bias=self.gate.e_score_correction_bias, | ||||
|             weight_key_map=weight_key_map, | ||||
|         ) | ||||
|  | ||||
|         self.gate = ReplicatedLinear( | ||||
|             fd_config=fd_config, | ||||
|             prefix=f"{prefix}.gate", | ||||
|             input_size=fd_config.model_config.hidden_size, | ||||
|             output_size=fd_config.model_config.n_routed_experts, | ||||
|             with_bias=False, | ||||
|             skip_quant=True, | ||||
|             weight_dtype="float32", | ||||
|         ) | ||||
|  | ||||
|         self.num_shared_experts = fd_config.model_config.n_shared_experts | ||||
|         shared_experts_intermediate_size = self.num_shared_experts * fd_config.model_config.moe_intermediate_size | ||||
|  | ||||
| @@ -258,6 +269,7 @@ class DeepseekV3MLAAttention(nn.Layer): | ||||
|  | ||||
|         self.kv_b_proj_bmm = KVBatchLinear( | ||||
|             fd_config=fd_config, | ||||
|             kv_b_proj=self.kv_b_proj, | ||||
|             prefix=f"{prefix}.kv_b_proj", | ||||
|             kv_lora_rank=self.kv_lora_rank, | ||||
|             num_attention_heads=self.num_attention_heads, | ||||
| @@ -617,7 +629,10 @@ class DeepseekV3ForCausalLM(ModelForCasualLM): | ||||
|         Args: | ||||
|             weights_iterator (Iterator): An iterator yielding (name, weight) pairs. | ||||
|         """ | ||||
|         from fastdeploy.model_executor.models.utils import default_weight_loader | ||||
|         from fastdeploy.model_executor.utils import ( | ||||
|             default_weight_loader, | ||||
|             process_weights_after_loading, | ||||
|         ) | ||||
|  | ||||
|         stacked_params_mapping = [ | ||||
|             # (param_name, shard_name, shard_id) | ||||
| @@ -637,7 +652,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM): | ||||
|             param_down_proj_name="experts.down_proj_", | ||||
|         ) | ||||
|         params_dict = dict(self.named_parameters()) | ||||
|  | ||||
|         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers())) | ||||
|         for loaded_weight_name, loaded_weight in weights_iterator: | ||||
|             loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model") | ||||
|  | ||||
| @@ -668,19 +683,18 @@ class DeepseekV3ForCausalLM(ModelForCasualLM): | ||||
|                     weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id) | ||||
|                     break | ||||
|                 else: | ||||
|                     if loaded_weight_name not in params_dict: | ||||
|                     model_param_name = loaded_weight_name | ||||
|                     if model_param_name not in params_dict: | ||||
|                         continue | ||||
|                     param = params_dict[loaded_weight_name] | ||||
|                     param = params_dict[model_param_name] | ||||
|                     weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) | ||||
|                     weight_loader(param, loaded_weight) | ||||
|                     if "kv_b_proj.weight" in loaded_weight_name: | ||||
|                         # handle kv_b_proj_bmm | ||||
|                         model_param_name = loaded_weight_name.replace( | ||||
|                             "kv_b_proj.weight", "kv_b_proj_bmm.k_b_proj_weight" | ||||
|                         ) | ||||
|                         param = params_dict[model_param_name] | ||||
|                         weight_loader = getattr(param, "weight_loader", None) | ||||
|                         weight_loader(param, loaded_weight, shard_id) | ||||
|  | ||||
|             model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name) | ||||
|             if "kv_b_proj" in model_sublayer_name: | ||||
|                 kv_model_sublayer_name = model_sublayer_name.replace("kv_b_proj", "kv_b_proj_bmm") | ||||
|                 process_weights_after_loading_fn(kv_model_sublayer_name) | ||||
|             process_weights_after_loading_fn(model_sublayer_name, param) | ||||
|  | ||||
|     def compute_logits(self, hidden_states: paddle.Tensor): | ||||
|         """ """ | ||||
|   | ||||
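The re.sub call added above strips the trailing parameter suffix so the loader can find the sublayer that owns the parameter and trigger its post-load processing. A quick illustration with a hypothetical parameter name:

    import re

    name = "model.layers.3.mlp.experts.up_gate_proj_weight"
    sublayer = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", name)
    print(sublayer)  # model.layers.3.mlp.experts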
| @@ -17,6 +17,7 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import inspect | ||||
| import re | ||||
| from functools import partial | ||||
| from typing import Dict, Union | ||||
|  | ||||
| @@ -149,15 +150,6 @@ class Ernie4_5_MoE(nn.Layer): | ||||
|                 "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight", | ||||
|             } | ||||
|  | ||||
|         self.experts = FusedMoE( | ||||
|             fd_config=fd_config, | ||||
|             moe_intermediate_size=fd_config.model_config.moe_intermediate_size, | ||||
|             num_experts=fd_config.model_config.moe_num_experts, | ||||
|             top_k=fd_config.model_config.moe_k, | ||||
|             layer_idx=layer_id, | ||||
|             weight_key_map=weight_key_map, | ||||
|         ) | ||||
|  | ||||
|         self.gate = ReplicatedLinear( | ||||
|             fd_config=fd_config, | ||||
|             prefix=f"{prefix}.gate", | ||||
| @@ -168,6 +160,25 @@ class Ernie4_5_MoE(nn.Layer): | ||||
|             weight_dtype="float32", | ||||
|         ) | ||||
|  | ||||
|         self.experts = FusedMoE( | ||||
|             fd_config=fd_config, | ||||
|             moe_intermediate_size=fd_config.model_config.moe_intermediate_size, | ||||
|             num_experts=fd_config.model_config.moe_num_experts, | ||||
|             top_k=fd_config.model_config.moe_k, | ||||
|             layer_idx=layer_id, | ||||
|             gate_correction_bias=None, | ||||
|             weight_key_map=weight_key_map, | ||||
|         ) | ||||
|  | ||||
|         if fd_config.model_config.moe_use_aux_free: | ||||
|             self.experts.gate_correction_bias = self.create_parameter( | ||||
|                 shape=[1, fd_config.model_config.moe_num_experts], | ||||
|                 dtype="float32", | ||||
|                 default_initializer=paddle.nn.initializer.Constant(0), | ||||
|             ) | ||||
|         else: | ||||
|             self.experts.gate_correction_bias = None | ||||
|  | ||||
|         self.num_shared_experts = fd_config.model_config.moe_num_shared_experts | ||||
|         if self.num_shared_experts > 0: | ||||
|             shared_experts_hidden_dim = self.num_shared_experts * fd_config.model_config.moe_intermediate_size | ||||
| @@ -180,6 +191,13 @@ class Ernie4_5_MoE(nn.Layer): | ||||
|     def load_state_dict(self, state_dict): | ||||
|         self.gate.load_state_dict(state_dict) | ||||
|         self.experts.load_state_dict(state_dict) | ||||
|         if self.experts.gate_correction_bias is not None: | ||||
|             gate_correction_bias_tensor = state_dict.pop(self.experts.gate_correction_bias_key) | ||||
|             if self.experts.gate_correction_bias.shape != gate_correction_bias_tensor.shape: | ||||
|                 gate_correction_bias_tensor = gate_correction_bias_tensor.reshape( | ||||
|                     self.experts.gate_correction_bias.shape | ||||
|                 ) | ||||
|             self.experts.gate_correction_bias.set_value(gate_correction_bias_tensor) | ||||
|         if self.num_shared_experts > 0: | ||||
|             self.shared_experts.load_state_dict(state_dict) | ||||
|  | ||||
| @@ -441,12 +459,16 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM): | ||||
|             weights_iterator (Iterator): An iterator yielding (name, weight) pairs. | ||||
|         """ | ||||
|  | ||||
|         from fastdeploy.model_executor.models.utils import default_weight_loader | ||||
|         from fastdeploy.model_executor.utils import ( | ||||
|             default_weight_loader, | ||||
|             process_weights_after_loading, | ||||
|         ) | ||||
|  | ||||
|         general_params_mapping = [ | ||||
|             # (param_name, weight_name, expert_id, shard_id) | ||||
|             ("embed_tokens.embeddings", "embed_tokens", None, None), | ||||
|             ("lm_head.linear", "lm_head", None, None), | ||||
|             ("experts.gate_correction_bias", "moe_statics.e_score_correction_bias", None, None), | ||||
|         ] | ||||
|  | ||||
|         expert_params_mapping = [] | ||||
| @@ -458,13 +480,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM): | ||||
|                 param_gate_up_proj_name="experts.up_gate_proj_", | ||||
|                 param_down_proj_name="experts.down_proj_", | ||||
|             ) | ||||
|             expert_params_mapping.append( | ||||
|                 ("experts.gate_correction_bias", "moe_statics.e_score_correction_bias", None, "gate_bias") | ||||
|             ) | ||||
|             logger.info(f"expert params mapping:{expert_params_mapping}") | ||||
|         all_param_mapping = general_params_mapping + expert_params_mapping | ||||
|  | ||||
|         params_dict = dict(self.named_parameters()) | ||||
|         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers())) | ||||
|         expert_id = None | ||||
|         shard_id = None | ||||
|  | ||||
| @@ -478,9 +497,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM): | ||||
|                 shard_id = shard_id | ||||
|                 break | ||||
|             else: | ||||
|                 if loaded_weight_name not in params_dict.keys(): | ||||
|                 model_param_name = loaded_weight_name | ||||
|                 if model_param_name not in params_dict.keys(): | ||||
|                     continue | ||||
|                 param = params_dict[loaded_weight_name] | ||||
|                 param = params_dict[model_param_name] | ||||
|  | ||||
|             # Get weight loader from parameter and set weight | ||||
|             weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) | ||||
| @@ -490,6 +510,8 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM): | ||||
|             else: | ||||
|                 weight_loader(param, loaded_weight) | ||||
|  | ||||
|             model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name) | ||||
|             process_weights_after_loading_fn(model_sublayer_name, param) | ||||
|         if self.tie_word_embeddings: | ||||
|             self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0])) | ||||
|  | ||||
|   | ||||
| @@ -34,7 +34,7 @@ from paddle.nn.functional.flash_attention import ( | ||||
| from paddleformers.transformers.model_utils import PretrainedModel | ||||
|  | ||||
| from fastdeploy.model_executor.layers.utils import divide, get_tensor | ||||
| from fastdeploy.model_executor.models.utils import set_weight_attrs | ||||
| from fastdeploy.model_executor.utils import set_weight_attrs | ||||
|  | ||||
| from .activation import ACT2FN | ||||
| from .configuration import DFNRopeVisionTransformerConfig | ||||
|   | ||||
| @@ -17,6 +17,7 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import inspect | ||||
| import re | ||||
| from dataclasses import dataclass | ||||
| from functools import partial | ||||
| from typing import Dict, Optional, Union | ||||
| @@ -38,7 +39,6 @@ from fastdeploy.model_executor.layers.linear import ReplicatedLinear | ||||
| from fastdeploy.model_executor.layers.lm_head import ParallelLMHead | ||||
| from fastdeploy.model_executor.layers.moe.moe import FusedMoE | ||||
| from fastdeploy.model_executor.layers.normalization import RMSNorm | ||||
| from fastdeploy.model_executor.layers.utils import get_tensor | ||||
| from fastdeploy.model_executor.models.ernie4_5_moe import ( | ||||
|     Ernie4_5_Attention, | ||||
|     Ernie4_5_MLP, | ||||
| @@ -75,7 +75,15 @@ class VLMoEMeta: | ||||
|  | ||||
|  | ||||
| class Ernie4_5_VLMoeBlock(nn.Layer): | ||||
|     def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str, moe_tag: str, expert_id_offset: int) -> None: | ||||
|     def __init__( | ||||
|         self, | ||||
|         fd_config: FDConfig, | ||||
|         layer_id: int, | ||||
|         prefix: str, | ||||
|         moe_tag: str, | ||||
|         expert_id_offset: int, | ||||
|         gate_correction_bias=None, | ||||
|     ) -> None: | ||||
|         super().__init__() | ||||
|         moe_quant_type = "" | ||||
|         if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None: | ||||
| @@ -120,6 +128,7 @@ class Ernie4_5_VLMoeBlock(nn.Layer): | ||||
|             layer_idx=layer_id, | ||||
|             moe_tag=moe_tag, | ||||
|             weight_key_map=weight_key_map, | ||||
|             gate_correction_bias=gate_correction_bias, | ||||
|         ) | ||||
|  | ||||
|         self.gate = ReplicatedLinear( | ||||
| @@ -133,29 +142,10 @@ class Ernie4_5_VLMoeBlock(nn.Layer): | ||||
|             weight_key="weight" if moe_tag == "Text" else "weight_1", | ||||
|         ) | ||||
|  | ||||
|         if moe_tag == "Text": | ||||
|             self.experts.extract_gate_correction_bias = self.extract_gate_correction_bias_text | ||||
|         elif moe_tag == "Image": | ||||
|             self.experts.extract_gate_correction_bias = self.extract_gate_correction_bias_image | ||||
|  | ||||
|     def forward(self, hidden_states: paddle.Tensor): | ||||
|         out = self.experts(hidden_states, self.gate) | ||||
|         return out | ||||
|  | ||||
|     def extract_gate_correction_bias_text(self, gate_correction_bias_key, state_dict): | ||||
|         """ | ||||
|         extract_gate_correction_bias function. | ||||
|         """ | ||||
|         gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32") | ||||
|         return gate_correction_bias_tensor[0].unsqueeze(0) | ||||
|  | ||||
|     def extract_gate_correction_bias_image(self, gate_correction_bias_key, state_dict): | ||||
|         """ | ||||
|         extract_gate_correction_bias function. | ||||
|         """ | ||||
|         gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32") | ||||
|         return gate_correction_bias_tensor[1].unsqueeze(0) | ||||
|  | ||||
|     def load_state_dict(self, state_dict): | ||||
|         self.experts.load_state_dict(state_dict) | ||||
|         self.gate.load_state_dict(state_dict) | ||||
| @@ -186,10 +176,25 @@ class Ernie4_5_VLMoE(nn.Layer): | ||||
|             image_moe_layer_end_index = moe_layer_end_index[1] | ||||
|  | ||||
|         assert text_moe_layer_start_index <= text_moe_layer_end_index | ||||
|         if fd_config.model_config.moe_use_aux_free: | ||||
|             self.gate_correction_bias = self.create_parameter( | ||||
|                 shape=[2, fd_config.model_config.moe_num_experts[0]], | ||||
|                 dtype="float32", | ||||
|                 default_initializer=paddle.nn.initializer.Constant(0), | ||||
|             ) | ||||
|             if not self.gate_correction_bias._is_initialized(): | ||||
|                 self.gate_correction_bias.initialize() | ||||
|         else: | ||||
|             self.gate_correction_bias = None | ||||
|  | ||||
|         if layer_id >= text_moe_layer_start_index and layer_id <= text_moe_layer_end_index: | ||||
|             self.text_fused_moe = Ernie4_5_VLMoeBlock( | ||||
|                 fd_config=fd_config, layer_id=layer_id, prefix=f"{prefix}", moe_tag="Text", expert_id_offset=0 | ||||
|                 fd_config=fd_config, | ||||
|                 layer_id=layer_id, | ||||
|                 prefix=f"{prefix}", | ||||
|                 moe_tag="Text", | ||||
|                 expert_id_offset=0, | ||||
|                 gate_correction_bias=self.gate_correction_bias[0] if fd_config.model_config.moe_use_aux_free else None, | ||||
|             ) | ||||
|         else: | ||||
|             self.text_fused_moe = Ernie4_5_VLMLP( | ||||
| @@ -207,6 +212,7 @@ class Ernie4_5_VLMoE(nn.Layer): | ||||
|                 prefix=f"{prefix}", | ||||
|                 moe_tag="Image", | ||||
|                 expert_id_offset=fd_config.model_config.moe_num_experts[0], | ||||
|                 gate_correction_bias=self.gate_correction_bias[1] if fd_config.model_config.moe_use_aux_free else None, | ||||
|             ) | ||||
|         else: | ||||
|             self.image_fused_moe = Ernie4_5_VLMLP( | ||||
| @@ -226,10 +232,13 @@ class Ernie4_5_VLMoE(nn.Layer): | ||||
|             ) | ||||
|  | ||||
|     def load_state_dict(self, state_dict): | ||||
|         if self.gate_correction_bias is not None: | ||||
|             gate_correction_bias_tensor = state_dict.pop(self.text_fused_moe.experts.gate_correction_bias_key) | ||||
|             if self.gate_correction_bias.shape != gate_correction_bias_tensor.shape: | ||||
|                 gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(self.gate_correction_bias.shape) | ||||
|             self.gate_correction_bias.set_value(gate_correction_bias_tensor) | ||||
|         self.text_fused_moe.load_state_dict(state_dict) | ||||
|         self.image_fused_moe.load_state_dict(state_dict) | ||||
|         if self.text_fused_moe.experts.moe_use_gate_correction_bias: | ||||
|             state_dict.pop(self.text_fused_moe.experts.gate_correction_bias_key) | ||||
|         if self.num_shared_experts > 0: | ||||
|             self.shared_experts.load_state_dict(state_dict) | ||||
|  | ||||
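As the [0]/[1] indexing above shows, the checkpoint stores one fused moe_statics.e_score_correction_bias with a row per modality: row 0 feeds the Text expert block and row 1 the Image expert block. A hedged illustration with made-up sizes:

    import paddle

    # Fused correction bias as stored in the checkpoint: one row per modality.
    moe_statics_bias = paddle.zeros([2, 64], dtype="float32")  # [2, moe_num_experts[0]]

    text_bias = moe_statics_bias[0]   # passed as gate_correction_bias to the Text MoE block
    image_bias = moe_statics_bias[1]  # passed as gate_correction_bias to the Image MoE block
    print(text_bias.shape, image_bias.shape)  # [64] [64]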
| @@ -563,19 +572,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): | ||||
|     def name(self): | ||||
|         return "Ernie4_5_VLMoeForConditionalGeneration" | ||||
|  | ||||
|     def gate_correction_bias_loader(self, params_dict, loaded_weight_name, loaded_weight): | ||||
|         text_param_name = loaded_weight_name.replace( | ||||
|             "moe_statics.e_score_correction_bias", "text_fused_moe.experts.gate_correction_bias" | ||||
|         ) | ||||
|         image_param_name = loaded_weight_name.replace( | ||||
|             "moe_statics.e_score_correction_bias", "image_fused_moe.experts.gate_correction_bias" | ||||
|         ) | ||||
|         text_param = params_dict[text_param_name] | ||||
|         image_param = params_dict[image_param_name] | ||||
|         loaded_weight = get_tensor(loaded_weight) | ||||
|         text_param.copy_(loaded_weight[0].unsqueeze(0), False) | ||||
|         image_param.copy_(loaded_weight[1].unsqueeze(0), False) | ||||
|  | ||||
|     @paddle.no_grad() | ||||
|     def load_weights(self, weights_iterator) -> None: | ||||
|         """ | ||||
| @@ -585,7 +581,10 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): | ||||
|             weights_iterator (Iterator): An iterator yielding (name, weight) pairs. | ||||
|         """ | ||||
|  | ||||
|         from fastdeploy.model_executor.models.utils import default_weight_loader | ||||
|         from fastdeploy.model_executor.utils import ( | ||||
|             default_weight_loader, | ||||
|             process_weights_after_loading, | ||||
|         ) | ||||
|  | ||||
|         general_params_mapping = [ | ||||
|             # (param_name, weight_name, expert_id, shard_id) | ||||
| @@ -594,6 +593,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): | ||||
|             ("mlp.image_fused_moe.gate.weight", "mlp.gate.weight_1", None, "gate"), | ||||
|             ("mlp.text_fused_moe.gate.weight", "mlp.gate.weight", None, "gate"), | ||||
|             ("resampler_model", "ernie.resampler_model", None, None), | ||||
|             ("vision_model", "ernie.vision_model", None, None), | ||||
|             ("gate_correction_bias", "moe_statics.e_score_correction_bias", None, None), | ||||
|         ] | ||||
|  | ||||
|         text_expert_params_mapping = [] | ||||
| @@ -617,6 +618,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): | ||||
|         all_param_mapping = general_params_mapping + text_expert_params_mapping + image_expert_params_mapping | ||||
|  | ||||
|         params_dict = dict(self.named_parameters()) | ||||
|         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers())) | ||||
|         expert_id = None | ||||
|         shard_id = None | ||||
|         for loaded_weight_name, loaded_weight in weights_iterator: | ||||
| @@ -629,10 +631,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): | ||||
|                 shard_id = shard_id | ||||
|                 break | ||||
|             else: | ||||
|                 # text and image gate_correction_bias is fused in ckpt and need load independently | ||||
|                 if "moe_statics.e_score_correction_bias" in loaded_weight_name: | ||||
|                     self.gate_correction_bias_loader(params_dict, loaded_weight_name, loaded_weight) | ||||
|                     continue | ||||
|                 if loaded_weight_name not in params_dict.keys(): | ||||
|                     continue | ||||
|                 model_param_name = loaded_weight_name | ||||
| @@ -646,7 +644,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): | ||||
|                 weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id) | ||||
|             else: | ||||
|                 weight_loader(param, loaded_weight) | ||||
|  | ||||
|             model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name) | ||||
|             process_weights_after_loading_fn(model_sublayer_name, param) | ||||
|         if self.tie_word_embeddings: | ||||
|             self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0])) | ||||
|  | ||||
|   | ||||
| @@ -30,7 +30,7 @@ from fastdeploy.model_executor.models.ernie4_5_vl.dist_utils import ( | ||||
|     reduce_scatter_group, | ||||
|     scatter_axis, | ||||
| ) | ||||
| from fastdeploy.model_executor.models.utils import set_weight_attrs | ||||
| from fastdeploy.model_executor.utils import set_weight_attrs | ||||
|  | ||||
|  | ||||
| class ScatterOp(PyLayer): | ||||
|   | ||||
| @@ -16,6 +16,7 @@ | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| import re | ||||
| from functools import partial | ||||
|  | ||||
| import paddle | ||||
| @@ -254,7 +255,10 @@ class Qwen3ForCausalLM(ModelForCasualLM): | ||||
|             weights_iterator (Iterator): An iterator yielding (name, weight) pairs. | ||||
|         """ | ||||
|  | ||||
|         from fastdeploy.model_executor.models.utils import default_weight_loader | ||||
|         from fastdeploy.model_executor.utils import ( | ||||
|             default_weight_loader, | ||||
|             process_weights_after_loading, | ||||
|         ) | ||||
|  | ||||
|         stacked_params_mapping = [ | ||||
|             # (param_name, shard_name, shard_id) | ||||
| @@ -266,8 +270,8 @@ class Qwen3ForCausalLM(ModelForCasualLM): | ||||
|             ("embed_tokens.embeddings", "embed_tokens", None), | ||||
|             ("lm_head.linear", "lm_head", None), | ||||
|         ] | ||||
|  | ||||
|         params_dict = dict(self.named_parameters()) | ||||
|         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers())) | ||||
|         for loaded_weight_name, loaded_weight in weights_iterator: | ||||
|             for param_name, weight_name, shard_id in stacked_params_mapping: | ||||
|                 if weight_name not in loaded_weight_name: | ||||
| @@ -280,11 +284,14 @@ class Qwen3ForCausalLM(ModelForCasualLM): | ||||
|                 weight_loader(param, loaded_weight, shard_id) | ||||
|                 break | ||||
|             else: | ||||
|                 if loaded_weight_name not in params_dict: | ||||
|                 model_param_name = loaded_weight_name | ||||
|                 if model_param_name not in params_dict: | ||||
|                     continue | ||||
|                 param = params_dict[loaded_weight_name] | ||||
|                 param = params_dict[model_param_name] | ||||
|                 weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) | ||||
|                 weight_loader(param, loaded_weight) | ||||
|             model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name) | ||||
|             process_weights_after_loading_fn(model_sublayer_name, param) | ||||
|  | ||||
|         if self.tie_word_embeddings: | ||||
|             self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0])) | ||||
|   | ||||
| @@ -16,6 +16,7 @@ | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| import re | ||||
| from functools import partial | ||||
|  | ||||
| import paddle | ||||
| @@ -334,7 +335,10 @@ class Qwen3MoeForCausalLM(ModelForCasualLM): | ||||
|             weights_iterator (Iterator): An iterator yielding (name, weight) pairs. | ||||
|         """ | ||||
|  | ||||
|         from fastdeploy.model_executor.models.utils import default_weight_loader | ||||
|         from fastdeploy.model_executor.utils import ( | ||||
|             default_weight_loader, | ||||
|             process_weights_after_loading, | ||||
|         ) | ||||
|  | ||||
|         stacked_params_mapping = [ | ||||
|             # (param_name, shard_name, shard_id) | ||||
| @@ -348,6 +352,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM): | ||||
|         ] | ||||
|         expert_params_mapping = self.get_expert_mapping() | ||||
|         params_dict = dict(self.named_parameters()) | ||||
|         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers())) | ||||
|         for loaded_weight_name, loaded_weight in weights_iterator: | ||||
|             for param_name, weight_name, shard_id in stacked_params_mapping: | ||||
|                 if weight_name not in loaded_weight_name: | ||||
| @@ -374,12 +379,16 @@ class Qwen3MoeForCausalLM(ModelForCasualLM): | ||||
|                     weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id) | ||||
|                     break | ||||
|                 else: | ||||
|                     if loaded_weight_name not in params_dict: | ||||
|                     model_param_name = loaded_weight_name | ||||
|                     if model_param_name not in params_dict: | ||||
|                         continue | ||||
|                     param = params_dict[loaded_weight_name] | ||||
|                     param = params_dict[model_param_name] | ||||
|                     weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) | ||||
|                     weight_loader(param, loaded_weight) | ||||
|  | ||||
|             model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name) | ||||
|             process_weights_after_loading_fn(model_sublayer_name, param) | ||||
|  | ||||
|     @paddle.no_grad() | ||||
|     def set_state_dict(self, state_dict): | ||||
|         """ | ||||
|   | ||||
| @@ -24,7 +24,7 @@ import random | ||||
| import re | ||||
| import struct | ||||
| from functools import partial | ||||
| from typing import Any, NamedTuple, Optional, Union | ||||
| from typing import NamedTuple, Optional | ||||
|  | ||||
| import numpy as np | ||||
| import paddle | ||||
| @@ -40,73 +40,10 @@ from paddleformers.utils.env import ( | ||||
| from paddleformers.utils.log import logger | ||||
| from tqdm import tqdm | ||||
|  | ||||
| from fastdeploy.config import FDConfig | ||||
| from fastdeploy.model_executor.layers.utils import get_tensor | ||||
|  | ||||
| MAX_BSZ = 512 | ||||
| MAX_DRAFT_TOKENS = 6 | ||||
|  | ||||
|  | ||||
| def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]): | ||||
|     if param_attr_map is None: | ||||
|         return | ||||
|     for key, value in param_attr_map.items(): | ||||
|         setattr(param, key, value) | ||||
|  | ||||
|  | ||||
| def slice_fn(weight_or_paramter, output_dim, start, end, step=1): | ||||
|     if hasattr(weight_or_paramter, "get_shape"): | ||||
|         shape = weight_or_paramter.get_shape() | ||||
|     else: | ||||
|         shape = weight_or_paramter.shape | ||||
|     if len(shape) == 1: | ||||
|         weight_or_paramter = weight_or_paramter[start:end] | ||||
|     elif output_dim: | ||||
|         weight_or_paramter = weight_or_paramter[..., start:end] | ||||
|     else: | ||||
|         weight_or_paramter = weight_or_paramter[start:end, ...] | ||||
|     return weight_or_paramter | ||||
|  | ||||
|  | ||||
| def default_weight_loader(fd_config: FDConfig) -> None: | ||||
|     """Default weight loader""" | ||||
|  | ||||
|     def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None): | ||||
|         """fn""" | ||||
|         try: | ||||
|             output_dim = getattr(param, "output_dim", None) | ||||
|             # Tensor parallelism splits the weight along the output_dim | ||||
|             if output_dim is not None: | ||||
|                 dim = -1 if output_dim else 0 | ||||
|                 size = loaded_weight.get_shape()[dim] | ||||
|                 block_size = size // fd_config.parallel_config.tensor_parallel_size | ||||
|                 shard_offset = fd_config.parallel_config.tensor_parallel_rank * block_size | ||||
|                 shard_size = (fd_config.parallel_config.tensor_parallel_rank + 1) * block_size | ||||
|                 if output_dim: | ||||
|                     loaded_weight = loaded_weight[..., shard_offset:shard_size] | ||||
|                 else: | ||||
|                     loaded_weight = loaded_weight[shard_offset:shard_size, ...] | ||||
|  | ||||
|             loaded_weight = get_tensor(loaded_weight) | ||||
|             # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation | ||||
|             if param.dtype != loaded_weight.dtype: | ||||
|                 loaded_weight = loaded_weight.cast(param.dtype) | ||||
|  | ||||
|             if param.shape != loaded_weight.shape: | ||||
|                 try: | ||||
|                     param = param.reshape(loaded_weight.shape) | ||||
|                 except ValueError as e: | ||||
|                     raise ValueError( | ||||
|                         f" Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape}). {e}" | ||||
|                     ) | ||||
|  | ||||
|             param.copy_(loaded_weight, False) | ||||
|         except Exception: | ||||
|             raise | ||||
|  | ||||
|     return fn | ||||
|  | ||||
|  | ||||
| class LayerIdPlaceholder(str, enum.Enum): | ||||
|     """LayerIdPlaceholder""" | ||||
|  | ||||
|   | ||||
							
								
								
									
179  fastdeploy/model_executor/utils.py  (new file)
							| @@ -0,0 +1,179 @@ | ||||
| """ | ||||
| # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| """ | ||||
|  | ||||
| from typing import Any, Optional, Union | ||||
|  | ||||
| from fastdeploy.config import FDConfig | ||||
| from fastdeploy.model_executor.layers.utils import get_tensor | ||||
|  | ||||
|  | ||||
| class BitMaskTracker: | ||||
|     def __init__(self, length: int): | ||||
|         """ | ||||
|         Track filling status along a single dimension using a bitmask. | ||||
|  | ||||
|         Args: | ||||
|             length (int): Number of positions to track (e.g., columns or rows) | ||||
|         """ | ||||
|         self.length = length | ||||
|         self.mask = 0 | ||||
|  | ||||
|     def mark(self, start: int, end: int): | ||||
|         """ | ||||
|         Mark the range [start, end) as filled. | ||||
|  | ||||
|         Args: | ||||
|             start (int): Start index (inclusive) | ||||
|             end (int): End index (exclusive) | ||||
|         """ | ||||
|         if start < 0 or end > self.length or start >= end: | ||||
|             raise ValueError("Invalid mark range") | ||||
|         block = ((1 << (end - start)) - 1) << start | ||||
|         self.mask |= block | ||||
|  | ||||
|     def is_full(self) -> bool: | ||||
|         """Return True if all positions are filled.""" | ||||
|         return self.mask == (1 << self.length) - 1 | ||||
|  | ||||
|  | ||||
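A minimal usage example for BitMaskTracker (illustrative, not part of the file): two tensor-parallel shards each mark the range they wrote, and the tracker reports completion once every position is covered.

    from fastdeploy.model_executor.utils import BitMaskTracker

    tracker = BitMaskTracker(length=8)
    tracker.mark(0, 4)          # shard 0 wrote positions [0, 4)
    print(tracker.is_full())    # False
    tracker.mark(4, 8)          # shard 1 wrote positions [4, 8)
    print(tracker.is_full())    # True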
| class TensorTracker: | ||||
|     def __init__(self, shape: tuple, output_dim: int): | ||||
|         """ | ||||
|         Unified tracker for 2D or 3D tensors. | ||||
|  | ||||
|         Args: | ||||
|             shape (tuple): Tensor shape | ||||
|             output_dim (bool): | ||||
|                 - 2D: True = track columns (dim=1), False = track rows (dim=0) | ||||
|                 - 3D: True = track columns (dim=2), False = track rows (dim=1) | ||||
|         """ | ||||
|         self.shape = shape | ||||
|         self.output_dim = output_dim | ||||
|  | ||||
|         if len(shape) == 2: | ||||
|             self.track_dim = 1 if output_dim else 0 | ||||
|             self.trackers = [BitMaskTracker(shape[self.track_dim])] | ||||
|         elif len(shape) == 3: | ||||
|             batch = shape[0] | ||||
|             self.track_dim = 2 if output_dim else 1 | ||||
|             self.trackers = [BitMaskTracker(shape[self.track_dim]) for _ in range(batch)] | ||||
|         else: | ||||
|             raise ValueError("Only 2D or 3D tensors supported") | ||||
|  | ||||
|     def mark(self, start: int = 0, end: int = None, batch_id: int = None): | ||||
|         """ | ||||
|         Mark a slice of the tensor as filled. | ||||
|  | ||||
|         Args: | ||||
|             batch_id (int, optional): Batch index for 3D tensors | ||||
|             start (int): Start index along tracked dimension | ||||
|             end (int): End index along tracked dimension | ||||
|         """ | ||||
|         if end is None: | ||||
|             end = self.shape[self.track_dim] | ||||
|  | ||||
|         if len(self.shape) == 2: | ||||
|             self.trackers[0].mark(start, end) | ||||
|         else: | ||||
|             if batch_id is None: | ||||
|                 raise ValueError("batch_id must be provided for 3D tensor") | ||||
|             self.trackers[batch_id].mark(start, end) | ||||
|  | ||||
|     def is_fully_copied(self) -> bool: | ||||
|         """Return True if the tensor is fully filled along tracked dimension(s).""" | ||||
|         return all(tr.is_full() for tr in self.trackers) | ||||
|  | ||||
|  | ||||
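And a sketch of TensorTracker for a 2D weight whose output dimension is filled shard by shard (illustrative shapes, not part of the file):

    from fastdeploy.model_executor.utils import TensorTracker

    tracker = TensorTracker(shape=(1024, 4096), output_dim=True)  # track dim=1 (output channels)
    tracker.mark(start=0, end=2048)       # first half of the output channels copied
    print(tracker.is_fully_copied())      # False
    tracker.mark(start=2048, end=4096)    # second half copied
    print(tracker.is_fully_copied())      # True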
| def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]): | ||||
|     if param_attr_map is None: | ||||
|         return | ||||
|     for key, value in param_attr_map.items(): | ||||
|         setattr(param, key, value) | ||||
|  | ||||
|  | ||||
| def slice_fn(weight_or_paramter, output_dim, start, end, step=1): | ||||
|     if hasattr(weight_or_paramter, "get_shape"): | ||||
|         shape = weight_or_paramter.get_shape() | ||||
|     else: | ||||
|         shape = weight_or_paramter.shape | ||||
|     if len(shape) == 1: | ||||
|         weight_or_paramter = weight_or_paramter[start:end] | ||||
|     elif output_dim: | ||||
|         weight_or_paramter = weight_or_paramter[..., start:end] | ||||
|     else: | ||||
|         weight_or_paramter = weight_or_paramter[start:end, ...] | ||||
|     return weight_or_paramter | ||||
|  | ||||
|  | ||||
| def process_weights_after_loading(sublayers_dict: dict): | ||||
|     """ | ||||
|     process_weights_after_loading: e.g., handle extracted weights (quantization, reshaping, etc.) | ||||
|     """ | ||||
|  | ||||
|     def fn(model_sublayer_name: str, param=None): | ||||
|         from fastdeploy.model_executor.layers.linear import KVBatchLinear | ||||
|  | ||||
|         if model_sublayer_name not in sublayers_dict: | ||||
|             return | ||||
|         model_sublayer = sublayers_dict[model_sublayer_name] | ||||
|         if isinstance(model_sublayer, KVBatchLinear): | ||||
|             model_sublayer.process_weights_after_loading() | ||||
|         if hasattr(model_sublayer, "quant_method"): | ||||
|             quant_method = getattr(model_sublayer, "quant_method", None) | ||||
|             if not hasattr(quant_method, "process_weights_after_loading"): | ||||
|                 return | ||||
|             if param is not None and hasattr(param, "tensor_track") and not param.tensor_track.is_fully_copied(): | ||||
|                 return | ||||
|             quant_method.process_weights_after_loading(model_sublayer) | ||||
|  | ||||
|     return fn | ||||
|  | ||||
|  | ||||
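For context, the closure returned by process_weights_after_loading is what the load_weights changes above call after each parameter is copied. Condensed, the model-side pattern looks roughly like this (model, fd_config and weights_iterator are assumed from the surrounding loaders, not defined in this file):

    import re

    from fastdeploy.model_executor.utils import (
        default_weight_loader,
        process_weights_after_loading,
    )

    def load_weights_sketch(model, fd_config, weights_iterator):
        params_dict = dict(model.named_parameters())
        post_load_fn = process_weights_after_loading(dict(model.named_sublayers()))
        for name, loaded_weight in weights_iterator:
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader(fd_config))
            weight_loader(param, loaded_weight)
            # Once the parameter is fully copied, the owning sublayer's quant method
            # can quantize or re-pack it.
            sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", name)
            post_load_fn(sublayer_name, param)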
| def free_tensor(tensor): | ||||
|     if hasattr(tensor, "tensor_track"): | ||||
|         tensor.tensor_track = None | ||||
|     tensor.value().get_tensor()._clear() | ||||
|     del tensor | ||||
|  | ||||
|  | ||||
| def default_weight_loader(fd_config: FDConfig) -> None: | ||||
|     """Default weight loader""" | ||||
|  | ||||
|     def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None): | ||||
|         """fn""" | ||||
|         output_dim = getattr(param, "output_dim", None) | ||||
|         # Tensor parallelism splits the weight along the output_dim | ||||
|         if output_dim is not None and fd_config.parallel_config.tensor_parallel_size > 1: | ||||
|             dim = -1 if output_dim else 0 | ||||
|             size = loaded_weight.get_shape()[dim] | ||||
|             block_size = size // fd_config.parallel_config.tensor_parallel_size | ||||
|             shard_offset = fd_config.parallel_config.tensor_parallel_rank * block_size | ||||
|             shard_size = (fd_config.parallel_config.tensor_parallel_rank + 1) * block_size | ||||
|             loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size) | ||||
|  | ||||
|         loaded_weight = get_tensor(loaded_weight) | ||||
|         # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation | ||||
|         if param.dtype != loaded_weight.dtype: | ||||
|             loaded_weight = loaded_weight.cast(param.dtype) | ||||
|         if param.shape != loaded_weight.shape: | ||||
|             # for e_score_correction_bias | ||||
|             loaded_weight = loaded_weight.reshape(param.shape) | ||||
|         assert param.shape == loaded_weight.shape, ( | ||||
|             f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" | ||||
|         ) | ||||
|         param.copy_(loaded_weight, False) | ||||
|  | ||||
|     return fn | ||||
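A worked example of the tensor-parallel slice in default_weight_loader, with illustrative numbers: for tensor_parallel_size = 4, tensor_parallel_rank = 1, a loaded weight of shape [4096, 11008] and output_dim = True, block_size = 11008 // 4 = 2752, shard_offset = 1 * 2752 = 2752 and shard_size = 2 * 2752 = 5504, so this rank slices loaded_weight[..., 2752:5504] into its local shard before the dtype cast and copy_.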