[TSP] Support qwen3 moe tsp + cudagraph (#4871)

* support qwen3_moe tsp mode

* fix

* fix

* update

* update

* update

* fix

* support external_rmsnorm

* update

* fix
This commit is contained in:
Yuanle Liu
2025-11-10 23:37:51 +08:00
committed by GitHub
parent fb2eb403ab
commit 3dc0ffa46d
28 changed files with 173 additions and 273 deletions


@@ -240,7 +240,7 @@ class EngineArgs:
disable_custom_all_reduce: bool = False
"""
- Flag to enable the custom all-reduce kernel.
+ Flag to disable the custom all-reduce kernel.
"""
use_internode_ll_two_stage: bool = False
@@ -248,6 +248,19 @@ class EngineArgs:
Flag to use the internode_ll_two_stage kernel.
"""
disable_sequence_parallel_moe: bool = False
"""
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens results in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# This optimization is enabled by default, and can be disabled by using this flag.
"""
engine_worker_queue_port: str = "0"
"""
Port for worker queue communication.
@@ -766,6 +779,12 @@ class EngineArgs:
default=EngineArgs.use_internode_ll_two_stage,
help="Flag to use the internode_ll_two_stage kernel.",
)
parallel_group.add_argument(
"--disable-sequence-parallel-moe",
action="store_true",
default=EngineArgs.disable_sequence_parallel_moe,
help="Flag to disable disable the sequence parallel moe.",
)
parallel_group.add_argument(
"--max-num-seqs",
type=int,
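The docstring added above explains the optimization behind `disable_sequence_parallel_moe`: an all_reduce after o_proj leaves every tensor-parallel rank holding the full replicated sequence, so the experts redo the same work on each rank, whereas a sequence-parallel (reduce-scatter style) layout gives each rank a distinct slice. A minimal pure-Python sketch of the token-count accounting (the function name and numbers are illustrative, not FastDeploy code):

```python
def moe_tokens_processed(seq_len: int, tp_size: int, sequence_parallel: bool) -> int:
    """Total tokens fed to the experts across all tensor-parallel ranks."""
    if sequence_parallel:
        # After a reduce-scatter, each rank holds a distinct 1/tp_size slice
        # of the sequence, so the group processes each token exactly once.
        per_rank = seq_len // tp_size
    else:
        # After an all_reduce, every rank holds the full replicated sequence,
        # so the group processes each token tp_size times.
        per_rank = seq_len
    return per_rank * tp_size


# With tp_size=8, the replicated layout does 8x the expert work:
replicated = moe_tokens_processed(seq_len=4096, tp_size=8, sequence_parallel=False)
sharded = moe_tokens_processed(seq_len=4096, tp_size=8, sequence_parallel=True)
print(replicated, sharded)  # 32768 4096
```

This is why the flag is off by default: dropping sequence parallelism for the MoE block multiplies both expert compute and All2All traffic by the tensor-parallel degree.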