[RL] Resolve shape mismatch problems in RL-related modules (#5032)

* RL fix

* update
This commit is contained in:
bukejiyu
2025-11-19 11:12:48 +08:00
committed by GitHub
parent 4694ed2a43
commit a82f25ea7b
12 changed files with 61 additions and 87 deletions

View File

@@ -56,8 +56,8 @@ class DefaultModelLoaderV1(BaseModelLoader):
load_weights_from_cache(model, weights_iterator)
else:
model.load_weights(weights_iterator)
if fd_config.speculative_config.model_type != "mtp":
process_final_after_loading(model, fd_config)
if fd_config.speculative_config.model_type != "mtp":
process_final_after_loading(model, fd_config)
self.clean_memory_fragments()
@@ -76,6 +76,7 @@ class DefaultModelLoaderV1(BaseModelLoader):
architectures = architectures + "RL"
enable_cache, _, weight_cache_context = is_weight_cache_enabled(fd_config)
fd_config.model_config.enable_cache = enable_cache
with weight_cache_context:
with context:
model_cls = ModelRegistry.get_class(architectures)
@@ -88,6 +89,8 @@ class DefaultModelLoaderV1(BaseModelLoader):
assert_never(convert_type)
model = model_cls(fd_config)
if fd_config.load_config.dynamic_load_weight or fd_config.model_config.enable_cache:
process_final_after_loading(model, fd_config)
model.eval()
# RL model not need set_state_dict