[Feature][MTP] Support cacheKV transfer in per_chunk mode (#2890)

* support chunk_prefill in both normal and speculative decoding (MTP) modes

* optimize pd-disaggregation config

* fix bug
freeliuzc
2025-07-17 17:58:08 +08:00
committed by GitHub
parent 67180c1ff9
commit d49f8fb30a
10 changed files with 110 additions and 27 deletions


@@ -16,6 +16,7 @@
 from __future__ import annotations
 
+import os
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import Literal, Optional
@@ -109,7 +110,7 @@ class ModelConfig:
         self.ori_vocab_size = self.vocab_size
         if "Ernie4_5_ForCausalLM" in self.architectures or "Ernie4_5_MoeForCausalLM" in self.architectures:
-            self.ori_vocab_size = args["ori_vocab_size"]
+            self.ori_vocab_size = args.get("ori_vocab_size", self.ori_vocab_size)
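Using dict.get with the current value as the fallback keeps model loading from failing when a config does not carry ori_vocab_size. A minimal sketch of the difference, using a hypothetical args dict:

vocab_size = 100352   # illustrative value already set from the model config
args = {}             # hypothetical model args without an "ori_vocab_size" entry

# Old behavior: args["ori_vocab_size"] raises KeyError when the key is absent.
# New behavior: fall back to the value that was already set.
ori_vocab_size = args.get("ori_vocab_size", vocab_size)
print(ori_vocab_size)  # -> 100352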
class ParallelConfig:
"""Configuration for the distributed execution."""
@@ -191,6 +192,18 @@ class ParallelConfig:
         # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
         self.enable_custom_all_reduce: bool = False
+
+        # pd_disaggregation
+        use_pd_disaggregation: int = int(
+            os.getenv("FLAGS_use_pd_disaggregation", 0))
+        use_pd_disaggregation_per_chunk: int = int(
+            os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
+        if use_pd_disaggregation_per_chunk:
+            self.pd_disaggregation_mode = "per_chunk"
+        elif use_pd_disaggregation:
+            self.pd_disaggregation_mode = "per_query"
+        else:
+            self.pd_disaggregation_mode = "None"
class SpeculativeConfig:
"""
Configuration for speculative decoding.
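The pd_disaggregation mode added to ParallelConfig above is selected purely through environment flags, with FLAGS_use_pd_disaggregation_per_chunk taking precedence over FLAGS_use_pd_disaggregation. Below is a hedged sketch of that selection logic and of how a deployment might enable per_chunk mode; the helper function name is illustrative, and only the flag names and mode strings come from the diff:

import os

def resolve_pd_disaggregation_mode() -> str:
    # Mirrors the selection added to ParallelConfig: per_chunk wins over per_query.
    use_pd_disaggregation = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
    use_pd_disaggregation_per_chunk = int(
        os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
    if use_pd_disaggregation_per_chunk:
        return "per_chunk"   # transfer cacheKV chunk by chunk during chunked prefill
    elif use_pd_disaggregation:
        return "per_query"   # transfer cacheKV once per query
    return "None"

# Example: enable per_chunk transfer before the engine builds its ParallelConfig.
os.environ["FLAGS_use_pd_disaggregation_per_chunk"] = "1"
print(resolve_pd_disaggregation_mode())  # -> per_chunk

Note that the disabled case stores the string "None" rather than Python's None, so downstream checks presumably compare against that string.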