polish code with new pre-commit rule (#2923)

Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions
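The diff is a pure reformatting pass produced by the new pre-commit hooks; no behavior changes. The hook configuration itself is not shown on this page, but judging from the output (backslash continuations and hanging indents collapsed onto one line where they fit, multi-line parenthesized blocks with trailing commas where they do not, blank lines after class docstrings, lines up to roughly 120 characters), a plausible .pre-commit-config.yaml would look like the sketch below. The hook revisions and the 120-character line length are inferred assumptions, not taken from the repository.

# Hypothetical sketch: revisions and line length are inferred from the
# reformatted output in this diff, not confirmed by the commit.
repos:
  - repo: https://github.com/psf/black
    rev: 24.4.2
    hooks:
      - id: black
        args: [--line-length=120]
  - repo: https://github.com/PyCQA/isort
    rev: 5.13.2
    hooks:
      - id: isort
        args: [--profile=black, --line-length=120]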


@@ -31,10 +31,12 @@ import paddle
 from paddle.common_ops_import import convert_dtype
 from paddleformers.transformers.model_utils import _add_variant
 from paddleformers.transformers.utils import paddleformers_load
-from paddleformers.utils.env import (PADDLE_WEIGHTS_INDEX_NAME,
-                                     SAFE_MASTER_WEIGHTS_INDEX_NAME,
-                                     SAFE_PEFT_WEIGHTS_INDEX_NAME,
-                                     SAFE_WEIGHTS_INDEX_NAME)
+from paddleformers.utils.env import (
+    PADDLE_WEIGHTS_INDEX_NAME,
+    SAFE_MASTER_WEIGHTS_INDEX_NAME,
+    SAFE_PEFT_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+)
 from paddleformers.utils.log import logger
 from tqdm import tqdm
@@ -44,6 +46,7 @@ MAX_DRAFT_TOKENS = 6
 class LayerIdPlaceholder(str, enum.Enum):
     """LayerIdPlaceholder"""
+
     LAYER_ID = "layer_id"
     FFN_LAYER_ID = "ffn_layer_id"
     MOE_LAYER_ID = "moe_layer_id"
@@ -51,6 +54,7 @@ class LayerIdPlaceholder(str, enum.Enum):
     TEXT_EXPERT_ID = "text_export_id"
     IMG_EXPERT_ID = "img_export_id"
 
+
 class WeightMeta(NamedTuple):
     """
     #tensor split parameters
@@ -59,6 +63,7 @@ class WeightMeta(NamedTuple):
     # is_column: whether to split by columns
     # extra: optional flags like "is_naive_2fuse", "is_gqa", "is_naive_3fuse"
     """
+
     weight_name: str
     is_column: bool
     extra: Optional[str] = None
@@ -81,8 +86,7 @@ class UniqueIDGenerator:
         first_key = sorted_keys[0]
         first_parameter = state_dict[first_key].cast("float32")
         # Assume the model parameters are unique; derive the md5sum from the first key
-        model_md5 = hashlib.md5(str(
-            first_parameter.sum()).encode("utf-8")).hexdigest()
+        model_md5 = hashlib.md5(str(first_parameter.sum()).encode("utf-8")).hexdigest()
         unique_id = f"{model_md5}-{random.randint(10000, 99999)}"
         return unique_id
@@ -99,20 +103,16 @@ def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
"""
# Load the index
pdparams_file = os.path.join(folder,
_add_variant("model_state.pdparams", variant))
lora_pdparams_file = os.path.join(
folder, _add_variant("lora_model_state.pdparams", variant))
safetensors_file = os.path.join(folder,
_add_variant("model.safetensors", variant))
pdparams_file = os.path.join(folder, _add_variant("model_state.pdparams", variant))
lora_pdparams_file = os.path.join(folder, _add_variant("lora_model_state.pdparams", variant))
safetensors_file = os.path.join(folder, _add_variant("model.safetensors", variant))
if os.path.isfile(pdparams_file):
return paddle.load(pdparams_file, return_numpy=return_numpy)
if os.path.isfile(lora_pdparams_file):
return paddle.load(lora_pdparams_file, return_numpy=return_numpy)
if os.path.isfile(safetensors_file):
try:
from paddleformers.utils.safetensors import \
fast_load_file as safe_load_file
from paddleformers.utils.safetensors import fast_load_file as safe_load_file
except ImportError:
from safetensors.numpy import load_file as safe_load_file
@@ -120,18 +120,13 @@ def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
         if not return_numpy:
             for key in list(state_dict.keys()):
                 if isinstance(state_dict[key], np.ndarray):
-                    state_dict[key] = paddle.Tensor(state_dict.pop(key),
-                                                    zero_copy=True)
+                    state_dict[key] = paddle.Tensor(state_dict.pop(key), zero_copy=True)
         return state_dict
 
-    index_file = os.path.join(folder,
-                              _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
-    safe_index_file = os.path.join(
-        folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
-    safe_master_file = os.path.join(
-        folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant))
-    safe_peft_file = os.path.join(
-        folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant))
+    index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
+    safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
+    safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant))
+    safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant))
 
     index_present = os.path.isfile(index_file)
     safe_index_present = os.path.isfile(safe_index_file)
@@ -152,14 +147,11 @@ def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
         load_safe = True
         load_index = safe_peft_file
     else:
-        raise ValueError(
-            f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}"
-        )
+        raise ValueError(f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}")
 
     if load_safe:
         try:
-            from paddleformers.utils.safetensors import \
-                fast_load_file as safe_load_file
+            from paddleformers.utils.safetensors import fast_load_file as safe_load_file
         except ImportError:
             from safetensors.numpy import load_file as safe_load_file
@@ -167,8 +159,7 @@ def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
         index = json.load(f)
 
     shard_files = list(set(index["weight_map"].values()))
-    loader = (safe_load_file if load_safe else partial(
-        paddleformers_load, map_location="np" if return_numpy else "cpu"))
+    loader = safe_load_file if load_safe else partial(paddleformers_load, map_location="np" if return_numpy else "cpu")
 
     ret = {}
     for shard_file in tqdm(shard_files):
@@ -183,8 +174,7 @@ def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
     return ret
 
 
-def convert_ndarray_dtype(np_array: np.ndarray,
-                          target_dtype: str) -> np.ndarray:
+def convert_ndarray_dtype(np_array: np.ndarray, target_dtype: str) -> np.ndarray:
     """convert ndarray
 
     Args:
@@ -195,8 +185,11 @@ def convert_ndarray_dtype(np_array: np.ndarray,
         np.ndarray: converted numpy ndarray instance
     """
     source_dtype = convert_dtype(np_array.dtype)
-    if source_dtype == "uint16" and target_dtype == "bfloat16" and paddle.is_compiled_with_custom_device(
-            "iluvatar_gpu"):
+    if (
+        source_dtype == "uint16"
+        and target_dtype == "bfloat16"
+        and paddle.is_compiled_with_custom_device("iluvatar_gpu")
+    ):
         return np_array.view(dtype=target_dtype)
     if source_dtype == "uint16" or target_dtype == "bfloat16":
         if paddle.is_compiled_with_xpu():
@@ -235,11 +228,9 @@ def pad_batch_data(insts, pad_id=0, return_seq_len=False, pad_style="right"):
     # pad to max input len
     # max_len = args.max_len
     if pad_style == "left":
-        inst_data = np.array([[pad_id] * (max_len - len(inst)) + list(inst)
-                              for inst in insts])
+        inst_data = np.array([[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts])
     else:
-        inst_data = np.array(
-            [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts])
+        inst_data = np.array([list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts])
     if return_seq_len:
         seq_len = np.array([len(inst) for inst in insts])
         return inst_data.astype("int64").reshape([-1, max_len]), seq_len
@@ -258,8 +249,7 @@ def load_prefix_weights(
     Args:
         prefix_path (str): the path of prefix weight
     """
-    past_key_values = paddle.to_tensor(
-        np.load(f"{prefix_path}/pre_caches.npy")).unsqueeze(2)
+    past_key_values = paddle.to_tensor(np.load(f"{prefix_path}/pre_caches.npy")).unsqueeze(2)
 
     if batch_size > 1:
         past_key_values = paddle.concat([past_key_values] * batch_size, axis=2)
@@ -305,8 +295,7 @@ def w4a8_weight_convert(state_dict):
                 name,
                 w4a8_weight_bites_name_map,
             )
-            state_dict[name] = weight_q.numpy(
-            ) if weight_q is not None else value
+            state_dict[name] = weight_q.numpy() if weight_q is not None else value
             del weight_q
     w4a8_weight_bites_layers_map = {}
     w4a8_weight_bites_layers_map["qkv_gemm_bits_map"] = []
@@ -319,13 +308,10 @@ def w4a8_weight_convert(state_dict):
elif "out_proj" in name_keys:
w4a8_weight_bites_layers_map["out_gemm_bits_map"].append(gemm_bits)
elif "linear1" in name_keys:
w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"].append(
gemm_bits)
w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"].append(gemm_bits)
elif "linear2" in name_keys:
w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"].append(
gemm_bits)
logger.debug(
f"w4a8_weight_bites_layers_map:{w4a8_weight_bites_layers_map}")
w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"].append(gemm_bits)
logger.debug(f"w4a8_weight_bites_layers_map:{w4a8_weight_bites_layers_map}")
return state_dict, w4a8_weight_bites_layers_map
@@ -415,10 +401,13 @@ def calculate_effective_tokens(training_args, train_dataset, max_seq_len):
     else:
         sharding_parallel_degree = 1
 
-    total_batch = (training_args.max_steps *
-                   training_args.per_device_train_batch_size *
-                   training_args.gradient_accumulation_steps *
-                   sharding_parallel_degree * data_parallel_degree)
+    total_batch = (
+        training_args.max_steps
+        * training_args.per_device_train_batch_size
+        * training_args.gradient_accumulation_steps
+        * sharding_parallel_degree
+        * data_parallel_degree
+    )
     for i, data in enumerate(train_dataset):
         if i == total_batch:
             break
@@ -464,7 +453,7 @@ def parser_quant_type(quant_type):
"fp8": "float8_e4m3fn",
"fp16": "float16",
"bf16": "bfloat16",
"fp32": "float32"
"fp32": "float32",
}
cache_type = default_type
if "c8" in quant_type:
@@ -483,8 +472,7 @@ def parser_quant_type(quant_type):
pattern = f"({'|'.join(map(re.escape, ['w', 'a', 'c']))})"
splited_type = re.split(pattern, quant_type)
splited_type = [tmp_type for tmp_type in splited_type if tmp_type]
assert (len(splited_type) % 2 == 0 and len(splited_type)
<= 6), f"Quant type[{quant_type}] format error."
assert len(splited_type) % 2 == 0 and len(splited_type) <= 6, f"Quant type[{quant_type}] format error."
quant_type_list = []
if "w" in splited_type: