mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
support dynamic load for normal (#5437)
This commit is contained in:
@@ -95,3 +95,31 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
# TODO(gongshaotian): Now, only support safetensor
|
||||
self.load_weights(model, fd_config, architectures)
|
||||
return model
|
||||
|
||||
def load_rl_mock_model(self, fd_config: FDConfig) -> nn.Layer:
|
||||
"""use for rl model load"""
|
||||
# (TODO:gaoziyuan) optimze
|
||||
original_architectures = fd_config.model_config.architectures[0]
|
||||
logger.info(f"Starting to load model {original_architectures}.")
|
||||
|
||||
import fastdeploy.rl # noqa
|
||||
|
||||
if fd_config.speculative_config.model_type != "mtp":
|
||||
model_architectures = original_architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
|
||||
else:
|
||||
model_architectures = original_architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
|
||||
|
||||
model_architectures += "RL"
|
||||
context = paddle.LazyGuard()
|
||||
|
||||
with context:
|
||||
model_cls = ModelRegistry.get_class(model_architectures)
|
||||
model = model_cls(fd_config)
|
||||
|
||||
model.eval()
|
||||
|
||||
if fd_config.load_config.load_strategy == "normal":
|
||||
# normal strategy need load weight and architectures need without "RL"
|
||||
self.load_weights(model, fd_config, original_architectures)
|
||||
# RL model not need set_state_dict
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user