mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
support mtp in ep64 (#4280)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
@@ -38,13 +38,21 @@ class Proposer(ABC):
         Init Speculative proposer
         """
         cfg.parallel_config.tp_group = None
+        cfg.parallel_config.ep_group = None
         self.cfg = deepcopy(cfg)
         cfg.parallel_config.tp_group = dist.get_group(
             cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
+        cfg.parallel_config.ep_group = dist.get_group(
+            cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        )
         self.cfg.parallel_config.tp_group = dist.get_group(
             cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
+        self.cfg.parallel_config.ep_group = dist.get_group(
+            cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        )
+
         self.parallel_config = self.cfg.parallel_config
         self.model_config = self.cfg.model_config
         self.speculative_config = self.cfg.speculative_config
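This hunk wires an expert-parallel (ep) process group through the proposer config alongside the existing tensor-parallel (tp) group. Both groups are nulled before deepcopy(cfg), presumably because live process-group handles cannot be deep-copied, and are then re-resolved on both the original and the copy. The group ids share FD_TP_GROUP_GID_OFFSET: each data-parallel rank owns its own TP group id (offset + rank), while the EP group id (offset + data_parallel_size) evaluates to the same value on every rank, so all ranks resolve one shared group. A minimal sketch of that arithmetic, with hypothetical values (not FastDeploy code):

    # Assumed example: 64 ranks split into 8 data-parallel groups.
    FD_TP_GROUP_GID_OFFSET = 1000  # hypothetical offset value
    data_parallel_size = 8

    for data_parallel_rank in range(data_parallel_size):
        tp_gid = data_parallel_rank + FD_TP_GROUP_GID_OFFSET
        ep_gid = data_parallel_size + FD_TP_GROUP_GID_OFFSET
        print(f"dp_rank={data_parallel_rank}: tp_gid={tp_gid}, ep_gid={ep_gid}")

    # tp_gid runs 1000..1007 (one TP group per DP rank), while every rank
    # computes the same ep_gid 1008, i.e. a single EP group spanning ranks.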
@@ -695,6 +695,9 @@ class MTPProposer(Proposer):
 
             if substep != self.num_model_steps - 1:
                 target_hidden_states = self._get_self_hidden_states(hidden_states)
+            else:
+                if hasattr(self.model, "empty_input_forward"):
+                    self.model.empty_input_forward()
 
     def _get_self_hidden_states(self, hidden_states):
         target_hidden_states = eagle_get_self_hidden_states(
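The second hunk handles the last MTP substep. The proposer no longer needs fresh self hidden states there, but under expert parallelism a rank that sits out a forward pass could leave its peers blocked in the MoE all-to-all collectives, which is presumably why models exposing empty_input_forward run an empty forward instead; the hasattr guard keeps the call optional for models without the hook. A minimal sketch of the pattern (DummyEPModel and finish_substep are hypothetical; only the guarded call mirrors the diff):

    class DummyEPModel:
        def empty_input_forward(self):
            # Run the model with an empty batch: no tokens are scored, but
            # this rank still joins the collectives its EP peers expect.
            print("dummy forward for EP collectives")

    def finish_substep(model, substep, num_model_steps):
        if substep != num_model_steps - 1:
            return "propagate hidden states"  # non-final substeps keep working
        if hasattr(model, "empty_input_forward"):
            # Models without the hook (e.g. non-EP models) simply skip it.
            model.empty_input_forward()
        return "done"

    finish_substep(DummyEPModel(), substep=2, num_model_steps=3)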