Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
support dynamic load for normal (#5437)
@@ -95,3 +95,31 @@ class DefaultModelLoader(BaseModelLoader):
        # TODO(gongshaotian): for now, only safetensors are supported
        self.load_weights(model, fd_config, architectures)
        return model

    def load_rl_mock_model(self, fd_config: FDConfig) -> nn.Layer:
        """Used for RL model loading."""
        # TODO(gaoziyuan): optimize
        original_architectures = fd_config.model_config.architectures[0]
        logger.info(f"Starting to load model {original_architectures}.")

        import fastdeploy.rl  # noqa

        if fd_config.speculative_config.model_type != "mtp":
            model_architectures = original_architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
        else:
            model_architectures = original_architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")

        model_architectures += "RL"
        context = paddle.LazyGuard()

        with context:
            model_cls = ModelRegistry.get_class(model_architectures)
            model = model_cls(fd_config)

        model.eval()

        if fd_config.load_config.load_strategy == "normal":
            # The "normal" strategy loads weights here; the architecture name must not carry the "RL" suffix.
            self.load_weights(model, fd_config, original_architectures)
        # RL models do not need set_state_dict here.
        return model
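For readers following the diff, the branch above only rewrites the architecture string and appends an "RL" suffix before the registry lookup. A minimal sketch of that renaming, using the class names from the diff (the mock_rl_architecture helper is illustrative and not part of FastDeploy):

def mock_rl_architecture(original: str, is_mtp: bool) -> str:
    """Illustrative only: mirror the renaming done in load_rl_mock_model."""
    target = "Ernie5MTPForCausalLM" if is_mtp else "Ernie5MoeForCausalLM"
    return original.replace("Ernie5ForCausalLM", target) + "RL"

# The resulting registry keys for the two branches:
assert mock_rl_architecture("Ernie5ForCausalLM", is_mtp=False) == "Ernie5MoeForCausalLMRL"
assert mock_rl_architecture("Ernie5ForCausalLM", is_mtp=True) == "Ernie5MTPForCausalLMRL"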
@@ -98,3 +98,30 @@ class DefaultModelLoaderV1(BaseModelLoader):
            return model
        self.load_weights(model, fd_config, enable_cache)
        return model

    def load_rl_mock_model(self, fd_config: FDConfig) -> nn.Layer:
        """Used for RL model loading."""
        # TODO(gaoziyuan): optimize
        original_architectures = fd_config.model_config.architectures[0]

        import fastdeploy.rl  # noqa

        if fd_config.speculative_config.model_type != "mtp":
            model_architectures = original_architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
        else:
            model_architectures = original_architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")

        model_architectures += "RL"
        context = paddle.LazyGuard()

        with context:
            model_cls = ModelRegistry.get_class(model_architectures)
            model = model_cls(fd_config)

        model.eval()

        if fd_config.load_config.load_strategy == "normal":
            # The "normal" strategy loads weights here; the architecture name must not carry the "RL" suffix.
            self.load_weights(model, fd_config, original_architectures)
        # RL models do not need set_state_dict here.
        return model
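ModelRegistry.get_class resolves the suffixed name to a class that "import fastdeploy.rl" registers as a side effect. The snippet below is a hypothetical name-to-class registry, sketched only to show why the suffixed architecture name must match a registered class; FastDeploy's actual ModelRegistry may differ in detail:

# Hypothetical registry sketch, not FastDeploy's implementation.
_MODEL_CLASSES: dict[str, type] = {}

def register_model(cls: type) -> type:
    """Register a model class under its class name so get_class can find it."""
    _MODEL_CLASSES[cls.__name__] = cls
    return cls

def get_class(architecture: str) -> type:
    try:
        return _MODEL_CLASSES[architecture]
    except KeyError:
        raise ValueError(f"No model registered for architecture {architecture!r}")

@register_model
class Ernie5MoeForCausalLMRL:  # stand-in for the real RL model class
    def __init__(self, fd_config):
        self.fd_config = fd_config

Importing the module that defines the decorated classes is what populates the table, which is why the loader imports fastdeploy.rl before calling get_class.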
@@ -86,6 +86,7 @@ class DynamicWeightManager:
        strategy_handlers = {
            "ipc_snapshot": self._update_ipc_snapshot,
            "ipc": self._update_ipc,
            "normal": self._normal_load_weight,
        }

        if handler := strategy_handlers.get(self.load_config.load_strategy):
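The handler table maps load_config.load_strategy to an update routine, and this commit adds "normal" alongside the IPC strategies. A minimal sketch of the same walrus-operator dispatch pattern (the run_update helper and its error message are illustrative, not part of the diff):

def run_update(strategy: str, handlers: dict) -> None:
    # Look up the handler registered for this strategy and call it, or fail loudly.
    if handler := handlers.get(strategy):
        handler()
    else:
        raise ValueError(f"Unsupported load strategy: {strategy!r}")

run_update("normal", {"normal": lambda: print("reloading weights via the normal path")})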
@@ -100,6 +101,14 @@ class DynamicWeightManager:
        # step5: recapture cuda_graph
        # step6: update weight status signal

    def _normal_load_weight(self):
        """Used for the RL mock load path."""
        from fastdeploy.model_executor.model_loader import get_model_loader

        model_loader = get_model_loader(load_config=self.fd_config.load_config)
        state_dict = model_loader.load_rl_mock_model(fd_config=self.fd_config).state_dict()
        self._update_model_from_state(state_dict, "raw")

    def _update_ipc_snapshot(self):
        """Update using the IPC snapshot strategy for elastic recovery."""
        model_path = os.path.join(
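_normal_load_weight builds a fresh mock model, takes its state_dict(), and hands the raw tensors to _update_model_from_state. As a rough sketch of the underlying mechanism, assuming only plain Paddle APIs rather than FastDeploy's update path, one layer's weights can refresh another through a state dict like this:

import paddle
import paddle.nn as nn

# Two layers with the same structure but independently initialized weights.
source = nn.Linear(4, 4)
target = nn.Linear(4, 4)

# state_dict() yields a name -> tensor mapping; set_state_dict copies it in.
target.set_state_dict(source.state_dict())

x = paddle.randn([2, 4])
# After the copy both layers produce identical outputs.
assert paddle.allclose(source(x), target(x)).item()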