[BugFix] Fix the abnormal memory usage caused by shape errors in the triton moe backend (#4026)

* fix device_id to int

* fix triton_moe bug
Author: Yuanle Liu
Date:   2025-09-10 11:05:54 +08:00
Commit: c3b2a60fb8 (parent: dbab579299)
4 changed files with 12 additions and 10 deletions


@@ -671,7 +671,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
             layer,
             down_proj_weight_name,
             layer.create_parameter(
-                shape=self.up_gate_proj_weight_shape,
+                shape=self.down_proj_weight_shape,
                 dtype=self.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
             ),
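
Why this one-line shape fix matters for memory: in a fused MoE layer the up/gate projection weight is roughly twice the size of the down projection weight, so allocating down_proj parameters with the up_gate_proj shape doubles their footprint across every expert and every layer. A minimal sketch, using assumed shape conventions (they are illustrative, not FastDeploy's actual dimensions):

    # Illustrative only -- shape conventions below are assumptions, not repo code.
    num_experts, hidden_size, moe_intermediate_size = 64, 8192, 3584

    # up_gate_proj fuses the up and gate projections, so its last dim is doubled.
    up_gate_proj_weight_shape = [num_experts, hidden_size, 2 * moe_intermediate_size]
    down_proj_weight_shape = [num_experts, moe_intermediate_size, hidden_size]

    def numel(shape):
        n = 1
        for dim in shape:
            n *= dim
        return n

    # With the bug, every down_proj parameter was allocated at twice its real size.
    print(numel(up_gate_proj_weight_shape) / numel(down_proj_weight_shape))  # 2.0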


@@ -19,9 +19,11 @@ from typing import List

 import numpy as np
 import paddle
+from paddle import nn
 from paddleformers.utils.log import logger

 from fastdeploy import envs
+from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request, RequestType
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.layers.attention import get_attention_backend
@@ -52,7 +54,7 @@ class MTPProposer(Proposer):
     Proposer for Multi-Token-Prediction(MTP)
     """

-    def __init__(self, cfg, main_model, local_rank, device_id, target_model_inputs):
+    def __init__(self, cfg: FDConfig, main_model: nn.Layer, local_rank: int, device_id: int, target_model_inputs):
         super().__init__(cfg)
         self.num_main_model_layers = self.model_config.num_hidden_layers
         self.local_rank = local_rank
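
The new annotations (enabled by the `nn` and `FDConfig` imports added above) also let a static checker flag the stringly-typed `device_id` this PR fixes at the call site. A standalone sketch with a hypothetical function name:

    # Standalone sketch -- init_proposer is a hypothetical stand-in, not repo code.
    def init_proposer(device_id: int) -> str:
        return f"gpu:{device_id}"

    device_ids = "0,1".split(",")       # ids parsed from an env var are strings
    init_proposer(device_ids[0])        # mypy/pyright: expected "int", got "str"
    init_proposer(int(device_ids[0]))   # OK: the explicit cast this commit adds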


@@ -516,13 +516,13 @@ def print_gpu_memory_use(gpu_id: int, title: str) -> None:
     print(
         f"\n{title}:",
-        f"\n\tDevice Total memory: {meminfo.total}",
-        f"\n\tDevice Used memory: {meminfo.used}",
-        f"\n\tDevice Free memory: {meminfo.free}",
-        f"\n\tPaddle max memory Reserved: {paddle_max_reserved}",
-        f"\n\tPaddle max memory Allocated: {paddle_max_allocated}",
-        f"\n\tPaddle memory Reserved: {paddle_reserved}",
-        f"\n\tPaddle memory Allocated: {paddle_allocated}",
+        f"\n\tDevice Total memory(GiB): {meminfo.total / 1024.0 / 1024.0 / 1024.0}",
+        f"\n\tDevice Used memory(GiB): {meminfo.used / 1024.0 / 1024.0 / 1024.0}",
+        f"\n\tDevice Free memory(GiB): {meminfo.free / 1024.0 / 1024.0 / 1024.0}",
+        f"\n\tPaddle max memory Reserved(GiB): {paddle_max_reserved / 1024.0 / 1024.0 / 1024.0}",
+        f"\n\tPaddle max memory Allocated(GiB): {paddle_max_allocated / 1024.0 / 1024.0 / 1024.0}",
+        f"\n\tPaddle memory Reserved(GiB): {paddle_reserved / 1024.0 / 1024.0 / 1024.0}",
+        f"\n\tPaddle memory Allocated(GiB): {paddle_allocated / 1024.0 / 1024.0 / 1024.0}",
     )
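
The logging change converts raw byte counts to GiB (1024^3 bytes, distinct from the decimal GB of 10^9), which makes the abnormal-memory symptom readable at a glance. A minimal sketch of just the conversion:

    # Minimal sketch of the byte -> GiB conversion used in the new log lines.
    GIB = 1024.0 ** 3  # GiB is 1024^3 bytes; decimal GB would be 1e9

    def to_gib(num_bytes: int) -> float:
        return num_bytes / GIB

    print(f"Device Total memory(GiB): {to_gib(85_899_345_920):.2f}")  # 80.00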


@@ -84,7 +84,7 @@ class GpuWorker(WorkerBase):
         self.model_runner: ModelRunnerBase = ModelRunner(
             fd_config=self.fd_config,
             device=self.device,
-            device_id=self.device_ids[self.local_rank % self.max_chips_per_node],
+            device_id=int(self.device_ids[self.local_rank % self.max_chips_per_node]),
             rank=self.rank,
             local_rank=self.local_rank,
         )
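
The `int()` cast guards against `device_ids` holding strings, which is the usual result of splitting a comma-separated environment variable. A sketch of that failure mode; the variable name and parsing path are assumptions for illustration, not the repo's actual code:

    # Assumed parsing path -- the env var name is an assumption, not repo code.
    import os

    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "0,1,2,3")
    device_ids = visible.split(",")                       # ['0', '1', '2', '3']

    local_rank, max_chips_per_node = 5, 4
    raw = device_ids[local_rank % max_chips_per_node]     # '1' -- a str, not an int
    device_id = int(raw)                                  # the cast this commit adds
    print(type(raw).__name__, type(device_id).__name__)  # str int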