mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
add qwen-2.5-7B-PRM/ernie-rm (#4319)
This commit is contained in:
@@ -16,8 +16,10 @@
|
||||
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import paddle
|
||||
from paddleformers.utils.log import logger
|
||||
@@ -150,6 +152,36 @@ def process_weights_after_loading(sublayers_dict: dict):
|
||||
return fn
|
||||
|
||||
|
||||
@dataclass
|
||||
class WeightsMapper:
|
||||
orig_to_new_prefix: Mapping[str, Optional[str]] = field(default_factory=dict)
|
||||
|
||||
def _map_name(self, key: str) -> Optional[str]:
|
||||
for prefix, new_key in self.orig_to_new_prefix.items():
|
||||
if key.startswith(prefix):
|
||||
key = key.replace(prefix, new_key, 1)
|
||||
return key
|
||||
|
||||
def apply(self, weight_name):
|
||||
return self._map_name(weight_name)
|
||||
|
||||
|
||||
def process_weights_before_loading(
|
||||
*, skip_prefixes: Optional[List[str]] = None, mapper: Optional[WeightsMapper] = None
|
||||
):
|
||||
def _can_skip(weight_name):
|
||||
return any(weight_name.startswith(p) for p in (skip_prefixes or []))
|
||||
|
||||
def fn(weight_name):
|
||||
if mapper is not None:
|
||||
weight_name = mapper.apply(weight_name)
|
||||
if _can_skip(weight_name):
|
||||
weight_name = None
|
||||
return weight_name
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def free_tensor(tensor):
|
||||
if hasattr(tensor, "tensor_track"):
|
||||
tensor.tensor_track = None
|
||||
|
||||
Reference in New Issue
Block a user