From add524d80caff60aa0d306d50ad6d3da781b345f Mon Sep 17 00:00:00 2001
From: Longzhi Wang <583087864@qq.com>
Date: Mon, 1 Dec 2025 15:17:18 +0800
Subject: [PATCH] [Feature] support chunked moe (#4575)

* [Feature] support chunked moe
* update
* update
* fix and add test
* update
* fix conflict and modify test
* fix fused_moe
* fix fused_moe
* fix docstring
* fix
* fix typo
* fix test
* fix
* fix
* fix test
* fix test
---
 fastdeploy/config.py                        |   5 +
 fastdeploy/engine/args_utils.py             |  22 +++
 fastdeploy/engine/async_llm.py              |   1 +
 fastdeploy/engine/engine.py                 |   2 +
 fastdeploy/model_executor/layers/moe/moe.py |  57 +++++-
 fastdeploy/worker/gpu_model_runner.py       |  66 ++++++-
 fastdeploy/worker/model_runner_base.py      |  14 ++
 fastdeploy/worker/worker_process.py         |  11 ++
 tests/distributed/chunked_moe.py            | 183 ++++++++++++++++++++
 tests/distributed/test_chunked_moe.py       |  49 ++++++
 10 files changed, 405 insertions(+), 5 deletions(-)
 create mode 100644 tests/distributed/chunked_moe.py
 create mode 100644 tests/distributed/test_chunked_moe.py

diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 3f4d326e1..aea9bf33d 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -540,6 +540,11 @@ class ParallelConfig:
         self.expert_parallel_size = 1  # EP degree
         self.data_parallel_size = 1  # DP degree
         self.enable_expert_parallel = False
+        self.enable_chunked_moe = False
+        self.chunked_moe_size = 256
+        self.max_moe_num_chunk = 1
+        self.moe_num_chunk = 1
+
         self.local_data_parallel_id = 0
         # Engine worker queue port
         self.engine_worker_queue_port: str = "9923"
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 3f7952fc5..3fec68001 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -286,6 +286,16 @@ class EngineArgs:
     Enable expert parallelism.
     """
 
+    enable_chunked_moe: bool = False
+    """
+    Whether to use chunked MoE.
+    """
+
+    chunked_moe_size: int = 256
+    """
+    Chunk size of the MoE input.
+    """
+
     cache_transfer_protocol: str = "ipc"
     """
     Protocol to use for cache transfer.
@@ -870,6 +880,18 @@ class EngineArgs:
             default=EngineArgs.eplb_config,
             help="Config of eplb.",
         )
+        parallel_group.add_argument(
+            "--enable-chunked-moe",
+            action="store_true",
+            default=EngineArgs.enable_chunked_moe,
+            help="Enable chunked MoE.",
+        )
+        parallel_group.add_argument(
+            "--chunked-moe-size",
+            type=int,
+            default=EngineArgs.chunked_moe_size,
+            help="Chunk size of the MoE input.",
+        )
 
         # Load group
         load_group = parser.add_argument_group("Load Configuration")
diff --git a/fastdeploy/engine/async_llm.py b/fastdeploy/engine/async_llm.py
index e77c0a02a..70c004aa3 100644
--- a/fastdeploy/engine/async_llm.py
+++ b/fastdeploy/engine/async_llm.py
@@ -812,6 +812,7 @@ class AsyncLLMEngine:
                 f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}"
                 f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
                 f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
+                f" --chunked_moe_size {self.cfg.parallel_config.chunked_moe_size}"
                 f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
                 f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
                 f" --ori_vocab_size {ori_vocab_size}"
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 6271d054a..66e950ad9 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -544,6 +544,7 @@ class LLMEngine:
                 f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}"
                 f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
                 f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
+                f" --chunked_moe_size {self.cfg.parallel_config.chunked_moe_size}"
                 f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
                 f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
                 f" --ori_vocab_size {ori_vocab_size}"
@@ -573,6 +574,7 @@ class LLMEngine:
 
         worker_store_true_flag = {
             "enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel,
+            "enable_chunked_moe": self.cfg.parallel_config.enable_chunked_moe,
             "enable_prefix_caching": self.cfg.cache_config.enable_prefix_caching,
             "enable_chunked_prefill": self.cfg.cache_config.enable_chunked_prefill,
             "do_profile": self.do_profile,
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index e99356e6b..2de8cd567 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -612,6 +612,7 @@ class FusedMoE(nn.Layer):
             multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, x.shape[1]], dtype=x.dtype)
             paddle.distributed.all_gather(multi_outs, out, self.tp_group)
             out = multi_outs[:token_num, :]
+
         return out
 
     def forward(self, x: paddle.Tensor, gate: nn.Layer):
@@ -633,9 +634,63 @@ class FusedMoE(nn.Layer):
             and token_num >= self.tp_size
         ):
             out = self.forward_split_allgather(x, gate)
+        elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
+            out = self.forward_chunked_moe(x, gate)
         else:
-            out = self.quant_method.apply(self, x, gate)
+            out = self.forward_normal(x, gate)
 
         if self.reduce_results and self.tp_size > 1:
             out = tensor_model_parallel_all_reduce(out, self.tp_group)
 
         return out
+
+    def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer):
+        """
+        Split the input into multiple chunks to reduce the peak memory usage of the MoE layer.
+
+        Args:
+            x (Tensor): Input tensor to the MoE layer.
+
+        Returns:
+            Tensor: Output tensor.
+        """
+        chunk_size = self.fd_config.parallel_config.chunked_moe_size
+        token_num = x.shape[0]
+        fake_x = paddle.empty(
+            shape=[0, self.fd_config.model_config.hidden_size],
+            dtype=paddle.get_default_dtype(),
+        )
+        # Inputs that are smaller than one chunk (or empty) still have to call the MoE kernel
+        # repeatedly until the rank with the largest input has finished all of its chunks.
+        if token_num > chunk_size:  # chunked moe
+            x_split_list = paddle.tensor_split(x, self.fd_config.parallel_config.moe_num_chunk, axis=0)
+            out_split_list = [None] * self.fd_config.parallel_config.moe_num_chunk
+
+            for i in range(self.fd_config.parallel_config.max_moe_num_chunk):
+                if i < self.fd_config.parallel_config.moe_num_chunk:
+                    out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
+                else:
+                    # pad with the empty input so that every rank runs max_moe_num_chunk iterations.
+                    self.quant_method.apply(self, fake_x, gate)
+
+            out = paddle.concat(out_split_list, axis=0)
+        else:
+            # only one real chunk: run it once, then pad the remaining iterations with the empty input.
+            out = self.quant_method.apply(self, x, gate)
+            for i in range(self.fd_config.parallel_config.max_moe_num_chunk - 1):
+                self.quant_method.apply(self, fake_x, gate)
+
+        return out
+
+    def forward_normal(self, x: paddle.Tensor, gate: nn.Layer):
+        """
+        Normal (unchunked) forward pass.
+
+        Args:
+            x (Tensor): Input tensor to the MoE layer.
+
+        Returns:
+            Tensor: Output tensor.
+
+        """
+        out = self.quant_method.apply(self, x, gate)
+        return out
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 16f3b0a0e..e4ed0de57 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -95,7 +95,11 @@ from fastdeploy.model_executor.logits_processor import build_logits_processors
 from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
 from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
 from fastdeploy.output.pooler import PoolerOutput
-from fastdeploy.worker.model_runner_base import ModelRunnerBase
+from fastdeploy.worker.model_runner_base import (
+    DistributedOut,
+    DistributedStatus,
+    ModelRunnerBase,
+)
 from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput
 
 
@@ -250,6 +254,56 @@ class GPUModelRunner(ModelRunnerBase):
 
         return if_only_prefill
 
+    def collect_distributed_status(self):
+        """
+        Collect and synchronize distributed status (prefill/decode and MoE chunk counts) across ranks.
+        """
+        dist_status_list = []
+        dist_status_obj = DistributedStatus()
+        dist_out = DistributedOut()
+
+        prefill_exists = None
+        if_only_decode = True
+        # mixed EP within a single node
+        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
+            prefill_exists = self.exist_prefill()
+            dist_status_obj.only_decode = not prefill_exists
+
+        # decide how many MoE chunks this rank needs
+        if self.fd_config.parallel_config.enable_chunked_moe:
+            chunk_size = self.fd_config.parallel_config.chunked_moe_size
+            token_num = self.share_inputs["ids_remove_padding"].shape[0]
+
+            if token_num > chunk_size:
+                self.fd_config.parallel_config.moe_num_chunk = (token_num + chunk_size - 1) // chunk_size
+            else:
+                self.fd_config.parallel_config.moe_num_chunk = 1
+
+            dist_status_obj.moe_num_chunk = self.fd_config.parallel_config.moe_num_chunk
+
+        # only EP needs to collect and sync the distributed status
+        if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
+            # a single call gathers the status objects from all ranks
+            paddle.distributed.all_gather_object(dist_status_list, dist_status_obj)
+
+            # the batch counts as decode-only for cuda graph only if every rank is decode-only
+            if_only_decode = all(dist_status.only_decode for dist_status in dist_status_list)
+
+            if_only_decode = if_only_decode and not (
+                prefill_exists if prefill_exists is not None else self.exist_prefill()
+            )
+
+            max_moe_num_chunk = None
+            if self.fd_config.parallel_config.enable_chunked_moe:
+                max_moe_num_chunk = max(dist_status.moe_num_chunk for dist_status in dist_status_list)
+
+            dist_out = DistributedOut(
+                if_only_decode=if_only_decode,
+                max_moe_num_chunk=max_moe_num_chunk,
+            )
+
+        return dist_out
+
     def only_decode(self):
         """
         check whether decode only
@@ -1355,7 +1409,7 @@ class GPUModelRunner(ModelRunnerBase):
     def initialize_forward_meta(self, is_dummy_or_profile_run=False):
         """
-        Initialize forward meta and attention meta data
+        Initialize forward meta and attention metadata, and update the related parallel config.
         """
         # Initialize forward meta
         self.forward_meta = ForwardMeta(
@@ -1386,8 +1440,12 @@ class GPUModelRunner(ModelRunnerBase):
             kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
         )
 
-        # Update Batch type for cuda graph for only_decode_batch
-        if_only_decode = self.only_decode()
+        dist_status = self.collect_distributed_status()
+
+        if_only_decode = dist_status.if_only_decode
+        if self.fd_config.parallel_config.enable_chunked_moe:
+            self.fd_config.parallel_config.max_moe_num_chunk = dist_status.max_moe_num_chunk
+
         only_decode_use_cudagraph = self.use_cudagraph and if_only_decode
 
         # Update config about moe for better performance
diff --git a/fastdeploy/worker/model_runner_base.py b/fastdeploy/worker/model_runner_base.py
index 699182576..ece670530 100644
--- a/fastdeploy/worker/model_runner_base.py
+++ b/fastdeploy/worker/model_runner_base.py
@@ -15,6 +15,8 @@
 """
 
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Optional
 
 from paddle import nn
 
@@ -25,6 +27,18 @@ from fastdeploy.worker.output import ModelRunnerOutput
 logger = get_logger("model_runner_base", "model_runner_base.log")
 
 
+@dataclass
+class DistributedStatus:
+    only_decode: bool = True
+    moe_num_chunk: int = 1
+
+
+@dataclass
+class DistributedOut:
+    if_only_decode: bool = True
+    max_moe_num_chunk: Optional[int] = None
+
+
 class ModelRunnerBase(ABC):
     """
     Engine -> (WIP)Executor -> Worker -> ModelRunner -> Model
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index f836e07c6..67d58e7f8 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -720,6 +720,17 @@ def parse_args():
         action="store_true",
         help="enable expert parallel",
     )
+    parser.add_argument(
+        "--enable_chunked_moe",
+        action="store_true",
+        help="enable chunked moe",
+    )
+    parser.add_argument(
+        "--chunked_moe_size",
+        type=int,
+        default=256,
+        help="chunk size of moe input",
+    )
     parser.add_argument("--ori_vocab_size", type=int, default=None)
     parser.add_argument("--think_end_id", type=int, default=-1)
     parser.add_argument("--image_patch_id", type=int, default=-1)
diff --git a/tests/distributed/chunked_moe.py b/tests/distributed/chunked_moe.py
new file mode 100644
index 000000000..86adce5b5
--- /dev/null
+++ b/tests/distributed/chunked_moe.py
@@ -0,0 +1,183 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import Mock
+
+import paddle
+import paddle.distributed as dist
+from paddle.distributed import fleet
+
+from fastdeploy.config import MoEPhase
+from fastdeploy.model_executor.layers.moe import FusedMoE
+from fastdeploy.worker.gpu_model_runner import GPUModelRunner
+
+
+class MockStructuredOutputsConfig:
+    logits_processors = []
+
+
+class MockModelConfig:
+    max_model_len = 10
+    pad_token_id = 0
+    eos_tokens_lens = 1
+    eos_token_id = 0
+    temperature = 1.0
+    penalty_score = 1.0
+    frequency_score = 1.0
+    min_length = 1
+    vocab_size = 1
+    top_p = 1.0
+    presence_score = 1.0
+    max_stop_seqs_num = 5
+    stop_seqs_max_len = 2
+    head_dim = 128
+    model_type = ["mock"]
+    moe_phase = MoEPhase(phase="prefill")
+    hidden_size = 1536
+
+
+class MockCacheConfig:
+    block_size = 64
+    total_block_num = 256
+    kv_cache_ratio = 0.9
+    enc_dec_block_num = 2
+
+
+class MockFDConfig:
+    class ParallelConfig:
+        enable_expert_parallel = True
+        enable_chunked_moe = True
+        chunked_moe_size = 2
+        max_moe_num_chunk = 1
+        moe_num_chunk = 1
+        use_ep = True
+        use_sequence_parallel_moe = False
+
+    class SchedulerConfig:
+        name = "default"
+        splitwise_role = "mixed"
+        max_num_seqs = 2
+
+    parallel_config = ParallelConfig()
+    scheduler_config = SchedulerConfig()
+    structured_outputs_config = MockStructuredOutputsConfig()
+    model_config = MockModelConfig()
+
+
+class MockAttentionBackend:
+    def __init__(self):
+        pass
+
+    def init_attention_metadata(self, forward_meta):
+        pass
+
+
+class MockQuantMethod:
+    def apply(self, layer, x, gate):
+        return x
+
+
+class TestChunkedMoE(unittest.TestCase):
+    def setUp(self) -> None:
+        paddle.seed(2025)
+
+        strategy = fleet.DistributedStrategy()
+        strategy.hybrid_configs = {
+            "dp_degree": 1,
+            "mp_degree": 2,
+            "pp_degree": 1,
+            "sharding_degree": 1,
+        }
+
+        fleet.init(is_collective=True, strategy=strategy)
+
+        self.model_runner = self.setup_model_runner()
+        self.fused_moe = self.setup_fused_moe()
+
+    def setup_model_runner(self):
+        """Helper method to setup GPUModelRunner with different configurations"""
+        mock_fd_config = MockFDConfig()
+
+        mock_model_config = MockModelConfig()
+        mock_cache_config = MockCacheConfig()
+
+        model_runner = GPUModelRunner.__new__(GPUModelRunner)
+        model_runner.fd_config = mock_fd_config
+        model_runner.model_config = mock_model_config
+        model_runner.cache_config = mock_cache_config
+        model_runner.attn_backends = [MockAttentionBackend()]
+        model_runner.enable_mm = True
+        model_runner.cudagraph_only_prefill = False
+        model_runner.use_cudagraph = False
+        model_runner.speculative_decoding = False
+        model_runner._init_share_inputs(mock_fd_config.scheduler_config.max_num_seqs)
+        model_runner.share_inputs["caches"] = None
+
+        if dist.get_rank() == 0:
+            model_runner.share_inputs["ids_remove_padding"] = paddle.ones([10])
+        else:
+            model_runner.share_inputs["ids_remove_padding"] = paddle.ones([1])
+
+        return model_runner
+
+    def setup_fused_moe(self):
+        mock_fd_config = MockFDConfig()
+
+        fused_moe = FusedMoE.__new__(FusedMoE)
+        fused_moe.ep_size = 2
+        fused_moe.tp_size = 1
+        fused_moe.reduce_results = True
+
+        fused_moe.fd_config = mock_fd_config
+        fused_moe.quant_method = MockQuantMethod()
+        return fused_moe
+
+    def run_model_runner(self):
+        self.model_runner.initialize_forward_meta()
+
+        assert self.model_runner.fd_config.parallel_config.max_moe_num_chunk == 5, (
+            f"chunk size is 2, max token_num is 10, max_moe_num_chunk should be 5, "
+            f"but got {self.model_runner.fd_config.parallel_config.max_moe_num_chunk}"
+        )
+        if dist.get_rank() == 0:
+            assert self.model_runner.fd_config.parallel_config.moe_num_chunk == 5, (
+                f"chunk size is 2, token_num is 10, moe_num_chunk in rank 0 should be 5, "
+                f"but got {self.model_runner.fd_config.parallel_config.moe_num_chunk}"
+            )
+        else:
+            assert self.model_runner.fd_config.parallel_config.moe_num_chunk == 1, (
+                f"chunk size is 2, token_num is 1, moe_num_chunk in rank 1 should be 1"
+                f", but got {self.model_runner.fd_config.parallel_config.moe_num_chunk}"
+            )
+
+    def run_fused_moe(self):
+        gate = Mock()
+        if dist.get_rank() == 0:
+            x = paddle.ones([10])
+        else:
+            x = paddle.ones([1])
+
+        out = self.fused_moe.forward(x, gate)
+        assert out.shape == x.shape
+
+    def test_case(self):
+        # check that the max_moe_num_chunk collected across ranks is correct.
+        self.run_model_runner()
+        # check the forward method of chunked MoE.
+        self.run_fused_moe()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/distributed/test_chunked_moe.py b/tests/distributed/test_chunked_moe.py
new file mode 100644
index 000000000..3740b4159
--- /dev/null
+++ b/tests/distributed/test_chunked_moe.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import subprocess
+import sys
+
+
+def test_fused_moe_launch():
+    """
+    Launch chunked_moe.py on two GPUs with paddle.distributed.launch and check that it exits cleanly.
+    """
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    chunked_moe_script = os.path.join(current_dir, "chunked_moe.py")
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
+    command = [
+        sys.executable,
+        "-m",
+        "paddle.distributed.launch",
+        "--gpus",
+        "0,1",
+        chunked_moe_script,
+    ]
+
+    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+
+    try:
+        stdout, stderr = process.communicate(timeout=400)
+        return_code = process.returncode
+    except subprocess.TimeoutExpired:
+        process.kill()
+        stdout, stderr = process.communicate()
+        return_code = -1
+    print(f"std_out: {stdout}")
+    assert return_code == 0, f"Process exited with code {return_code}, stdout: {stdout}, stderr: {stderr}"
+
+
+test_fused_moe_launch()
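
The sketches below are illustrative reading aids for this patch: standalone toy code written under stated assumptions, not part of the diff itself.

collect_distributed_status() derives the per-rank chunk count by ceil-dividing the padded token count by chunked_moe_size (falling back to a single chunk for small or empty inputs), and max_moe_num_chunk is the maximum of those counts across ranks, so every rank issues the same number of MoE kernel calls. A minimal sketch of that arithmetic, with made-up per-rank token counts (chunk_count is a throwaway helper, not a FastDeploy API):

    # Plain-Python sketch of the chunk-count arithmetic in collect_distributed_status.
    def chunk_count(token_num: int, chunk_size: int) -> int:
        # ceil-divide only when the input is larger than one chunk, otherwise a single chunk
        return (token_num + chunk_size - 1) // chunk_size if token_num > chunk_size else 1

    chunk_size = 256
    tokens_per_rank = [1024, 90, 0, 300]  # hypothetical per-rank batch sizes
    moe_num_chunk = [chunk_count(t, chunk_size) for t in tokens_per_rank]
    max_moe_num_chunk = max(moe_num_chunk)  # every rank must run this many MoE iterations
    assert moe_num_chunk == [4, 1, 1, 2] and max_moe_num_chunk == 4

Because of the ceil-division, no chunk ever holds more than chunk_size tokens, which is what bounds the peak activation memory of the MoE layer.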
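
forward_chunked_moe() splits the input along the token axis with paddle.tensor_split, runs the MoE kernel chunk by chunk, and pads with a zero-token input until max_moe_num_chunk iterations have been issued; under expert parallelism the kernel is effectively a collective that all EP ranks must enter the same number of times. A single-process toy version of that control flow (toy_moe and chunked_forward are stand-ins invented for this sketch; the real layer calls quant_method.apply):

    import paddle

    def toy_moe(x):
        # stand-in for the real MoE kernel (quant_method.apply), which is a collective under EP
        return x * 2.0

    def chunked_forward(x, chunk_size, max_num_chunk):
        token_num = x.shape[0]
        num_chunk = (token_num + chunk_size - 1) // chunk_size if token_num > chunk_size else 1
        fake_x = paddle.empty([0, x.shape[1]], dtype=x.dtype)  # zero-token pad input
        if num_chunk > 1:
            chunks = paddle.tensor_split(x, num_chunk, axis=0)
            outs = [toy_moe(c) for c in chunks]
            for _ in range(max_num_chunk - num_chunk):  # keep all ranks in lock-step
                toy_moe(fake_x)
            return paddle.concat(outs, axis=0)
        out = toy_moe(x)
        for _ in range(max_num_chunk - 1):
            toy_moe(fake_x)
        return out

    x = paddle.rand([10, 8])
    y = chunked_forward(x, chunk_size=2, max_num_chunk=5)
    assert y.shape == x.shape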
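
Each rank fills a DistributedStatus with its local decode-only flag and chunk count, a single paddle.distributed.all_gather_object call collects them, and every rank then reduces the gathered list to the same DistributedOut. A sketch of that reduction with the gathered list simulated in-process, so it runs without a distributed launcher:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class DistributedStatus:
        only_decode: bool = True
        moe_num_chunk: int = 1

    @dataclass
    class DistributedOut:
        if_only_decode: bool = True
        max_moe_num_chunk: Optional[int] = None

    gathered = [  # what all_gather_object would return with two EP ranks
        DistributedStatus(only_decode=False, moe_num_chunk=5),  # rank 0 still prefilling
        DistributedStatus(only_decode=True, moe_num_chunk=1),   # rank 1 decode-only
    ]
    dist_out = DistributedOut(
        if_only_decode=all(s.only_decode for s in gathered),
        max_moe_num_chunk=max(s.moe_num_chunk for s in gathered),
    )
    assert dist_out.if_only_decode is False and dist_out.max_moe_num_chunk == 5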
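
The feature is exercised end to end by the new test, which launches tests/distributed/chunked_moe.py on two GPUs via paddle.distributed.launch; the same command used by the test can be run by hand:

    CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch --gpus 0,1 tests/distributed/chunked_moe.py

On the serving side, the behavior is switched on through the new engine arguments --enable-chunked-moe and --chunked-moe-size, and it takes effect only when expert parallelism is enabled (the new elif branch in FusedMoE.forward also checks use_ep).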