mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
* Adjust processing order
* fix bug
* fix update_parameters bug
* refine code
@@ -18,6 +18,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field
 from typing import Callable, Dict, List, Optional

+import paddle.jit.dy2static.utils as jit_utils
 import paddle.nn.layer
 from paddle.device.cuda import graphs

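The hunks below replace method-local imports of CUDAGraphState with this new module-level alias. A minimal before/after sketch of the two styles (illustrative only; the rationale for the original local import is the NOTE removed in the next hunk):

    # Before: the enum was imported locally inside each method
    from paddle.jit.dy2static.utils import CUDAGraphState
    state = CUDAGraphState.DISABLE

    # After: one module alias at import time; the enum is resolved on attribute access
    import paddle.jit.dy2static.utils as jit_utils
    state = jit_utils.CUDAGraphState.DISABLE
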
@@ -52,27 +53,24 @@ class ConcreteSizeEntry:

 class Dy2StCudaGraphManager:
     def __init__(self):
-        # NOTE(gongshaotian): Use local import to avoid RLHF version problems
-        from paddle.jit.dy2static.utils import CUDAGraphState

-        self.state = CUDAGraphState.DISABLE
+        self.state = jit_utils.CUDAGraphState.DISABLE
         self.captured_batch_size = set()
         self.batch_size = -1

     def run_impl(self, original_run_impl, inputs, parameters, attrs):
-        from paddle.jit.dy2static.utils import CUDAGraphState

         run_state = self.state
         prog_attrs, cuda_graph_attrs = attrs
-        if run_state == CUDAGraphState.REPLAY:
+        if run_state == jit_utils.CUDAGraphState.REPLAY:
             if self.batch_size not in self.captured_batch_size:
-                run_state = CUDAGraphState.DISABLE
-        elif run_state == CUDAGraphState.CAPTURE:
+                run_state = jit_utils.CUDAGraphState.DISABLE
+        elif run_state == jit_utils.CUDAGraphState.CAPTURE:
             self.captured_batch_size.add(self.batch_size)

         cuda_graph_attrs |= {
             "cuda_graph_state": run_state,
-            "cuda_graph_dispatch_key": self.batch_size if run_state != CUDAGraphState.DISABLE else 0,
+            "cuda_graph_dispatch_key": self.batch_size if run_state != jit_utils.CUDAGraphState.DISABLE else 0,
         }
         return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))

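For reference, a minimal standalone sketch (not part of the commit; the helper name is illustrative, jit_utils.CUDAGraphState is the real Paddle enum used above) of the dispatch rule run_impl applies: a REPLAY request for a batch size that was never captured falls back to DISABLE, and a CAPTURE request records the batch size for later replays.

    import paddle.jit.dy2static.utils as jit_utils

    def resolve_run_state(state, batch_size, captured_batch_sizes):
        # No captured graph exists for this batch size, so run eagerly instead of replaying.
        if state == jit_utils.CUDAGraphState.REPLAY and batch_size not in captured_batch_sizes:
            return jit_utils.CUDAGraphState.DISABLE
        # Remember which batch sizes have a captured graph.
        if state == jit_utils.CUDAGraphState.CAPTURE:
            captured_batch_sizes.add(batch_size)
        return state
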
@@ -105,7 +103,6 @@ class CudaGraphPiecewiseBackend:
         self.cuda_graph_manager = Dy2StCudaGraphManager()

     def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
-        from paddle.jit.dy2static.utils import CUDAGraphState

         if not entry.captured:
             # Warmup the model
@@ -122,14 +119,14 @@ class CudaGraphPiecewiseBackend:
             entry.input_addresses = input_addresses

             # Capture
-            self.cuda_graph_manager.state = CUDAGraphState.CAPTURE
+            self.cuda_graph_manager.state = jit_utils.CUDAGraphState.CAPTURE
             self.cuda_graph_manager.batch_size = entry.real_shape
             entry.captured = True
             with self.cuda_graph_manager.run_impl_guard():
                 entry.runnable(**kwargs)

         # Replay
-        self.cuda_graph_manager.state = CUDAGraphState.REPLAY
+        self.cuda_graph_manager.state = jit_utils.CUDAGraphState.REPLAY
         self.cuda_graph_manager.batch_size = entry.real_shape
         with self.cuda_graph_manager.run_impl_guard():
             return entry.runnable(**kwargs)
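A hedged usage sketch of the capture/replay flow above (identifiers from the diff; backend, entry and inputs are assumed to be set up elsewhere): the first run_static_model call for a given entry warms up, captures under run_impl_guard(), then falls through to the replay block; later calls take only the replay path.

    # First call: entry.captured is False -> warmup, CAPTURE, then REPLAY at the end.
    out = backend.run_static_model(entry, **inputs)
    # Subsequent calls with the same entry: entry.captured is True -> REPLAY only.
    out = backend.run_static_model(entry, **inputs)
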

@@ -44,6 +44,7 @@ class DynamicWeightManager:
         self.model: nn.Layer = model
         self._capture_model_state()
         self.update_parameters()
+        self.finalize_update()

         logger.info(
             f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
@@ -79,8 +80,6 @@ class DynamicWeightManager:

         logger.info(f"Update parameters in {time.perf_counter()-start_time:.2f}s")

-        self._finalize_update(pid)
-
     def _update_ipc_snapshot(self):
         """Update using IPC snapshot strategy for elastic recovery."""
         model_path = os.path.join(
@@ -144,7 +143,7 @@ class DynamicWeightManager:
         if src.shape != dst.shape:
             raise ValueError(f"Shape mismatch for {name}: {src.shape} vs {dst.shape}")

-    def _finalize_update(self, pid: int):
+    def finalize_update(self, pid: int = 0):
         """Finalize update process with verification."""
         self._verify_parameters("update")
         if self.parallel_config.tensor_parallel_size > 1:

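In short, the private _finalize_update(pid) becomes the public finalize_update(pid=0): update_parameters() no longer finalizes implicitly, and __init__ now calls finalize_update() itself after the initial load. A hedged sketch of the resulting call pattern (manager and pid assumed to be in scope; behavior inferred from the hunks above and the runner changes below):

    # During an RL weight refresh the two phases are now explicit:
    manager.update_parameters(pid)   # load the new weights only, no implicit finalize
    # ... the caller rebuilds whatever depends on the weights (KV cache, CUDA graphs) ...
    manager.finalize_update(pid)     # verify parameters and signal that the update is done
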
@@ -1705,25 +1705,27 @@ class GPUModelRunner(ModelRunnerBase):
         self.forward_meta.clear_caches()

     def clear_parameters(self, pid):
-        """ " Dynamic model loader use to clear parameters use for RL"""
+        """Dynamic model loader use to clear parameters use for RL"""
+        # Clear CUDAGraph
+        if self.use_cudagraph:
+            self.model.clear_grpah_opt_backend()
+        # Clear parameters and Send single
         self.dynamic_weight_manager.clear_parameters(pid)
         self.clear_cache()
         paddle.device.cuda.empty_cache()
-
-        # Clear CudaGraph
-        if self.use_cudagraph:
-            self.model.clear_grpah_opt_backend()
-
         self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")

     def update_parameters(self, pid):
-        """ " Dynamic model loader use to update parameters use for RL"""
+        """Dynamic model loader use to update parameters use for RL"""
+        # Update parameters
         self.dynamic_weight_manager.update_parameters(pid)
         self.initialize_kv_cache()
-        # Recapture CudaGraph
+        # Recapture CUDAGraph
         if self.use_cudagraph:
             self.capture_model()
+        # Send single
+        self.dynamic_weight_manager.finalize_update(pid)

         self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")

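Taken together with the DynamicWeightManager changes, the adjusted processing order for an RL weight refresh is: tear down CUDA graphs before dropping weights, and recapture them on the new weights before signalling completion. A hedged sketch of the sequence a caller would drive (illustrative; runner is a GPUModelRunner and pid the id passed through to the weight manager):

    runner.clear_parameters(pid)    # clear the CUDA graph backend first, then weights and caches
    # ... the training side produces updated weights ...
    runner.update_parameters(pid)   # reload weights, rebuild KV cache, recapture CUDA graphs,
                                    # then finalize_update(pid) verifies and signals completion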