Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-31 03:46:40 +08:00)
	Support 45t fp8 8 GPU (#3659)

@@ -31,7 +31,8 @@ __global__ void quant_per_token_per_block(const T *input,
     const int num_warp = blockDim.x / 32;
     static constexpr int NUM_PER_THREADS = 128 / 32; // 4
     static constexpr float MAX_VALUE = 448.f;
-    const int end_iter = hidden_size / 128; // warp_iter_num
+    // Note(ZKK) use ceil_div!!
+    const int end_iter = (hidden_size + 127) / 128; // warp_iter_num
     AlignedVector<T, NUM_PER_THREADS> load_vec;
     AlignedVector<float, NUM_PER_THREADS> load_vec_float;
     AlignedVector<phi::dtype::float8_e4m3fn, NUM_PER_THREADS> res_vec;
@@ -42,7 +43,16 @@ __global__ void quant_per_token_per_block(const T *input,
         // deal a block per warp
         for (int iter = warp_id; iter < end_iter; iter += num_warp) {
             const int start_offset = iter * 128;
+
+
+            const bool is_valid_data = start_offset + lane_id * NUM_PER_THREADS < hidden_size;
+
+            if (is_valid_data) {
                 Load<T, NUM_PER_THREADS>(input_now + start_offset + lane_id * NUM_PER_THREADS, &load_vec);
+            } else {
+                #pragma unroll
+                for (int vid = 0; vid < NUM_PER_THREADS; vid++) load_vec[vid] = T(0.f);
+            }
             // get max value per thread
             float max_value_thread = -5e4;
 #pragma unroll
@@ -71,6 +81,7 @@ __global__ void quant_per_token_per_block(const T *input,
                 res_vec[vid] = static_cast<phi::dtype::float8_e4m3fn>(load_vec_float[vid] * MAX_VALUE / max_value_thread);
             }
             // store
+            if (is_valid_data)
             Store<phi::dtype::float8_e4m3fn, NUM_PER_THREADS>(res_vec, quanted_res_now + start_offset + lane_id * NUM_PER_THREADS);
             if (lane_id == 0) {
                 quanted_scale_now[iter] = scale_to_store;
@@ -84,7 +95,10 @@ std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor& input,
     auto input_dim = input.dims();
     const int token_num = input_dim[0];
     const int hidden_size = input_dim[1];
-    const int hidden_size_scale = hidden_size / block_size;
+    // Note(ZKK) here we use ceil_div to support 4.5T running on 8 GPUs,
+    // where moe_intermediate_size is 448 and is not divisible by 128.
+    const int hidden_size_scale = (hidden_size + block_size - 1) / block_size;
+
     auto quanted_x = GetEmptyTensor(
         {token_num, hidden_size},
         paddle::DataType::FLOAT8_E4M3FN,
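
The kernel change above switches end_iter to a ceiling division and predicates the vectorized Load/Store on is_valid_data, so tail lanes past hidden_size read zeros and skip the store. As a rough reference for the indexing this implies, here is a small NumPy sketch of per-token, per-128-element-block quantization with a zero-padded tail; the float8_e4m3 cast and the warp-level reduction are deliberately omitted, and the helper names below are illustrative, not FastDeploy APIs:

import numpy as np

BLOCK = 128
FP8_MAX = 448.0  # largest finite float8_e4m3 value, matching MAX_VALUE in the kernel


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def quant_per_token_per_block_ref(x: np.ndarray):
    """Reference sketch: x is [token_num, hidden_size]; hidden_size need not be a multiple of 128."""
    token_num, hidden_size = x.shape
    num_blocks = ceil_div(hidden_size, BLOCK)               # same ceil_div as the patched end_iter
    padded = np.zeros((token_num, num_blocks * BLOCK), dtype=np.float32)
    padded[:, :hidden_size] = x                             # out-of-range lanes behave as zeros
    blocks = padded.reshape(token_num, num_blocks, BLOCK)
    max_abs = np.maximum(np.abs(blocks).max(axis=-1), 1e-12)
    scales = max_abs / FP8_MAX                              # one scale per 128-wide block
    quanted = blocks / scales[..., None]                    # the real kernel casts this to float8_e4m3fn
    return quanted.reshape(token_num, -1)[:, :hidden_size], scales


x = np.random.randn(2, 448).astype(np.float32)              # 448 is the awkward size from the kernel comment
q, s = quant_per_token_per_block_ref(x)
print(q.shape, s.shape)                                      # (2, 448) (2, 4): 4 = ceil_div(448, 128), not 3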

@@ -617,13 +617,13 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         ]
         self.up_gate_proj_scale_shape = [
             layer.num_local_experts,
-            layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
-            layer.hidden_size // self.quant_config.weight_block_size[1],
+            ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
+            ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
         ]
         self.down_proj_scale_shape = [
             layer.num_local_experts,
-            layer.hidden_size // self.quant_config.weight_block_size[0],
-            layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
+            ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
+            ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
         ]
         if self.quant_config.is_checkpoint_bf16:
             layer.up_gate_proj_weight = layer.create_parameter(
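
To make the scale-shape change concrete, assume the configuration described in the kernel comment: moe_intermediate_size = 448 per GPU and weight_block_size = [128, 128] (hidden_size below is only a placeholder, and the expert dimension is omitted). Floor division would produce 3 scale blocks along the 448-wide axis and silently drop the 64-wide tail; ceil_div keeps 4:

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


moe_intermediate_size = 448     # 4.5T MoE intermediate size after 8-way split (per the kernel comment)
hidden_size = 8192              # placeholder value, assumed for illustration only
weight_block_size = [128, 128]

up_gate_proj_scale_shape = [
    ceil_div(moe_intermediate_size * 2, weight_block_size[0]),   # 896 / 128 = 7 (exact either way)
    ceil_div(hidden_size, weight_block_size[1]),                 # 64
]
down_proj_scale_shape = [
    ceil_div(hidden_size, weight_block_size[0]),                 # 64
    ceil_div(moe_intermediate_size, weight_block_size[1]),       # 4, whereas 448 // 128 == 3
]
print(up_gate_proj_scale_shape, down_proj_scale_shape)            # [7, 64] [64, 4]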

@@ -14,14 +14,11 @@
 # limitations under the License.
 """
 
+import triton
 import triton.language as tl
 
-from fastdeploy.model_executor.ops.triton_ops.triton_utils_v2 import (
-    paddle_use_triton_v2,
-)
-
 
-@paddle_use_triton_v2()
+@triton.jit()
 def fused_moe_kernel_paddle(
     a_ptr,
     b_ptr,
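
The last file drops the paddle_use_triton_v2 wrapper and registers fused_moe_kernel_paddle directly with triton.jit. For orientation only, here is a minimal, unrelated @triton.jit kernel with the usual mask-guarded loads and stores (it requires a CUDA GPU plus the triton and torch packages, and is not FastDeploy code):

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements                  # guard the tail, same spirit as is_valid_data above
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x = torch.randn(1000, device="cuda")
y = torch.randn(1000, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)            # ceil_div over the launch grid
add_kernel[grid](x, y, out, x.numel(), BLOCK=256)
assert torch.allclose(out, x + y)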