[Feature] support chunked moe (#4575)

* [Feature] support chunked moe

* update

* update

* fix and add test

* update

* fix conflict and modify test

* fix fused_moe

* fix fused_moe

* fix docstring

* fix

* fix typo

* fix test

* fix

* fix

* fix test

* fix test
Longzhi Wang
2025-12-01 15:17:18 +08:00
committed by GitHub
parent 6f42c37359
commit add524d80c
10 changed files with 405 additions and 5 deletions

View File

@@ -540,6 +540,11 @@ class ParallelConfig:
self.expert_parallel_size = 1 # EP degree
self.data_parallel_size = 1 # DP degree
self.enable_expert_parallel = False
self.enable_chunked_moe = False
self.chunked_moe_size = 256
self.max_moe_num_chunk = 1
self.moe_num_chunk = 1
self.local_data_parallel_id = 0
# Engine worker queue port
self.engine_worker_queue_port: str = "9923"
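To make the relationship between these fields concrete: `chunked_moe_size` caps the number of tokens per MoE call, `moe_num_chunk` is the chunk count a rank derives from its current batch, and `max_moe_num_chunk` is the maximum of those counts across EP ranks, which every rank must match with dummy calls. A minimal standalone arithmetic sketch (illustrative, not part of the diff):

```python
import math

def chunk_counts(token_nums_per_rank, chunked_moe_size=256):
    """Illustrative only: per-rank chunk counts and the EP-wide maximum."""
    # Each rank needs ceil(token_num / chunked_moe_size) MoE calls for its own
    # tokens; an empty or tiny input still counts as one chunk.
    per_rank = [max(1, math.ceil(n / chunked_moe_size)) for n in token_nums_per_rank]
    # Every EP rank must issue the same number of MoE calls, so the shared
    # target (max_moe_num_chunk) is the maximum across ranks.
    return per_rank, max(per_rank)

# Example: rank 0 holds 1000 tokens, rank 1 holds 10 -> counts [4, 1], maximum 4.
print(chunk_counts([1000, 10], chunked_moe_size=256))
```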

View File

@@ -286,6 +286,16 @@ class EngineArgs:
Enable expert parallelism.
"""
enable_chunked_moe: bool = False
"""
Whether to use chunked MoE.
"""
chunked_moe_size: int = 256
"""
Chunk size of moe input.
"""
cache_transfer_protocol: str = "ipc"
"""
Protocol to use for cache transfer.
@@ -870,6 +880,18 @@ class EngineArgs:
default=EngineArgs.eplb_config,
help="Config of eplb.",
)
parallel_group.add_argument(
"--enable-chunked-moe",
action="store_true",
default=EngineArgs.enable_chunked_moe,
help="Use chunked moe.",
)
parallel_group.add_argument(
"--chunked-moe-size",
type=int,
default=EngineArgs.chunked_moe_size,
help="Chunked size of moe input.",
)
# Load group
load_group = parser.add_argument_group("Load Configuration")
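For context on how these options are meant to be switched on, here is a hedged usage sketch; the `fastdeploy.engine.args_utils` import path, the `model` field, and the idea that `EngineArgs` can be built directly with these keyword arguments are assumptions based on the dataclass fields above, not something this diff states:

```python
# Sketch only: import path and constructor arguments are assumptions.
from fastdeploy.engine.args_utils import EngineArgs

args = EngineArgs(
    model="your-moe-model",    # placeholder model name/path
    enable_chunked_moe=True,   # same effect as the --enable-chunked-moe flag
    chunked_moe_size=128,      # same effect as --chunked-moe-size 128
)
```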

View File

@@ -812,6 +812,7 @@ class AsyncLLMEngine:
f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}"
f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
f" --chunked_moe_size {self.cfg.parallel_config.chunked_moe_size}"
f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
f" --ori_vocab_size {ori_vocab_size}"

View File

@@ -544,6 +544,7 @@ class LLMEngine:
f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}"
f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
f" --chunked_moe_size {self.cfg.parallel_config.chunked_moe_size}"
f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
f" --ori_vocab_size {ori_vocab_size}"
@@ -573,6 +574,7 @@ class LLMEngine:
worker_store_true_flag = {
"enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel,
"enable_chunked_moe": self.cfg.parallel_config.enable_chunked_moe,
"enable_prefix_caching": self.cfg.cache_config.enable_prefix_caching,
"enable_chunked_prefill": self.cfg.cache_config.enable_chunked_prefill,
"do_profile": self.do_profile,

View File

@@ -612,6 +612,7 @@ class FusedMoE(nn.Layer):
multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, x.shape[1]], dtype=x.dtype)
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
out = multi_outs[:token_num, :]
return out
def forward(self, x: paddle.Tensor, gate: nn.Layer):
@@ -633,9 +634,63 @@ class FusedMoE(nn.Layer):
and token_num >= self.tp_size
):
out = self.forward_split_allgather(x, gate)
elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
out = self.forward_chunked_moe(x, gate)
else:
out = self.quant_method.apply(self, x, gate)
out = self.forward_normal(x, gate)
if self.reduce_results and self.tp_size > 1:
out = tensor_model_parallel_all_reduce(out, self.tp_group)
return out
def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer):
"""
Split the input into multiple chunks to reduce the peak memory usage of the MoE layer.
Args:
x (Tensor): Input tensor to the MoE layer.
Returns:
Tensor: Output tensor.
"""
chunk_size = self.fd_config.parallel_config.chunked_moe_size
token_num = x.shape[0]
fake_x = paddle.empty(
shape=[0, self.fd_config.model_config.hidden_size],
dtype=paddle.get_default_dtype(),
)
# Ranks whose input is smaller than one chunk (or empty) must still issue dummy MoE calls
# until the rank with the most chunks has finished all of its chunks.
if token_num > chunk_size: # chunked moe
x_split_list = paddle.tensor_split(x, self.fd_config.parallel_config.moe_num_chunk, axis=0)
out_split_list = [None] * self.fd_config.parallel_config.moe_num_chunk
for i in range(self.fd_config.parallel_config.max_moe_num_chunk):
if i < self.fd_config.parallel_config.moe_num_chunk:
out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
else:
# Pad with empty inputs so this rank issues as many MoE calls as the busiest rank.
self.quant_method.apply(self, fake_x, gate)
out = paddle.concat(out_split_list, axis=0)
else:
# With a single chunk, run the real data once, then pad with empty inputs.
out = self.quant_method.apply(self, x, gate)
for i in range(self.fd_config.parallel_config.max_moe_num_chunk - 1):
self.quant_method.apply(self, fake_x, gate)
return out
def forward_normal(self, x: paddle.Tensor, gate: nn.Layer):
"""
Standard (non-chunked) forward pass.
Args:
x (Tensor): Input tensor to the MoE layer.
Returns:
Tensor: Output tensor.
"""
out = self.quant_method.apply(self, x, gate)
return out
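The invariant `forward_chunked_moe` maintains is that every EP rank calls the expert kernel exactly `max_moe_num_chunk` times per layer, padding with an empty tensor once it runs out of real chunks, so the expert-parallel collectives stay in lockstep. A standalone sketch of that pattern, with a stand-in `apply_experts` in place of `self.quant_method.apply`:

```python
import paddle

def apply_experts(x):
    """Stand-in for quant_method.apply: any collective expert dispatch."""
    return x * 1.0

def chunked_moe_forward(x, chunk_size, moe_num_chunk, max_moe_num_chunk, hidden_size):
    """Run the MoE in chunks; pad with empty calls so all EP ranks stay in sync."""
    fake_x = paddle.empty([0, hidden_size], dtype=x.dtype)
    if x.shape[0] > chunk_size:
        chunks = paddle.tensor_split(x, moe_num_chunk, axis=0)
        outs = []
        for i in range(max_moe_num_chunk):
            if i < moe_num_chunk:
                outs.append(apply_experts(chunks[i]))  # real chunk
            else:
                apply_experts(fake_x)  # dummy call to keep pace with busier ranks
        return paddle.concat(outs, axis=0)
    # Single (or empty) input: one real call, then dummy calls up to the shared maximum.
    out = apply_experts(x)
    for _ in range(max_moe_num_chunk - 1):
        apply_experts(fake_x)
    return out

# Example: 10 tokens of width 4 with chunk size 2 -> 5 real chunks on this rank.
x = paddle.ones([10, 4])
print(chunked_moe_forward(x, chunk_size=2, moe_num_chunk=5, max_moe_num_chunk=5, hidden_size=4).shape)
```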

View File

@@ -95,7 +95,11 @@ from fastdeploy.model_executor.logits_processor import build_logits_processors
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
from fastdeploy.output.pooler import PoolerOutput
from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.model_runner_base import (
DistributedOut,
DistributedStatus,
ModelRunnerBase,
)
from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput
@@ -250,6 +254,56 @@ class GPUModelRunner(ModelRunnerBase):
return if_only_prefill
def collect_distributed_status(self):
"""
Collect and synchronize distributed status (prefill/decode state and MoE chunk counts) across ranks.
"""
dist_status_list = []
dist_status_obj = DistributedStatus()
dist_out = DistributedOut()
prefill_exists = None
if_only_decode = True
# mix ep in single node
if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
prefill_exists = self.exist_prefill()
dist_status_obj.only_decode = not prefill_exists
# whether chunked moe
if self.fd_config.parallel_config.enable_chunked_moe:
chunk_size = self.fd_config.parallel_config.chunked_moe_size
token_num = self.share_inputs["ids_remove_padding"].shape[0]
if token_num > chunk_size:
self.fd_config.parallel_config.moe_num_chunk = (token_num + chunk_size - 1) // chunk_size
else:
self.fd_config.parallel_config.moe_num_chunk = 1
dist_status_obj.moe_num_chunk = self.fd_config.parallel_config.moe_num_chunk
# only ep need to collect and sync distributed status
if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
# call once to gather all status
paddle.distributed.all_gather_object(dist_status_list, dist_status_obj)
# Update Batch type for cuda graph for if_only_decode
if_only_decode = all(dist_status.only_decode for dist_status in dist_status_list)
if_only_decode = if_only_decode and not (
prefill_exists if prefill_exists is not None else self.exist_prefill()
)
max_moe_num_chunk = None
if self.fd_config.parallel_config.enable_chunked_moe:
max_moe_num_chunk = max(dist_status.moe_num_chunk for dist_status in dist_status_list)
dist_out = DistributedOut(
if_only_decode=if_only_decode,
max_moe_num_chunk=max_moe_num_chunk,
)
return dist_out
def only_decode(self):
"""
check whether decode only
@@ -1355,7 +1409,7 @@ class GPUModelRunner(ModelRunnerBase):
def initialize_forward_meta(self, is_dummy_or_profile_run=False):
"""
Initialize forward meta and attention meta data
Initialize forward meta and attention metadata, and update related config.
"""
# Initialize forward meta
self.forward_meta = ForwardMeta(
@@ -1386,8 +1440,12 @@ class GPUModelRunner(ModelRunnerBase):
kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
)
# Update Batch type for cuda graph for only_decode_batch
if_only_decode = self.only_decode()
dist_status = self.collect_distributed_status()
if_only_decode = dist_status.if_only_decode
if self.fd_config.parallel_config.enable_chunked_moe:
self.fd_config.parallel_config.max_moe_num_chunk = dist_status.max_moe_num_chunk
only_decode_use_cudagraph = self.use_cudagraph and if_only_decode
# Update config about moe for better performance

View File

@@ -15,6 +15,8 @@
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
from paddle import nn
@@ -25,6 +27,18 @@ from fastdeploy.worker.output import ModelRunnerOutput
logger = get_logger("model_runner_base", "model_runner_base.log")
@dataclass
class DistributedStatus:
only_decode: bool = True
moe_num_chunk: int = 1
@dataclass
class DistributedOut:
if_only_decode: bool = True
max_moe_num_chunk: Optional[int] = None
class ModelRunnerBase(ABC):
"""
Engine -> (WIP)Executor -> Worker -> ModelRunner -> Model

View File

@@ -720,6 +720,17 @@ def parse_args():
action="store_true",
help="enable expert parallel",
)
parser.add_argument(
"--enable_chunked_moe",
action="store_true",
help="enable chunked moe",
)
parser.add_argument(
"--chunked_moe_size",
type=int,
default=256,
help="chunk size of moe input",
)
parser.add_argument("--ori_vocab_size", type=int, default=None)
parser.add_argument("--think_end_id", type=int, default=-1)
parser.add_argument("--image_patch_id", type=int, default=-1)

View File

@@ -0,0 +1,183 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import Mock
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from fastdeploy.config import MoEPhase
from fastdeploy.model_executor.layers.moe import FusedMoE
from fastdeploy.worker.gpu_model_runner import GPUModelRunner
class MockStructuredOutputsConfig:
logits_processors = []
class MockModelConfig:
max_model_len = 10
pad_token_id = 0
eos_tokens_lens = 1
eos_token_id = 0
temperature = 1.0
penalty_score = 1.0
frequency_score = 1.0
min_length = 1
vocab_size = 1
top_p = 1.0
presence_score = 1.0
max_stop_seqs_num = 5
stop_seqs_max_len = 2
head_dim = 128
model_type = ["mock"]
moe_phase = MoEPhase(phase="prefill")
hidden_size = 1536
class MockCacheConfig:
block_size = 64
total_block_num = 256
kv_cache_ratio = 0.9
enc_dec_block_num = 2
class MockFDConfig:
class ParallelConfig:
enable_expert_parallel = True
enable_chunked_moe = True
chunked_moe_size = 2
max_moe_num_chunk = 1
moe_num_chunk = 1
use_ep = True
use_sequence_parallel_moe = False
class SchedulerConfig:
name = "default"
splitwise_role = "mixed"
max_num_seqs = 2
parallel_config = ParallelConfig()
scheduler_config = SchedulerConfig()
structured_outputs_config = MockStructuredOutputsConfig()
model_config = MockModelConfig()
class MockAttentionBackend:
def __init__(self):
pass
def init_attention_metadata(self, forward_meta):
pass
class MockQuantMethod:
def apply(self, layer, x, gate):
return x
class TestChunkedMoE(unittest.TestCase):
def setUp(self) -> None:
paddle.seed(2025)
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": 2,
"pp_degree": 1,
"sharding_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)
self.model_runner = self.setup_model_runner()
self.fused_moe = self.setup_fused_moe()
def setup_model_runner(self):
"""Helper method to setup GPUModelRunner with different configurations"""
mock_fd_config = MockFDConfig()
mock_model_config = MockModelConfig()
mock_cache_config = MockCacheConfig()
model_runner = GPUModelRunner.__new__(GPUModelRunner)
model_runner.fd_config = mock_fd_config
model_runner.model_config = mock_model_config
model_runner.cache_config = mock_cache_config
model_runner.attn_backends = [MockAttentionBackend()]
model_runner.enable_mm = True
model_runner.cudagraph_only_prefill = False
model_runner.use_cudagraph = False
model_runner.speculative_decoding = False
model_runner._init_share_inputs(mock_fd_config.scheduler_config.max_num_seqs)
model_runner.share_inputs["caches"] = None
if dist.get_rank() == 0:
model_runner.share_inputs["ids_remove_padding"] = paddle.ones([10])
else:
model_runner.share_inputs["ids_remove_padding"] = paddle.ones([1])
return model_runner
def setup_fused_moe(self):
mock_fd_config = MockFDConfig()
fused_moe = FusedMoE.__new__(FusedMoE)
fused_moe.ep_size = 2
fused_moe.tp_size = 1
fused_moe.reduce_results = True
fused_moe.fd_config = mock_fd_config
fused_moe.quant_method = MockQuantMethod()
return fused_moe
def run_model_runner(self):
self.model_runner.initialize_forward_meta()
assert self.model_runner.fd_config.parallel_config.max_moe_num_chunk == 5, (
f"chunk size is 2, max token_num is 10, max_moe_num_chunk should be 5, "
f"but got {self.model_runner.fd_config.parallel_config.max_moe_num_chunk}"
)
if dist.get_rank() == 0:
assert self.model_runner.fd_config.parallel_config.moe_num_chunk == 5, (
f"chunk size is 2, token_num is 10, moe_num_chunk in rank 0 should be 5"
f"but got {self.model_runner.fd_config.parallel_config.moe_num_chunk}"
)
else:
assert self.model_runner.fd_config.parallel_config.moe_num_chunk == 1, (
f"chunk size is 2, token_num is 1, moe_num_chunk in rank 1 should be 1"
f", but got {self.model_runner.fd_config.parallel_config.moe_num_chunk}"
)
def run_fused_moe(self):
gate = Mock()
if dist.get_rank() == 0:
x = paddle.ones([10])
else:
x = paddle.ones([1])
out = self.fused_moe.forward(x, gate)
assert out.shape == x.shape
def test_case(self):
# check whether dist collected max_moe_num_chunk is correct.
self.run_model_runner()
# check the forward method of chunked MoE.
self.run_fused_moe()
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,49 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
def test_fused_moe_launch():
"""
Launch the chunked MoE test on two GPUs via paddle.distributed.launch.
"""
current_dir = os.path.dirname(os.path.abspath(__file__))
chunked_moe_script = os.path.join(current_dir, "chunked_moe.py")
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
command = [
sys.executable,
"-m",
"paddle.distributed.launch",
"--gpus",
"0,1",
chunked_moe_script,
]
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
try:
stdout, stderr = process.communicate(timeout=400)
return_code = process.returncode
except subprocess.TimeoutExpired:
process.kill()
stdout, stderr = process.communicate()
return_code = -1
print(f"std_out: {stdout}")
assert return_code == 0, f"Process exited with code {return_code}, stdout: {stdout}, stderr: {stderr}"
test_fused_moe_launch()