fix deepcopy(tp_group) in spec (#3648)

This commit is contained in:
lzy
2025-08-29 16:08:21 +08:00
committed by GitHub
parent 45f81b34f0
commit 48d760539b
3 changed files with 12 additions and 2 deletions

View File

@@ -112,7 +112,7 @@ class Ernie4_5_MoE(nn.Layer):
self.tp_group = fd_config.parallel_config.tp_group self.tp_group = fd_config.parallel_config.tp_group
self.use_ep = self.expert_parallel_size > 1 self.use_ep = self.expert_parallel_size > 1
self.us_tp = self.tensor_parallel_size > 1 self.use_tp = self.tensor_parallel_size > 1
if moe_quant_type == "w4a8" or moe_quant_type == "w4afp8": if moe_quant_type == "w4a8" or moe_quant_type == "w4afp8":
weight_key_map = { weight_key_map = {

View File

@@ -58,7 +58,7 @@ class Qwen3MoeBlock(nn.Layer):
self.tp_group = fd_config.parallel_config.tp_group self.tp_group = fd_config.parallel_config.tp_group
self.use_ep = self.expert_parallel_size > 1 self.use_ep = self.expert_parallel_size > 1
self.us_tp = self.tensor_parallel_size > 1 self.use_tp = self.tensor_parallel_size > 1
weight_key_map = { weight_key_map = {
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight", "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",

View File

@@ -18,6 +18,9 @@ from abc import ABC, abstractmethod
from copy import deepcopy from copy import deepcopy
from typing import Any from typing import Any
import paddle.distributed as dist
from fastdeploy import envs
from fastdeploy.config import FDConfig from fastdeploy.config import FDConfig
from fastdeploy.utils import spec_logger from fastdeploy.utils import spec_logger
@@ -34,7 +37,14 @@ class Proposer(ABC):
""" """
Init Speculative proposer Init Speculative proposer
""" """
cfg.parallel_config.tp_group = None
self.cfg = deepcopy(cfg) self.cfg = deepcopy(cfg)
cfg.parallel_config.tp_group = dist.get_group(
cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
)
self.cfg.parallel_config.tp_group = dist.get_group(
cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
)
self.parallel_config = self.cfg.parallel_config self.parallel_config = self.cfg.parallel_config
self.model_config = self.cfg.model_config self.model_config = self.cfg.model_config
self.speculative_config = self.cfg.speculative_config self.speculative_config = self.cfg.speculative_config