mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[BugFix] fix cpu prefix cache bug (#5544)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* fix_dy_c8_bug
* add block_num check
* fix test case
* update ci case
This commit is contained in:
@@ -66,6 +66,13 @@ def parse_args():
|
||||
choices=["uint8", "bfloat16", "block_wise_fp8"],
|
||||
help="cache dtype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--default_dtype",
|
||||
type=str,
|
||||
default="bfloat16",
|
||||
choices=["float16", "bfloat16", "uint8"],
|
||||
help="paddle default dtype, swap_cache_batch only support float16、bfloat16 and uint8 now",
|
||||
)
|
||||
parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape")
|
||||
parser.add_argument("--value_cache_shape", type=str, default="", help="value cache shape")
|
||||
parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
|
||||
@@ -124,6 +131,7 @@ class CacheTransferManager:
|
||||
self.num_gpu_blocks = self.key_cache_shape[0]
|
||||
self.num_extra_layers = self.speculative_config.num_extra_cache_layer
|
||||
self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
|
||||
paddle.set_default_dtype(args.default_dtype)
|
||||
|
||||
self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
@@ -279,6 +279,7 @@ class PrefixCacheManager:
|
||||
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
|
||||
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
|
||||
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
|
||||
+ f" --default_dtype '{self.config.model_config.dtype}'"
|
||||
+ (" --create_cache_tensor" if create_cache_tensor else "")
|
||||
+ f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user