mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[Feature][MTP] Support cacheKV transfer in per_chunk mode (#2890)
* Support chunk_prefill for both normal and speculative decoding (MTP)
* Optimize the PD-disaggregation config
* Fix a bug
@@ -16,6 +16,7 @@
 from __future__ import annotations
 
+import os
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import Literal, Optional
@@ -109,7 +110,7 @@ class ModelConfig:
         self.ori_vocab_size = self.vocab_size
         if "Ernie4_5_ForCausalLM" in self.architectures or "Ernie4_5_MoeForCausalLM" in self.architectures:
-            self.ori_vocab_size = args["ori_vocab_size"]
+            self.ori_vocab_size = args.get("ori_vocab_size", self.ori_vocab_size)
 
 
 class ParallelConfig:
     """Configuration for the distributed execution."""
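Note on the ModelConfig hunk above: switching from args["ori_vocab_size"] to args.get("ori_vocab_size", self.ori_vocab_size) lets an Ernie 4.5 deployment config omit ori_vocab_size and fall back to vocab_size instead of raising KeyError. A minimal sketch of that fallback, assuming a plain dict of args; the resolve_ori_vocab_size helper and the numbers are illustrative only, not part of this patch:

# Illustrative sketch of the fallback introduced above; not FastDeploy API.
def resolve_ori_vocab_size(args: dict, vocab_size: int) -> int:
    # Old behavior: args["ori_vocab_size"] raised KeyError when the key was missing.
    # New behavior: fall back to the vocab_size that is already known.
    return args.get("ori_vocab_size", vocab_size)

assert resolve_ori_vocab_size({"ori_vocab_size": 32000}, 32000) == 32000  # placeholder sizes
assert resolve_ori_vocab_size({}, 32000) == 32000  # missing key no longer fails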
@@ -191,6 +192,18 @@ class ParallelConfig:
         # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
         self.enable_custom_all_reduce: bool = False
+
+        # pd_disaggregation
+        use_pd_disaggregation: int = int(
+            os.getenv("FLAGS_use_pd_disaggregation", 0))
+        use_pd_disaggregation_per_chunk: int = int(
+            os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
+        if use_pd_disaggregation_per_chunk:
+            self.pd_disaggregation_mode = "per_chunk"
+        elif use_pd_disaggregation:
+            self.pd_disaggregation_mode = "per_query"
+        else:
+            self.pd_disaggregation_mode = "None"
 
 class SpeculativeConfig:
     """
     Configuration for speculative decoding.
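For reference, the mode selection added to ParallelConfig above reads as a small decision table: FLAGS_use_pd_disaggregation_per_chunk=1 selects "per_chunk" (the cacheKV-per-chunk transfer added by this PR), otherwise FLAGS_use_pd_disaggregation=1 selects "per_query", and with neither flag set the mode stays "None". A standalone sketch of that decision under the same environment flags; the select_pd_disaggregation_mode helper is illustrative only and not part of the patch:

import os

# Re-statement of the selection logic added to ParallelConfig in the hunk above.
def select_pd_disaggregation_mode() -> str:
    per_chunk = int(os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
    per_query = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
    if per_chunk:
        return "per_chunk"  # chunked-prefill cacheKV transfer
    if per_query:
        return "per_query"  # per-query cacheKV transfer
    return "None"           # pd-disaggregation disabled

# Example: opting in to the per_chunk path.
os.environ["FLAGS_use_pd_disaggregation_per_chunk"] = "1"
assert select_pd_disaggregation_mode() == "per_chunk"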