cache feature (#3857)

bukejiyu
2025-09-07 18:52:46 +08:00
committed by GitHub
parent 30a1c1783f
commit e52ce1c4b1
12 changed files with 346 additions and 56 deletions


@@ -14,9 +14,15 @@
 # limitations under the License.
 """
+import contextlib
+import hashlib
+import inspect
 import json
 import os
+import pickle
 import time
+from functools import wraps
+from pathlib import Path

 import paddle
 import paddle.distributed as dist
@@ -28,22 +34,139 @@ from paddleformers.utils.safetensors import fast_safe_open
 from safetensors import safe_open
 from tqdm import tqdm

 from fastdeploy import envs
 from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.layers.linear import KVBatchLinear
 from fastdeploy.model_executor.models.tp_utils import (
     check_tensor_parallel_prerequisites,
 )
+from fastdeploy.model_executor.utils import switch_config_context
 from fastdeploy.platforms import current_platform

-def measure_time(func):
-    def wrapper(*args, **kwargs):
-        time_before_load = time.time()
-        result = func(*args, **kwargs)
-        time_after_load = time.time()
-        logger.info(f"Model loading took {time_after_load - time_before_load} seconds")
-        return result
-
-    return wrapper
+def pdparams_weight_iterator(paddle_file_list: list[str]):
+    for pdparams_file in tqdm(
+        paddle_file_list,
+        desc="Loading pdparams checkpoint shards",
+    ):
+        state_dict = paddle.load(pdparams_file)
+        yield from state_dict.items()
+        del state_dict
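Since this is a generator that drops each shard's state dict before loading the next, peak host memory stays around one shard. A tiny usage sketch, with the shard file names made up:

for name, tensor in pdparams_weight_iterator(["model-00001.pdparams", "model-00002.pdparams"]):
    print(name, tuple(tensor.shape))  # tensors stream out shard by shard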

+def load_weights_form_cache(model, weights_iterator):
+    params_dict = dict(model.named_parameters())
+    for loaded_weight_name, loaded_weight in weights_iterator:
+        param = params_dict[loaded_weight_name]
+        param.copy_(loaded_weight, False)
+        if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False):
+            model.lm_head.load_state_dict({model.lm_head.weight_key: loaded_weight})
+
+    for _, model_sublayer in model.named_sublayers():
+        if isinstance(model_sublayer, KVBatchLinear):
+            model_sublayer.process_weights_after_loading()

+def get_weight_iterator(model_path: str):
+    _, files_list, use_safetensors = get_all_weights_file(model_path)
+    if use_safetensors:
+        weights_iterator = fast_weights_iterator(files_list)
+    else:
+        weights_iterator = pdparams_weight_iterator(files_list)
+    return weights_iterator
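A sketch of how these two helpers compose to restore a cached model; the directory path and the `model` object are assumptions for illustration, not part of this diff:

# Stream whatever format sits in the per-rank cache directory into an
# already-constructed model.
weights_iterator = get_weight_iterator("/models/my_model/.cache/<hash_key>/rank0")
load_weights_form_cache(model, weights_iterator)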

+def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
+    weight_cache_context = contextlib.nullcontext()
+    weight_cache_dir = None
+    enable_cache = False
+    if envs.FD_ENABLE_MODEL_CACHE:
+        model_weight_cache_path = os.path.join(fd_config.model_config.model, weight_cache_path)
+        # Cache key: model_type + quantization + tp_size + ep_size (only TP is supported for now).
+        weight_cache_key = "_".join(
+            [
+                fd_config.model_config.model_type,
+                fd_config.quant_config.name(),
+                str(fd_config.parallel_config.tensor_parallel_size),
+                str(fd_config.parallel_config.expert_parallel_size),
+            ]
+        )
+        hash_key = hashlib.md5(pickle.dumps(weight_cache_key)).hexdigest()
+        weight_cache_dir = os.path.join(model_weight_cache_path, hash_key)
+        if os.path.exists(weight_cache_dir):
+            logger.info(
+                f"Loading will prioritize cached models. Users are responsible for ensuring the saved model is correct. If any error occurs, deleting the cache at {weight_cache_dir} may resolve it."
+            )
+            enable_cache = True
+            weight_cache_context = switch_config_context(fd_config.quant_config, "is_checkpoint_bf16", False)
+    return enable_cache, weight_cache_dir, weight_cache_context
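For illustration, the cache directory name is just the MD5 digest of the pickled key string, so the same config always maps to the same directory. A minimal, self-contained sketch with made-up config values:

import hashlib
import os
import pickle

# Made-up stand-ins for model_type, quant name, tp_size, ep_size.
weight_cache_key = "_".join(["ernie4_5_moe", "wint4", "8", "1"])
hash_key = hashlib.md5(pickle.dumps(weight_cache_key)).hexdigest()
weight_cache_dir = os.path.join("/models/my_model/.cache", hash_key)
print(weight_cache_dir)  # identical inputs always yield the same cache path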

+def save_model(model_arg_name="model", config_arg_name="fd_config"):
+    @measure_time("Model saving")
+    def _save_model(model_dict, weight_cache_dir):
+        # Note: ProcessGroupNCCL does not support the deepcopy protocol, so we patch it here.
+        paddle.distributed.communication.group.Group.__deepcopy__ = lambda self, _: self
+        paddle.distributed.communication.group.Group.to_json = lambda self: repr(self)
+        paddle.save(model_dict, weight_cache_dir)
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            sig = inspect.signature(func)
+            bound_args = sig.bind(*args, **kwargs)
+            bound_args.apply_defaults()
+            fd_config = bound_args.arguments.get(config_arg_name, None)
+            model = bound_args.arguments.get(model_arg_name, None)
+            assert fd_config is not None, "fd_config cannot be None"
+            assert model is not None, "model cannot be None"
+            enable_cache, weight_cache_dir, _ = is_weight_cache_enabled(fd_config)
+            if enable_cache:
+                tp_weight_cache_dir = os.path.join(
+                    weight_cache_dir, f"rank{fd_config.parallel_config.tensor_parallel_rank}"
+                )
+                context = switch_config_context(fd_config.model_config, "model", tp_weight_cache_dir)
+            else:
+                context = contextlib.nullcontext()
+            with context:
+                result = func(*args, **kwargs)
+            if envs.FD_ENABLE_MODEL_CACHE and weight_cache_dir is not None:
+                if not os.path.exists(weight_cache_dir):
+                    assert fd_config.quant_config is not None and getattr(
+                        fd_config.quant_config, "is_checkpoint_bf16", False
+                    ), "Save cache only for dynamic quantization"
+                    tp_weight_cache_dir = os.path.join(
+                        weight_cache_dir, f"rank{fd_config.parallel_config.tensor_parallel_rank}"
+                    )
+                    logger.info(f"Saving model to {tp_weight_cache_dir}")
+                    os.makedirs(tp_weight_cache_dir, exist_ok=True)
+                    _save_model(model.state_dict(), os.path.join(tp_weight_cache_dir, "cache.pdparams"))
+                else:
+                    logger.info("Weights are already cached, skip saving")
+            return result
+
+        return wrapper
+
+    return decorator
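A hypothetical use of the decorator: the wrapped loader's parameter names must match model_arg_name and config_arg_name, since the wrapper locates both through signature binding. The function below is illustrative, not from this diff:

@save_model()  # default arg names ("model", "fd_config") match this signature
def load_model_weights(model, fd_config):
    # On a cache hit the decorator redirects model_config.model to the per-rank
    # cache dir before this body runs; on a cache-enabled miss it saves the
    # state dict afterwards.
    for name, weight in get_weight_iterator(fd_config.model_config.model):
        ...  # copy each tensor into the matching parameter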

+def measure_time(prefix: str = "Model loading"):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            time_before = time.time()
+            result = func(*args, **kwargs)
+            time_after = time.time()
+            logger.info(f"{prefix} took {time_after - time_before:.3f} seconds")
+            return result
+
+        return wrapper
+
+    return decorator
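Hypothetical usage of the reworked decorator: it is now a decorator factory, so it must be called (even with no arguments) before decorating:

@measure_time("Weight cache load")
def load_cached_weights():
    ...  # elapsed time is logged as "Weight cache load took X.XXX seconds"

@measure_time()  # falls back to the default "Model loading" prefix
def load_model():
    ...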
def load_reordered_experts(model_path: str, key_name: str):
@@ -232,33 +355,38 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete
     """
     state_dict = {}
-    _, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
+    _, safetensor_files, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}"))
     weights_iterator = safetensors_weights_iterator(safetensor_files)
     for name, weight in weights_iterator:
         state_dict[name] = weight
     return state_dict


-def get_all_safetensors(model_path: str):
+def get_all_weights_file(model_path: str):
     """
-    get_all_safetensors
+    get_all_weights_file
     """
-    safe_model_path = os.path.join(model_path, "model.safetensors")
-    if os.path.exists(safe_model_path):
-        safetensor_list = [safe_model_path]
-        with safe_open(safe_model_path, framework="np", device="cpu") as f:
-            key_name_list = f.keys()
-        return key_name_list, safetensor_list
+    model_path = Path(model_path)
+    use_safetensors = True
+    if any(model_path.glob("*.pdparams")):
+        key_name_list = []
+        files_list = [str(file) for file in model_path.glob("*.pdparams")]
+        use_safetensors = False
     else:
-        with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
-            weight_map = json.load(f)["weight_map"]
-        weight_files_in_index = set()
-        for weight_name in weight_map:
-            weight_files_in_index.add(os.path.join(model_path, weight_map[weight_name]))
-        key_name_list = list(set(weight_map.keys()))
-        safetensor_list = list(weight_files_in_index)
-        safetensor_list.sort()
-        return key_name_list, safetensor_list
+        safe_model_path = model_path / "model.safetensors"
+        if safe_model_path.exists():
+            files_list = [str(safe_model_path)]
+            with safe_open(safe_model_path, framework="np", device="cpu") as f:
+                key_name_list = f.keys()
+            return key_name_list, files_list, use_safetensors
+        else:
+            index_file = model_path / "model.safetensors.index.json"
+            with index_file.open("r") as f:
+                weight_map = json.load(f)["weight_map"]
+            weight_files_in_index = {str(model_path / weight_map[name]) for name in weight_map}
+            key_name_list = list(weight_map.keys())
+            files_list = sorted(weight_files_in_index)
+    return key_name_list, files_list, use_safetensors
def load_tp_checkpoint_v1(
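Callers now branch on the new third return value to pick the right iterator; a small sketch (the path is made up), which is essentially what get_weight_iterator above already does:

keys, files, use_safetensors = get_all_weights_file("/models/my_model")
if use_safetensors:
    weights_iterator = fast_weights_iterator(files)  # *.safetensors shards
else:
    weights_iterator = pdparams_weight_iterator(files)  # legacy *.pdparams shards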
@@ -271,7 +399,7 @@ def load_tp_checkpoint_v1(
     load_tp_checkpoint_v1
     """
-    safetensor_keys, safetensor_files = get_all_safetensors(model_path)
+    safetensor_keys, safetensor_files, _ = get_all_weights_file(model_path)
     if use_fastsafetensor:
         weights_iterator = fastsafetensors_weights_iterator(safetensor_files)