Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
polish code with new pre-commit rule (#2923)
@@ -31,10 +31,12 @@ import paddle
 from paddle.common_ops_import import convert_dtype
 from paddleformers.transformers.model_utils import _add_variant
 from paddleformers.transformers.utils import paddleformers_load
-from paddleformers.utils.env import (PADDLE_WEIGHTS_INDEX_NAME,
-                                     SAFE_MASTER_WEIGHTS_INDEX_NAME,
-                                     SAFE_PEFT_WEIGHTS_INDEX_NAME,
-                                     SAFE_WEIGHTS_INDEX_NAME)
+from paddleformers.utils.env import (
+    PADDLE_WEIGHTS_INDEX_NAME,
+    SAFE_MASTER_WEIGHTS_INDEX_NAME,
+    SAFE_PEFT_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+)
 from paddleformers.utils.log import logger
 from tqdm import tqdm

@@ -44,6 +46,7 @@ MAX_DRAFT_TOKENS = 6

 class LayerIdPlaceholder(str, enum.Enum):
     """LayerIdPlaceholder"""
+
     LAYER_ID = "layer_id"
     FFN_LAYER_ID = "ffn_layer_id"
     MOE_LAYER_ID = "moe_layer_id"
@@ -51,6 +54,7 @@ class LayerIdPlaceholder(str, enum.Enum):
     TEXT_EXPERT_ID = "text_export_id"
     IMG_EXPERT_ID = "img_export_id"

+
 class WeightMeta(NamedTuple):
     """
     #tensor split parameters
@@ -59,6 +63,7 @@ class WeightMeta(NamedTuple):
     # is_column: whether to split by columns
     # extra: optional flags like "is_naive_2fuse", "is_gqa", "is_naive_3fuse"
     """
+
     weight_name: str
     is_column: bool
     extra: Optional[str] = None
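As an aside, a minimal usage sketch of the WeightMeta tuple documented above; the weight name and the extra flag value here are hypothetical and not taken from this diff:

    # fields follow the order shown above: weight_name, is_column, extra
    meta = WeightMeta("layers.0.self_attn.qkv_proj.weight", True, "is_gqa")
    if meta.is_column:
        # a column split shards this tensor along its columns; the extra flag
        # carries fuse/GQA hints as listed in the docstring
        print(f"split {meta.weight_name} by columns, extra flag: {meta.extra}")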
@@ -81,8 +86,7 @@ class UniqueIDGenerator:
         first_key = sorted_keys[0]
         first_parameter = state_dict[first_key].cast("float32")
         # Assume the model parameters are unique; use the first key to compute the md5sum
-        model_md5 = hashlib.md5(str(
-            first_parameter.sum()).encode("utf-8")).hexdigest()
+        model_md5 = hashlib.md5(str(first_parameter.sum()).encode("utf-8")).hexdigest()
         unique_id = f"{model_md5}-{random.randint(10000, 99999)}"
         return unique_id

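For reference, a self-contained sketch of the ID scheme used by the method above, with numpy standing in for Paddle tensors and a hypothetical state_dict (an assumption for illustration, not part of the diff):

    import hashlib
    import random

    import numpy as np

    state_dict = {"linear.weight": np.ones([2, 2], dtype="float32")}  # hypothetical weights
    first_key = sorted(state_dict.keys())[0]
    first_parameter = state_dict[first_key].astype("float32")
    # md5 of the first parameter's sum, plus a random suffix, mirroring the diff above
    model_md5 = hashlib.md5(str(first_parameter.sum()).encode("utf-8")).hexdigest()
    unique_id = f"{model_md5}-{random.randint(10000, 99999)}"
    print(unique_id)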
@@ -99,20 +103,16 @@ def load_sharded_checkpoint(folder, variant=None, return_numpy=False):

     """
     # Load the index
-    pdparams_file = os.path.join(folder,
-                                 _add_variant("model_state.pdparams", variant))
-    lora_pdparams_file = os.path.join(
-        folder, _add_variant("lora_model_state.pdparams", variant))
-    safetensors_file = os.path.join(folder,
-                                    _add_variant("model.safetensors", variant))
+    pdparams_file = os.path.join(folder, _add_variant("model_state.pdparams", variant))
+    lora_pdparams_file = os.path.join(folder, _add_variant("lora_model_state.pdparams", variant))
+    safetensors_file = os.path.join(folder, _add_variant("model.safetensors", variant))
     if os.path.isfile(pdparams_file):
         return paddle.load(pdparams_file, return_numpy=return_numpy)
     if os.path.isfile(lora_pdparams_file):
         return paddle.load(lora_pdparams_file, return_numpy=return_numpy)
     if os.path.isfile(safetensors_file):
         try:
-            from paddleformers.utils.safetensors import \
-                fast_load_file as safe_load_file
+            from paddleformers.utils.safetensors import fast_load_file as safe_load_file
         except ImportError:
             from safetensors.numpy import load_file as safe_load_file

@@ -120,18 +120,13 @@ def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
         if not return_numpy:
             for key in list(state_dict.keys()):
                 if isinstance(state_dict[key], np.ndarray):
-                    state_dict[key] = paddle.Tensor(state_dict.pop(key),
-                                                    zero_copy=True)
+                    state_dict[key] = paddle.Tensor(state_dict.pop(key), zero_copy=True)
         return state_dict

-    index_file = os.path.join(folder,
-                              _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
-    safe_index_file = os.path.join(
-        folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
-    safe_master_file = os.path.join(
-        folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant))
-    safe_peft_file = os.path.join(
-        folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant))
+    index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
+    safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
+    safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant))
+    safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant))

     index_present = os.path.isfile(index_file)
     safe_index_present = os.path.isfile(safe_index_file)
@@ -152,14 +147,11 @@ def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
         load_safe = True
         load_index = safe_peft_file
     else:
-        raise ValueError(
-            f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}"
-        )
+        raise ValueError(f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}")

     if load_safe:
         try:
-            from paddleformers.utils.safetensors import \
-                fast_load_file as safe_load_file
+            from paddleformers.utils.safetensors import fast_load_file as safe_load_file
         except ImportError:
             from safetensors.numpy import load_file as safe_load_file

@@ -167,8 +159,7 @@ def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
         index = json.load(f)

     shard_files = list(set(index["weight_map"].values()))
-    loader = (safe_load_file if load_safe else partial(
-        paddleformers_load, map_location="np" if return_numpy else "cpu"))
+    loader = safe_load_file if load_safe else partial(paddleformers_load, map_location="np" if return_numpy else "cpu")

     ret = {}
     for shard_file in tqdm(shard_files):
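A hedged usage sketch of the function being reformatted in the hunks above; the checkpoint directory is hypothetical and the function is assumed to be importable from this module:

    # loads model_state.pdparams / model.safetensors directly, or falls back to
    # the sharded index files resolved above
    state_dict = load_sharded_checkpoint("./my_checkpoint_dir", variant=None, return_numpy=True)
    for name, tensor in state_dict.items():
        print(name, tensor.shape)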
@@ -183,8 +174,7 @@ def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
     return ret


-def convert_ndarray_dtype(np_array: np.ndarray,
-                          target_dtype: str) -> np.ndarray:
+def convert_ndarray_dtype(np_array: np.ndarray, target_dtype: str) -> np.ndarray:
     """convert ndarray

     Args:
@@ -195,8 +185,11 @@ def convert_ndarray_dtype(np_array: np.ndarray,
         np.ndarray: converted numpy ndarray instance
     """
     source_dtype = convert_dtype(np_array.dtype)
-    if source_dtype == "uint16" and target_dtype == "bfloat16" and paddle.is_compiled_with_custom_device(
-            "iluvatar_gpu"):
+    if (
+        source_dtype == "uint16"
+        and target_dtype == "bfloat16"
+        and paddle.is_compiled_with_custom_device("iluvatar_gpu")
+    ):
         return np_array.view(dtype=target_dtype)
     if source_dtype == "uint16" or target_dtype == "bfloat16":
         if paddle.is_compiled_with_xpu():
@@ -235,11 +228,9 @@ def pad_batch_data(insts, pad_id=0, return_seq_len=False, pad_style="right"):
     # pad to max input len
     # max_len = args.max_len
     if pad_style == "left":
-        inst_data = np.array([[pad_id] * (max_len - len(inst)) + list(inst)
-                              for inst in insts])
+        inst_data = np.array([[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts])
     else:
-        inst_data = np.array(
-            [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts])
+        inst_data = np.array([list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts])
     if return_seq_len:
         seq_len = np.array([len(inst) for inst in insts])
         return inst_data.astype("int64").reshape([-1, max_len]), seq_len
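To illustrate the two padding branches above, a standalone numpy sketch with a hypothetical batch; max_len is assumed here to be the longest input, per the "pad to max input len" comment:

    import numpy as np

    insts = [[1, 2, 3], [4, 5]]  # hypothetical token-id batch
    pad_id = 0
    max_len = max(len(inst) for inst in insts)

    # right padding appends pad_id; left padding prepends it
    right = np.array([list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts])
    left = np.array([[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts])
    print(right.tolist())  # [[1, 2, 3], [4, 5, 0]]
    print(left.tolist())   # [[1, 2, 3], [0, 4, 5]]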
@@ -258,8 +249,7 @@ def load_prefix_weights(
     Args:
         prefix_path (str): the path of prefix weight
     """
-    past_key_values = paddle.to_tensor(
-        np.load(f"{prefix_path}/pre_caches.npy")).unsqueeze(2)
+    past_key_values = paddle.to_tensor(np.load(f"{prefix_path}/pre_caches.npy")).unsqueeze(2)

     if batch_size > 1:
         past_key_values = paddle.concat([past_key_values] * batch_size, axis=2)
@@ -305,8 +295,7 @@ def w4a8_weight_convert(state_dict):
             name,
             w4a8_weight_bites_name_map,
         )
-        state_dict[name] = weight_q.numpy(
-        ) if weight_q is not None else value
+        state_dict[name] = weight_q.numpy() if weight_q is not None else value
         del weight_q
     w4a8_weight_bites_layers_map = {}
     w4a8_weight_bites_layers_map["qkv_gemm_bits_map"] = []
@@ -319,13 +308,10 @@ def w4a8_weight_convert(state_dict):
         elif "out_proj" in name_keys:
             w4a8_weight_bites_layers_map["out_gemm_bits_map"].append(gemm_bits)
         elif "linear1" in name_keys:
-            w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"].append(
-                gemm_bits)
+            w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"].append(gemm_bits)
         elif "linear2" in name_keys:
-            w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"].append(
-                gemm_bits)
-    logger.debug(
-        f"w4a8_weight_bites_layers_map:{w4a8_weight_bites_layers_map}")
+            w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"].append(gemm_bits)
+    logger.debug(f"w4a8_weight_bites_layers_map:{w4a8_weight_bites_layers_map}")
     return state_dict, w4a8_weight_bites_layers_map


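For orientation, the layers map returned above collects per-layer GEMM bit widths under the four keys built in this function; a hypothetical result for a two-layer model (bit values invented for illustration) might look like:

    w4a8_weight_bites_layers_map = {
        "qkv_gemm_bits_map": [8, 4],            # hypothetical per-layer bit widths
        "out_gemm_bits_map": [8, 4],
        "up_gate_proj_gemm_bits_map": [8, 8],
        "down_proj_gemm_bits_map": [8, 8],
    }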
@@ -415,10 +401,13 @@ def calculate_effective_tokens(training_args, train_dataset, max_seq_len):
     else:
         sharding_parallel_degree = 1

-    total_batch = (training_args.max_steps *
-                   training_args.per_device_train_batch_size *
-                   training_args.gradient_accumulation_steps *
-                   sharding_parallel_degree * data_parallel_degree)
+    total_batch = (
+        training_args.max_steps
+        * training_args.per_device_train_batch_size
+        * training_args.gradient_accumulation_steps
+        * sharding_parallel_degree
+        * data_parallel_degree
+    )
     for i, data in enumerate(train_dataset):
         if i == total_batch:
             break
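As a worked example of the product above (all values hypothetical): with max_steps=1000, per_device_train_batch_size=2, gradient_accumulation_steps=8, sharding_parallel_degree=2 and data_parallel_degree=2, total_batch = 1000 * 2 * 8 * 2 * 2 = 64000 samples drawn from train_dataset.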
@@ -464,7 +453,7 @@ def parser_quant_type(quant_type):
         "fp8": "float8_e4m3fn",
         "fp16": "float16",
         "bf16": "bfloat16",
-        "fp32": "float32"
+        "fp32": "float32",
     }
     cache_type = default_type
     if "c8" in quant_type:
@@ -483,8 +472,7 @@ def parser_quant_type(quant_type):
     pattern = f"({'|'.join(map(re.escape, ['w', 'a', 'c']))})"
     splited_type = re.split(pattern, quant_type)
     splited_type = [tmp_type for tmp_type in splited_type if tmp_type]
-    assert (len(splited_type) % 2 == 0 and len(splited_type)
-            <= 6), f"Quant type[{quant_type}] format error."
+    assert len(splited_type) % 2 == 0 and len(splited_type) <= 6, f"Quant type[{quant_type}] format error."

     quant_type_list = []
     if "w" in splited_type:
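To make the split above concrete, a small standalone sketch with a hypothetical quant string (only the re.split behavior is shown, not the full parser):

    import re

    quant_type = "w4a8"  # hypothetical input
    pattern = f"({'|'.join(map(re.escape, ['w', 'a', 'c']))})"  # builds "(w|a|c)"
    # splitting on a capturing group keeps the markers; empty strings are dropped
    splited_type = [t for t in re.split(pattern, quant_type) if t]
    print(splited_type)  # ['w', '4', 'a', '8'] -> even length and <= 6, so the assert passes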