add qwen-2.5-7B-PRM/ernie-rm (#4319)

This commit is contained in:
bukejiyu
2025-10-20 15:31:03 +08:00
committed by GitHub
parent 47595a2480
commit de2eaf4f81
10 changed files with 352 additions and 24 deletions

View File

@@ -16,8 +16,10 @@
import os
import re
from collections.abc import Mapping
from contextlib import contextmanager
from typing import Any, Optional, Union
from dataclasses import dataclass, field
from typing import Any, List, Optional, Union
import paddle
from paddleformers.utils.log import logger
@@ -150,6 +152,36 @@ def process_weights_after_loading(sublayers_dict: dict):
return fn
@dataclass
class WeightsMapper:
orig_to_new_prefix: Mapping[str, Optional[str]] = field(default_factory=dict)
def _map_name(self, key: str) -> Optional[str]:
for prefix, new_key in self.orig_to_new_prefix.items():
if key.startswith(prefix):
key = key.replace(prefix, new_key, 1)
return key
def apply(self, weight_name):
return self._map_name(weight_name)
def process_weights_before_loading(
*, skip_prefixes: Optional[List[str]] = None, mapper: Optional[WeightsMapper] = None
):
def _can_skip(weight_name):
return any(weight_name.startswith(p) for p in (skip_prefixes or []))
def fn(weight_name):
if mapper is not None:
weight_name = mapper.apply(weight_name)
if _can_skip(weight_name):
weight_name = None
return weight_name
return fn
def free_tensor(tensor):
if hasattr(tensor, "tensor_track"):
tensor.tensor_track = None