Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 08:37:06 +08:00
[MetaxGPU] Support FastDeploy on metax gpu (#3241)
* [MetaxGPU] Support FastDeploy on Metax GPU

* Update metax_worker.py
  1. change worker log;
  2. remove custom allreduce, adapt it later;
  3. remove CUDA graph;

* Update __init__.py
  1. remove Metax's keyword comment

* Update __init__.py
  1. remove Metax's keyword comment;
  2. add fused_moe_kernel_paddle import

---------

Co-authored-by: yongqiangma <xing.wo@163.com>
10
build.sh
@@ -126,6 +126,16 @@ function copy_ops(){
        return
    fi

    is_maca=`$python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"`
    if [ "$is_maca" = "True" ]; then
        DEVICE_TYPE="metax_gpu"
        mkdir -p ../fastdeploy/model_executor/ops/base
        cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
        cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
        echo -e "MACA ops have been copied to fastdeploy"
        return
    fi

    DEVICE_TYPE="cpu"
    cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
    cd ../../../../
@@ -509,6 +509,7 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
}

#ifndef PADDLE_WITH_HIP
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
__forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr,
                                                    int mode = 0) {
  uint32_t flag;
@@ -541,7 +542,7 @@ __forceinline__ __device__ void st_flag_release(uint32_t *flag_addr,
               "l"(flag_addr));
}
}

#endif
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
  int max_shared_mem_per_block_opt_in = 0;
  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
@@ -564,6 +564,72 @@ elif paddle.is_compiled_with_custom_device("gcu"):
|
||||
]
|
||||
),
|
||||
)
|
||||
elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
maca_path = os.getenv("MACA_PATH", "/opt/maca")
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://gitee.com/learnlov/mirrors_nlohmann_json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
sources = [
|
||||
"gpu_ops/save_with_output.cc",
|
||||
"gpu_ops/set_mask_value.cu",
|
||||
"gpu_ops/set_value_by_flags.cu",
|
||||
"gpu_ops/ngram_mask.cu",
|
||||
"gpu_ops/gather_idx.cu",
|
||||
"gpu_ops/get_output_ep.cc",
|
||||
"gpu_ops/token_penalty_multi_scores.cu",
|
||||
"gpu_ops/token_penalty_only_once.cu",
|
||||
"gpu_ops/stop_generation.cu",
|
||||
"gpu_ops/stop_generation_multi_ends.cu",
|
||||
"gpu_ops/set_flags.cu",
|
||||
"gpu_ops/fused_get_rope.cu",
|
||||
"gpu_ops/get_padding_offset.cu",
|
||||
"gpu_ops/update_inputs.cu",
|
||||
"gpu_ops/update_inputs_beam.cu",
|
||||
"gpu_ops/beam_search_softmax.cu",
|
||||
"gpu_ops/rebuild_padding.cu",
|
||||
"gpu_ops/step.cu",
|
||||
"gpu_ops/step_reschedule.cu",
|
||||
"gpu_ops/step_system_cache.cu",
|
||||
"gpu_ops/set_data_ipc.cu",
|
||||
"gpu_ops/read_data_ipc.cu",
|
||||
"gpu_ops/dequant_int8.cu",
|
||||
"gpu_ops/share_external_data.cu",
|
||||
"gpu_ops/extract_text_token_output.cu",
|
||||
"gpu_ops/moe/tritonmoe_preprocess.cu",
|
||||
"gpu_ops/moe/moe_topk_select.cu",
|
||||
"gpu_ops/recover_decode_task.cu",
|
||||
]
|
||||
|
||||
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
|
||||
sources += find_end_files("gpu_ops/speculate_decoding", ".cc")
|
||||
|
||||
setup(
|
||||
name="fastdeploy_ops",
|
||||
ext_modules=CUDAExtension(
|
||||
sources=sources,
|
||||
extra_compile_args={
|
||||
"cxx": ["-O3"],
|
||||
"nvcc": [
|
||||
"-O3",
|
||||
"-Ithird_party/nlohmann_json/include",
|
||||
"-Igpu_ops",
|
||||
"-DPADDLE_DEV",
|
||||
"-DPADDLE_WITH_CUSTOM_DEVICE_METAX_GPU",
|
||||
],
|
||||
},
|
||||
library_dirs=[os.path.join(maca_path, "lib")],
|
||||
extra_link_args=["-lruntime_cu"],
|
||||
include_dirs=[
|
||||
os.path.join(maca_path, "include"),
|
||||
os.path.join(maca_path, "include/mcr"),
|
||||
os.path.join(maca_path, "include/common"),
|
||||
],
|
||||
),
|
||||
)
|
||||
else:
|
||||
use_bf16 = envs.FD_CPU_USE_BF16 == "True"
|
||||
|
||||
|
@@ -37,6 +37,8 @@ class ForwardMode(IntEnum):
    DECODE = auto()
    # Mixed mode
    MIXED = auto()
    # Native mode
    NATIVE = auto()

    def is_prefill(self):
        """Is Extend mode"""
@@ -50,6 +52,10 @@ class ForwardMode(IntEnum):
        """Is Mixed mode"""
        return self == ForwardMode.MIXED

    def is_native(self):
        """Is Native mode"""
        return self == ForwardMode.NATIVE


@dataclass
class ForwardMeta:
@@ -68,6 +68,7 @@ class SiluAndMul(nn.Layer):
            or current_platform.is_xpu()
            or current_platform.is_iluvatar()
            or current_platform.is_dcu()
            or current_platform.is_maca()
        ):
            self.forward = self.forward_cuda
        elif current_platform.is_gcu():
@@ -86,6 +86,15 @@ class AttentionBackend(ABC):
                layer,
                forward_meta,
            )
        elif forward_meta.forward_mode.is_native():
            return self.forward_native_backend(
                q,
                k,
                v,
                qkv,
                layer,
                forward_meta,
            )
        else:
            return self.forward_extend(
                q,
@@ -139,3 +148,15 @@ class AttentionBackend(ABC):
    ) -> paddle.Tensor:
        """Run a forward for extend."""
        raise NotImplementedError

    def forward_native_backend(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        qkv: paddle.Tensor,
        layer: paddle.nn.Layer,
        forward_meta: ForwardMeta,
    ) -> paddle.Tensor:
        """Run a forward for native."""
        raise NotImplementedError
@@ -48,3 +48,10 @@ if current_platform.is_dcu():
    if hasattr(dcu, "__all__"):
        globals().update({name: getattr(dcu, name) for name in dcu.__all__})
        __all__.extend(dcu.__all__)

if current_platform.is_maca():
    from . import metax

    if hasattr(metax, "__all__"):
        globals().update({name: getattr(metax, name) for name in metax.__all__})
        __all__.extend(metax.__all__)
21
fastdeploy/model_executor/layers/backends/metax/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .attention.flash_attn_backend import FlashAttentionBackend
|
||||
from .moe.fused_moe_triton_metax_backend import MetaxTritonWeightOnlyMoEMethod
|
||||
|
||||
__all__ = [
|
||||
"FlashAttentionBackend",
|
||||
"MetaxTritonWeightOnlyMoEMethod",
|
||||
]
|
@@ -0,0 +1,30 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
metax gpu backend attention methods
|
||||
"""
|
||||
from .flash_attention_interface import (
|
||||
flash_attn_func,
|
||||
flash_attn_kvcache_func,
|
||||
flash_attn_unpadded_func,
|
||||
)
|
||||
from .flash_attn_backend import FlashAttentionBackend
|
||||
|
||||
__all__ = [
|
||||
"FlashAttentionBackend",
|
||||
"flash_attn_func",
|
||||
"flash_attn_unpadded_func",
|
||||
"flash_attn_kvcache_func",
|
||||
]
|
@@ -0,0 +1,104 @@
|
||||
import os
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import paddle
|
||||
from paddle import Tensor
|
||||
|
||||
for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
|
||||
if lib.endswith(".so"):
|
||||
paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
|
||||
|
||||
|
||||
def flash_attn_func(
|
||||
q: Tensor,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
fixed_seed_offset: Optional[Tensor] = None,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
dropout_prob: float = 0.0,
|
||||
causal: bool = False,
|
||||
return_softmax: bool = False,
|
||||
is_test: bool = True,
|
||||
rng_name: str = "",
|
||||
) -> Union[Tensor, Tuple[Tensor, ...]]:
|
||||
return paddle._C_ops.flash_attn(
|
||||
q, k, v, fixed_seed_offset, attn_mask, dropout_prob, causal, return_softmax, is_test, rng_name
|
||||
)
|
||||
|
||||
|
||||
def flash_attn_unpadded_func(
|
||||
q: Tensor,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
cu_seqlens_q: Tensor,
|
||||
cu_seqlens_k: Tensor,
|
||||
max_seqlen_q: Union[int, float],
|
||||
max_seqlen_k: Union[int, float],
|
||||
fixed_seed_offset: Optional[Tensor] = None,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
softmax_scale: float = 1.0,
|
||||
dropout: float = 0.0,
|
||||
causal: bool = False,
|
||||
return_softmax: bool = False,
|
||||
is_test: bool = True,
|
||||
rng_name: str = "",
|
||||
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
max_seqlen_q_t = paddle.to_tensor(max_seqlen_q, dtype="int64")
|
||||
max_seqlen_k_t = paddle.to_tensor(max_seqlen_k, dtype="int64")
|
||||
|
||||
outputs = paddle._C_ops.flash_attn_unpadded(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
fixed_seed_offset,
|
||||
attn_mask,
|
||||
max_seqlen_q_t,
|
||||
max_seqlen_k_t,
|
||||
softmax_scale,
|
||||
dropout,
|
||||
causal,
|
||||
return_softmax,
|
||||
is_test,
|
||||
rng_name,
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
def flash_attn_kvcache_func(
|
||||
q: Tensor,
|
||||
k_cache: Tensor,
|
||||
v_cache: Tensor,
|
||||
seqlens_k: Tensor,
|
||||
block_table: Tensor,
|
||||
k: Optional[Tensor] = None,
|
||||
v: Optional[Tensor] = None,
|
||||
rotary_cos: Optional[Tensor] = None,
|
||||
rotary_sin: Optional[Tensor] = None,
|
||||
cache_batch_idx: Optional[Tensor] = None,
|
||||
causal: bool = True,
|
||||
is_rotary_interleaved: bool = False,
|
||||
num_splits: int = 1,
|
||||
dropout: float = 0.0,
|
||||
return_softmax: bool = False,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
out, softmax_lse = paddle._C_ops._run_custom_op(
|
||||
"flash_attn_kvcache",
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
k,
|
||||
v,
|
||||
seqlens_k,
|
||||
rotary_cos,
|
||||
rotary_sin,
|
||||
cache_batch_idx,
|
||||
block_table,
|
||||
causal,
|
||||
is_rotary_interleaved,
|
||||
num_splits,
|
||||
dropout,
|
||||
return_softmax,
|
||||
)
|
||||
return out, softmax_lse
|
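For orientation, here is a minimal usage sketch of the varlen wrapper above. It is not part of the diff; it assumes a Metax build of Paddle with CUSTOM_DEVICE_ROOT set so the custom ops load on import, and the shapes are illustrative.

import paddle
from fastdeploy.model_executor.layers.backends.metax.attention.flash_attention_interface import (
    flash_attn_unpadded_func,
)

# Two sequences of lengths 3 and 5 packed into one "unpadded" batch.
num_heads, head_dim = 8, 128
total_tokens = 3 + 5
q = paddle.randn([total_tokens, num_heads, head_dim]).astype("float16")
k = paddle.randn([total_tokens, num_heads, head_dim]).astype("float16")
v = paddle.randn([total_tokens, num_heads, head_dim]).astype("float16")
cu_seqlens = paddle.to_tensor([0, 3, 8], dtype="int32")

# The wrapper returns the raw tuple from paddle._C_ops.flash_attn_unpadded;
# index 0 is the attention output, which is how the backend below consumes it.
out = flash_attn_unpadded_func(
    q, k, v, cu_seqlens, cu_seqlens,
    max_seqlen_q=5, max_seqlen_k=5,
    softmax_scale=head_dim**-0.5,
    causal=True,
)[0]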
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta, ForwardMode
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend,
|
||||
AttentionMetadata,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
|
||||
from fastdeploy.model_executor.layers.backends.metax.attention.flash_attention_interface import (
|
||||
flash_attn_kvcache_func,
|
||||
flash_attn_unpadded_func,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata(AttentionMetadata):
|
||||
"""
|
||||
FlashAttentionMetadata
|
||||
"""
|
||||
|
||||
max_len_kv: paddle.Tensor = None
|
||||
set_max_lengths: int = -1
|
||||
encoder_batch_ids: paddle.Tensor = None
|
||||
encoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
encoder_num_blocks: paddle.Tensor = None
|
||||
kv_batch_ids: paddle.Tensor = None
|
||||
kv_tile_ids_per_batch: paddle.Tensor = None
|
||||
kv_num_blocks: paddle.Tensor = None
|
||||
decoder_batch_ids: paddle.Tensor = None
|
||||
decoder_tile_ids_per_batch: paddle.Tensor = None
|
||||
decoder_num_blocks: paddle.Tensor = None
|
||||
|
||||
_dtype: paddle.dtype = paddle.bfloat16
|
||||
encoder_max_partition_size: int = 32768
|
||||
max_partition_size: int = 32768
|
||||
block_tables: Optional[paddle.Tensor] = None
|
||||
rotary_embs: Optional[paddle.Tensor] = None
|
||||
attn_mask: Optional[paddle.Tensor] = None
|
||||
encoder_block_shape_q: int = -1
|
||||
decoder_block_shape_q: int = -1
|
||||
_fuse_kernel_compute_dtype: str = "bf16"
|
||||
|
||||
# pd_disaggregation
|
||||
kv_signal_metadata: Optional[paddle.Tensor] = None
|
||||
kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
FlashAttentionBackend backend implementation.
|
||||
"""
|
||||
|
||||
__infer_dynamic_dims_fields__ = ["attention_metadata"]
|
||||
attention_metadata: FlashAttentionMetadata
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
kv_num_heads: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
encoder_block_shape_q: int = -1,
|
||||
decoder_block_shape_q: int = -1,
|
||||
) -> None:
|
||||
"""
|
||||
FlashAttentionBackend __init__
|
||||
"""
|
||||
super().__init__()
|
||||
self.attention_metadata: FlashAttentionMetadata = None
|
||||
self.block_size: int = fd_config.parallel_config.block_size
|
||||
self.max_seq_len: int = fd_config.parallel_config.max_model_len
|
||||
self.rope_theta: float = (
|
||||
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
|
||||
)
|
||||
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
|
||||
self.causal: bool = getattr(fd_config.model_config, "causal", True)
|
||||
self.speculative_method: str = fd_config.speculative_config.method
|
||||
self.use_speculate: bool = self.speculative_method is not None
|
||||
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
|
||||
self.encoder_block_shape_q: int = encoder_block_shape_q
|
||||
self.decoder_block_shape_q: int = decoder_block_shape_q
|
||||
|
||||
self.kv_num_heads: int = kv_num_heads
|
||||
self.num_heads: int = num_heads
|
||||
self.head_dim: int = fd_config.model_config.head_dim
|
||||
self.num_layers: int = fd_config.model_config.num_hidden_layers
|
||||
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 32768))
|
||||
|
||||
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
|
||||
|
||||
self.start_layer_index: int = fd_config.model_config.start_layer_index
|
||||
|
||||
if fd_config.parallel_config.expert_parallel_rank is None:
|
||||
fd_config.parallel_config.expert_parallel_rank = 0
|
||||
|
||||
self.rank, self.device_id = init_rank_and_device_id(fd_config)
|
||||
|
||||
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
|
||||
forward_meta.forward_mode = ForwardMode.NATIVE
|
||||
return
|
||||
|
||||
def get_attntion_meta(self) -> AttentionMetadata:
|
||||
"""get_attntion_meta"""
|
||||
return self.attention_metadata
|
||||
|
||||
def get_kv_cache_shape(
|
||||
self,
|
||||
max_num_blocks: int,
|
||||
kv_cache_quant_type: str = None,
|
||||
):
|
||||
"""
|
||||
Calculate kv cache shape
|
||||
"""
|
||||
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
|
||||
return (
|
||||
max_num_blocks,
|
||||
self.kv_num_heads,
|
||||
self.block_size,
|
||||
self.head_dim // 2,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
max_num_blocks,
|
||||
self.kv_num_heads,
|
||||
self.block_size,
|
||||
self.head_dim,
|
||||
)
|
||||
|
||||
def split_qkv(self, qkv, num_head_q, num_head_kv, dim):
|
||||
q = qkv[:, : num_head_q * dim].reshape([-1, num_head_q, dim])
|
||||
k = qkv[:, num_head_q * dim : num_head_q * dim + num_head_kv * dim].reshape([-1, num_head_kv, dim])
|
||||
v = qkv[:, num_head_q * dim + num_head_kv * dim :].reshape([-1, num_head_kv, dim])
|
||||
return q, k, v
|
||||
|
||||
def flash_attn_varlen(self, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
|
||||
num_head = q.shape[1]
|
||||
dim = q.shape[2]
|
||||
|
||||
q_ = q.reshape([-1, num_head, dim])
|
||||
k_ = k.reshape([-1, num_head, dim])
|
||||
v_ = v.reshape([-1, num_head, dim])
|
||||
|
||||
bsz = cu_seqlens_q.shape[0] - 1
|
||||
out = []
|
||||
for i in range(bsz):
|
||||
start_q, end_q = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
|
||||
start_k, end_k = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()
|
||||
qi = q_[start_q:end_q] # [seq_q, nh, dim]
|
||||
ki = k_[start_k:end_k] # [seq_k, nh, dim]
|
||||
vi = v_[start_k:end_k] # [seq_k, nh, dim]
|
||||
qi = qi.transpose([1, 0, 2]) # [nh, seq_q, dim]
|
||||
ki = ki.transpose([1, 2, 0]) # [nh, dim, seq_k]
|
||||
vi = vi.transpose([1, 0, 2]) # [nh, seq_k, dim]
|
||||
|
||||
score = paddle.matmul(qi, ki) / math.sqrt(dim) # [nh, seq_q, seq_k]
|
||||
prob = F.softmax(score, axis=-1)
|
||||
o = paddle.matmul(prob, vi) # [nh, seq_q, dim]
|
||||
o = o.transpose([1, 0, 2]) # [seq_q, nh, dim]
|
||||
out.append(o)
|
||||
|
||||
return paddle.concat(out, axis=0) # [total_q, nh, dim]
|
||||
|
||||
def flash_attn_with_kvcache(self, q, cache_k, cache_v, cache_seqlens, block_tables=None):
|
||||
bs, _, nh, dim = q.shape
|
||||
out = []
|
||||
for i in range(bs):
|
||||
q_i = q[i] # [1, nh, dim]
|
||||
k_i = cache_k[i, : cache_seqlens[i, 0]] # [seqlen, nh, dim]
|
||||
v_i = cache_v[i, : cache_seqlens[i, 0]]
|
||||
qi = q_i.transpose([1, 0, 2]) # [nh, 1, dim]
|
||||
ki = k_i.transpose([1, 2, 0]) # [nh, dim, seqlen]
|
||||
vi = v_i.transpose([1, 0, 2]) # [nh, seqlen, dim]
|
||||
score = paddle.matmul(qi, ki) / math.sqrt(dim)
|
||||
prob = F.softmax(score, axis=-1)
|
||||
o = paddle.matmul(prob, vi).transpose([1, 0, 2]) # [1, nh, dim]
|
||||
out.append(o)
|
||||
return paddle.concat(out, axis=0) # [bs, nh, dim]
|
||||
|
||||
def block_cache_to_naive_cache(self, cache_k, cache_v, bsz, block_tables, cache_seq_len):
|
||||
_, num_head, blocksize, dim_head = cache_k.shape
|
||||
out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype)
|
||||
out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype)
|
||||
for i in range(bsz):
|
||||
for j in range(cache_seq_len):
|
||||
out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
|
||||
out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
|
||||
return out_cache_k, out_cache_v
|
||||
|
||||
def block_cache_to_naive_cache__(self, cache_k, cache_v, bsz, block_tables, max_cache_seq_len):
|
||||
_, num_head, blocksize, dim_head = cache_k.shape
|
||||
out_cache_k = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_k.dtype)
|
||||
out_cache_v = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_v.dtype)
|
||||
for i in range(bsz):
|
||||
for j in range(max_cache_seq_len):
|
||||
out_cache_k[i, j, :, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
|
||||
out_cache_v[i, j, :, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
|
||||
return out_cache_k, out_cache_v
|
||||
|
||||
def update_encoder_kv_cache(self, k, v, seq_lens_encoder, cache_k, cache_v, block_tables):
|
||||
_, num_head, blocksize, dim_head = cache_k.shape
|
||||
offset = 0
|
||||
for batch_idx, seq_len in enumerate(seq_lens_encoder.numpy()):
|
||||
if seq_len == 0:
|
||||
continue
|
||||
for seq_idx in range(seq_len):
|
||||
block_id = block_tables[batch_idx, seq_idx // blocksize]
|
||||
assert block_id != -1
|
||||
index = offset + seq_idx
|
||||
cache_k[block_id, :, seq_idx % blocksize, :] = k[index, :, :]
|
||||
cache_v[block_id, :, seq_idx % blocksize, :] = v[index, :, :]
|
||||
|
||||
offset += seq_len
|
||||
|
||||
def update_decoder_kv_cache(self, k, v, seq_lens_decoder, cache_k, cache_v, block_tables):
|
||||
_, num_head, blocksize, dim_head = cache_k.shape
|
||||
for batch_idx, seq_idx in enumerate(seq_lens_decoder.numpy()):
|
||||
if seq_idx == 0:
|
||||
continue
|
||||
block_id = block_tables[batch_idx, seq_idx // blocksize]
|
||||
assert block_id != -1
|
||||
cache_k[block_id, :, seq_idx % blocksize, :] = k[batch_idx, :, :]
|
||||
cache_v[block_id, :, seq_idx % blocksize, :] = v[batch_idx, :, :]
|
||||
|
||||
def apply_rope(self, qk, cos, sin):
|
||||
rotate_half = paddle.reshape(
|
||||
paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1),
|
||||
paddle.shape(qk),
|
||||
)
|
||||
out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin))
|
||||
return paddle.cast(out, qk.dtype)
|
||||
|
||||
def forward_native_backend(
|
||||
self,
|
||||
q: paddle.Tensor,
|
||||
k: paddle.Tensor,
|
||||
v: paddle.Tensor,
|
||||
qkv: paddle.Tensor,
|
||||
layer,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
|
||||
bsz = forward_meta.seq_lens_this_time.shape[0]
|
||||
num_head_q, num_head_kv, dim = layer.num_heads, layer.kv_num_heads, layer.head_dim
|
||||
|
||||
# 1. Separate the encoder / decoder token indices
|
||||
seq_lens_encoder = forward_meta.seq_lens_encoder.squeeze(-1)
|
||||
seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
|
||||
seq_lens_this_time = forward_meta.seq_lens_this_time.squeeze(-1)
|
||||
encoder_indices = []
|
||||
decoder_indices = []
|
||||
|
||||
offset = 0
|
||||
for i in range(bsz):
|
||||
length = seq_lens_this_time[i].item()
|
||||
if seq_lens_encoder[i] > 0:
|
||||
encoder_indices.extend(range(offset, offset + length))
|
||||
elif seq_lens_decoder[i] > 0:
|
||||
decoder_indices.extend(range(offset, offset + length))
|
||||
offset += length
|
||||
|
||||
encoder_indices = paddle.to_tensor(encoder_indices, dtype="int32")
|
||||
decoder_indices = paddle.to_tensor(decoder_indices, dtype="int32")
|
||||
|
||||
encoder_qkv = paddle.index_select(qkv, encoder_indices, axis=0)
|
||||
decoder_qkv = paddle.index_select(qkv, decoder_indices, axis=0)
|
||||
|
||||
# 2. Split the encoder and decoder qkv
|
||||
encoder_q, encoder_k, encoder_v = self.split_qkv(encoder_qkv, num_head_q, num_head_kv, dim)
|
||||
decoder_q, decoder_k, decoder_v = self.split_qkv(decoder_qkv, num_head_q, num_head_kv, dim)
|
||||
cache_k = forward_meta.caches[2 * layer.layer_id]
|
||||
cache_v = forward_meta.caches[2 * layer.layer_id + 1]
|
||||
|
||||
# 3. Rotary Embedding
|
||||
if decoder_q.numel() != 0 or encoder_q.numel() != 0:
|
||||
for batch_idx in range(forward_meta.seq_lens_this_time.shape[0]):
|
||||
seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
|
||||
if seq_len_i == 0:
|
||||
continue
|
||||
cached_kv_len = seq_lens_decoder[batch_idx]
|
||||
cu_seq_start_q = forward_meta.cu_seqlens_q[batch_idx]
|
||||
cu_seq_end_q = forward_meta.cu_seqlens_q[batch_idx + 1]
|
||||
if forward_meta.rotary_embs is not None and cu_seq_end_q > cu_seq_start_q:
|
||||
cos = forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
|
||||
sin = forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
|
||||
|
||||
def rope_func(qk):
|
||||
qk[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(qk[cu_seq_start_q:cu_seq_end_q], cos, sin)
|
||||
|
||||
if encoder_q.numel() != 0:
|
||||
rope_func(encoder_q)
|
||||
rope_func(encoder_k)
|
||||
if decoder_q.numel() != 0:
|
||||
rope_func(decoder_q)
|
||||
rope_func(decoder_k)
|
||||
|
||||
# 4. Flash Attention for encoder
|
||||
encoder_v = encoder_v
|
||||
cu_seqlens_q = forward_meta.cu_seqlens_q
|
||||
cu_seqlens_k = forward_meta.cu_seqlens_k
|
||||
max_seqlen_q = paddle.max(seq_lens_this_time)
|
||||
max_seqlen_k = max_seqlen_q
|
||||
|
||||
if encoder_q.numel() > 0:
|
||||
encoder_out = flash_attn_unpadded_func(
|
||||
encoder_q,
|
||||
encoder_k,
|
||||
encoder_v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
attn_mask=forward_meta.attn_mask,
|
||||
causal=self.causal,
|
||||
)
|
||||
self.update_encoder_kv_cache(
|
||||
encoder_k, encoder_v, seq_lens_encoder, cache_k, cache_v, forward_meta.block_tables
|
||||
)
|
||||
else:
|
||||
encoder_out = None
|
||||
|
||||
# 5. decoder attention with kv cache
|
||||
bs = decoder_q.shape[0]
|
||||
decoder_q = decoder_q.reshape([bs, 1, num_head_q, dim])
|
||||
decoder_k_ = decoder_k.reshape([bs, 1, num_head_kv, dim])
|
||||
decoder_v_ = decoder_v.reshape([bs, 1, num_head_kv, dim])
|
||||
cache_seqlens = paddle.index_select(forward_meta.seq_lens_decoder, decoder_indices, axis=0)
|
||||
|
||||
# 5.1 convert paged kv cache to continuous cache
|
||||
if decoder_q.numel() > 0:
|
||||
max_cache_seq_len = paddle.max(cache_seqlens)
|
||||
c_cache_k, c_cache_v = self.block_cache_to_naive_cache__(
|
||||
cache_k, cache_v, bs, forward_meta.block_tables, max_cache_seq_len
|
||||
)
|
||||
decoder_out = flash_attn_kvcache_func(
|
||||
decoder_q,
|
||||
c_cache_k,
|
||||
c_cache_v,
|
||||
cache_seqlens.squeeze(-1),
|
||||
None,
|
||||
decoder_k_,
|
||||
decoder_v_,
|
||||
causal=self.causal,
|
||||
)
|
||||
self.update_decoder_kv_cache(
|
||||
decoder_k, decoder_v, seq_lens_decoder, cache_k, cache_v, forward_meta.block_tables
|
||||
)
|
||||
else:
|
||||
decoder_out = None
|
||||
|
||||
# 6. Concatenate encoder_out and decoder_out
|
||||
total_len = qkv.shape[0]
|
||||
out = paddle.zeros([total_len, num_head_q, dim])
|
||||
if encoder_out is not None:
|
||||
out = paddle.tensor.put_along_axis(
|
||||
out, encoder_indices.unsqueeze(-1).unsqueeze(-1), encoder_out[0], axis=0
|
||||
)
|
||||
if decoder_out is not None:
|
||||
new_decoder_out = decoder_out[0].squeeze(1)
|
||||
out = paddle.tensor.put_along_axis(
|
||||
out, decoder_indices.unsqueeze(-1).unsqueeze(-1), new_decoder_out, axis=0
|
||||
)
|
||||
|
||||
out.reshape_([total_len, num_head_q * dim])
|
||||
|
||||
return out
|
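A small self-contained check of the packed-qkv layout that split_qkv() above assumes (illustrative only, hypothetical sizes; plain Paddle, runs on CPU):

import paddle

num_head_q, num_head_kv, dim = 8, 2, 16
tokens = 4
# qkv packs [q | k | v] along the last axis, exactly as split_qkv() slices it.
qkv = paddle.randn([tokens, (num_head_q + 2 * num_head_kv) * dim])

q = qkv[:, : num_head_q * dim].reshape([-1, num_head_q, dim])
k = qkv[:, num_head_q * dim : (num_head_q + num_head_kv) * dim].reshape([-1, num_head_kv, dim])
v = qkv[:, (num_head_q + num_head_kv) * dim :].reshape([-1, num_head_kv, dim])
assert q.shape == [tokens, num_head_q, dim]
assert k.shape == v.shape == [tokens, num_head_kv, dim]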
@@ -0,0 +1,19 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .triton_moe_kernels import fused_moe_kernel_paddle
|
||||
|
||||
__all__ = [
|
||||
"fused_moe_kernel_paddle",
|
||||
]
|
@@ -0,0 +1,276 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
|
||||
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
|
||||
from fastdeploy.utils import ceil_div
|
||||
|
||||
from .triton_moe_kernels import fused_moe_kernel_paddle
|
||||
|
||||
|
||||
class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
"""
|
||||
Use Triton Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config=None):
|
||||
"""
|
||||
Triton Group Gemm to compute Fused MoE.
|
||||
"""
|
||||
self.quant_config = quant_config
|
||||
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
|
||||
self.added_scale_attrs = [
|
||||
"up_gate_proj_weight_scale",
|
||||
"down_proj_weight_scale",
|
||||
]
|
||||
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
|
||||
"""process_prequanted_weights"""
|
||||
pass
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Triton MoE create weight process.
|
||||
"""
|
||||
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
assert len(up_gate_proj_weights) == layer.num_local_experts
|
||||
assert len(down_proj_weights) == layer.num_local_experts
|
||||
|
||||
if layer.quant_method.quant_config:
|
||||
algo = layer.quant_method.quant_config.name()
|
||||
|
||||
assert up_gate_proj_weights[0].shape == [
|
||||
layer.hidden_size,
|
||||
layer.moe_intermediate_size * 2,
|
||||
]
|
||||
assert down_proj_weights[0].shape == [
|
||||
layer.moe_intermediate_size,
|
||||
layer.hidden_size,
|
||||
]
|
||||
|
||||
up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0)
|
||||
down_proj_tensor = paddle.stack(down_proj_weights, axis=0)
|
||||
|
||||
if algo == "wint8":
|
||||
max_bound = 127
|
||||
elif algo == "wint4":
|
||||
max_bound = 7
|
||||
|
||||
for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
scale_name = self.added_scale_attrs[idx]
|
||||
|
||||
quanted_weight_scale = weight_tensor.abs().max(axis=1)
|
||||
if self.quant_config is not None:
|
||||
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
|
||||
quanted_weight = paddle.round(quanted_weight).astype("int8")
|
||||
quanted_weight_scale = quanted_weight_scale / max_bound
|
||||
|
||||
setattr(
|
||||
layer,
|
||||
weight_name,
|
||||
layer.create_parameter(
|
||||
shape=quanted_weight.shape,
|
||||
dtype=quanted_weight.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
getattr(layer, weight_name).set_value(quanted_weight)
|
||||
|
||||
setattr(
|
||||
layer,
|
||||
scale_name,
|
||||
layer.create_parameter(
|
||||
shape=quanted_weight_scale.shape,
|
||||
dtype=quanted_weight_scale.dtype,
|
||||
),
|
||||
)
|
||||
getattr(layer, scale_name).set_value(quanted_weight_scale)
|
||||
else:
|
||||
setattr(
|
||||
layer,
|
||||
weight_name,
|
||||
layer.create_parameter(
|
||||
shape=quanted_weight.shape,
|
||||
dtype=quanted_weight.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
getattr(layer, weight_name).set_value(quanted_weight)
|
||||
|
||||
setattr(
|
||||
layer,
|
||||
scale_name,
|
||||
layer.create_parameter(
|
||||
shape=quanted_weight_scale.shape,
|
||||
dtype=quanted_weight_scale.dtype,
|
||||
),
|
||||
)
|
||||
getattr(layer, scale_name).set_value(quanted_weight_scale)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
"""
|
||||
token_num = x.shape[0]
|
||||
top_k = layer.top_k
|
||||
num_local_experts = layer.num_local_experts
|
||||
top_k = layer.top_k
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
top_k,
|
||||
True, # apply_norm_weight,
|
||||
False,
|
||||
)
|
||||
up_gate_proj_out = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
if self.quant_config is not None:
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
|
||||
topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
|
||||
)
|
||||
max_possible_num_post_padded = sorted_token_ids.shape[0]
|
||||
grid = (
|
||||
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
|
||||
* ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
x,
|
||||
layer.up_gate_proj_weight,
|
||||
up_gate_proj_out,
|
||||
None,
|
||||
layer.up_gate_proj_weight_scale,
|
||||
None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
max_possible_num_post_padded,
|
||||
token_num * top_k,
|
||||
N=moe_intermediate_size * 2,
|
||||
K=hidden_size,
|
||||
stride_am=x.strides[0],
|
||||
stride_ak=x.strides[1],
|
||||
stride_be=layer.up_gate_proj_weight.strides[0],
|
||||
stride_bk=layer.up_gate_proj_weight.strides[1],
|
||||
stride_bn=layer.up_gate_proj_weight.strides[2],
|
||||
stride_cm=up_gate_proj_out.strides[0],
|
||||
stride_cn=up_gate_proj_out.strides[1],
|
||||
#
|
||||
stride_asm=-1,
|
||||
stride_ask=-1,
|
||||
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
|
||||
stride_bsk=-1,
|
||||
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
|
||||
group_n=-1,
|
||||
group_k=-1,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
|
||||
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
|
||||
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
|
||||
GROUP_SIZE_M=config["GROUP_SIZE_M"],
|
||||
MUL_ROUTED_WEIGHT=False,
|
||||
top_k=top_k,
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a16=True,
|
||||
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out)
|
||||
|
||||
down_proj_out = paddle.empty(
|
||||
(token_num * top_k, hidden_size),
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
grid = (
|
||||
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
|
||||
* ceil_div(hidden_size, config["BLOCK_SIZE_N"]),
|
||||
)
|
||||
fused_moe_kernel_paddle[grid](
|
||||
down_proj_input,
|
||||
layer.down_proj_weight,
|
||||
down_proj_out,
|
||||
None,
|
||||
layer.down_proj_weight_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
max_possible_num_post_padded,
|
||||
token_num * top_k,
|
||||
N=hidden_size,
|
||||
K=moe_intermediate_size,
|
||||
stride_am=down_proj_input.strides[0],
|
||||
stride_ak=down_proj_input.strides[1],
|
||||
stride_be=layer.down_proj_weight.strides[0],
|
||||
stride_bk=layer.down_proj_weight.strides[1],
|
||||
stride_bn=layer.down_proj_weight.strides[2],
|
||||
stride_cm=down_proj_out.strides[0],
|
||||
stride_cn=down_proj_out.strides[1],
|
||||
stride_asm=-1,
|
||||
stride_ask=-1,
|
||||
stride_bse=layer.down_proj_weight_scale.strides[0],
|
||||
stride_bsk=-1,
|
||||
stride_bsn=layer.down_proj_weight_scale.strides[1],
|
||||
group_n=-1,
|
||||
group_k=-1,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
|
||||
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
|
||||
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
|
||||
GROUP_SIZE_M=config["GROUP_SIZE_M"],
|
||||
MUL_ROUTED_WEIGHT=True,
|
||||
top_k=1,
|
||||
compute_type_enum=1,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a16=True,
|
||||
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
|
||||
)
|
||||
|
||||
down_proj_out.reshape_([token_num, top_k, hidden_size])
|
||||
out = down_proj_out.sum(axis=1)
|
||||
return out
|
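The wint8 path in create_weights() boils down to a per-(expert, output-channel) absmax scheme; below is a toy sketch of just that math (illustrative only, dummy shapes, plain Paddle, runs on CPU):

import paddle

w = paddle.randn([2, 64, 128])          # [num_experts, K, N], as stacked above
scale = w.abs().max(axis=1)             # one scale per expert and output channel: [num_experts, N]
w_int8 = paddle.round(w / scale[:, None, :] * 127).astype("int8")   # max_bound = 127 for wint8
w_deq = w_int8.astype("float32") * (scale / 127)[:, None, :]        # what the kernel undoes via b_scale
print(float((w - w_deq).abs().max()))   # rounding error stays within half a quantization step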
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_moe_kernel_paddle(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
topk_weights_ptr,
|
||||
sorted_token_ids_ptr,
|
||||
expert_ids_ptr,
|
||||
num_tokens_post_padded_ptr,
|
||||
# Matrix dimensions
|
||||
max_possible_num_post_padded,
|
||||
num_valid_tokens,
|
||||
N,
|
||||
K,
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_be,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
# Block size for block-wise fp8 quantization
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
compute_type_enum: tl.constexpr,
|
||||
use_fp8_w8a8: tl.constexpr,
|
||||
use_int8_w8a16: tl.constexpr,
|
||||
even_Ks: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
|
||||
Key Parameters:
|
||||
- A: The input tensor representing tokens with shape (*, K), where '*' can
|
||||
be any shape representing batches and K is the feature dimension of
|
||||
each token.
|
||||
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
|
||||
the number of experts, K is the input feature dimension, and N is
|
||||
the output feature dimension.
|
||||
- C: The output cache tensor with shape (M, topk, N), where M is the
|
||||
total number of tokens post padding, topk is the number of times
|
||||
each token is repeated, and N is the output feature dimension.
|
||||
- sorted_token_ids: A tensor containing the sorted indices of tokens,
|
||||
repeated topk times and arranged by the expert index they are
|
||||
assigned to.
|
||||
- expert_ids: A tensor containing the indices of the expert for each
|
||||
block. It determines which expert matrix from B should be used for
|
||||
each block in A.
|
||||
This kernel performs the multiplication of a token by its corresponding
|
||||
expert matrix as determined by `expert_ids`. The sorting of
|
||||
`sorted_token_ids` by expert index and padding ensures divisibility by
|
||||
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
|
||||
multiplication across different blocks processed by the same expert.
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
assert compute_type_enum == 1
|
||||
compute_type = tl.bfloat16
|
||||
|
||||
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
|
||||
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
||||
return
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
|
||||
token_mask = offs_token < num_valid_tokens
|
||||
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)
|
||||
|
||||
off_experts = tl.load(expert_ids_ptr + pid_m)
|
||||
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
|
||||
if use_int8_w8a16:
|
||||
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
|
||||
b_scale = tl.load(b_scale_ptrs)
|
||||
|
||||
if use_fp8_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
|
||||
offs_bsn = offs_bn // group_n
|
||||
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
|
||||
else:
|
||||
# (Zkk): every expert has one activation scale and weight scale.
|
||||
a_scale = tl.load(a_scale_ptr + off_experts)
|
||||
b_scale = tl.load(b_scale_ptr + off_experts)
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
if even_Ks:
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=token_mask[:, None],
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(b_ptrs, cache_modifier=".ca", eviction_policy="evict_first")
|
||||
else:
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
|
||||
# We accumulate along the K dimension.
|
||||
if use_int8_w8a16:
|
||||
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
|
||||
elif use_fp8_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
k_start = k * BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
a_scale = tl.load(
|
||||
a_scale_ptrs + offs_ks * stride_ask,
|
||||
mask=token_mask,
|
||||
other=0.0,
|
||||
)
|
||||
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
|
||||
|
||||
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
|
||||
else:
|
||||
accumulator = tl.dot(a, b, acc=accumulator)
|
||||
else:
|
||||
accumulator += tl.dot(a, b)
|
||||
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
||||
accumulator = accumulator * moe_weight[:, None]
|
||||
if use_int8_w8a16:
|
||||
accumulator = (accumulator * b_scale).to(compute_type)
|
||||
elif use_fp8_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
else:
|
||||
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
|
||||
else:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
# Write back the block of the output
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
||||
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
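To make the docstring's description of sorted_token_ids / expert_ids concrete, here is a conceptual NumPy illustration of the sort-by-expert-and-pad-to-BLOCK_SIZE_M scheme; this is not the tritonmoe_preprocess op, whose exact outputs may differ.

import numpy as np

topk_ids = np.array([[0, 2], [1, 0], [2, 2]])    # 3 tokens, top_k = 2, 3 experts
BLOCK_SIZE_M, num_experts = 4, 3

flat = topk_ids.reshape(-1)                      # one "slot" per (token, top_k) pair
order = np.argsort(flat, kind="stable")          # group slots by the expert they route to
sorted_token_ids, expert_ids = [], []
for e in range(num_experts):
    slots = order[flat[order] == e]
    pad = (-len(slots)) % BLOCK_SIZE_M           # pad each expert's group to a multiple of BLOCK_SIZE_M
    padded = np.concatenate([slots, np.full(pad, flat.size)])   # out-of-range ids mark padding
    sorted_token_ids.append(padded)
    expert_ids += [e] * (len(padded) // BLOCK_SIZE_M)           # one expert id per BLOCK_SIZE_M block
sorted_token_ids = np.concatenate(sorted_token_ids)
# In the kernel, padding slots are dropped by `token_mask = offs_token < num_valid_tokens`.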
@@ -107,6 +107,7 @@ class LinearBase(nn.Layer):
|
||||
or current_platform.is_iluvatar()
|
||||
or current_platform.is_gcu()
|
||||
or current_platform.is_dcu()
|
||||
or current_platform.is_maca()
|
||||
):
|
||||
self.forward = self.forward_cuda
|
||||
else:
|
||||
|
@@ -49,6 +49,12 @@ def get_moe_method():
|
||||
from fastdeploy.model_executor.layers.backends import GCUFusedMoeMethod
|
||||
|
||||
return GCUFusedMoeMethod(None)
|
||||
elif current_platform.is_maca():
|
||||
from fastdeploy.model_executor.layers.backends import (
|
||||
MetaxTritonWeightOnlyMoEMethod,
|
||||
)
|
||||
|
||||
return MetaxTritonWeightOnlyMoEMethod(None)
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
@@ -94,6 +94,16 @@ class WeightOnlyConfig(QuantConfigBase):
|
||||
)
|
||||
|
||||
return DCUWeightOnlyLinearMethod(self)
|
||||
elif current_platform.is_maca():
|
||||
if isinstance(layer, FusedMoE):
|
||||
from fastdeploy.model_executor.layers.backends import (
|
||||
MetaxTritonWeightOnlyMoEMethod,
|
||||
)
|
||||
|
||||
return MetaxTritonWeightOnlyMoEMethod(self)
|
||||
else:
|
||||
|
||||
return GPUWeightOnlyLinearMethod(self)
|
||||
else:
|
||||
if isinstance(layer, FusedMoE):
|
||||
if layer.use_method == "cutlass":
|
||||
@@ -196,14 +206,24 @@ class WeightOnlyLinearMethod(QuantMethodBase):
|
||||
raise NotImplementedError
|
||||
|
||||
def apply(self, layer, x):
|
||||
linear_out = weight_only_linear(
|
||||
x,
|
||||
weight=layer.weight,
|
||||
bias=layer.bias if layer.add_bias else None,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
|
||||
arch=self.quant_config.weight_only_linear_arch,
|
||||
)
|
||||
if current_platform.is_maca():
|
||||
linear_out = weight_only_linear(
|
||||
x,
|
||||
weight=layer.weight,
|
||||
bias=layer.bias if layer.add_bias else None,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
|
||||
arch=80,
|
||||
)
|
||||
else:
|
||||
linear_out = weight_only_linear(
|
||||
x,
|
||||
weight=layer.weight,
|
||||
bias=layer.bias if layer.add_bias else None,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
|
||||
arch=self.quant_config.weight_only_linear_arch,
|
||||
)
|
||||
return linear_out
|
||||
|
||||
|
||||
@@ -240,6 +260,7 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
|
||||
algo=self.quant_config.algo,
|
||||
arch=self.quant_config.weight_only_linear_arch,
|
||||
)
|
||||
|
||||
if current_platform.is_maca():
|
||||
quanted_weight_tensor = paddle.transpose(quanted_weight_tensor, [1, 0])
|
||||
layer.weight.set_value(quanted_weight_tensor)
|
||||
layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))
|
||||
|
@@ -51,6 +51,10 @@ class ErnieRotaryEmbedding:
|
||||
# shape: [B, S, D]
|
||||
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
|
||||
return rot_emb
|
||||
elif paddle.is_compiled_with_custom_device("metax_gpu"):
|
||||
# shape: [B, S, D]
|
||||
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
|
||||
emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
|
||||
else:
|
||||
# shape: [B, S, D/2]
|
||||
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
|
||||
|
@@ -119,6 +119,23 @@ def apply_penalty_multi_scores(
|
||||
min_dec_lens,
|
||||
eos_token_ids,
|
||||
)
|
||||
elif current_platform.is_maca():
|
||||
from fastdeploy.model_executor.ops.gpu import get_token_penalty_multi_scores
|
||||
|
||||
logits = get_token_penalty_multi_scores(
|
||||
pre_token_ids,
|
||||
prompt_ids,
|
||||
prompt_lens,
|
||||
logits,
|
||||
repetition_penalties,
|
||||
frequency_penalties,
|
||||
presence_penalties,
|
||||
temperature,
|
||||
bad_words_token_ids,
|
||||
step_idx,
|
||||
min_dec_lens,
|
||||
eos_token_ids,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@@ -177,6 +177,7 @@ class Sampler(nn.Layer):
|
||||
or current_platform.is_iluvatar()
|
||||
or current_platform.is_gcu()
|
||||
or current_platform.is_dcu()
|
||||
or current_platform.is_maca()
|
||||
):
|
||||
self.forward = self.forward_cuda
|
||||
else:
|
||||
|
@@ -45,6 +45,14 @@ elif current_platform.is_dcu():
|
||||
step_paddle,
|
||||
update_inputs,
|
||||
)
|
||||
elif current_platform.is_maca():
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
get_padding_offset,
|
||||
save_output,
|
||||
set_stop_value_multi_ends,
|
||||
step_paddle,
|
||||
update_inputs,
|
||||
)
|
||||
else:
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
get_padding_offset,
|
||||
@@ -225,6 +233,19 @@ def post_process_normal(
|
||||
model_output.stop_seqs_len,
|
||||
False,
|
||||
) # multi ends
|
||||
elif current_platform.is_maca():
|
||||
set_stop_value_multi_ends(
|
||||
sampler_output.sampled_token_ids,
|
||||
model_output.stop_flags,
|
||||
model_output.seq_lens_this_time,
|
||||
model_output.eos_token_id,
|
||||
model_output.next_tokens,
|
||||
model_output.pre_ids,
|
||||
model_output.step_idx,
|
||||
model_output.stop_token_ids,
|
||||
model_output.stop_seqs_len,
|
||||
False,
|
||||
) # multi ends
|
||||
else:
|
||||
set_stop_value_multi_ends(
|
||||
sampler_output.sampled_token_ids,
|
||||
@@ -573,6 +594,18 @@ def rebuild_padding(
|
||||
output_padding_offset,
|
||||
max_input_length,
|
||||
)
|
||||
elif current_platform.is_maca():
|
||||
from fastdeploy.model_executor.ops.gpu import rebuild_padding
|
||||
|
||||
hidden_states = rebuild_padding(
|
||||
tmp_out,
|
||||
cum_offsets,
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
output_padding_offset,
|
||||
max_input_length,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Not supported platform")
|
||||
return hidden_states
|
||||
|
@@ -23,6 +23,7 @@ from .cuda import CUDAPlatform
from .dcu import DCUPlatform
from .gcu import GCUPlatform
from .iluvatar import IluvatarPlatform
from .maca import MACAPlatform
from .npu import NPUPlatform
from .xpu import XPUPlatform

@@ -46,6 +47,8 @@ def __getattr__(name: str):
        _current_platform = IluvatarPlatform()
    elif paddle.is_compiled_with_custom_device("gcu"):
        _current_platform = GCUPlatform()
    elif paddle.is_compiled_with_custom_device("metax_gpu"):
        _current_platform = MACAPlatform()
    else:
        _current_platform = CPUPlatform()
    return _current_platform
@@ -77,6 +77,12 @@ class Platform:
        """
        return paddle.is_compiled_with_custom_device("gcu")

    def is_maca(self) -> bool:
        """
        Whether the platform is a Metax GPU.
        """
        return paddle.is_compiled_with_custom_device("metax_gpu")

    @classmethod
    def get_attention_backend_cls(self, selected_backend):
        """Get the attention backend"""
65
fastdeploy/platforms/maca.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
# Copyright (c) 2025 MetaX-tech Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
"""
|
||||
maca platform file
|
||||
"""
|
||||
|
||||
import paddle
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from .base import Platform, _Backend
|
||||
|
||||
|
||||
class MACAPlatform(Platform):
|
||||
"""
|
||||
maca platform class
|
||||
"""
|
||||
|
||||
device_name = "metax_gpu"
|
||||
|
||||
@classmethod
|
||||
def available(self):
|
||||
"""
|
||||
Check whether MACA is available.
|
||||
"""
|
||||
try:
|
||||
assert len(paddle.static.cuda_places()) > 0
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"You are using GPU version PaddlePaddle, but there is no GPU "
|
||||
"detected on your machine. Maybe CUDA devices is not set properly."
|
||||
f"\n Original Error is {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_attention_backend_cls(cls, selected_backend: _Backend):
|
||||
"""
|
||||
get_attention_backend_cls
|
||||
"""
|
||||
if selected_backend == _Backend.NATIVE_ATTN:
|
||||
logger.info("Using NATIVE ATTN backend.")
|
||||
return "fastdeploy.model_executor.layers.attention.PaddleNativeAttnBackend"
|
||||
elif selected_backend == _Backend.APPEND_ATTN:
|
||||
logger.info("Using FLASH ATTN backend to instead of attend attention.")
|
||||
return "fastdeploy.model_executor.layers.backends.metax.attention.flash_attn_backend.FlashAttentionBackend"
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid attention backend you specified.\n"
|
||||
"Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN] in cuda place."
|
||||
)
|
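A hypothetical snippet (not part of the diff) showing how the platform plumbing above fits together on a Metax build; the import paths mirror the files changed in this commit:

import paddle
from fastdeploy.platforms import current_platform
from fastdeploy.platforms.base import _Backend

# platforms/__init__.py selects MACAPlatform when Paddle is compiled with the
# "metax_gpu" custom device, so both checks agree on such a build.
assert current_platform.is_maca() == paddle.is_compiled_with_custom_device("metax_gpu")

# On MACA, APPEND_ATTN is redirected to the Metax FlashAttentionBackend added above.
print(current_platform.get_attention_backend_cls(_Backend.APPEND_ATTN))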
1664
fastdeploy/worker/metax_model_runner.py
Normal file
File diff suppressed because it is too large
203
fastdeploy/worker/metax_worker.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy.utils import get_logger
|
||||
from fastdeploy.worker.metax_model_runner import MetaxModelRunner
|
||||
from fastdeploy.worker.output import ModelRunnerOutput
|
||||
from fastdeploy.worker.worker_base import WorkerBase
|
||||
|
||||
logger = get_logger("metax_worker", "metax_worker.log")
|
||||
|
||||
|
||||
class MetaxWorker(WorkerBase):
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
):
|
||||
super().__init__(
|
||||
fd_config=fd_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
)
|
||||
pass
|
||||
|
||||
def init_device(self):
|
||||
"""
|
||||
Initialize device and construct model runner
|
||||
"""
|
||||
self.max_chips_per_node = 8
|
||||
if paddle.is_compiled_with_custom_device("metax_gpu"):
|
||||
# Set environment variable
|
||||
self.device_ids = self.parallel_config.device_ids.split(",")
|
||||
self.device = f"metax_gpu:{self.local_rank % self.max_chips_per_node}"
|
||||
paddle.device.set_device(self.device)
|
||||
paddle.set_default_dtype(self.parallel_config.dtype)
|
||||
|
||||
gc.collect()
|
||||
paddle.device.cuda.empty_cache()
|
||||
else:
|
||||
raise RuntimeError(f"Not support device type: {self.device_config.device}")
|
||||
|
||||
# Construct model runner
|
||||
self.model_runner: MetaxModelRunner = MetaxModelRunner(
|
||||
fd_config=self.fd_config,
|
||||
device=self.device,
|
||||
device_id=self.device_ids[self.local_rank % self.max_chips_per_node],
|
||||
rank=self.rank,
|
||||
local_rank=self.local_rank,
|
||||
)
|
||||
|
||||
    def exist_prefill(self):
        """
        Check whether a prefill stage exists.
        """
        return self.model_runner.exist_prefill()

    def determine_available_memory(self) -> int:
        """
        Profiles the peak memory usage of the model to determine how much
        memory can be used for the KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # 1. Record memory state before profile run
        start_time = time.perf_counter()
        Gb = 1024**3

        local_rank = self.local_rank % self.max_chips_per_node
        paddle.device.cuda.reset_max_memory_reserved(local_rank)
        paddle.device.cuda.reset_max_memory_allocated(local_rank)
        # max memory for Allocator
        paddle_reserved_mem_before_run = paddle.device.cuda.max_memory_reserved(local_rank)
        # max memory for Tensor
        paddle_allocated_mem_before_run = paddle.device.cuda.max_memory_allocated(local_rank)  # not reserved

        device_id = int(self.device_ids[local_rank])
        if os.getenv("MACA_VISIBLE_DEVICES") is not None:
            device_id = int(os.getenv("MACA_VISIBLE_DEVICES").split(",")[device_id])

        import pymxsml

        pymxsml.mxSmlInit()
        info = pymxsml.mxSmlGetMemoryInfo(device_id)
        before_run_meminfo_total = info.vramTotal * 1024
        before_run_meminfo_used = info.vramUse * 1024
        before_run_meminfo_free = before_run_meminfo_total - before_run_meminfo_used

        logger.info("Before running the profile, the memory usage info of the Metax GPU is as follows:")
        logger.info(f"Device Index: {device_id}")
        logger.info(f"Device Total memory: {before_run_meminfo_total / Gb}")
        logger.info(f"Device used memory: {before_run_meminfo_used / Gb}")
        logger.info(f"Device free memory: {before_run_meminfo_free / Gb}")
        logger.info(f"Paddle reserved memory: {paddle_reserved_mem_before_run / Gb}")
        logger.info(f"Paddle allocated memory: {paddle_allocated_mem_before_run / Gb}")

        # 2. Profile run
        self.model_runner.profile_run()

        # 3. Collect memory statistics after the profile run
        paddle_reserved_mem_after_run = paddle.device.cuda.max_memory_reserved(local_rank)
        paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated(local_rank)

        model_block_memory_used = self.cal_theortical_kvcache()
        paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run

        paddle.device.cuda.empty_cache()

        info = pymxsml.mxSmlGetMemoryInfo(device_id)
        after_run_meminfo_total = info.vramTotal * 1024
        after_run_meminfo_used = info.vramUse * 1024
        after_run_meminfo_free = after_run_meminfo_total - after_run_meminfo_used

        available_kv_cache_memory = (
            after_run_meminfo_total * self.cache_config.gpu_memory_utilization
            - after_run_meminfo_used
            - paddle_peak_increase
        )
        available_kv_cache_memory += model_block_memory_used * self.parallel_config.total_block_num

        end_time = time.perf_counter()

        logger.info("After running the profile, the memory usage info of the Metax GPU is as follows:")
        logger.info(f"Device Index: {device_id}")
        logger.info(f"Device Total memory: {after_run_meminfo_total / Gb}")
        logger.info(f"Device used memory: {after_run_meminfo_used / Gb}")
        logger.info(f"Device free memory: {after_run_meminfo_free / Gb}")
        logger.info(f"Paddle reserved memory: {paddle_reserved_mem_after_run / Gb}")
        logger.info(f"Paddle allocated memory: {paddle_allocated_mem_after_run / Gb}")
        logger.info(f"Paddle available_kv_cache_memory: {available_kv_cache_memory / Gb}")
        logger.info(f"Profile time: {end_time - start_time}")

        return available_kv_cache_memory

    def load_model(self) -> None:
        """Load the model."""
        self.model_runner.load_model()

    def get_model(self) -> nn.Layer:
        """Get the current model."""
        return self.model_runner.get_model()

    def initialize_cache(self, num_gpu_blocks: int) -> None:
        """Initialize the KV cache with the accurate num_gpu_blocks."""
        # accurate cache size
        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

    def execute_model(
        self,
        model_forward_batch: Optional[List[Request]] = None,
    ) -> Optional[ModelRunnerOutput]:
        """Run one forward step of the model and return its output."""
        output = self.model_runner.execute_model(model_forward_batch)
        return output

    def preprocess_new_task(self, req_dicts: List[Request]) -> None:
        """Process new requests and then start the decode loop;
        workers and model runners should not perceive it.
        """
        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
            self.model_runner.insert_tasks_v1(req_dicts=req_dicts)
        else:
            self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)

    def check_health(self) -> bool:
        """Always report the worker as healthy."""
        return True

    def cal_theortical_kvcache(self) -> int:
        """Calculate the block memory required."""
        return self.model_runner.cal_theortical_kvcache()
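To make the accounting in determine_available_memory concrete, here is a small, self-contained sketch of the same arithmetic with made-up numbers. The estimate_available_kv_cache_memory helper and all values below are illustrative only; the real numbers come from pymxsml and Paddle's memory APIs.

# Illustrative sketch of the KV cache memory formula used above; not FastDeploy code.
GB = 1024**3


def estimate_available_kv_cache_memory(
    total_bytes: int,           # device VRAM reported after the profile run
    used_bytes: int,            # device VRAM in use after the profile run
    paddle_peak_increase: int,  # reserved-after-run minus allocated-before-run
    block_bytes: int,           # theoretical memory of one KV cache block
    total_block_num: int,       # blocks already reserved for the profile run
    gpu_memory_utilization: float,
) -> int:
    """Mirror of the formula in MetaxWorker.determine_available_memory."""
    available = total_bytes * gpu_memory_utilization - used_bytes - paddle_peak_increase
    # Blocks allocated for the profile run are returned to the pool, so add them back.
    available += block_bytes * total_block_num
    return int(available)


# Example with hypothetical numbers: 64 GiB card, 30 GiB used after profiling,
# 2 GiB peak increase inside Paddle, 2 MiB per block, 1024 blocks, 90% utilization.
mem = estimate_available_kv_cache_memory(64 * GB, 30 * GB, 2 * GB, 2 * 1024**2, 1024, 0.9)
print(f"available KV cache memory: {mem / GB:.2f} GiB")  # -> 27.60 GiB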
@@ -75,6 +75,10 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
        from fastdeploy.worker.gcu_worker import GcuWorker

        return GcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
    if current_platform.is_maca():
        from fastdeploy.worker.metax_worker import MetaxWorker

        return MetaxWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)


def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
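As a usage note, the MetaxWorker branch above is only taken when the running platform reports MACA. A quick interactive check of that condition might look like the sketch below; the import path for current_platform is assumed from the usage in this hunk, and the calls only return True on a Metax (MACA) build of Paddle with FastDeploy installed.

# Quick probe: confirm the MACA platform is detected before get_worker() can
# return a MetaxWorker. Import path for current_platform is an assumption.
import paddle
from fastdeploy.platforms import current_platform

print(paddle.device.is_compiled_with_custom_device("metax_gpu"))  # True on a MACA build
print(current_platform.is_maca())  # True -> get_worker() dispatches to MetaxWorker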
39
requirements_metaxgpu.txt
Normal file
@@ -0,0 +1,39 @@
setuptools>=62.3.0,<80.0
pre-commit
yapf
flake8
ruamel.yaml
zmq
aiozmq
openai>=1.93.0
tqdm
pynvml
uvicorn
fastapi
paddleformers
redis
etcd3
httpx
tool_helpers
cupy-cuda12x
pybind11[global]
tabulate
gradio
xlwt
visualdl
setuptools-scm>=8
prometheus-client
decord
moviepy
triton
use-triton-in-paddle
crcmod
fastsafetensors==0.1.14
msgpack
opentelemetry-api>=1.24.0
opentelemetry-sdk>=1.24.0
opentelemetry-instrumentation-redis
opentelemetry-instrumentation-mysql
opentelemetry-distro
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
6
setup.py
@@ -151,13 +151,15 @@ def load_requirements():
        requirements_file_name = "requirements_iluvatar.txt"
    elif paddle.is_compiled_with_rocm():
        requirements_file_name = "requirements_dcu.txt"
    elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
        requirements_file_name = "requirements_metaxgpu.txt"
    requirements_path = os.path.join(os.path.dirname(__file__), requirements_file_name)
    with open(requirements_path, "r") as f:
        return [line.strip() for line in f if line.strip() and not line.startswith("#")]


def get_device_type():
    """Get the device type (rocm/gpu/xpu/npu/cpu) that paddle is compiled with."""
    """Get the device type (rocm/gpu/xpu/npu/cpu/metax-gpu) that paddle is compiled with."""
    if paddle.is_compiled_with_rocm():
        return "rocm"
    elif paddle.is_compiled_with_cuda():
@@ -170,6 +172,8 @@ def get_device_type():
        return "iluvatar-gpu"
    elif paddle.is_compiled_with_custom_device("gcu"):
        return "gcu"
    elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
        return "metax-gpu"
    else:
        return "cpu"
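Both setup.py hunks hinge on the same detection call, paddle.device.is_compiled_with_custom_device("metax_gpu"). A minimal standalone probe is sketched below; it is safe to run on any Paddle build and prints False where the Metax plugin is absent. The "requirements.txt" fallback name is an assumption for illustration, not taken from setup.py.

# Minimal probe for a Metax (MACA) build of Paddle; illustrative sketch only.
import paddle

is_metax = paddle.device.is_compiled_with_custom_device("metax_gpu")
device_type = "metax-gpu" if is_metax else "other"
requirements = "requirements_metaxgpu.txt" if is_metax else "requirements.txt"  # fallback name assumed
print(is_metax, device_type, requirements)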
|
Reference in New Issue
Block a user