Mirror of https://github.com/PaddlePaddle/FastDeploy.git
cache feature (#3857)
@@ -97,6 +97,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))),
    # Whether to use new get_output and save_output method (0 or 1)
    "FD_USE_GET_SAVE_OUTPUT_V1": lambda: bool(int(os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0"))),
+    # Whether to enable model cache feature
+    "FD_ENABLE_MODEL_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_CACHE", "0"))),
}
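For orientation, a minimal sketch of how this flag is meant to be switched on. Only the names from this hunk are used; the value is parsed through the lambda above, and the loader changes further down consult it as envs.FD_ENABLE_MODEL_CACHE.

# Hedged sketch: enable the weight cache before FastDeploy reads the flag.
import os
os.environ["FD_ENABLE_MODEL_CACHE"] = "1"

from fastdeploy import envs

# bool(int("1")) -> True; the default_v1 loader below checks this attribute.
assert envs.FD_ENABLE_MODEL_CACHE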
@@ -1053,7 +1053,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
        self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
        self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]

-        if layer.fd_config.load_config.load_choices == "default_v1":
+        if self.quant_config.is_checkpoint_bf16:
            layer.up_gate_proj_weight = layer.create_parameter(
                shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
                dtype=layer.weight_dtype,
@@ -1138,7 +1138,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):

    def process_weights_after_loading(self, layer):
        """ """
-        if not layer.fd_config.load_config.load_choices == "default_v1":
+        if not self.quant_config.is_checkpoint_bf16:
            return
        weight_id_map = {"gate_up": 0, "down": 1}
        if (
@@ -82,7 +82,7 @@ class MixQuantConfig(QuantConfigBase):
            .from_config(
                {
                    "is_permuted": self.is_permuted,
-                    "self.is_checkpoint_bf16": self.is_checkpoint_bf16,
+                    "is_checkpoint_bf16": self.is_checkpoint_bf16,
                    "hadamard_block_size": self.hadamard_block_size,
                }
            )
@@ -94,7 +94,7 @@ class MixQuantConfig(QuantConfigBase):
            .from_config(
                {
                    "is_permuted": self.is_permuted,
-                    "self.is_checkpoint_bf16": self.is_checkpoint_bf16,
+                    "is_checkpoint_bf16": self.is_checkpoint_bf16,
                    "hadamard_block_size": self.hadamard_block_size,
                }
            )
@@ -112,6 +112,6 @@ class MixQuantConfig(QuantConfigBase):
        else:
            return (
                get_quantization_config(self.dense_quant_type)
-                .from_config({"self.is_checkpoint_bf16": self.is_checkpoint_bf16})
+                .from_config({"is_checkpoint_bf16": self.is_checkpoint_bf16})
                .get_quant_method(layer)
            )
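The key rename matters because the WeightOnly configs below read the flag with an exact-key lookup; a small illustration with plain dicts, no FastDeploy imports needed:

# The old key carried a stray "self." prefix, so the lookup used by
# from_config never saw the flag and it silently stayed False.
bad = {"algo": "weight_only_int4", "self.is_checkpoint_bf16": True}
good = {"algo": "weight_only_int4", "is_checkpoint_bf16": True}

print(bad.get("is_checkpoint_bf16", False))   # False: flag silently dropped
print(good.get("is_checkpoint_bf16", False))  # True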
@@ -45,6 +45,7 @@ class WeightOnlyConfig(QuantConfigBase):
    def __init__(
        self,
        algo: str,
+        is_checkpoint_bf16: bool = False,
    ) -> None:
        super().__init__()
        self.algo = algo
@@ -56,6 +57,7 @@ class WeightOnlyConfig(QuantConfigBase):
        self.quant_max_bound = 0
        self.quant_min_bound = 0
        self.quant_round_type = 0
+        self.is_checkpoint_bf16 = is_checkpoint_bf16

    def name(self) -> str:
        return "weight_only"
@@ -63,7 +65,8 @@ class WeightOnlyConfig(QuantConfigBase):
    @classmethod
    def from_config(cls, config: dict) -> "WeightOnlyConfig":
        algo = config["algo"]
-        return cls(algo)
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(algo, is_checkpoint_bf16)

    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
        if current_platform.is_xpu():
@@ -154,12 +157,13 @@ class WINT8Config(WeightOnlyConfig):
    weight only int8 config
    """

-    def __init__(self) -> None:
-        super().__init__("weight_only_int8")
+    def __init__(self, is_checkpoint_bf16: bool = False) -> None:
+        super().__init__("weight_only_int8", is_checkpoint_bf16)

    @classmethod
    def from_config(cls, config: dict) -> "WINT8Config":
-        return cls()
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(is_checkpoint_bf16)

    def name(self) -> str:
        return "wint8"
@@ -172,12 +176,14 @@ class WINT4Config(WeightOnlyConfig):

    def __init__(
        self,
+        is_checkpoint_bf16: bool = False,
    ) -> None:
-        super().__init__("weight_only_int4")
+        super().__init__("weight_only_int4", is_checkpoint_bf16)

    @classmethod
    def from_config(cls, config: dict) -> "WINT4Config":
-        return cls()
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(is_checkpoint_bf16)

    def name(self) -> str:
        return "wint4"
@@ -196,7 +202,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
        self.quant_config = quant_config

    def create_weights(self, layer, **extra_weight_attrs):
-        if layer.fd_config.load_config.load_choices == "default_v1":
+        if self.quant_config.is_checkpoint_bf16:
            layer.weight = layer.create_parameter(
                shape=layer.weight_shape,
                dtype=layer.weight_dtype,
@@ -259,7 +265,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
            )

    def process_weights_after_loading(self, layer) -> None:
-        if not layer.fd_config.load_config.load_choices == "default_v1":
+        if not self.quant_config.is_checkpoint_bf16:
            return
        quanted_weight_tensor, weight_scale_tensor = weight_quantize(
            layer.weight,
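Taken together, quantize-on-load is now keyed off whether the checkpoint itself is bf16 rather than off the loader choice. A hedged round-trip sketch of the new config plumbing; the import path is an assumption, the class names come from this diff:

# Sketch of the new config round-trip (module path assumed).
from fastdeploy.model_executor.layers.quantization.weight_only import WINT4Config

cfg = WINT4Config.from_config({"is_checkpoint_bf16": True})
assert cfg.algo == "weight_only_int4"
# This flag now drives create_weights / process_weights_after_loading above.
assert cfg.is_checkpoint_bf16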
@@ -14,9 +14,15 @@
# limitations under the License.
"""

+import contextlib
+import hashlib
+import inspect
import json
import os
+import pickle
import time
+from functools import wraps
+from pathlib import Path

import paddle
import paddle.distributed as dist
@@ -28,23 +34,140 @@ from paddleformers.utils.safetensors import fast_safe_open
from safetensors import safe_open
from tqdm import tqdm

+from fastdeploy import envs
from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.layers.linear import KVBatchLinear
from fastdeploy.model_executor.models.tp_utils import (
    check_tensor_parallel_prerequisites,
)
+from fastdeploy.model_executor.utils import switch_config_context
from fastdeploy.platforms import current_platform


-def measure_time(func):
-    def wrapper(*args, **kwargs):
-        time_before_load = time.time()
-        result = func(*args, **kwargs)
-        time_after_load = time.time()
-        logger.info(f"Model loading took {time_after_load - time_before_load} seconds")
-        return result
-    return wrapper
+def pdparams_weight_iterator(paddle_file_list: list[str]):
+    for pdparams_file in tqdm(
+        paddle_file_list,
+        desc="Loading pdparams checkpoint shards",
+    ):
+        state_dict = paddle.load(pdparams_file)
+        yield from state_dict.items()
+        del state_dict
+
+
+def load_weights_form_cache(model, weights_iterator):
+    params_dict = dict(model.named_parameters())
+    for loaded_weight_name, loaded_weight in weights_iterator:
+        param = params_dict[loaded_weight_name]
+        param.copy_(loaded_weight, False)
+        if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False):
+            model.lm_head.load_state_dict({model.lm_head.weight_key: loaded_weight})
+    for _, model_sublayer in model.named_sublayers():
+        if isinstance(model_sublayer, KVBatchLinear):
+            model_sublayer.process_weights_after_loading()
+
+
+def get_weight_iterator(model_path: str):
+    _, files_list, use_safetensors = get_all_weights_file(model_path)
+    if use_safetensors:
+        weights_iterator = fast_weights_iterator(files_list)
+    else:
+        weights_iterator = pdparams_weight_iterator(files_list)
+    return weights_iterator
+
+
+def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
+    weight_cache_context = contextlib.nullcontext()
+    weight_cache_dir = None
+    enable_cache = False
+    if envs.FD_ENABLE_MODEL_CACHE:
+        model_weight_cache_path = os.path.join(fd_config.model_config.model, weight_cache_path)
+        # model_type + quantization + tp_size + ep_size
+        weight_cache_key = "_".join(
+            [
+                fd_config.model_config.model_type,
+                fd_config.quant_config.name(),
+                str(fd_config.parallel_config.tensor_parallel_size),
+                str(fd_config.parallel_config.expert_parallel_size),
+            ]
+        )
+        # only support tp now
+        hash_key = hashlib.md5(pickle.dumps(weight_cache_key)).hexdigest()
+        weight_cache_dir = os.path.join(model_weight_cache_path, hash_key)
+        if os.path.exists(weight_cache_dir):
+            logger.info(
+                f"Loading will prioritize cached models. Users are responsible for ensuring the saved model is correct. If any error occurs, deleting the cache at {weight_cache_dir} may resolve it."
+            )
+            enable_cache = True
+            weight_cache_context = switch_config_context(fd_config.quant_config, "is_checkpoint_bf16", False)
+
+    return enable_cache, weight_cache_dir, weight_cache_context
+
+
+def save_model(model_arg_name="model", config_arg_name="fd_config"):
+    @measure_time("Model saving")
+    def _save_model(model_dict, weight_cache_dir):
+        # Note: ProcessGroupNCCL do not support deepcopy protocol, we made modifications here.
+        paddle.distributed.communication.group.Group.__deepcopy__ = lambda self, _: self
+        paddle.distributed.communication.group.Group.to_json = lambda self: repr(self)
+        paddle.save(model_dict, weight_cache_dir)
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            sig = inspect.signature(func)
+            bound_args = sig.bind(*args, **kwargs)
+            bound_args.apply_defaults()
+            fd_config = bound_args.arguments.get(config_arg_name, None)
+            model = bound_args.arguments.get(model_arg_name, None)
+            enable_cache, weight_cache_dir, _ = is_weight_cache_enabled(fd_config)
+            assert fd_config is not None, "fd_config cannot be None"
+            assert model is not None, "model cannot be None"
+            if enable_cache:
+                tp_weight_cache_dir = os.path.join(
+                    weight_cache_dir, f"rank{str(fd_config.parallel_config.tensor_parallel_rank)}"
+                )
+                context = switch_config_context(fd_config.model_config, "model", tp_weight_cache_dir)
+            else:
+                context = contextlib.nullcontext()
+
+            with context:
+                result = func(*args, **kwargs)
+            if envs.FD_ENABLE_MODEL_CACHE and weight_cache_dir is not None and not os.path.exists(weight_cache_dir):
+                assert fd_config.quant_config is not None and getattr(
+                    fd_config.quant_config, "is_checkpoint_bf16", False
+                ), "Save cache only for dynamic quantization"
+                tp_weight_cache_dir = os.path.join(
+                    weight_cache_dir, f"rank{str(fd_config.parallel_config.tensor_parallel_rank)}"
+                )
+                logger.info(f"Saving model to {tp_weight_cache_dir}")
+                os.makedirs(
+                    tp_weight_cache_dir,
+                    exist_ok=True,
+                )
+                _save_model(model.state_dict(), os.path.join(tp_weight_cache_dir, "cache.pdparams"))
+            else:
+                logger.info("Weights are already cached, skip saving")
+            return result
+
+        return wrapper
+
+    return decorator
+
+
+def measure_time(prefix: str = "Model loading"):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            time_before = time.time()
+            result = func(*args, **kwargs)
+            time_after = time.time()
+            logger.info(f"{prefix} took {time_after - time_before:.3f} seconds")
+            return result
+
+        return wrapper
+
+    return decorator


def load_reordered_experts(model_path: str, key_name: str):
    from safetensors import safe_open
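A small sketch of where the cache lands for a hypothetical deployment; the key fields mirror is_weight_cache_enabled above, but every concrete value here is invented:

import hashlib
import os
import pickle

# model_type + quantization + tp_size + ep_size, as in is_weight_cache_enabled
weight_cache_key = "_".join(["ernie4_5_moe", "wint4", "2", "1"])
hash_key = hashlib.md5(pickle.dumps(weight_cache_key)).hexdigest()

# e.g. <model_dir>/.cache/<md5>/rank0/cache.pdparams (paths invented)
print(os.path.join("/models/ernie-4_5-21b", ".cache", hash_key, "rank0", "cache.pdparams"))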
@@ -232,33 +355,38 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete
    """

    state_dict = {}
-    _, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
+    _, safetensor_files, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}"))
    weights_iterator = safetensors_weights_iterator(safetensor_files)
    for name, weight in weights_iterator:
        state_dict[name] = weight
    return state_dict


-def get_all_safetensors(model_path: str):
+def get_all_weights_file(model_path: str):
    """
    get_all_safetensors
    """
-    safe_model_path = os.path.join(model_path, "model.safetensors")
-    if os.path.exists(safe_model_path):
-        safetensor_list = [safe_model_path]
-        with safe_open(safe_model_path, framework="np", device="cpu") as f:
-            key_name_list = f.keys()
-        return key_name_list, safetensor_list
-    else:
-        with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
-            weight_map = json.load(f)["weight_map"]
-        weight_files_in_index = set()
-        for weight_name in weight_map:
-            weight_files_in_index.add(os.path.join(model_path, weight_map[weight_name]))
-        key_name_list = list(set(weight_map.keys()))
-        safetensor_list = list(weight_files_in_index)
-        safetensor_list.sort()
-        return key_name_list, safetensor_list
+    model_path = Path(model_path)
+    use_safetensors = True
+    if any(model_path.glob("*.pdparams")):
+        key_name_list = []
+        files_list = [str(file) for file in model_path.glob("*.pdparams")]
+        use_safetensors = False
+    else:
+        safe_model_path = model_path / "model.safetensors"
+        if safe_model_path.exists():
+            files_list = [str(safe_model_path)]
+            with safe_open(safe_model_path, framework="np", device="cpu") as f:
+                key_name_list = f.keys()
+            return key_name_list, files_list, use_safetensors
+        else:
+            index_file = model_path / "model.safetensors.index.json"
+            with index_file.open("r") as f:
+                weight_map = json.load(f)["weight_map"]
+            weight_files_in_index = {str(model_path / weight_map[name]) for name in weight_map}
+            key_name_list = list(weight_map.keys())
+            files_list = sorted(weight_files_in_index)
+    return key_name_list, files_list, use_safetensors


def load_tp_checkpoint_v1(
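Call sites now unpack a 3-tuple instead of a pair; a hedged example, assuming get_all_weights_file is imported from the module above (the checkpoint path is invented):

# keys: tensor names; files: shard paths; use_safetensors: which iterator applies.
keys, files, use_safetensors = get_all_weights_file("/models/ernie-4_5-21b")
if use_safetensors:
    print(f"{len(files)} safetensors shard(s)")
else:
    print(f"{len(files)} pdparams shard(s)")  # served by pdparams_weight_iterator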
@@ -271,7 +399,7 @@ def load_tp_checkpoint_v1(
    load_tp_checkpoint_v1
    """

-    safetensor_keys, safetensor_files = get_all_safetensors(model_path)
+    safetensor_keys, safetensor_files, _ = get_all_weights_file(model_path)

    if use_fastsafetensor:
        weights_iterator = fastsafetensors_weights_iterator(safetensor_files)
@@ -51,7 +51,7 @@ class DefaultModelLoader(BaseModelLoader):
        paddle.device.cuda.empty_cache()
        paddle.device.synchronize()

-    @measure_time
+    @measure_time()
    def load_weights(self, model, fd_config: FDConfig, architectures: str) -> None:
        model_class = ModelRegistry.get_pretrain_cls(architectures)
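Because measure_time is now a decorator factory, existing call sites gain parentheses and new ones can pass a label; a minimal sketch, assuming measure_time is in scope (the label below is hypothetical):

# @measure_time                  # old form: decorated directly, fixed log message
# @measure_time()                # new form: default prefix "Model loading"
# @measure_time("Model saving")  # new form: custom prefix, as used by save_model above

@measure_time("Weight preprocessing")  # hypothetical label
def preprocess_weights():
    ...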
@@ -20,9 +20,11 @@ from paddleformers.utils.log import logger

from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
from fastdeploy.model_executor.load_weight_utils import (
-    fast_weights_iterator,
-    get_all_safetensors,
+    get_weight_iterator,
+    is_weight_cache_enabled,
+    load_weights_form_cache,
    measure_time,
+    save_model,
)
from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
from fastdeploy.model_executor.models.model_base import ModelRegistry
@@ -44,10 +46,13 @@ class DefaultModelLoaderV1(BaseModelLoader):
        paddle.device.cuda.empty_cache()
        paddle.device.synchronize()

-    @measure_time
-    def load_weights(self, model, fd_config: FDConfig) -> None:
-        _, safetensor_files = get_all_safetensors(fd_config.model_config.model)
-        weights_iterator = fast_weights_iterator(safetensor_files)
-        model.load_weights(weights_iterator)
+    @save_model()
+    @measure_time()
+    def load_weights(self, model, fd_config: FDConfig, enable_cache: bool = False) -> None:
+        weights_iterator = get_weight_iterator(fd_config.model_config.model)
+        if enable_cache:
+            load_weights_form_cache(model, weights_iterator)
+        else:
+            model.load_weights(weights_iterator)
        self.clean_memory_fragments()

@@ -61,14 +66,15 @@ class DefaultModelLoaderV1(BaseModelLoader):

            architectures = architectures + "RL"

+        enable_cache, _, weight_cache_context = is_weight_cache_enabled(fd_config)
+        with weight_cache_context:
            with context:
                model_cls = ModelRegistry.get_class(architectures)
                model = model_cls(fd_config)

                model.eval()

                # RL model not need set_state_dict
                if fd_config.load_config.dynamic_load_weight:
                    return model
-                self.load_weights(model, fd_config)
+                self.load_weights(model, fd_config, enable_cache)
                return model
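Read together, the decorator and the context produce a two-run lifecycle. A hedged sketch of that reading; the loader construction and load_model signature are assumptions, the rest follows the hunks above:

loader = DefaultModelLoaderV1(load_config)  # hypothetical construction

# Run 1: FD_ENABLE_MODEL_CACHE=1, cache dir absent. is_weight_cache_enabled
# returns (False, cache_dir, nullcontext); bf16 shards are loaded and quantized,
# then save_model()'s wrapper writes <cache_dir>/rank<i>/cache.pdparams.
model = loader.load_model(fd_config)

# Run 2: cache dir present. is_weight_cache_enabled returns (True, cache_dir,
# a context forcing quant_config.is_checkpoint_bf16 = False), so layers allocate
# quantized weights and load_weights_form_cache copies cached tensors directly.
model = loader.load_model(fd_config)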
@@ -199,3 +199,14 @@ def temporary_dtype(dtype: str):
        yield
    finally:
        paddle.set_default_dtype(orig_dtype)
+
+
+@contextmanager
+def switch_config_context(config_obj, config_attr_name, value):
+    """switch_config_context"""
+    origin_value = getattr(config_obj, config_attr_name)
+    setattr(config_obj, config_attr_name, value)
+    try:
+        yield
+    finally:
+        setattr(config_obj, config_attr_name, origin_value)
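A self-contained usage sketch of the new context manager, assuming switch_config_context above is importable; SimpleNamespace stands in for a FastDeploy config object and the paths are invented:

from types import SimpleNamespace

cfg = SimpleNamespace(model="/models/x")  # stand-in for fd_config.model_config

with switch_config_context(cfg, "model", "/models/x/.cache/ab12cd/rank0"):
    assert cfg.model.endswith("rank0")  # loading code sees the cache path

assert cfg.model == "/models/x"  # restored afterwards, even if the body raises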
@@ -15,7 +15,7 @@ from safetensors.numpy import save_file as safe_save_file
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.load_weight_utils import (
-    get_all_safetensors,
+    get_all_weights_file,
    safetensors_weights_iterator,
)
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
@@ -158,7 +158,7 @@ def main():
            Ernie4_5Tokenizer.resource_files_names["vocab_file"] = vocab_file_names[i]
            break
    tokenizer = Ernie4_5Tokenizer.from_pretrained(args.model_name_or_path)
-    _, safetensor_files = get_all_safetensors(args.model_name_or_path)
+    _, safetensor_files, _ = get_all_weights_file(args.model_name_or_path)
    weights_iterator = safetensors_weights_iterator(safetensor_files)
    state_dict = {}
    save_state_dict = {}
@@ -56,7 +56,11 @@ model_param_map = {
            "backend": "triton",
            "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"},
        },
-        {"quant_type": "block_wise_fp8", "backend": "deepgemm", "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"}},
+        {
+            "quant_type": "block_wise_fp8",
+            "backend": "deepgemm",
+            "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17", "FD_USE_DEEP_GEMM": "1"},
+        },
    ],
},
"DeepSeek-V3-0324": {
tests/model_loader/test_model_cache.py (new file, 128 lines)
@@ -0,0 +1,128 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

import pytest

current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from tests.model_loader.utils import (
    check_tokens_id_and_text_close,
    form_model_get_output_topp0,
    get_paddle_model_path,
    run_with_timeout,
)

FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))

prompts = ["解释下“温故而知新", "Hello, how are you?"]


model_param_map = {
    "ernie-4_5-21b-a3b-bf16-paddle": {
        "tensor_parallel_size": 2,
        "quantizations": [
            {
                "quant_type": "wint4",
                "env": {"FD_ENABLE_MODEL_CACHE": "1"},
            }
        ],
    }
}


params = []
for model, cfg in model_param_map.items():
    for q in cfg["quantizations"]:
        if isinstance(q, dict):
            quant, backend, env = q["quant_type"], q.get("backend", "default"), q.get("env", {})
        else:
            quant, backend, env = q, "default", {}
        params.append(
            pytest.param(
                model,
                cfg.get("tensor_parallel_size", 1),
                cfg.get("max_model_len", 1024),
                quant,
                cfg.get("max_tokens", 32),
                env,
                marks=[pytest.mark.core_model],
                id=f"{model}.{quant}.{backend}",
            )
        )


@pytest.mark.parametrize(
    "model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens,env",
    params,
)
def test_model_cache(
    fd_runner,
    model_name_or_path: str,
    tensor_parallel_size: int,
    max_model_len: int,
    max_tokens: int,
    quantization: str,
    env,
    monkeypatch,
) -> None:
    model_path = get_paddle_model_path(model_name_or_path)

    fd_outputs_v1 = run_with_timeout(
        target=form_model_get_output_topp0,
        args=(
            fd_runner,
            model_path,
            tensor_parallel_size,
            max_model_len,
            max_tokens,
            quantization,
            "default_v1",
            FD_ENGINE_QUEUE_PORT,
            prompts,
            FD_CACHE_QUEUE_PORT,
        ),
    )

    if env:
        for k, v in env.items():
            monkeypatch.setenv(k, v)

    fd_outputs_v1_with_cache = run_with_timeout(
        target=form_model_get_output_topp0,
        args=(
            fd_runner,
            model_path,
            tensor_parallel_size,
            max_model_len,
            max_tokens,
            quantization,
            "default_v1",
            FD_ENGINE_QUEUE_PORT,
            prompts,
            FD_CACHE_QUEUE_PORT,
        ),
    )
    check_tokens_id_and_text_close(
        outputs_0_lst=fd_outputs_v1,
        outputs_1_lst=fd_outputs_v1_with_cache,
        name_0="default_v1 loader",
        name_1="default_v1 loader using cache",
    )
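The test exercises the same model twice through the default_v1 loader, the second time with FD_ENABLE_MODEL_CACHE=1 set via monkeypatch, and requires token ids and text to match across runs. A plausible way to run it in isolation (the pytest flags are assumptions):

# Hypothetical invocation:
#   python -m pytest tests/model_loader/test_model_cache.py -m core_model -v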
@@ -63,9 +63,14 @@ def run_with_timeout(target, args, timeout=60 * 5):
        print_logs()
        raise RuntimeError("Worker process hung and was terminated")
    try:
-        return result_queue.get(timeout=60)
+        result = result_queue.get(timeout=60)
    except Exception as e:
        raise RuntimeError(f"Failed to get result from worker: {e}")
+    finally:
+        result_queue.close()
+        result_queue.join_thread()
+
+    return result


def form_model_get_output_topp0(
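The added finally block follows the standard multiprocessing.Queue teardown order; a minimal self-contained sketch of the pattern:

import multiprocessing as mp
import queue

q = mp.Queue()
try:
    result = q.get(timeout=1)  # raises queue.Empty here, since nothing was put
except queue.Empty:
    result = None
finally:
    q.close()        # no more puts; flush buffered data to the pipe
    q.join_thread()  # join the feeder thread so shutdown cannot hang
print(result)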