Support MLA_CACHE & Fix V1_Schedule Bug (#4318)

Support MLA_CACHE & Fix V1_Schedule Bug
Author: AIbin
Date: 2025-10-09 12:11:25 +08:00
Committed by: GitHub
Parent: 791b101195
Commit: 48fd5d757d
2 changed files with 41 additions and 14 deletions


@@ -460,8 +460,7 @@ class ResourceManagerV1(ResourceManager):
                     # Prepare decoding task
                     scheduled_reqs.append(self._prepare_decode_task(request))
                     num_decoding_req_nums += 1
                     token_budget -= 1
                 if (
                     request.use_extend_tables
                     and request.request_id not in self.using_extend_tables_req_id
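The hunk above belongs to the V1 scheduler fix: when a running request is scheduled for a decode step, the per-step token budget is decremented once per decode slot. The sketch below is a minimal, self-contained illustration of that budget accounting; the names (`schedule_decodes`, `running_requests`) are hypothetical stand-ins, not the ResourceManagerV1 API.

```python
# Minimal sketch of the token-budget accounting visible in the hunk above.
# schedule_decodes/running_requests are hypothetical names, not FastDeploy API.
def schedule_decodes(running_requests, token_budget):
    scheduled_reqs = []
    num_decoding_req_nums = 0
    for request in running_requests:
        if token_budget <= 0:
            break
        # Prepare decoding task: each scheduled decode step reserves one token.
        scheduled_reqs.append(("decode", request))
        num_decoding_req_nums += 1
        token_budget -= 1
    return scheduled_reqs, num_decoding_req_nums, token_budget
```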


@@ -1192,31 +1192,48 @@ class GPUModelRunner(ModelRunnerBase):
         logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
         cache_kvs_list = []
+        # NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention,
+        # To rationalize the allocation of kvcache.
+        from fastdeploy import envs
+
+        self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN"
         for i in range(self.model_config.num_hidden_layers):
             key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
-            val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
+            if not self.mla_cache:
+                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
             if create_cache_tensor:
                 logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}")
                 key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
-                val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
                 set_data_ipc(key_cache, key_cache_name)
-                set_data_ipc(val_cache, val_cache_name)
-                cache_kvs_list.extend([key_cache, val_cache])
+                if not self.mla_cache:
+                    val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
+                    set_data_ipc(val_cache, val_cache_name)
+                    cache_kvs_list.extend([key_cache, val_cache])
+                else:
+                    cache_kvs_list.extend([key_cache])
                 if kv_cache_quant_type == "block_wise_fp8":
                     key_cache_scales = paddle.full(
                         shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
                     )
-                    val_cache_scales = paddle.full(
-                        shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
-                    )
-                    cache_kvs_list.extend([key_cache_scales, val_cache_scales])
+                    if not self.mla_cache:
+                        val_cache_scales = paddle.full(
+                            shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
+                        )
+                        cache_kvs_list.extend([key_cache_scales, val_cache_scales])
+                    else:
+                        cache_kvs_list.extend([key_cache_scales])
             else:
                 logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}")
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
-                val_cache = paddle.empty(shape=[], dtype=cache_type)
                 key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
-                val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape)
-                cache_kvs_list.extend([key_cache, val_cache])
+                if not self.mla_cache:
+                    val_cache = paddle.empty(shape=[], dtype=cache_type)
+                    val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape)
+                    cache_kvs_list.extend([key_cache, val_cache])
+                else:
+                    cache_kvs_list.extend([key_cache])
         self.share_inputs["caches"] = cache_kvs_list
         if not profile and create_cache_tensor:
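The allocation rule introduced above reduces to one decision per layer: with MLA, the compressed latent and the rotary key slice (compress_kv + k_pe) live in a single cache tensor, so no separate value cache is created or attached. Below is a condensed sketch of just that branching, using the same paddle call as the diff; the fp8-scale and cache-attach paths are omitted, and the helper name is hypothetical.

```python
import paddle

def allocate_layer_caches(num_layers, kv_cache_shape, cache_type, mla_cache):
    """Condensed sketch of the per-layer allocation rule above (creation path only)."""
    caches = []
    for _ in range(num_layers):
        key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
        if mla_cache:
            # MLA: the single tensor already holds compress_kv + k_pe per token.
            caches.append(key_cache)
        else:
            # Standard attention: separate key and value caches per layer.
            val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
            caches.extend([key_cache, val_cache])
    return caches
```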
@@ -1936,7 +1953,18 @@ class GPUModelRunner(ModelRunnerBase):
             if self.speculative_method in ["mtp"]
             else self.model_config.num_hidden_layers
         )
-        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
+        # NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention,
+        # To rationalize the allocation of kvcache.
+        if self.mla_cache:
+            required_memory = (
+                byte_of_dtype
+                * (self.fd_config.model_config.kv_lora_rank + self.fd_config.model_config.qk_rope_head_dim)
+                * (self.cache_config.block_size)
+                * num_layers
+            )  # compress_kv + k_pe
+        else:
+            required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
         return required_memory
 
     def not_need_stop(self) -> bool:
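The last hunk adjusts the per-block memory estimate accordingly: standard attention stores K and V (hence the factor of 2 over hidden_dim), while MLA stores only kv_lora_rank + qk_rope_head_dim values per token. The helper below mirrors the two branches; the concrete numbers in the usage line are illustrative assumptions (bfloat16, block_size 64, 61 layers, kv_lora_rank 512, qk_rope_head_dim 64), not values read from the diff.

```python
def kv_cache_bytes_per_block(byte_of_dtype, block_size, num_layers,
                             hidden_dim=None, kv_lora_rank=None,
                             qk_rope_head_dim=None, mla_cache=False):
    """Per-block KV-cache bytes, mirroring the two branches in the diff."""
    if mla_cache:
        # MLA: compress_kv + k_pe per token, one cache tensor per layer.
        return byte_of_dtype * (kv_lora_rank + qk_rope_head_dim) * block_size * num_layers
    # Standard attention: K and V per token, hence the factor of 2.
    return byte_of_dtype * 2 * (block_size * hidden_dim) * num_layers

# Illustrative values only (not taken from any model config):
mla_bytes = kv_cache_bytes_per_block(
    byte_of_dtype=2, block_size=64, num_layers=61,
    kv_lora_rank=512, qk_rope_head_dim=64, mla_cache=True,
)
print(mla_bytes)  # 4497408 bytes, i.e. about 4.3 MiB per block
```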