[Metax] adapt DeepSeek (#4498)

xiaozude
2025-10-24 10:14:53 +08:00
committed by GitHub
parent 8718fa34b2
commit f7069b8057
19 changed files with 1538 additions and 324 deletions

View File

@@ -14,7 +14,9 @@
#include "cute/tensor.hpp"
#include "helper.h"
#include "paddle/extension.h"
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
#include "paddle/phi/core/memory/memcpy.h"
#endif
#include "utils.cuh"
template <int THREADBLOCK_SIZE>

View File

@@ -15,6 +15,7 @@
#include <cuda_runtime.h>
#include <stdint.h>
#include <cooperative_groups/memcpy_async.h>
enum class SharedMemFillMode { kFillZero, kNoFill };
@@ -42,18 +43,35 @@ __device__ __forceinline__ void ldmatrix_m8n8x4_trans_impl(uint32_t* R,
}
__device__ __forceinline__ void commit_group() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
{}
#else
asm volatile("cp.async.commit_group;\n" ::);
#endif
}
template <size_t n>
__device__ __forceinline__ void wait_group() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
cooperative_groups::wait(cooperative_groups::this_thread_block());
#else
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
#endif
}
template <PrefetchMode prefetch_mode, typename T>
__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
} else {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
}
#else
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
asm volatile(
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(
@@ -68,6 +86,7 @@ __device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
"n"(16),
"r"(16));
}
#endif
}
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
@@ -76,6 +95,28 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr,
bool predicate) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 16 : 0;
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
} else {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
}
} else {
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
}
} else {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
}
}
}
#else
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 16 : 0;
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
@@ -115,6 +156,7 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr,
"n"(16));
}
}
#endif
}
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
@@ -123,6 +165,17 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr,
bool predicate) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 8 : 0;
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 8);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
} else {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 8);
}
}
#else
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 8 : 0;
asm volatile(
@@ -141,6 +194,7 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr,
"l"(gmem_ptr),
"n"(8));
}
#endif
}
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
@@ -149,6 +203,17 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr,
bool predicate) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 4 : 0;
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 4);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
} else {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 4);
}
}
#else
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 4 : 0;
asm volatile(
@@ -167,6 +232,7 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr,
"l"(gmem_ptr),
"n"(4));
}
#endif
}
template <size_t num_bits, PrefetchMode prefetch_mode, typename T>
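Note on the guards above: they all follow one idea. When cp.async PTX is unavailable on Metax, the 16/8/4-byte asynchronous copies degrade to memset/memcpy through the generic shared-memory address, and wait_group degrades to a block-wide cooperative_groups::wait. A self-contained sketch of how a staging loop can be written against both paths (illustrative only; the stage_tile kernel and its 128-float tile are hypothetical, only the PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU macro is from the patch):
#include <cuda_runtime.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
namespace cg = cooperative_groups;
// Hypothetical kernel: each block stages 128 floats into shared memory.
// Launch with 128 threads per block.
__global__ void stage_tile(const float* __restrict__ in, float* __restrict__ out) {
  __shared__ float tile[128];
  cg::thread_block block = cg::this_thread_block();
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
  // Metax path: cooperative copy plus a block-wide wait, mirroring the
  // memcpy-based load_128b and the cooperative_groups::wait in wait_group.
  cg::memcpy_async(block, tile, in + blockIdx.x * 128, sizeof(float) * 128);
  cg::wait(block);
#else
  // CUDA path: 16-byte per-thread cp.async copies, then commit/wait the group.
  if (threadIdx.x < 32) {
    uint32_t smem_addr = static_cast<uint32_t>(
        __cvta_generic_to_shared(tile + threadIdx.x * 4));
    const float* gmem = in + blockIdx.x * 128 + threadIdx.x * 4;
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(smem_addr),
                 "l"(gmem));
  }
  asm volatile("cp.async.commit_group;\n" ::);
  asm volatile("cp.async.wait_group 0;\n" ::);
#endif
  block.sync();
  out[blockIdx.x * 128 + threadIdx.x] = tile[threadIdx.x];
}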

View File

@@ -595,10 +595,13 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
#endif
inline int GetSMVersion() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
return 80;
#else
static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
phi::backends::gpu::GetCurrentDeviceId());
return sm_version;
#endif
}
inline bool GetMlaUseTensorcore() {

View File

@@ -18,6 +18,7 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include "helper.h"
#include <cuda/std/limits>
namespace cg = cooperative_groups;
@@ -601,7 +602,7 @@ __global__ void group_idx_and_topk_idx_kernel(
if (i < topk) {
s_topk_value[i] = value;
}
topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
topk_sum += cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
}
}
@@ -658,6 +659,11 @@ void invokeNoAuxTc(T* scores,
cudaStream_t const stream) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
group_scores, scores_with_bias, num_tokens, num_cases, n_group, num_experts / n_group);
#else
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
cudaLaunchConfig_t config;
config.gridDim = topk_with_k2_num_blocks;
@@ -671,6 +677,7 @@ void invokeNoAuxTc(T* scores,
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
num_tokens, num_cases, n_group, num_experts / n_group);
#endif
int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
@@ -678,6 +685,12 @@ void invokeNoAuxTc(T* scores,
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
topk);
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
group_idx_and_topk_idx_kernel<T, IdxT><<<topk_with_k_group_num_blocks, BLOCK_SIZE, dynamic_smem_in_bytes, stream>>>(
scores, group_scores, topk_values, topk_indices, scores_with_bias,
num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group,
renormalize, routed_scaling_factor);
#else
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
config.gridDim = topk_with_k_group_num_blocks;
config.blockDim = BLOCK_SIZE;
@@ -691,6 +704,7 @@ void invokeNoAuxTc(T* scores,
topk_values, topk_indices, scores_with_bias, num_tokens,
n_group, topk_group, topk, num_experts,
num_experts / n_group, renormalize, routed_scaling_factor);
#endif
}
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
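The launch change in invokeNoAuxTc above follows a single pattern: Metax builds cannot use cudaLaunchKernelEx, so the kernels fall back to the classic triple-chevron launch, while the explicit launch-config setup is kept only on the CUDA path. A minimal host-side sketch of that pattern (hypothetical add_one kernel; the config here is simplified and omits the launch attributes set above):
#include <cuda_runtime.h>
__global__ void add_one(float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] += 1.0f;
}
void launch_add_one(float* x, int n, cudaStream_t stream) {
  dim3 grid((n + 255) / 256), block(256);
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
  // Metax path: plain launch, no cudaLaunchConfig_t / cudaLaunchKernelEx.
  add_one<<<grid, block, 0, stream>>>(x, n);
#else
  // CUDA path: explicit launch config, as invokeNoAuxTc does above.
  cudaLaunchConfig_t config = {};
  config.gridDim = grid;
  config.blockDim = block;
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  cudaLaunchKernelEx(&config, add_one, x, n);
#endif
}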

View File

@@ -601,9 +601,16 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
"gpu_ops/read_data_ipc.cu",
"gpu_ops/dequant_int8.cu",
"gpu_ops/share_external_data.cu",
"gpu_ops/recover_decode_task.cu",
"gpu_ops/noaux_tc.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/text_image_gather_scatter.cu",
"gpu_ops/text_image_index_out.cu",
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/append_attn/mla_cache_kernel.cu",
"gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu",
"gpu_ops/moe/tritonmoe_preprocess.cu",
"gpu_ops/moe/moe_topk_select.cu",
"gpu_ops/recover_decode_task.cu",
"metax_ops/moe_dispatch.cu",
"metax_ops/moe_ffn.cu",
"metax_ops/moe_reduce.cu",

View File

@@ -1144,7 +1144,7 @@ class CacheConfig:
self.kv_cache_ratio = 1.0
else:
self.kv_cache_ratio = 0.75
self.enc_dec_block_num = 0 if current_platform.is_maca() else envs.FD_ENC_DEC_BLOCK_NUM
self.enc_dec_block_num = envs.FD_ENC_DEC_BLOCK_NUM
self.prealloc_dec_block_slot_num_threshold = 12
self.cache_dtype = "bfloat16"
self.model_cfg = None

View File

@@ -1060,7 +1060,8 @@ class EngineService:
exit sub services
"""
self.running = False
self.engine_worker_queue_server.cleanup()
if hasattr(self, "engine_worker_queue_server") and self.engine_worker_queue_server is not None:
self.engine_worker_queue_server.cleanup()
self.exist_task_signal.clear()
self.exist_swapped_task_signal.clear()
self.worker_healthy_live_signal.clear()

View File

@@ -13,11 +13,13 @@
# limitations under the License.
from .attention.flash_attn_backend import FlashAttentionBackend
from .attention.mla_attn_metax_backend import MetaxMLAAttentionBackend
from .moe.fused_moe_cutlass_metax_backend import MetaxCutlassWeightOnlyMoEMethod
from .moe.fused_moe_triton_metax_backend import MetaxTritonWeightOnlyMoEMethod
__all__ = [
"FlashAttentionBackend",
"MetaxMLAAttentionBackend",
"MetaxTritonWeightOnlyMoEMethod",
"MetaxCutlassWeightOnlyMoEMethod",
]

View File

@@ -0,0 +1,444 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import math
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Tuple
import paddle
from fastdeploy.model_executor.ops.gpu import (
decode_mla_write_cache,
get_block_shape_and_split_kv_block,
prefill_mla_write_cache,
)
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.model_executor.layers.backends.metax.attention.flash_attention_interface import (
flash_attn_unpadded_func,
)
def yarn_get_mscale(scale=1, mscale=1):
""" """
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
@dataclass
class MLAAttentionMetadata(AttentionMetadata):
"""
MLAAttentionMetadata for Multi-head Latent Attention
"""
_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
_fuse_kernel_compute_dtype: str = "bf16"
# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
max_enc_len_this_time: Optional[paddle.Tensor] = None
max_dec_len_this_time: Optional[paddle.Tensor] = None
max_kv_len_this_time: Optional[paddle.Tensor] = None
class MetaxMLAAttentionBackend(AttentionBackend):
"""
MLA Attention Backend implementation.
"""
__infer_dynamic_dims_fields__ = ["attention_metadata"]
attention_metadata: MLAAttentionMetadata
flash_attn_func: callable = None
def __init__(
self,
fd_config: FDConfig,
kv_num_heads: int,
num_heads: int,
head_dim: int,
encoder_block_shape_q: int = -1,
decoder_block_shape_q: int = -1,
) -> None:
"""
MLAAttentionBackend __init__
"""
super().__init__()
self.attention_metadata: MLAAttentionMetadata = None
# Basic configuration
self.block_size: int = fd_config.cache_config.block_size
self.max_seq_len: int = fd_config.model_config.max_model_len
self.rope_theta: float = (
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method: str = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
self.group_size: int = self.num_heads // self.kv_num_heads
self.head_dim: int = fd_config.model_config.head_dim
self.num_layers: int = fd_config.model_config.num_hidden_layers
self.encoder_block_shape_q: int = encoder_block_shape_q
self.decoder_block_shape_q: int = decoder_block_shape_q
# For Multi Head Latent Attention
self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank
self.qk_rope_head_dim: int = fd_config.model_config.qk_rope_head_dim
self.qk_head_dim: int = fd_config.model_config.qk_nope_head_dim + fd_config.model_config.qk_rope_head_dim
self.attn_softmax_scale: float = self.qk_head_dim**-0.5
if fd_config.model_config.rope_scaling:
mscale_all_dim = fd_config.model_config.rope_scaling.get("mscale_all_dim", False) # 1.0
scaling_factor = fd_config.model_config.rope_scaling["factor"] # 40
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
self.start_layer_index: int = fd_config.model_config.start_layer_index
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.flash_attn_func = flash_attn_unpadded_func
self.flash_attn_kwargs = {"softmax_scale": self.attn_softmax_scale}
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
metadata = MLAAttentionMetadata()
metadata.max_partition_size = 32768
metadata.encoder_max_partition_size = self.max_seq_len
metadata._dtype = paddle.get_default_dtype()
if metadata._dtype == "bfloat16":
metadata._fuse_kernel_compute_dtype = "bf16"
elif metadata._dtype == "float16":
metadata._fuse_kernel_compute_dtype = "fp16"
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
metadata.block_tables = forward_meta.block_tables
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask
metadata.pre_caches_length = forward_meta.pre_caches_length
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.decoder_batch_ids, # decoder_batch_ids_per_ctax
forward_meta.decoder_tile_ids_per_batch, # decoder_chunk_ids_per_ctax_each_batch
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
# MLA
metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8]
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
self.attention_metadata: AttentionMetadata = metadata
def get_attntion_meta(self) -> AttentionMetadata:
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(
self,
max_num_blocks: int,
kv_cache_quant_type: str = None,
) -> Tuple[int, int, int, int]:
"""
Calculate kv cache shape for MLA
"""
return (
max_num_blocks,
1,
self.block_size,
self.kv_lora_rank + self.qk_rope_head_dim,
)
def forward_extend(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""
Forward pass for the prefill stage
"""
metadata = self.attention_metadata
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None
# Write to the KV cache
prefill_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
getattr(forward_meta, "max_input_length", -1),
)
# Flash attention computation
fmha_out = self.flash_attn_func(
q,
k,
v,
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
causal=self.causal,
**self.flash_attn_kwargs,
)[0]
return fmha_out
def _run_single_flash_mla(self, query, latent_cache, block_tables, seq_lens, draft_token_num):
from flash_mla_paddle import flash_mla_with_kvcache, get_mla_metadata
qk_head_dim = self.kv_lora_rank + self.qk_rope_head_dim
v_head_dim = self.kv_lora_rank
q_head_num = self.num_heads
kv_head_num = latent_cache.shape[2]
query = query.reshape([-1, draft_token_num, q_head_num, qk_head_dim])
tile_scheduler_metadata, num_splits = get_mla_metadata(
seq_lens, draft_token_num * q_head_num // kv_head_num, kv_head_num
)
out, _ = flash_mla_with_kvcache(
query,
latent_cache,
block_tables,
seq_lens,
v_head_dim,
tile_scheduler_metadata,
num_splits,
softmax_scale=self.attn_softmax_scale,
causal=True,
)
return out.reshape([-1, q_head_num, v_head_dim])
def compute_flash_mla(self, query, latent_cache, forward_meta):
block_tables = self.attention_metadata.block_tables
seq_lens_decoder = forward_meta.seq_lens_decoder
seq_lens_this_time = forward_meta.seq_lens_this_time
assert block_tables is not None and seq_lens_decoder is not None and seq_lens_this_time is not None
assert block_tables.shape[0] == seq_lens_decoder.shape[0]
query = query.reshape([-1, self.num_heads, self.kv_lora_rank + self.qk_rope_head_dim])
latent_cache = latent_cache.transpose([0, 2, 1, 3])
seq_lens_decoder = seq_lens_decoder.squeeze(-1)
seq_lens_this_time = seq_lens_this_time.squeeze(-1)
non_zero_index = paddle.nonzero(seq_lens_this_time).flatten()
seq_lens_decoder = seq_lens_decoder[non_zero_index]
seq_lens_this_time = seq_lens_this_time[non_zero_index]
block_tables = block_tables[non_zero_index]
max_seq_lens_this_time = seq_lens_this_time.max().item()
min_seq_lens_this_time = seq_lens_this_time.min().item()
if max_seq_lens_this_time == min_seq_lens_this_time:
return self._run_single_flash_mla(
query, latent_cache, block_tables, seq_lens_decoder + seq_lens_this_time, max_seq_lens_this_time
)
else:
max_draft_token_num = self.speculate_max_draft_token_num + 1
seq_lens_this_time_cpu = seq_lens_this_time.cpu()
bsz = seq_lens_this_time_cpu.shape[0]
qk_head_dim = self.kv_lora_rank + self.qk_rope_head_dim
batched_query = paddle.zeros(
[bsz * max_draft_token_num, self.num_heads, qk_head_dim], dtype=query.dtype
).to(query.place)
full_token_index = paddle.arange(bsz * max_draft_token_num, dtype="int32").reshape(
[bsz, max_draft_token_num]
)
token_mapping_index = []
for group_id in range(bsz):
seq_len = seq_lens_this_time_cpu[group_id]
token_mapping_index.append(full_token_index[group_id, :seq_len])
token_mapping_index = paddle.concat(token_mapping_index)
assert token_mapping_index.shape[0] == query.shape[0]
batched_query[token_mapping_index] = query
seq_lens_this_time = paddle.full_like(seq_lens_this_time, fill_value=max_draft_token_num)
out = self._run_single_flash_mla(
batched_query, latent_cache, block_tables, seq_lens_decoder + seq_lens_this_time, max_draft_token_num
)
return out[token_mapping_index]
def forward_decode(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""
Forward pass for the decode stage
"""
metadata = self.attention_metadata
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None
# Get speculative decoding parameters
speculate_decoder = self.speculative_method is not None
# Write to the KV cache
decode_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_encoder,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
self.max_seq_len,
speculate_decoder,
)
# Multi-head latent attention computation
fmha_out = self.compute_flash_mla(q, latent_cache, forward_meta)
return fmha_out
def forward_mixed(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""
Forward pass for mixed (prefill + decode) mode
"""
metadata = self.attention_metadata
speculate_decoder = self.speculative_method is not None
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None
if k is not None:
prefill_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
self.max_seq_len,
)
# FA
fmha_out = self.flash_attn_func(
q,
k,
v,
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
causal=self.causal,
**self.flash_attn_kwargs,
)[0]
return fmha_out
# Decode
if k is None:
decode_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_encoder,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
self.max_seq_len,
speculate_decoder,
)
# Multi-head latent attention computation
fmha_out = self.compute_flash_mla(q, latent_cache, forward_meta)
return fmha_out
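For reference, the softmax-scale adjustment in MetaxMLAAttentionBackend.__init__ above (taken directly from yarn_get_mscale and the rope_scaling branch) can be written, with s = rope_scaling["factor"], a = mscale_all_dim, and d = qk_nope_head_dim + qk_rope_head_dim:
m = 0.1\,a\,\ln(s) + 1 \quad (m = 1 \text{ if } s \le 1), \qquad \text{attn\_softmax\_scale} = m^2 \cdot d^{-1/2}
i.e. the base 1/sqrt(d) scale is multiplied by mscale twice, as in the code.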

View File

@@ -19,8 +19,10 @@ from paddle import nn
import fastdeploy
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
from fastdeploy.utils import ceil_div
from .triton_moe_kernels import fused_moe_kernel_paddle
@@ -65,43 +67,74 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
layer.moe_intermediate_size,
layer.hidden_size,
]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
layer.up_gate_proj_weight = layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=self.weight_dtype,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
)
layer.down_proj_weight = layer.create_parameter(
shape=self.down_proj_weight_shape,
dtype=self.weight_dtype,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(
layer.up_gate_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
},
)
set_weight_attrs(
layer.down_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
},
)
else:
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.down_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# support cache feature in future
@paddle.no_grad()
def process_loaded_weights(self, layer: nn.Layer, state_dict):
@@ -114,6 +147,8 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
assert up_gate_proj_weights[0].shape == [
layer.hidden_size,
layer.moe_intermediate_size * 2,
@@ -143,6 +178,63 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
getattr(layer, weight_name).set_value(quanted_weight)
getattr(layer, scale_name).set_value(quanted_weight_scale)
@paddle.no_grad()
def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return
algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
max_bound = 127
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
layer.up_gate_proj_weight.tensor_track = None
else:
weight_type = "down"
layer.down_proj_weight.tensor_track = None
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
weight_tensor = getattr(layer, weight_name)
quanted_weight_scale = weight_tensor.abs().max(axis=1)
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
quanted_weight = paddle.round(quanted_weight).astype("int8")
quanted_weight_scale = quanted_weight_scale / max_bound
getattr(layer, weight_name).value().get_tensor()._clear()
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_tensor.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(quanted_weight, False)
getattr(layer, scale_name).copy_(quanted_weight_scale, False)
@paddle.no_grad()
def apply(
self,
@@ -157,38 +249,38 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
token_num = x.shape[0]
top_k = layer.top_k
num_local_experts = layer.num_local_experts
top_k = layer.top_k
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
layer.top_k,
True, # apply_norm_weight
False,
)
if layer.topk_method == "noaux_tc":
gate_out, topk_weights, topk_ids = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
else:
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
layer.top_k,
True, # apply_norm_weight
False,
)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
)
if self.quant_config is not None:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
}
else:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
}
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
}
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
)
@@ -237,6 +329,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
compute_type_enum=1,
use_fp8_w8a8=False,
use_int8_w8a16=True,
per_channel_quant=False,
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
)
@@ -289,11 +382,12 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
compute_type_enum=1,
use_fp8_w8a8=False,
use_int8_w8a16=True,
per_channel_quant=False,
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
)
down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
if layer.tp_size > 1:
tensor_model_parallel_all_reduce(out, layer.fd_config.parallel_config.tp_group)
if layer.reduce_results and layer.tp_size > 1:
out = tensor_model_parallel_all_reduce(out, layer.fd_config.parallel_config.tp_group)
return out
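The quantization added in process_weights_after_loading above is per-output-channel abs-max int8 (max_bound = 127): for a weight tensor W of shape [num_experts, K, N],
s_{e,n} = \frac{\max_k |W_{e,k,n}|}{127}, \qquad Q_{e,k,n} = \operatorname{round}\!\left(\frac{W_{e,k,n}}{s_{e,n}}\right) \in [-127, 127],
and the Triton kernel's use_int8_w8a16 path dequantizes on the fly as W \approx Q \cdot s.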

View File

@@ -16,7 +16,7 @@ import triton
import triton.language as tl
@triton.jit
@triton.jit()
def fused_moe_kernel_paddle(
a_ptr,
b_ptr,
@@ -30,20 +30,20 @@ def fused_moe_kernel_paddle(
# Matrix dimensions
max_possible_num_post_padded,
num_valid_tokens,
N,
K,
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
N: tl.constexpr,
K: tl.constexpr,
stride_am: tl.constexpr,
stride_ak: tl.constexpr,
stride_be: tl.constexpr,
stride_bk: tl.constexpr,
stride_bn: tl.constexpr,
stride_cm: tl.constexpr,
stride_cn: tl.constexpr,
stride_asm: tl.constexpr,
stride_ask: tl.constexpr,
stride_bse: tl.constexpr,
stride_bsk: tl.constexpr,
stride_bsn: tl.constexpr,
# Block size for block-wise fp8 quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
@@ -57,6 +57,7 @@ def fused_moe_kernel_paddle(
compute_type_enum: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
even_Ks: tl.constexpr,
):
"""
@@ -119,6 +120,13 @@ def fused_moe_kernel_paddle(
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
# channel-wise
elif per_channel_quant:
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
# Load per-token scale for activations
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
else:
# (Zkk): every expert has one activation scale and weight scale.
a_scale = tl.load(a_scale_ptr + off_experts)

View File

@@ -23,7 +23,7 @@ from paddle import nn
from fastdeploy.config import ModelConfig
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import fused_rotary_position_encoding
from .utils import CpuGuard

View File

@@ -43,12 +43,12 @@ class DefaultModelLoader(BaseModelLoader):
def clean_memory_fragments(self, state_dict: dict) -> None:
"""clean_memory_fragments"""
if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_maca():
if state_dict:
for k, v in state_dict.items():
if isinstance(v, paddle.Tensor):
v.value().get_tensor()._clear()
paddle.device.cuda.empty_cache()
paddle.device.empty_cache()
paddle.device.synchronize()
@measure_time()

View File

@@ -43,8 +43,8 @@ class DefaultModelLoaderV1(BaseModelLoader):
def clean_memory_fragments(self) -> None:
"""clean_memory_fragments"""
if current_platform.is_cuda():
paddle.device.cuda.empty_cache()
if current_platform.is_cuda() or current_platform.is_maca():
paddle.device.empty_cache()
paddle.device.synchronize()
@save_model()

View File

@@ -55,7 +55,7 @@ from fastdeploy.model_executor.models.model_base import (
)
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import (
get_position_ids_and_mask_encoder_batch,
)

View File

@@ -16,7 +16,7 @@
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import (
text_image_gather_scatter,
text_image_index_out,
@@ -32,6 +32,6 @@ elif current_platform.is_iluvatar():
text_image_index_out,
)
else:
raise ImportError("Unsupported platform, only support CUDA and XPU")
raise ImportError("Unsupported platform, only support CUDA, MACA and XPU")
__all__ = ["text_image_gather_scatter", "text_image_index_out"]

View File

@@ -60,6 +60,9 @@ class MACAPlatform(Platform):
elif selected_backend == _Backend.APPEND_ATTN:
logger.info("Using FLASH ATTN backend to instead of attend attention.")
return "fastdeploy.model_executor.layers.backends.metax.attention.flash_attn_backend.FlashAttentionBackend"
elif selected_backend == _Backend.MLA_ATTN:
logger.info("Using MLA ATTN backend.")
return "fastdeploy.model_executor.layers.backends.metax.attention.mla_attn_metax_backend.MetaxMLAAttentionBackend"
else:
raise ValueError(
"Invalid attention backend you specified.\n"

File diff suppressed because it is too large

View File

@@ -20,13 +20,12 @@ import time
from typing import List, Optional
import paddle
import pymxsml
from paddle import nn
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import get_logger
from fastdeploy.utils import get_logger, set_random_seed
from fastdeploy.worker.metax_model_runner import MetaxModelRunner
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.worker_base import WorkerBase
@@ -53,23 +52,21 @@ class MetaxWorker(WorkerBase):
Initialize device and construct model runner
"""
self.max_chips_per_node = 8
if paddle.is_compiled_with_custom_device("metax_gpu"):
# Set environment variable
self.device_ids = self.parallel_config.device_ids.split(",")
self.device = f"metax_gpu:{self.local_rank % self.max_chips_per_node}"
paddle.device.set_device(self.device)
paddle.set_default_dtype(self.model_config.dtype)
# Set environment variable
self.device_ids = self.parallel_config.device_ids.split(",")
self.device = f"metax_gpu:{self.local_rank % self.max_chips_per_node}"
paddle.device.set_device(self.device)
paddle.set_default_dtype(self.model_config.dtype)
gc.collect()
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
gc.collect()
paddle.device.empty_cache()
set_random_seed(self.fd_config.model_config.seed)
# Construct model runner
self.model_runner: MetaxModelRunner = MetaxModelRunner(
fd_config=self.fd_config,
device=self.device,
device_id=self.device_ids[self.local_rank % self.max_chips_per_node],
device_id=int(self.device_ids[self.local_rank % self.max_chips_per_node]),
rank=self.rank,
local_rank=self.local_rank,
)
@@ -99,6 +96,8 @@ class MetaxWorker(WorkerBase):
if fd_kvache_mem is not None:
return int(float(fd_kvache_mem) * 1024**3)
else:
import pymxsml
# 1. Record memory state before profile run
start_time = time.perf_counter()
Gb = 1024**3
@@ -200,9 +199,10 @@ class MetaxWorker(WorkerBase):
"""
Perform the warm-up and the graph optimization
"""
if self.model_runner.graph_opt_level >= 1:
if self.fd_config.graph_opt_config.graph_opt_level >= 1 and not self.model_runner.use_cudagraph:
self.model_runner.sot_warmup()
# Todo Trigger cuda graph capture.
# Trigger cuda graph capture
self.model_runner.capture_model()
def check_health(self) -> bool:
""" """