Yuanle Liu
2025-07-16 15:33:10 +08:00
committed by GitHub
parent a83a3eea5f
commit dda4a9f848
10 changed files with 26 additions and 131 deletions

View File

@@ -378,9 +378,7 @@ class LoadConfig:
         dynamic_load_weight: Whether to enable dynamic weight loading
         load_strategy: Specifies the weight loading method when enabled:
             - 'ipc': Real-time IPC streaming with automatic resharding
-            - 'ipc_no_reshard': Real-time IPC streaming without weight process
             - 'ipc_snapshot': Load from disk snapshot of IPC weights
-            - 'meta': provide RL traing worker, no_weights_load
             - None: No dynamic loading
     """
     def __init__(
@@ -389,7 +387,7 @@ class LoadConfig:
     ):
         self.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
         self.dynamic_load_weight: bool = False
-        self.load_strategy: Optional[Literal['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta']] = None
+        self.load_strategy: Optional[Literal['ipc', 'ipc_snapshot']] = None
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)
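
After this hunk only 'ipc' and 'ipc_snapshot' remain as dynamic-load strategies. A minimal usage sketch, assuming the overrides arrive as the dict named args that the loop above iterates (the full __init__ signature is outside the hunk, so this call form is an assumption):

    # Hypothetical construction; unknown keys are ignored by the hasattr() check shown above.
    load_cfg = LoadConfig({
        "dynamic_load_weight": True,
        "load_strategy": "ipc_snapshot",  # or "ipc"; 'meta' and 'ipc_no_reshard' no longer exist
    })
    assert load_cfg.load_strategy == "ipc_snapshot"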

View File

@@ -92,7 +92,7 @@ class EngineArgs:
     """
     dynamic load weight
     """
-    load_strategy: str = "meta"
+    load_strategy: str = "ipc_snapshot"
     """
     dynamic load weight strategy
     """

View File

@@ -43,7 +43,7 @@ class ModelConfig:
                  model_name_or_path: str,
                  config_json_file: str = "config.json",
                  dynamic_load_weight: bool = False,
-                 load_strategy: str = "meta",
+                 load_strategy: str = "ipc_snapshot",
                  quantization: str = None,
                  download_dir: Optional[str] = None):
         """

View File

@@ -140,7 +140,7 @@ class FusedMoE(nn.Layer):
             shape=gate_weight_shape,
             dtype="float32",
         )
-        if self.model_config.moe_use_aux_free:
+        if self.fd_config.model_config.moe_use_aux_free:
             self.gate_correction_bias = self.create_parameter(
                 shape=gate_correction_bias_shape,
                 dtype="float32",

View File

@@ -519,43 +519,6 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
         """
         return self.blocks[0].mlp.fc2.weight.dtype

-    def get_name_mappings_to_training(self, ):
-        """ get_name_mappings_to_training """
-        infer_to_train = {}
-        # vit train names
-        vit_names = [
-            "vision_model.patch_embed.proj.weight", "vision_model.ln.weight",
-            "vision_model.ln.bias"
-        ]
-        vit_layer = 32
-        for layer_idx in range(vit_layer):
-            vit_names.append(f"vision_model.blocks.{layer_idx}.norm1.weight")
-            vit_names.append(f"vision_model.blocks.{layer_idx}.norm1.bias")
-            vit_names.append(f"vision_model.blocks.{layer_idx}.norm2.weight")
-            vit_names.append(f"vision_model.blocks.{layer_idx}.norm2.bias")
-            vit_names.append(
-                f"vision_model.blocks.{layer_idx}.attn.qkv.weight")
-            vit_names.append(f"vision_model.blocks.{layer_idx}.attn.qkv.bias")
-            vit_names.append(
-                f"vision_model.blocks.{layer_idx}.attn.proj.weight")
-            vit_names.append(f"vision_model.blocks.{layer_idx}.attn.proj.bias")
-            vit_names.append(f"vision_model.blocks.{layer_idx}.mlp.fc1.weight")
-            vit_names.append(f"vision_model.blocks.{layer_idx}.mlp.fc1.bias")
-            vit_names.append(f"vision_model.blocks.{layer_idx}.mlp.fc2.weight")
-            vit_names.append(f"vision_model.blocks.{layer_idx}.mlp.fc2.bias")
-
-        for train_name in vit_names:
-            infer_to_train[train_name] = train_name
-
-        return infer_to_train
-
     def rot_pos_emb(self, grid_thw, num_pad=0):
         """_summary_

View File

@@ -513,7 +513,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
                 model_config.spatial_conv_size,
                 model_config.temporal_conv_size,
                 config=model_config,
-                prefix_name="ernie.resampler_model",
+                prefix_name="resampler_model",
             )
             resampler_model = paddle.amp.decorate(
                 models=resampler_model, level="O2", dtype="bfloat16"

View File

@@ -210,31 +210,6 @@ class VariableResolutionResamplerModel(nn.Layer):
             mark_as_sequence_parallel_parameter(self.mlp.bias)
             mark_as_sequence_parallel_parameter(self.after_norm.weight)

-    def get_name_mappings_to_training(self, ):
-        """ get_name_mappings_to_training """
-        infer_to_train = {}
-        resampler_names = [
-            "ernie.resampler_model.spatial_linear.0.weight",
-            "ernie.resampler_model.spatial_linear.0.bias",
-            "ernie.resampler_model.spatial_linear.2.weight",
-            "ernie.resampler_model.spatial_linear.2.bias",
-            "ernie.resampler_model.spatial_linear.3.weight",
-            "ernie.resampler_model.spatial_linear.3.bias",
-            "ernie.resampler_model.temporal_linear.0.weight",
-            "ernie.resampler_model.temporal_linear.0.bias",
-            "ernie.resampler_model.temporal_linear.2.weight",
-            "ernie.resampler_model.temporal_linear.2.bias",
-            "ernie.resampler_model.temporal_linear.3.weight",
-            "ernie.resampler_model.temporal_linear.3.bias",
-            "ernie.resampler_model.mlp.weight",
-            "ernie.resampler_model.mlp.bias",
-            "ernie.resampler_model.after_norm.weight",
-        ]
-        for train_name in resampler_names:
-            infer_to_train[train_name[len("ernie."):]] = train_name
-        return infer_to_train
-
     def spatial_conv_reshape(self, x, spatial_conv_size):
         """
         The reshape before the Linear, so that the Linear can mimic a conv's receptive field
@@ -376,9 +351,11 @@ class VariableResolutionResamplerModel(nn.Layer):
         for param_name, param in params_dict.items():
             state_dict_key = f"{self.prefix_name}.{param_name}"
             if state_dict_key not in state_dict:
-                raise ValueError(
-                    f"The key {state_dict_key} does not exist in state_dict. "
-                )
+                state_dict_key = f"ernie.{self.prefix_name}.{param_name}"
+                if state_dict_key not in state_dict:
+                    raise ValueError(
+                        f"The key {state_dict_key} does not exist in state_dict. "
+                    )
             tensor = get_tensor(state_dict.pop(state_dict_key))
             if param.shape != tensor.shape:
                 raise ValueError(
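
Together with the prefix_name change above (from "ernie.resampler_model" to "resampler_model"), this fallback keeps older checkpoints loadable: their keys still carry the "ernie." prefix. A self-contained sketch of the same two-step lookup, assuming state_dict is a plain dict (names here are illustrative):

    # Minimal sketch of the lookup-with-fallback introduced in the hunk above.
    def resolve_key(state_dict: dict, prefix_name: str, param_name: str) -> str:
        key = f"{prefix_name}.{param_name}"
        if key not in state_dict:
            key = f"ernie.{prefix_name}.{param_name}"  # legacy checkpoint prefix
            if key not in state_dict:
                raise ValueError(f"The key {key} does not exist in state_dict.")
        return key

    legacy_ckpt = {"ernie.resampler_model.mlp.weight": object()}
    print(resolve_key(legacy_ckpt, "resampler_model", "mlp.weight"))
    # -> "ernie.resampler_model.mlp.weight"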

View File

@@ -16,7 +16,7 @@
 import os
 import time
 from multiprocessing.shared_memory import SharedMemory
-from typing import Any, Dict, List
+from typing import Any, Dict

 import numpy as np
 import paddle
@@ -24,9 +24,6 @@ from paddle import nn
 from paddleformers.utils.log import logger

 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.load_weight_utils import \
-    load_composite_checkpoint
-from fastdeploy.model_executor.model_loader import MODEL_CLASSES


 class DynamicWeightManager:
@@ -43,11 +40,9 @@ class DynamicWeightManager:
         self.meta_src_id = self._get_gpu_id()
         self.first_load = True
         self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
-        self.models: List[nn.Layer] = [model]
+        self.model: nn.Layer = model
         self._capture_model_state()
-
-        if self.load_config.load_strategy != "meta":
-            self.update_parameters()
+        self.update_parameters()

         logger.info(
             f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
@@ -56,17 +51,11 @@ class DynamicWeightManager:
     @paddle.no_grad()
     def _capture_model_state(self):
         """Capture and store initial model parameters state."""
-        for model in self.models:
-            for name, param in model.state_dict().items():
-                logger.debug(
-                    f"Model param: {name}, shape={param.shape}, dtype={param.dtype}"
-                )
-                self.state_dict[name] = param
-
-    def add_model(self, model: nn.Layer):
-        """"add model"""
-        self.models.append(model)
-        self._capture_model_state()
+        for name, param in self.model.state_dict().items():
+            logger.debug(
+                f"Model param: {name}, shape={param.shape}, dtype={param.dtype}"
+            )
+            self.state_dict[name] = param

     def update_parameters(self, pid: int = 0) -> None:
         """Core method to update model parameters based on strategy."""
@@ -79,8 +68,6 @@ class DynamicWeightManager:
         strategy_handlers = {
             "ipc_snapshot": self._update_ipc_snapshot,
             "ipc": self._update_ipc,
-            "ipc_no_reshard": self._update_ipc_no_reshard,
-            "normal": self.load_model,
         }

         if handler := strategy_handlers.get(self.load_config.load_strategy):
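
Only the two surviving strategies stay registered in the handler table. A standalone sketch of the same dispatch pattern, with placeholder handlers instead of the manager's real methods:

    # Illustrative dispatch table mirroring the hunk above (Python 3.8+ for the walrus operator).
    def update_parameters(strategy: str) -> None:
        handlers = {
            "ipc_snapshot": lambda: print("load weights from a disk snapshot"),
            "ipc": lambda: print("stream weights over IPC and reshard"),
        }
        if handler := handlers.get(strategy):
            handler()
        else:
            raise ValueError(f"Unsupported load strategy: {strategy}")

    update_parameters("ipc_snapshot")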
@@ -106,13 +93,7 @@ class DynamicWeightManager:
             fallback_path = f"/shared_ipc_meta/model_state.tp0{self.meta_src_id}.pdparams"
             ipc_state_dict = paddle.load(fallback_path)

-        try:
-            self._update_model_from_state(ipc_state_dict, "snapshot")
-        except Exception:
-            self.models[0].set_state_dict(ipc_state_dict)
-            logger.warning(
-                "load model from no_reshard weight, maybe need more GPU memory"
-            )
+        self._update_model_from_state(ipc_state_dict, "snapshot")
         logger.info(
             f"IPC snapshot update parameters completed from {model_path}")
@@ -124,34 +105,12 @@ class DynamicWeightManager:
         logger.info(
             f"IPC update parameters completed from file: {self.ipc_path}")

-    def _update_ipc_no_reshard(self):
-        """Update using no-reshard IPC strategy (faster but uses more memory)."""
-        ipc_meta = paddle.load(self.ipc_path)
-        state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
-        self.models[0].set_state_dict(state_dict)
-        logger.info(
-            f"IPC no-reshard update parameters completed from file: {self.ipc_path}"
-        )
-
-    def load_model(self) -> nn.Layer:
-        """Standard model loading without IPC."""
-        architectures = self.fd_config.model_config.architectures[0]
-        model_class = MODEL_CLASSES[architectures]
-        state_dict = load_composite_checkpoint(
-            self.fd_config.parallel_config.model_name_or_path,
-            model_class,
-            self.fd_config.model_config,
-            return_numpy=True)
-        self.models[0].set_state_dict(state_dict)
-        logger.info("normal load update parameters completed")
-
     def clear_parameters(self, pid: int = 0) -> None:
         """Clear all model parameters and free memory."""
         logger.info("start clear paramaters")
         paddle.device.cuda.empty_cache()
-        for model in self.models:
-            for param in model.state_dict().values():
-                param._clear_data()
+        for param in self.model.state_dict().values():
+            param._clear_data()

         self._verify_parameters("clearance")

         if self.nranks > 1:

View File

@@ -14,6 +14,7 @@
 # limitations under the License.
 """

+from fastdeploy.worker.worker_process import initialize_fd_config
@@ -24,7 +25,7 @@ class RolloutModelConfig:
         max_model_len: int = 32768,
         tensor_parallel_size: int = 4,
         dynamic_load_weight: bool = True,
-        load_strategy: str = "meta",
+        load_strategy: str = "ipc_snapshot",
         enable_mm: bool = False,
         # Default values for all other parameters
         max_num_seqs: int = 34,

View File

@@ -535,14 +535,11 @@ def parse_args():
     parser.add_argument(
         "--load_strategy",
         type=str,
-        choices=['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta', 'normal'],
-        default='meta',
+        choices=['ipc', 'ipc_snapshot'],
+        default="ipc_snapshot",
         help="Weight loading method when dynamic loading is enabled: "
         "'ipc': real-time IPC streaming with automatic resharding, "
-        "'ipc_no_reshard': IPC streaming without weight processing, "
-        "'ipc_snapshot': load from disk snapshot of IPC weights, "
-        "'meta': provide RL traing worker, no_weights_load"
-        "'normal':normal load weight")
+        "'ipc_snapshot': load from disk snapshot of IPC weights.")
     parser.add_argument("--enable_mm",
                         type=str,
                         default="false",