diff --git a/fastdeploy/spec_decode/base.py b/fastdeploy/spec_decode/base.py index 1719b7f26..900e99c9b 100644 --- a/fastdeploy/spec_decode/base.py +++ b/fastdeploy/spec_decode/base.py @@ -38,13 +38,21 @@ class Proposer(ABC): Init Speculative proposer """ cfg.parallel_config.tp_group = None + cfg.parallel_config.ep_group = None self.cfg = deepcopy(cfg) cfg.parallel_config.tp_group = dist.get_group( cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET ) + cfg.parallel_config.ep_group = dist.get_group( + cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET + ) self.cfg.parallel_config.tp_group = dist.get_group( cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET ) + self.cfg.parallel_config.ep_group = dist.get_group( + cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET + ) + self.parallel_config = self.cfg.parallel_config self.model_config = self.cfg.model_config self.speculative_config = self.cfg.speculative_config diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 6ec6ee190..2614c4596 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -695,6 +695,9 @@ class MTPProposer(Proposer): if substep != self.num_model_steps - 1: target_hidden_states = self._get_self_hidden_states(hidden_states) + else: + if hasattr(self.model, "empty_input_forward"): + self.model.empty_input_forward() def _get_self_hidden_states(self, hidden_states): target_hidden_states = eagle_get_self_hidden_states(