mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-28 10:51:39 +08:00
[Executor]CUDAGraph support Speculate Decode (#4258)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [Executor]CUDAGraph support Speculate Decode
* fix problem
* solve problem
* fix
* fast compile
* CUDAGraph + mtp support eb5(only target model)
* Revert "fast compile"
This reverts commit 3cfe8373ed.
* fix precommit
* solve comment
* fix comment about #pragram unroll
---------
Co-authored-by: gongshaotian <gstain5555@outlook.com>
Co-authored-by: gongshaotian <gstian5555@outlook.com>
This commit is contained in:
@@ -33,31 +33,33 @@ class Proposer(ABC):
|
||||
the speculative decoding framework
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: FDConfig):
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
"""
|
||||
Init Speculative proposer
|
||||
"""
|
||||
cfg.parallel_config.tp_group = None
|
||||
cfg.parallel_config.ep_group = None
|
||||
self.cfg = deepcopy(cfg)
|
||||
cfg.parallel_config.tp_group = dist.get_group(
|
||||
cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
|
||||
fd_config.parallel_config.tp_group = None
|
||||
fd_config.parallel_config.ep_group = None
|
||||
self.fd_config = deepcopy(fd_config)
|
||||
fd_config.parallel_config.tp_group = dist.get_group(
|
||||
fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
|
||||
)
|
||||
cfg.parallel_config.ep_group = dist.get_group(
|
||||
cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
|
||||
fd_config.parallel_config.ep_group = dist.get_group(
|
||||
fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
|
||||
)
|
||||
self.cfg.parallel_config.tp_group = dist.get_group(
|
||||
cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
|
||||
self.fd_config.parallel_config.tp_group = dist.get_group(
|
||||
fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
|
||||
)
|
||||
self.cfg.parallel_config.ep_group = dist.get_group(
|
||||
cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
|
||||
self.fd_config.parallel_config.ep_group = dist.get_group(
|
||||
fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
|
||||
)
|
||||
|
||||
self.parallel_config = self.cfg.parallel_config
|
||||
self.model_config = self.cfg.model_config
|
||||
self.speculative_config = self.cfg.speculative_config
|
||||
self.cache_config = self.cfg.cache_config
|
||||
self.quant_config = self.cfg.quant_config
|
||||
self.parallel_config = self.fd_config.parallel_config
|
||||
self.model_config = self.fd_config.model_config
|
||||
self.speculative_config = self.fd_config.speculative_config
|
||||
self.cache_config = self.fd_config.cache_config
|
||||
self.quant_config = self.fd_config.quant_config
|
||||
self.graph_opt_config = self.fd_config.graph_opt_config
|
||||
self.scheduler_config = self.fd_config.scheduler_config
|
||||
|
||||
self.max_num_seqs = self.parallel_config.max_num_seqs
|
||||
self.max_model_len = self.parallel_config.max_model_len
|
||||
|
||||
Reference in New Issue
Block a user