[Speculative Decoding][MTP] Support static CacheKV C8 quantization and optimize memory usage (#5155)

* support static cachekv c8 quantization in mtp mode

* optimize memory allocation
This commit is contained in:
freeliuzc
2025-11-21 15:10:13 +08:00
committed by GitHub
parent 3c36283d7d
commit 2d1dade5e2
6 changed files with 350 additions and 295 deletions

View File

@@ -198,6 +198,8 @@ class ModelConfig:
self.pooler_config: Optional["PoolerConfig"] = field(init=False)
self.override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
self.revision = None
self.prefix_layer_name = "layers"
self.kv_cache_quant_scale_path = ""
self.partial_rotary_factor: float = 1.0
self.num_nextn_predict_layers = 0
@@ -244,6 +246,7 @@ class ModelConfig:
self.enable_mm = is_multimodal_model
self.kv_cache_quant_scale_path = os.path.join(self.model, "kv_cache_scale.json")
if self.runner_type == "pooling":
os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"
@@ -1589,6 +1592,10 @@ class FDConfig:
else:
self.scheduler_config.max_num_batched_tokens = self.model_config.max_model_len
self.scheduler_config.max_chunk_len = (
self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_extra_num_batched_tokens
)
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04)