cp dynamic Cfp8 (#4120)

* support dynamic Cfp8

* add unit test

* fix dynamic Cfp8 computation error

* fix Cfp8 for RL load

---------

Co-authored-by: carryyu <569782149@qq.com>
Author: Yuan Xiaolan
Date: 2025-09-17 11:55:47 +08:00
Committed by: GitHub
Parent: b6caf6e622
Commit: 25aa2d94aa
21 changed files with 1428 additions and 227 deletions
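Background for the diff below: "dynamic Cfp8" appears to refer to block-wise FP8 quantization of the KV cache, where each cache block's scale is computed at runtime from the block's own values rather than calibrated offline, and the scales are stored alongside the FP8 cache. The following is a minimal sketch of that idea only, not the PR's CUDA kernels; the [num_blocks, kv_num_heads, block_size, head_dim] layout and the 448.0 FP8 E4M3 clamp are assumptions for illustration.

import paddle

FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in float8_e4m3

def dynamic_block_wise_fp8(cache_block):
    # One scale per (block, head, slot): amax over the last (head_dim) axis.
    amax = paddle.max(paddle.abs(cache_block), axis=-1)
    scale = paddle.clip(amax, min=1e-4) / FP8_E4M3_MAX
    # Divide by the per-block scale; the cast to an FP8 storage dtype is elided here.
    quantized = paddle.clip(cache_block / scale.unsqueeze(-1), min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX)
    return quantized, scale

key_block = paddle.randn([4, 8, 64, 128], dtype="float32")  # hypothetical sizes
q_key, key_scales = dynamic_block_wise_fp8(key_block)
print(key_scales.shape)  # [4, 8, 64], i.e. the cache shape minus its last axis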

@@ -1010,6 +1010,8 @@ class GPUModelRunner(ModelRunnerBase):
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
             max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
         )
+        if kv_cache_quant_type == "block_wise_fp8":
+            kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
         if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
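A hedged reading of the shape arithmetic in this hunk, assuming (the hunk itself does not say so) that get_kv_cache_shape returns [max_block_num, kv_num_heads, block_size, head_dim]: the scale tensor keeps the first three axes and drops head_dim, so every row of every cache block gets one scale.

kv_cache_shape = [1024, 8, 64, 128]               # hypothetical values
kv_cache_scale_shape = list(kv_cache_shape[:3])   # same result as the explicit indexing above
assert kv_cache_scale_shape == [1024, 8, 64]      # one scale per (block, head, slot)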
@@ -1037,6 +1039,17 @@ class GPUModelRunner(ModelRunnerBase):
                     fill_value=0,
                     dtype=cache_type,
                 )
+                if kv_cache_quant_type == "block_wise_fp8":
+                    cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
+                    cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
             self.share_inputs["caches"] = list(cache_kvs.values())
             for value in cache_kvs.values():
                 del value
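The per-layer key_cache_scales_{i} / value_cache_scales_{i} tensors allocated above hold the runtime scales a consumer needs to dequantize FP8 cache blocks. A minimal sketch of that consumer side follows, with illustrative names and shapes rather than the attention backend's actual API.

import paddle

def dequantize_block(fp8_block, block_scales):
    # fp8_block:    [kv_num_heads, block_size, head_dim], values stored after per-block scaling
    # block_scales: [kv_num_heads, block_size], kept in the default dtype as allocated above
    dequant = paddle.cast(fp8_block, paddle.get_default_dtype())
    return dequant * block_scales.unsqueeze(-1)

blk = paddle.randn([8, 64, 128], dtype="float32")                 # hypothetical block
scales = paddle.rand([8, 64], dtype=paddle.get_default_dtype()) + 0.01
print(dequantize_block(blk, scales).shape)                        # [8, 64, 128]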