Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-11-03 02:53:26 +08:00
[Metax] adapt DeepSeek (#4498)
@@ -14,7 +14,9 @@
#include "cute/tensor.hpp"
#include "helper.h"
#include "paddle/extension.h"
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
#include "paddle/phi/core/memory/memcpy.h"
#endif
#include "utils.cuh"

template <int THREADBLOCK_SIZE>

@@ -15,6 +15,7 @@

#include <cuda_runtime.h>
#include <stdint.h>
#include <cooperative_groups/memcpy_async.h>

enum class SharedMemFillMode { kFillZero, kNoFill };

@@ -42,18 +43,35 @@ __device__ __forceinline__ void ldmatrix_m8n8x4_trans_impl(uint32_t* R,
}

__device__ __forceinline__ void commit_group() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
{}
#else
asm volatile("cp.async.commit_group;\n" ::);
#endif
}

template <size_t n>
__device__ __forceinline__ void wait_group() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
cooperative_groups::wait(cooperative_groups::this_thread_block());
#else
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
#endif
}

template <PrefetchMode prefetch_mode, typename T>
__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
} else {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
}
#else
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
asm volatile(
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(
@@ -68,6 +86,7 @@ __device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
"n"(16),
"r"(16));
}
#endif
}

template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
@@ -76,6 +95,28 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr,
bool predicate) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 16 : 0;
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
} else {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
}
} else {
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
}
} else {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
}
}
}
#else
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 16 : 0;
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
@@ -115,6 +156,7 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr,
"n"(16));
}
}
#endif
}

template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
@@ -123,6 +165,17 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr,
bool predicate) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 8 : 0;
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 8);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
} else {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 8);
}
}
#else
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 8 : 0;
asm volatile(
@@ -141,6 +194,7 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr,
"l"(gmem_ptr),
"n"(8));
}
#endif
}

template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
@@ -149,6 +203,17 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr,
bool predicate) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 4 : 0;
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 4);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
} else {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 4);
}
}
#else
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 4 : 0;
asm volatile(
@@ -167,6 +232,7 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr,
"l"(gmem_ptr),
"n"(4));
}
#endif
}

template <size_t num_bits, PrefetchMode prefetch_mode, typename T>

@@ -595,10 +595,13 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
#endif

inline int GetSMVersion() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
return 80;
#else
static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
phi::backends::gpu::GetCurrentDeviceId());
return sm_version;

#endif
}

inline bool GetMlaUseTensorcore() {

@@ -18,6 +18,7 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include "helper.h"
#include <cuda/std/limits>

namespace cg = cooperative_groups;

@@ -601,7 +602,7 @@ __global__ void group_idx_and_topk_idx_kernel(
if (i < topk) {
s_topk_value[i] = value;
}
topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
topk_sum += cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
}
}

@@ -658,6 +659,11 @@ void invokeNoAuxTc(T* scores,
cudaStream_t const stream) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;

#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
group_scores, scores_with_bias, num_tokens, num_cases, n_group, num_experts / n_group);
#else
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
cudaLaunchConfig_t config;
config.gridDim = topk_with_k2_num_blocks;
@@ -671,6 +677,7 @@ void invokeNoAuxTc(T* scores,
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
num_tokens, num_cases, n_group, num_experts / n_group);
#endif

int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
@@ -678,6 +685,12 @@ void invokeNoAuxTc(T* scores,
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
topk);

#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
group_idx_and_topk_idx_kernel<T, IdxT><<<topk_with_k_group_num_blocks, BLOCK_SIZE, dynamic_smem_in_bytes, stream>>>(
scores, group_scores, topk_values, topk_indices, scores_with_bias,
num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group,
renormalize, routed_scaling_factor);
#else
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
config.gridDim = topk_with_k_group_num_blocks;
config.blockDim = BLOCK_SIZE;
@@ -691,6 +704,7 @@ void invokeNoAuxTc(T* scores,
topk_values, topk_indices, scores_with_bias, num_tokens,
n_group, topk_group, topk, num_experts,
num_experts / n_group, renormalize, routed_scaling_factor);
#endif
}

#define INSTANTIATE_NOAUX_TC(T, IdxT) \

@@ -601,9 +601,16 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
"gpu_ops/read_data_ipc.cu",
"gpu_ops/dequant_int8.cu",
"gpu_ops/share_external_data.cu",
"gpu_ops/recover_decode_task.cu",
"gpu_ops/noaux_tc.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/text_image_gather_scatter.cu",
"gpu_ops/text_image_index_out.cu",
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/append_attn/mla_cache_kernel.cu",
"gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu",
"gpu_ops/moe/tritonmoe_preprocess.cu",
"gpu_ops/moe/moe_topk_select.cu",
"gpu_ops/recover_decode_task.cu",
"metax_ops/moe_dispatch.cu",
"metax_ops/moe_ffn.cu",
"metax_ops/moe_reduce.cu",

@@ -1144,7 +1144,7 @@ class CacheConfig:
self.kv_cache_ratio = 1.0
else:
self.kv_cache_ratio = 0.75
self.enc_dec_block_num = 0 if current_platform.is_maca() else envs.FD_ENC_DEC_BLOCK_NUM
self.enc_dec_block_num = envs.FD_ENC_DEC_BLOCK_NUM
self.prealloc_dec_block_slot_num_threshold = 12
self.cache_dtype = "bfloat16"
self.model_cfg = None

@@ -1060,7 +1060,8 @@ class EngineService:
exit sub services
"""
self.running = False
self.engine_worker_queue_server.cleanup()
if hasattr(self, "engine_worker_queue_server") and self.engine_worker_queue_server is not None:
self.engine_worker_queue_server.cleanup()
self.exist_task_signal.clear()
self.exist_swapped_task_signal.clear()
self.worker_healthy_live_signal.clear()

@@ -13,11 +13,13 @@
# limitations under the License.

from .attention.flash_attn_backend import FlashAttentionBackend
from .attention.mla_attn_metax_backend import MetaxMLAAttentionBackend
from .moe.fused_moe_cutlass_metax_backend import MetaxCutlassWeightOnlyMoEMethod
from .moe.fused_moe_triton_metax_backend import MetaxTritonWeightOnlyMoEMethod

__all__ = [
"FlashAttentionBackend",
"MetaxMLAAttentionBackend",
"MetaxTritonWeightOnlyMoEMethod",
"MetaxCutlassWeightOnlyMoEMethod",
]

@@ -0,0 +1,444 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import annotations

import math
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Tuple

import paddle

from fastdeploy.model_executor.ops.gpu import (
decode_mla_write_cache,
get_block_shape_and_split_kv_block,
prefill_mla_write_cache,
)

if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.model_executor.layers.backends.metax.attention.flash_attention_interface import (
flash_attn_unpadded_func,
)

def yarn_get_mscale(scale=1, mscale=1):
""" """
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
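
A quick numeric check of the factor this helper produces (an illustrative aside, not part of the diff; values are assumptions mirroring the inline comments in __init__ below: rope_scaling factor 40, mscale_all_dim 1.0, qk_head_dim = 128 + 64):

import math

mscale = 0.1 * 1.0 * math.log(40) + 1.0       # == yarn_get_mscale(40, 1.0) ~= 1.3689
base_scale = (128 + 64) ** -0.5               # qk_head_dim ** -0.5 ~= 0.0722
softmax_scale = base_scale * mscale * mscale  # ~= 0.1353, the value attn_softmax_scale ends up with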

@dataclass
class MLAAttentionMetadata(AttentionMetadata):
"""
MLAAttentionMetadata for Multi-Layer Attention
"""

_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
_fuse_kernel_compute_dtype: str = "bf16"

# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)

max_enc_len_this_time: Optional[paddle.Tensor] = None
max_dec_len_this_time: Optional[paddle.Tensor] = None
max_kv_len_this_time: Optional[paddle.Tensor] = None

class MetaxMLAAttentionBackend(AttentionBackend):
"""
MLA Attention Backend implementation.
"""

__infer_dynamic_dims_fields__ = ["attention_metadata"]
attention_metadata: MLAAttentionMetadata
flash_attn_func: callable = None

def __init__(
self,
fd_config: FDConfig,
kv_num_heads: int,
num_heads: int,
head_dim: int,
encoder_block_shape_q: int = -1,
decoder_block_shape_q: int = -1,
) -> None:
"""
MLAAttentionBackend __init__
"""
super().__init__()
self.attention_metadata: MLAAttentionMetadata = None

# Basic configuration
self.block_size: int = fd_config.cache_config.block_size
self.max_seq_len: int = fd_config.model_config.max_model_len
self.rope_theta: float = (
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method: str = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])

self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
self.group_size: int = self.num_heads // self.kv_num_heads
self.head_dim: int = fd_config.model_config.head_dim
self.num_layers: int = fd_config.model_config.num_hidden_layers
self.encoder_block_shape_q: int = encoder_block_shape_q
self.decoder_block_shape_q: int = decoder_block_shape_q

# For Multi Head Latent Attention
self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank
self.qk_rope_head_dim: int = fd_config.model_config.qk_rope_head_dim
self.qk_head_dim: int = fd_config.model_config.qk_nope_head_dim + fd_config.model_config.qk_rope_head_dim
self.attn_softmax_scale: float = self.qk_head_dim**-0.5
if fd_config.model_config.rope_scaling:
mscale_all_dim = fd_config.model_config.rope_scaling.get("mscale_all_dim", False) # 1.0
scaling_factor = fd_config.model_config.rope_scaling["factor"] # 40
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale

self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode

self.start_layer_index: int = fd_config.model_config.start_layer_index
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)

self.rank, self.device_id = init_rank_and_device_id(fd_config)

self.flash_attn_func = flash_attn_unpadded_func
self.flash_attn_kwargs = {"softmax_scale": self.attn_softmax_scale}

def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
metadata = MLAAttentionMetadata()
metadata.max_partition_size = 32768
metadata.encoder_max_partition_size = self.max_seq_len
metadata._dtype = paddle.get_default_dtype()
if metadata._dtype == "bfloat16":
metadata._fuse_kernel_compute_dtype = "bf16"
elif metadata._dtype == "float16":
metadata._fuse_kernel_compute_dtype = "fp16"
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"

metadata.block_tables = forward_meta.block_tables
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask
metadata.pre_caches_length = forward_meta.pre_caches_length

get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.decoder_batch_ids, # decoder_batch_ids_per_ctax
forward_meta.decoder_tile_ids_per_batch, # decoder_chunk_ids_per_ctax_each_batch
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)

# MLA
metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8]

# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers

self.attention_metadata: AttentionMetadata = metadata

def get_attntion_meta(self) -> AttentionMetadata:
"""get_attntion_meta"""
return self.attention_metadata

def get_kv_cache_shape(
self,
max_num_blocks: int,
kv_cache_quant_type: str = None,
) -> Tuple[int, int, int, int]:
"""
Calculate kv cache shape for MLA
"""
return (
max_num_blocks,
1,
self.block_size,
self.kv_lora_rank + self.qk_rope_head_dim,
)
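
The shape above stores a single compressed latent per token (kv_lora_rank channels plus qk_rope_head_dim rotary channels) with no per-head K/V copies. A rough, illustrative size estimate under assumed DeepSeek-style dimensions (not taken from this diff):

kv_lora_rank = 512        # assumption
qk_rope_head_dim = 64     # assumption
block_size = 64           # assumption
max_num_blocks = 1024     # assumption
bytes_per_elem = 2        # bfloat16
cache_bytes = max_num_blocks * 1 * block_size * (kv_lora_rank + qk_rope_head_dim) * bytes_per_elem
print(cache_bytes / 2**20, "MiB per layer")   # ~72 MiB for this hypothetical setup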

def forward_extend(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""
Forward pass in the prefill stage
"""
metadata = self.attention_metadata

latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None

# Write to the latent cache
prefill_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
getattr(forward_meta, "max_input_length", -1),
)

# Flash attention computation
fmha_out = self.flash_attn_func(
q,
k,
v,
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
causal=self.causal,
**self.flash_attn_kwargs,
)[0]

return fmha_out
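
The varlen flash-attention call above takes cumulative sequence offsets (cu_seqlens) instead of a padded batch. A minimal, standalone sketch of that convention (illustrative only; the real offsets come from ForwardMeta):

# cu_seqlens is the exclusive prefix sum of per-sequence lengths, so tokens of
# sequence i occupy rows [cu_seqlens[i], cu_seqlens[i + 1]) of the packed q/k/v.
seq_lens = [5, 3, 8]          # assumed per-request lengths
cu_seqlens = [0]
for n in seq_lens:
    cu_seqlens.append(cu_seqlens[-1] + n)
print(cu_seqlens)             # [0, 5, 8, 16]; the packed tensors have sum(seq_lens) rows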

def _run_single_flash_mla(self, query, latent_cache, block_tables, seq_lens, draft_token_num):
from flash_mla_paddle import flash_mla_with_kvcache, get_mla_metadata

qk_head_dim = self.kv_lora_rank + self.qk_rope_head_dim
v_head_dim = self.kv_lora_rank
q_head_num = self.num_heads
kv_head_num = latent_cache.shape[2]

query = query.reshape([-1, draft_token_num, q_head_num, qk_head_dim])
tile_scheduler_metadata, num_splits = get_mla_metadata(
seq_lens, draft_token_num * q_head_num // kv_head_num, kv_head_num
)

out, _ = flash_mla_with_kvcache(
query,
latent_cache,
block_tables,
seq_lens,
v_head_dim,
tile_scheduler_metadata,
num_splits,
softmax_scale=self.attn_softmax_scale,
causal=True,
)

return out.reshape([-1, q_head_num, v_head_dim])

def compute_flash_mla(self, query, latent_cache, forward_meta):
block_tables = self.attention_metadata.block_tables
seq_lens_decoder = forward_meta.seq_lens_decoder
seq_lens_this_time = forward_meta.seq_lens_this_time
assert block_tables is not None and seq_lens_decoder is not None and seq_lens_this_time is not None
assert block_tables.shape[0] == seq_lens_decoder.shape[0]

query = query.reshape([-1, self.num_heads, self.kv_lora_rank + self.qk_rope_head_dim])
latent_cache = latent_cache.transpose([0, 2, 1, 3])

seq_lens_decoder = seq_lens_decoder.squeeze(-1)
seq_lens_this_time = seq_lens_this_time.squeeze(-1)
non_zero_index = paddle.nonzero(seq_lens_this_time).flatten()
seq_lens_decoder = seq_lens_decoder[non_zero_index]
seq_lens_this_time = seq_lens_this_time[non_zero_index]
block_tables = block_tables[non_zero_index]

max_seq_lens_this_time = seq_lens_this_time.max().item()
min_seq_lens_this_time = seq_lens_this_time.min().item()

if max_seq_lens_this_time == min_seq_lens_this_time:
return self._run_single_flash_mla(
query, latent_cache, block_tables, seq_lens_decoder + seq_lens_this_time, max_seq_lens_this_time
)
else:
max_draft_token_num = self.speculate_max_draft_token_num + 1
seq_lens_this_time_cpu = seq_lens_this_time.cpu()
bsz = seq_lens_this_time_cpu.shape[0]
qk_head_dim = self.kv_lora_rank + self.qk_rope_head_dim
batched_query = paddle.zeros(
[bsz * max_draft_token_num, self.num_heads, qk_head_dim], dtype=query.dtype
).to(query.place)
full_token_index = paddle.arange(bsz * max_draft_token_num, dtype="int32").reshape(
[bsz, max_draft_token_num]
)
token_mapping_index = []
for group_id in range(bsz):
seq_len = seq_lens_this_time_cpu[group_id]
token_mapping_index.append(full_token_index[group_id, :seq_len])
token_mapping_index = paddle.concat(token_mapping_index)
assert token_mapping_index.shape[0] == query.shape[0]
batched_query[token_mapping_index] = query
seq_lens_this_time = paddle.full_like(seq_lens_this_time, fill_value=max_draft_token_num)
out = self._run_single_flash_mla(
batched_query, latent_cache, block_tables, seq_lens_decoder + seq_lens_this_time, max_draft_token_num
)
return out[token_mapping_index]
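
The else branch above pads ragged groups of draft tokens into fixed-size slots, runs the kernel once, and gathers the valid rows back. A small standalone sketch of that index bookkeeping (illustrative, plain Python):

max_draft_token_num = 4                 # assumed: speculate_max_draft_token_num + 1
seq_lens_this_time = [2, 4, 1]          # assumed ragged token counts per request
token_mapping_index = []
for group_id, seq_len in enumerate(seq_lens_this_time):
    base = group_id * max_draft_token_num
    token_mapping_index.extend(range(base, base + seq_len))
print(token_mapping_index)              # [0, 1, 4, 5, 6, 7, 8]
# scatter: batched_query[token_mapping_index] = query   (ragged rows into padded slots)
# gather:  out = padded_out[token_mapping_index]         (back to the original row order)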

def forward_decode(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""
Forward pass in the decode stage
"""
metadata = self.attention_metadata

latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None

# Get speculative decoding parameters
speculate_decoder = self.speculative_method is not None

# Write to the latent cache
decode_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_encoder,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
self.max_seq_len,
speculate_decoder,
)

# Multi-head latent attention computation
fmha_out = self.compute_flash_mla(q, latent_cache, forward_meta)

return fmha_out

def forward_mixed(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""
Forward pass in mixed mode
"""
metadata = self.attention_metadata
speculate_decoder = self.speculative_method is not None

latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None

if k is not None:
prefill_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
self.max_seq_len,
)

# FA
fmha_out = self.flash_attn_func(
q,
k,
v,
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
causal=self.causal,
**self.flash_attn_kwargs,
)[0]

return fmha_out

# Decode
if k is None:
decode_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_encoder,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
self.max_seq_len,
speculate_decoder,
)

# Multi-head latent attention computation
fmha_out = self.compute_flash_mla(q, latent_cache, forward_meta)

return fmha_out

@@ -19,8 +19,10 @@ from paddle import nn

import fastdeploy
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
from fastdeploy.utils import ceil_div

from .triton_moe_kernels import fused_moe_kernel_paddle
@@ -65,43 +67,74 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
layer.moe_intermediate_size,
layer.hidden_size,
]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.up_gate_proj_weight = layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=self.weight_dtype,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
)

layer.down_proj_weight = layer.create_parameter(
shape=self.down_proj_weight_shape,
dtype=self.weight_dtype,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"

set_weight_attrs(
layer.up_gate_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
},
)
set_weight_attrs(
layer.down_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
},
)
else:
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.down_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# support cache feature in future

@paddle.no_grad()
def process_loaded_weights(self, layer: nn.Layer, state_dict):
@@ -114,6 +147,8 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):

algo = layer.quant_method.quant_config.name()

assert algo == "wint8"

assert up_gate_proj_weights[0].shape == [
layer.hidden_size,
layer.moe_intermediate_size * 2,
@@ -143,6 +178,63 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
getattr(layer, weight_name).set_value(quanted_weight)
getattr(layer, scale_name).set_value(quanted_weight_scale)

@paddle.no_grad()
def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return

algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
max_bound = 127
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
layer.up_gate_proj_weight.tensor_track = None
else:
weight_type = "down"
layer.down_proj_weight.tensor_track = None

# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]

weight_tensor = getattr(layer, weight_name)
quanted_weight_scale = weight_tensor.abs().max(axis=1)
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
quanted_weight = paddle.round(quanted_weight).astype("int8")
quanted_weight_scale = quanted_weight_scale / max_bound

getattr(layer, weight_name).value().get_tensor()._clear()

# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_tensor.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(quanted_weight, False)
getattr(layer, scale_name).copy_(quanted_weight_scale, False)
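
The wint8 path above is plain abs-max, channel-wise quantization. The same round trip in a few lines of numpy (illustrative only, mirroring the scale/round/copy sequence above):

import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((2, 8, 4)).astype(np.float32)    # assumed [experts, in_dim, out_dim] layout
max_bound = 127
scale = np.abs(w).max(axis=1)                            # one scale per expert and output channel
q = np.rint(w / scale[:, None, :] * max_bound).astype(np.int8)
scale = scale / max_bound                                # stored so that w ~= q * scale
w_hat = q.astype(np.float32) * scale[:, None, :]
print(np.max(np.abs(w - w_hat)))                         # reconstruction error, roughly <= scale / 2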

@paddle.no_grad()
def apply(
self,
@@ -157,38 +249,38 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
token_num = x.shape[0]
top_k = layer.top_k
num_local_experts = layer.num_local_experts
top_k = layer.top_k
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size

topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
layer.top_k,
True, # apply_norm_weight
False,
)

if layer.topk_method == "noaux_tc":
gate_out, topk_weights, topk_ids = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
else:
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
layer.top_k,
True, # apply_norm_weight
False,
)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
)

if self.quant_config is not None:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
}
else:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
}

config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
}
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
)
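
tritonmoe_preprocess groups the (token, expert) assignments by expert and pads each group to a multiple of BLOCK_SIZE_M, so every Triton block works on rows of a single expert. A rough sketch of the presumed semantics (illustrative; not the actual kernel):

BLOCK_SIZE_M = 4
num_experts = 3
topk_ids = [[0, 2], [2, 1], [0, 0]]           # assumed token -> expert choices, top_k = 2
flat = [(t * 2 + k, e) for t, row in enumerate(topk_ids) for k, e in enumerate(row)]
pad_id = len(flat)                            # sentinel row for padded slots
sorted_token_ids, expert_ids = [], []
for e in range(num_experts):
    rows = [idx for idx, ex in flat if ex == e]
    while len(rows) % BLOCK_SIZE_M:
        rows.append(pad_id)                   # pad so the expert's rows fill whole blocks
    sorted_token_ids.extend(rows)
    expert_ids.extend([e] * (len(rows) // BLOCK_SIZE_M))
num_tokens_post_padded = len(sorted_token_ids)
print(sorted_token_ids, expert_ids, num_tokens_post_padded)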

@@ -237,6 +329,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
compute_type_enum=1,
use_fp8_w8a8=False,
use_int8_w8a16=True,
per_channel_quant=False,
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
)

@@ -289,11 +382,12 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
compute_type_enum=1,
use_fp8_w8a8=False,
use_int8_w8a16=True,
per_channel_quant=False,
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
)

down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
if layer.tp_size > 1:
tensor_model_parallel_all_reduce(out, layer.fd_config.parallel_config.tp_group)
if layer.reduce_results and layer.tp_size > 1:
out = tensor_model_parallel_all_reduce(out, layer.fd_config.parallel_config.tp_group)
return out

@@ -16,7 +16,7 @@ import triton
import triton.language as tl

@triton.jit
@triton.jit()
def fused_moe_kernel_paddle(
a_ptr,
b_ptr,
@@ -30,20 +30,20 @@ def fused_moe_kernel_paddle(
# Matrix dimensions
max_possible_num_post_padded,
num_valid_tokens,
N,
K,
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
N: tl.constexpr,
K: tl.constexpr,
stride_am: tl.constexpr,
stride_ak: tl.constexpr,
stride_be: tl.constexpr,
stride_bk: tl.constexpr,
stride_bn: tl.constexpr,
stride_cm: tl.constexpr,
stride_cn: tl.constexpr,
stride_asm: tl.constexpr,
stride_ask: tl.constexpr,
stride_bse: tl.constexpr,
stride_bsk: tl.constexpr,
stride_bsn: tl.constexpr,
# Block size for block-wise fp8 quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
@@ -57,6 +57,7 @@ def fused_moe_kernel_paddle(
compute_type_enum: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
even_Ks: tl.constexpr,
):
"""
@@ -119,6 +120,13 @@ def fused_moe_kernel_paddle(
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
# channel-wise
elif per_channel_quant:
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
# Load per-token scale for activations
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
else:
# (Zkk): every expert has one activation scale and weight scale.
a_scale = tl.load(a_scale_ptr + off_experts)

@@ -23,7 +23,7 @@ from paddle import nn
from fastdeploy.config import ModelConfig
from fastdeploy.platforms import current_platform

if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import fused_rotary_position_encoding

from .utils import CpuGuard

@@ -43,12 +43,12 @@ class DefaultModelLoader(BaseModelLoader):

def clean_memory_fragments(self, state_dict: dict) -> None:
"""clean_memory_fragments"""
if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_maca():
if state_dict:
for k, v in state_dict.items():
if isinstance(v, paddle.Tensor):
v.value().get_tensor()._clear()
paddle.device.cuda.empty_cache()
paddle.device.empty_cache()
paddle.device.synchronize()

@measure_time()

@@ -43,8 +43,8 @@ class DefaultModelLoaderV1(BaseModelLoader):

def clean_memory_fragments(self) -> None:
"""clean_memory_fragments"""
if current_platform.is_cuda():
paddle.device.cuda.empty_cache()
if current_platform.is_cuda() or current_platform.is_maca():
paddle.device.empty_cache()
paddle.device.synchronize()

@save_model()

@@ -55,7 +55,7 @@ from fastdeploy.model_executor.models.model_base import (
)
from fastdeploy.platforms import current_platform

if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import (
get_position_ids_and_mask_encoder_batch,
)

@@ -16,7 +16,7 @@

from fastdeploy.platforms import current_platform

if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import (
text_image_gather_scatter,
text_image_index_out,
@@ -32,6 +32,6 @@ elif current_platform.is_iluvatar():
text_image_index_out,
)
else:
raise ImportError("Unsupported platform, only support CUDA and XPU")
raise ImportError("Unsupported platform, only support CUDA, MACA and XPU")

__all__ = ["text_image_gather_scatter", "text_image_index_out"]

@@ -60,6 +60,9 @@ class MACAPlatform(Platform):
elif selected_backend == _Backend.APPEND_ATTN:
logger.info("Using FLASH ATTN backend to instead of attend attention.")
return "fastdeploy.model_executor.layers.backends.metax.attention.flash_attn_backend.FlashAttentionBackend"
elif selected_backend == _Backend.MLA_ATTN:
logger.info("Using MLA ATTN backend.")
return "fastdeploy.model_executor.layers.backends.metax.attention.mla_attn_metax_backend.MetaxMLAAttentionBackend"
else:
raise ValueError(
"Invalid attention backend you specified.\n"

File diff suppressed because it is too large

@@ -20,13 +20,12 @@ import time
from typing import List, Optional

import paddle
import pymxsml
from paddle import nn

from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import get_logger
from fastdeploy.utils import get_logger, set_random_seed
from fastdeploy.worker.metax_model_runner import MetaxModelRunner
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.worker_base import WorkerBase
@@ -53,23 +52,21 @@ class MetaxWorker(WorkerBase):
Initialize device and construct model runner
"""
self.max_chips_per_node = 8
if paddle.is_compiled_with_custom_device("metax_gpu"):
# Set environment variable
self.device_ids = self.parallel_config.device_ids.split(",")
self.device = f"metax_gpu:{self.local_rank % self.max_chips_per_node}"
paddle.device.set_device(self.device)
paddle.set_default_dtype(self.model_config.dtype)
# Set environment variable
self.device_ids = self.parallel_config.device_ids.split(",")
self.device = f"metax_gpu:{self.local_rank % self.max_chips_per_node}"
paddle.device.set_device(self.device)
paddle.set_default_dtype(self.model_config.dtype)

gc.collect()

else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
gc.collect()
paddle.device.empty_cache()

set_random_seed(self.fd_config.model_config.seed)
# Construct model runner
self.model_runner: MetaxModelRunner = MetaxModelRunner(
fd_config=self.fd_config,
device=self.device,
device_id=self.device_ids[self.local_rank % self.max_chips_per_node],
device_id=int(self.device_ids[self.local_rank % self.max_chips_per_node]),
rank=self.rank,
local_rank=self.local_rank,
)
@@ -99,6 +96,8 @@ class MetaxWorker(WorkerBase):
if fd_kvache_mem is not None:
return int(float(fd_kvache_mem) * 1024**3)
else:
import pymxsml

# 1. Record memory state before profile run
start_time = time.perf_counter()
Gb = 1024**3
@@ -200,9 +199,10 @@ class MetaxWorker(WorkerBase):
"""
Perform the warm-up and the graph optimization
"""
if self.model_runner.graph_opt_level >= 1:
if self.fd_config.graph_opt_config.graph_opt_level >= 1 and not self.model_runner.use_cudagraph:
self.model_runner.sot_warmup()
# Todo Trigger cuda graph capture.
# Trigger cuda graph capture
self.model_runner.capture_model()

def check_health(self) -> bool:
""" """