mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[RL]Resolve shape mismatch problems in RL-related modules (#5032)
* RL fix * update
This commit is contained in:
@@ -56,8 +56,8 @@ class DefaultModelLoaderV1(BaseModelLoader):
|
||||
load_weights_from_cache(model, weights_iterator)
|
||||
else:
|
||||
model.load_weights(weights_iterator)
|
||||
if fd_config.speculative_config.model_type != "mtp":
|
||||
process_final_after_loading(model, fd_config)
|
||||
if fd_config.speculative_config.model_type != "mtp":
|
||||
process_final_after_loading(model, fd_config)
|
||||
|
||||
self.clean_memory_fragments()
|
||||
|
||||
@@ -76,6 +76,7 @@ class DefaultModelLoaderV1(BaseModelLoader):
|
||||
architectures = architectures + "RL"
|
||||
|
||||
enable_cache, _, weight_cache_context = is_weight_cache_enabled(fd_config)
|
||||
fd_config.model_config.enable_cache = enable_cache
|
||||
with weight_cache_context:
|
||||
with context:
|
||||
model_cls = ModelRegistry.get_class(architectures)
|
||||
@@ -88,6 +89,8 @@ class DefaultModelLoaderV1(BaseModelLoader):
|
||||
assert_never(convert_type)
|
||||
|
||||
model = model_cls(fd_config)
|
||||
if fd_config.load_config.dynamic_load_weight or fd_config.model_config.enable_cache:
|
||||
process_final_after_loading(model, fd_config)
|
||||
|
||||
model.eval()
|
||||
# RL model not need set_state_dict
|
||||
|
||||
Reference in New Issue
Block a user