Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)
cache feature (#3857)
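In brief, the diff below adds an opt-in on-disk weight cache: when the FD_ENABLE_MODEL_CACHE environment flag is set, weights that were dynamically quantized at load time are saved once per tensor-parallel rank and reloaded directly on later startups, and loading .pdparams checkpoints alongside safetensors is added in support of this. The cache location is keyed on the deployment configuration. A minimal sketch of the directory derivation, mirroring is_weight_cache_enabled in the diff (the model path and config values here are made up; only the key layout and hashing scheme come from the diff):

    import hashlib
    import os
    import pickle

    # hypothetical deployment: model_type, quantization, tp_size, ep_size
    weight_cache_key = "_".join(["ernie4_5", "wint4", "8", "1"])
    hash_key = hashlib.md5(pickle.dumps(weight_cache_key)).hexdigest()
    cache_dir = os.path.join("/models/demo", ".cache", hash_key)
    # each rank then saves to <cache_dir>/rank<tp_rank>/cache.pdparams
    print(cache_dir)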
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
This commit is contained in:
@@ -14,9 +14,15 @@
 # limitations under the License.
 """
 
+import contextlib
+import hashlib
+import inspect
 import json
 import os
+import pickle
 import time
+from functools import wraps
+from pathlib import Path
 
 import paddle
 import paddle.distributed as dist
@@ -28,22 +34,139 @@ from paddleformers.utils.safetensors import fast_safe_open
 from safetensors import safe_open
 from tqdm import tqdm
 
 from fastdeploy import envs
 from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.layers.linear import KVBatchLinear
 from fastdeploy.model_executor.models.tp_utils import (
     check_tensor_parallel_prerequisites,
 )
+from fastdeploy.model_executor.utils import switch_config_context
 from fastdeploy.platforms import current_platform
 
 
-def measure_time(func):
-    def wrapper(*args, **kwargs):
-        time_before_load = time.time()
-        result = func(*args, **kwargs)
-        time_after_load = time.time()
-        logger.info(f"Model loading took {time_after_load - time_before_load} seconds")
-        return result
-
-    return wrapper
+def pdparams_weight_iterator(paddle_file_list: list[str]):
+    for pdparams_file in tqdm(
+        paddle_file_list,
+        desc="Loading pdparams checkpoint shards",
+    ):
+        state_dict = paddle.load(pdparams_file)
+        yield from state_dict.items()
+        del state_dict
+
+
+def load_weights_form_cache(model, weights_iterator):
+    params_dict = dict(model.named_parameters())
+    for loaded_weight_name, loaded_weight in weights_iterator:
+        param = params_dict[loaded_weight_name]
+        param.copy_(loaded_weight, False)
+        if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False):
+            model.lm_head.load_state_dict({model.lm_head.weight_key: loaded_weight})
+    for _, model_sublayer in model.named_sublayers():
+        if isinstance(model_sublayer, KVBatchLinear):
+            model_sublayer.process_weights_after_loading()
+
+
+def get_weight_iterator(model_path: str):
+    _, files_list, use_safetensors = get_all_weights_file(model_path)
+    if use_safetensors:
+        weights_iterator = fast_weights_iterator(files_list)
+    else:
+        weights_iterator = pdparams_weight_iterator(files_list)
+    return weights_iterator
+
+
+def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
+    weight_cache_context = contextlib.nullcontext()
+    weight_cache_dir = None
+    enable_cache = False
+    if envs.FD_ENABLE_MODEL_CACHE:
+        model_weight_cache_path = os.path.join(fd_config.model_config.model, weight_cache_path)
+        # cache key: model_type + quantization + tp_size + ep_size
+        weight_cache_key = "_".join(
+            [
+                fd_config.model_config.model_type,
+                fd_config.quant_config.name(),
+                str(fd_config.parallel_config.tensor_parallel_size),
+                str(fd_config.parallel_config.expert_parallel_size),
+            ]
+        )
+        # only TP is supported for now
+        hash_key = hashlib.md5(pickle.dumps(weight_cache_key)).hexdigest()
+        weight_cache_dir = os.path.join(model_weight_cache_path, hash_key)
+        if os.path.exists(weight_cache_dir):
+            logger.info(
+                f"Loading will prioritize cached models. Users are responsible for ensuring the saved model is correct. If any error occurs, deleting the cache at {weight_cache_dir} may resolve it."
+            )
+            enable_cache = True
+            weight_cache_context = switch_config_context(fd_config.quant_config, "is_checkpoint_bf16", False)
+
+    return enable_cache, weight_cache_dir, weight_cache_context
+
+
+def save_model(model_arg_name="model", config_arg_name="fd_config"):
+    @measure_time("Model saving")
+    def _save_model(model_dict, weight_cache_dir):
+        # Note: ProcessGroupNCCL does not support the deepcopy protocol, so we patch it here.
+        paddle.distributed.communication.group.Group.__deepcopy__ = lambda self, _: self
+        paddle.distributed.communication.group.Group.to_json = lambda self: repr(self)
+        paddle.save(model_dict, weight_cache_dir)
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            sig = inspect.signature(func)
+            bound_args = sig.bind(*args, **kwargs)
+            bound_args.apply_defaults()
+            fd_config = bound_args.arguments.get(config_arg_name, None)
+            model = bound_args.arguments.get(model_arg_name, None)
+            enable_cache, weight_cache_dir, _ = is_weight_cache_enabled(fd_config)
+            assert fd_config is not None, "fd_config cannot be None"
+            assert model is not None, "model cannot be None"
+            if enable_cache:
+                tp_weight_cache_dir = os.path.join(
+                    weight_cache_dir, f"rank{str(fd_config.parallel_config.tensor_parallel_rank)}"
+                )
+                context = switch_config_context(fd_config.model_config, "model", tp_weight_cache_dir)
+            else:
+                context = contextlib.nullcontext()
+
+            with context:
+                result = func(*args, **kwargs)
+            if envs.FD_ENABLE_MODEL_CACHE and weight_cache_dir is not None and not os.path.exists(weight_cache_dir):
+                assert fd_config.quant_config is not None and getattr(
+                    fd_config.quant_config, "is_checkpoint_bf16", False
+                ), "Save cache only for dynamic quantization"
+                tp_weight_cache_dir = os.path.join(
+                    weight_cache_dir, f"rank{str(fd_config.parallel_config.tensor_parallel_rank)}"
+                )
+                logger.info(f"Saving model to {tp_weight_cache_dir}")
+                os.makedirs(
+                    tp_weight_cache_dir,
+                    exist_ok=True,
+                )
+                _save_model(model.state_dict(), os.path.join(tp_weight_cache_dir, "cache.pdparams"))
+            else:
+                logger.info("Weights are already cached, skipping save")
+            return result
+
+        return wrapper
+
+    return decorator
+
+
+def measure_time(prefix: str = "Model loading"):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            time_before = time.time()
+            result = func(*args, **kwargs)
+            time_after = time.time()
+            logger.info(f"{prefix} took {time_after - time_before:.3f} seconds")
+            return result
+
+        return wrapper
+
+    return decorator
+
+
 def load_reordered_experts(model_path: str, key_name: str):
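Taken together, the two decorators above are meant to wrap a weight-loading routine: measure_time logs the elapsed time, and save_model writes the materialized state dict to the cache on the first run. A usage sketch (the load_model function and its body are hypothetical stand-ins, not part of the diff; the parameter names must match save_model's model_arg_name/config_arg_name defaults):

    @save_model()
    @measure_time()
    def load_model(fd_config, model):
        # hypothetical loader body: stream tensors into the model's parameters
        for name, weight in get_weight_iterator(fd_config.model_config.model):
            model.state_dict()[name].copy_(weight, False)

Note the redirection at work here: on a cache hit, save_model temporarily points fd_config.model_config.model at the per-rank cache directory via switch_config_context, so the same loader body transparently reads cache.pdparams instead of the original checkpoint.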
@@ -232,33 +355,38 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete
     """
 
     state_dict = {}
-    _, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
+    _, safetensor_files, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}"))
     weights_iterator = safetensors_weights_iterator(safetensor_files)
     for name, weight in weights_iterator:
         state_dict[name] = weight
     return state_dict
 
 
-def get_all_safetensors(model_path: str):
+def get_all_weights_file(model_path: str):
     """
     get_all_safetensors
     """
-    safe_model_path = os.path.join(model_path, "model.safetensors")
-    if os.path.exists(safe_model_path):
-        safetensor_list = [safe_model_path]
-        with safe_open(safe_model_path, framework="np", device="cpu") as f:
-            key_name_list = f.keys()
-        return key_name_list, safetensor_list
-    with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
-        weight_map = json.load(f)["weight_map"]
-    weight_files_in_index = set()
-    for weight_name in weight_map:
-        weight_files_in_index.add(os.path.join(model_path, weight_map[weight_name]))
-    key_name_list = list(set(weight_map.keys()))
-    safetensor_list = list(weight_files_in_index)
-    safetensor_list.sort()
-    return key_name_list, safetensor_list
+    model_path = Path(model_path)
+    use_safetensors = True
+    if any(model_path.glob("*.pdparams")):
+        key_name_list = []
+        files_list = [str(file) for file in model_path.glob("*.pdparams")]
+        use_safetensors = False
+    else:
+        safe_model_path = model_path / "model.safetensors"
+        if safe_model_path.exists():
+            files_list = [str(safe_model_path)]
+            with safe_open(safe_model_path, framework="np", device="cpu") as f:
+                key_name_list = f.keys()
+            return key_name_list, files_list, use_safetensors
+        else:
+            index_file = model_path / "model.safetensors.index.json"
+            with index_file.open("r") as f:
+                weight_map = json.load(f)["weight_map"]
+            weight_files_in_index = {str(model_path / weight_map[name]) for name in weight_map}
+            key_name_list = list(weight_map.keys())
+            files_list = sorted(weight_files_in_index)
+    return key_name_list, files_list, use_safetensors
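One behavioral note on the rewritten helper: if the directory contains any *.pdparams files they take precedence over safetensors, and key_name_list comes back empty in that case, so callers should branch on the third return value rather than on the keys. A small sketch, which is essentially what the new get_weight_iterator wrapper does (the path is hypothetical):

    keys, files, use_safetensors = get_all_weights_file("/models/demo/rank0")
    if use_safetensors:
        it = fast_weights_iterator(files)       # *.safetensors shards
    else:
        it = pdparams_weight_iterator(files)    # *.pdparams shards; keys == []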
@@ -271,7 +399,7 @@ def load_tp_checkpoint_v1(
     load_tp_checkpoint_v1
     """
 
-    safetensor_keys, safetensor_files = get_all_safetensors(model_path)
+    safetensor_keys, safetensor_files, _ = get_all_weights_file(model_path)
 
     if use_fastsafetensor:
         weights_iterator = fastsafetensors_weights_iterator(safetensor_files)
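Putting the pieces together, one plausible cache-aware wiring looks like the following (a sketch under the assumption that the caller owns an fd_config and an initialized model; load_model is the hypothetical loader sketched earlier, and none of this orchestration appears verbatim in the diff):

    enable_cache, weight_cache_dir, ctx = is_weight_cache_enabled(fd_config)
    with ctx:  # on a cache hit, quant_config.is_checkpoint_bf16 is forced to False
        if enable_cache:
            rank_dir = os.path.join(
                weight_cache_dir, f"rank{fd_config.parallel_config.tensor_parallel_rank}"
            )
            load_weights_form_cache(model, get_weight_iterator(rank_dir))
        else:
            load_model(fd_config, model)  # regular path; @save_model then populates the cache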