Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-09-26 20:41:53 +08:00
cache feature (#3857)
@@ -97,6 +97,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))),
     # Whether to use new get_output and save_output method (0 or 1)
     "FD_USE_GET_SAVE_OUTPUT_V1": lambda: bool(int(os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0"))),
+    # Whether to enable model cache feature
+    "FD_ENABLE_MODEL_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_CACHE", "0"))),
 }
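Each entry in environment_variables maps a name to a zero-argument lambda, so the string is parsed (via bool(int(...))) when the value is looked up. A minimal sketch of turning the new flag on from Python, with the surrounding serving setup assumed:

    import os

    # FD_ENABLE_MODEL_CACHE is parsed as bool(int(...)), so "1" enables it.
    os.environ["FD_ENABLE_MODEL_CACHE"] = "1"

    from fastdeploy import envs

    assert envs.FD_ENABLE_MODEL_CACHE is True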
@@ -1053,7 +1053,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
         self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
         self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]

-        if layer.fd_config.load_config.load_choices == "default_v1":
+        if self.quant_config.is_checkpoint_bf16:
             layer.up_gate_proj_weight = layer.create_parameter(
                 shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
                 dtype=layer.weight_dtype,
@@ -1138,7 +1138,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):

     def process_weights_after_loading(self, layer):
         """ """
-        if not layer.fd_config.load_config.load_choices == "default_v1":
+        if not self.quant_config.is_checkpoint_bf16:
             return
         weight_id_map = {"gate_up": 0, "down": 1}
         if (
@@ -82,7 +82,7 @@ class MixQuantConfig(QuantConfigBase):
                 .from_config(
                     {
                         "is_permuted": self.is_permuted,
-                        "self.is_checkpoint_bf16": self.is_checkpoint_bf16,
+                        "is_checkpoint_bf16": self.is_checkpoint_bf16,
                         "hadamard_block_size": self.hadamard_block_size,
                     }
                 )
@@ -94,7 +94,7 @@ class MixQuantConfig(QuantConfigBase):
                 .from_config(
                     {
                         "is_permuted": self.is_permuted,
-                        "self.is_checkpoint_bf16": self.is_checkpoint_bf16,
+                        "is_checkpoint_bf16": self.is_checkpoint_bf16,
                         "hadamard_block_size": self.hadamard_block_size,
                     }
                 )
@@ -112,6 +112,6 @@ class MixQuantConfig(QuantConfigBase):
         else:
             return (
                 get_quantization_config(self.dense_quant_type)
-                .from_config({"self.is_checkpoint_bf16": self.is_checkpoint_bf16})
+                .from_config({"is_checkpoint_bf16": self.is_checkpoint_bf16})
                 .get_quant_method(layer)
             )
@@ -45,6 +45,7 @@ class WeightOnlyConfig(QuantConfigBase):
     def __init__(
         self,
         algo: str,
+        is_checkpoint_bf16: bool = False,
     ) -> None:
         super().__init__()
         self.algo = algo
@@ -56,6 +57,7 @@ class WeightOnlyConfig(QuantConfigBase):
         self.quant_max_bound = 0
         self.quant_min_bound = 0
         self.quant_round_type = 0
+        self.is_checkpoint_bf16 = is_checkpoint_bf16

     def name(self) -> str:
         return "weight_only"
@@ -63,7 +65,8 @@ class WeightOnlyConfig(QuantConfigBase):
     @classmethod
     def from_config(cls, config: dict) -> "WeightOnlyConfig":
         algo = config["algo"]
-        return cls(algo)
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(algo, is_checkpoint_bf16)

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         if current_platform.is_xpu():
@@ -154,12 +157,13 @@ class WINT8Config(WeightOnlyConfig):
     weight only int8 config
     """

-    def __init__(self) -> None:
-        super().__init__("weight_only_int8")
+    def __init__(self, is_checkpoint_bf16: bool = False) -> None:
+        super().__init__("weight_only_int8", is_checkpoint_bf16)

     @classmethod
     def from_config(cls, config: dict) -> "WINT8Config":
-        return cls()
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(is_checkpoint_bf16)

     def name(self) -> str:
         return "wint8"
@@ -172,12 +176,14 @@ class WINT4Config(WeightOnlyConfig):

     def __init__(
         self,
+        is_checkpoint_bf16: bool = False,
     ) -> None:
-        super().__init__("weight_only_int4")
+        super().__init__("weight_only_int4", is_checkpoint_bf16)

     @classmethod
     def from_config(cls, config: dict) -> "WINT4Config":
-        return cls()
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(is_checkpoint_bf16)

     def name(self) -> str:
         return "wint4"
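With is_checkpoint_bf16 threaded through from_config in the base class and both subclasses, a quantization config built from a checkpoint dict now remembers whether the weights are still BF16. An illustrative round trip; the import path is assumed from the repository layout:

    from fastdeploy.model_executor.layers.quantization.weight_only import (  # path assumed
        WINT4Config,
        WINT8Config,
    )

    w4 = WINT4Config.from_config({"is_checkpoint_bf16": True})
    assert w4.name() == "wint4" and w4.is_checkpoint_bf16 is True

    # Omitting the key keeps the old behavior (checkpoint already quantized).
    w8 = WINT8Config.from_config({})
    assert w8.is_checkpoint_bf16 is False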
@@ -196,7 +202,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
         self.quant_config = quant_config

     def create_weights(self, layer, **extra_weight_attrs):
-        if layer.fd_config.load_config.load_choices == "default_v1":
+        if self.quant_config.is_checkpoint_bf16:
             layer.weight = layer.create_parameter(
                 shape=layer.weight_shape,
                 dtype=layer.weight_dtype,
@@ -259,7 +265,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
         )

     def process_weights_after_loading(self, layer) -> None:
-        if not layer.fd_config.load_config.load_choices == "default_v1":
+        if not self.quant_config.is_checkpoint_bf16:
             return
         quanted_weight_tensor, weight_scale_tensor = weight_quantize(
             layer.weight,
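Read together, the two WeightOnlyLinearMethod hunks switch the trigger for online quantization from the loader choice to the checkpoint format. The resulting flow, summarized as comments (a reading of the diff, not additional code):

    # create_weights:
    #   is_checkpoint_bf16 == True  -> allocate a BF16 layer.weight so raw checkpoint
    #                                  tensors can be copied in directly.
    #   is_checkpoint_bf16 == False -> allocate the pre-quantized weight/scale buffers.
    # process_weights_after_loading:
    #   is_checkpoint_bf16 == True  -> weight_quantize() converts the BF16 weight to the
    #                                  weight-only int4/int8 layout after loading.
    #   is_checkpoint_bf16 == False -> return early; the checkpoint was already quantized.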
@@ -14,9 +14,15 @@
 # limitations under the License.
 """

+import contextlib
+import hashlib
+import inspect
 import json
 import os
+import pickle
 import time
+from functools import wraps
+from pathlib import Path

 import paddle
 import paddle.distributed as dist
@@ -28,22 +34,139 @@ from paddleformers.utils.safetensors import fast_safe_open
 from safetensors import safe_open
 from tqdm import tqdm

+from fastdeploy import envs
 from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.layers.linear import KVBatchLinear
 from fastdeploy.model_executor.models.tp_utils import (
     check_tensor_parallel_prerequisites,
 )
+from fastdeploy.model_executor.utils import switch_config_context
 from fastdeploy.platforms import current_platform


-def measure_time(func):
-    def wrapper(*args, **kwargs):
-        time_before_load = time.time()
-        result = func(*args, **kwargs)
-        time_after_load = time.time()
-        logger.info(f"Model loading took {time_after_load - time_before_load} seconds")
-        return result
+def pdparams_weight_iterator(paddle_file_list: list[str]):
+    for pdparams_file in tqdm(
+        paddle_file_list,
+        desc="Loading pdparams checkpoint shards",
+    ):
+        state_dict = paddle.load(pdparams_file)
+        yield from state_dict.items()
+        del state_dict

-    return wrapper

+def load_weights_form_cache(model, weights_iterator):
+    params_dict = dict(model.named_parameters())
+    for loaded_weight_name, loaded_weight in weights_iterator:
+        param = params_dict[loaded_weight_name]
+        param.copy_(loaded_weight, False)
+        if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False):
+            model.lm_head.load_state_dict({model.lm_head.weight_key: loaded_weight})
+    for _, model_sublayer in model.named_sublayers():
+        if isinstance(model_sublayer, KVBatchLinear):
+            model_sublayer.process_weights_after_loading()
+
+
+def get_weight_iterator(model_path: str):
+    _, files_list, use_safetensors = get_all_weights_file(model_path)
+    if use_safetensors:
+        weights_iterator = fast_weights_iterator(files_list)
+    else:
+        weights_iterator = pdparams_weight_iterator(files_list)
+    return weights_iterator
+
+
+def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
+    weight_cache_context = contextlib.nullcontext()
+    weight_cache_dir = None
+    enable_cache = False
+    if envs.FD_ENABLE_MODEL_CACHE:
+        model_weight_cache_path = os.path.join(fd_config.model_config.model, weight_cache_path)
+        # Cache key: model_type + quantization + tp_size + ep_size
+        weight_cache_key = "_".join(
+            [
+                fd_config.model_config.model_type,
+                fd_config.quant_config.name(),
+                str(fd_config.parallel_config.tensor_parallel_size),
+                str(fd_config.parallel_config.expert_parallel_size),
+            ]
+        )
+        # only TP is supported for now
+        hash_key = hashlib.md5(pickle.dumps(weight_cache_key)).hexdigest()
+        weight_cache_dir = os.path.join(model_weight_cache_path, hash_key)
+        if os.path.exists(weight_cache_dir):
+            logger.info(
+                f"Loading will prioritize cached models. Users are responsible for ensuring the saved model is correct. If any error occurs, deleting the cache at {weight_cache_dir} may resolve it."
+            )
+            enable_cache = True
+            weight_cache_context = switch_config_context(fd_config.quant_config, "is_checkpoint_bf16", False)

+    return enable_cache, weight_cache_dir, weight_cache_context
+
+
+def save_model(model_arg_name="model", config_arg_name="fd_config"):
+    @measure_time("Model saving")
+    def _save_model(model_dict, weight_cache_dir):
+        # Note: ProcessGroupNCCL does not support the deepcopy protocol, so it is patched here.
+        paddle.distributed.communication.group.Group.__deepcopy__ = lambda self, _: self
+        paddle.distributed.communication.group.Group.to_json = lambda self: repr(self)
+        paddle.save(model_dict, weight_cache_dir)
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            sig = inspect.signature(func)
+            bound_args = sig.bind(*args, **kwargs)
+            bound_args.apply_defaults()
+            fd_config = bound_args.arguments.get(config_arg_name, None)
+            model = bound_args.arguments.get(model_arg_name, None)
+            assert fd_config is not None, "fd_config cannot be None"
+            assert model is not None, "model cannot be None"
+            enable_cache, weight_cache_dir, _ = is_weight_cache_enabled(fd_config)
+            if enable_cache:
+                tp_weight_cache_dir = os.path.join(
+                    weight_cache_dir, f"rank{str(fd_config.parallel_config.tensor_parallel_rank)}"
+                )
+                context = switch_config_context(fd_config.model_config, "model", tp_weight_cache_dir)
+            else:
+                context = contextlib.nullcontext()
+
+            with context:
+                result = func(*args, **kwargs)
+            if envs.FD_ENABLE_MODEL_CACHE and weight_cache_dir is not None and not os.path.exists(weight_cache_dir):
+                assert fd_config.quant_config is not None and getattr(
+                    fd_config.quant_config, "is_checkpoint_bf16", False
+                ), "Save cache only for dynamic quantization"
+                tp_weight_cache_dir = os.path.join(
+                    weight_cache_dir, f"rank{str(fd_config.parallel_config.tensor_parallel_rank)}"
+                )
+                logger.info(f"Saving model to {tp_weight_cache_dir}")
+                os.makedirs(
+                    tp_weight_cache_dir,
+                    exist_ok=True,
+                )
+                _save_model(model.state_dict(), os.path.join(tp_weight_cache_dir, "cache.pdparams"))
+            else:
+                logger.info("Weights are already cached, skip saving")
+            return result
+
+        return wrapper
+
+    return decorator
+
+
+def measure_time(prefix: str = "Model loading"):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            time_before = time.time()
+            result = func(*args, **kwargs)
+            time_after = time.time()
+            logger.info(f"{prefix} took {time_after - time_before:.3f} seconds")
+            return result
+
+        return wrapper
+
+    return decorator
+
+
 def load_reordered_experts(model_path: str, key_name: str):
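The cache directory is derived deterministically from the deployment configuration, so the same checkpoint served with a different quantization, TP size, or EP size gets its own cache. A worked sketch of the key derivation, with hypothetical values for model_type, quant name, tp_size, and ep_size:

    import hashlib
    import os
    import pickle

    # Hypothetical values; changing any of them selects a different cache directory.
    weight_cache_key = "_".join(["ernie4_5_moe", "wint4", "2", "1"])
    hash_key = hashlib.md5(pickle.dumps(weight_cache_key)).hexdigest()
    weight_cache_dir = os.path.join("/path/to/model", ".cache", hash_key)
    # Each TP rank then saves to <weight_cache_dir>/rank<i>/cache.pdparams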
@@ -232,33 +355,38 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete
     """
     state_dict = {}
-    _, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
+    _, safetensor_files, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}"))
     weights_iterator = safetensors_weights_iterator(safetensor_files)
     for name, weight in weights_iterator:
         state_dict[name] = weight
     return state_dict


-def get_all_safetensors(model_path: str):
+def get_all_weights_file(model_path: str):
     """
     get_all_safetensors
     """
-    safe_model_path = os.path.join(model_path, "model.safetensors")
-    if os.path.exists(safe_model_path):
-        safetensor_list = [safe_model_path]
-        with safe_open(safe_model_path, framework="np", device="cpu") as f:
-            key_name_list = f.keys()
-        return key_name_list, safetensor_list
-    else:
-        with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
-            weight_map = json.load(f)["weight_map"]
-        weight_files_in_index = set()
-        for weight_name in weight_map:
-            weight_files_in_index.add(os.path.join(model_path, weight_map[weight_name]))
-        key_name_list = list(set(weight_map.keys()))
-        safetensor_list = list(weight_files_in_index)
-        safetensor_list.sort()
-        return key_name_list, safetensor_list
+    model_path = Path(model_path)
+    use_safetensors = True
+    if any(model_path.glob("*.pdparams")):
+        key_name_list = []
+        files_list = [str(file) for file in model_path.glob("*.pdparams")]
+        use_safetensors = False
+        return key_name_list, files_list, use_safetensors
+    safe_model_path = model_path / "model.safetensors"
+    if safe_model_path.exists():
+        files_list = [str(safe_model_path)]
+        with safe_open(safe_model_path, framework="np", device="cpu") as f:
+            key_name_list = f.keys()
+        return key_name_list, files_list, use_safetensors
+    else:
+        index_file = model_path / "model.safetensors.index.json"
+        with index_file.open("r") as f:
+            weight_map = json.load(f)["weight_map"]
+        weight_files_in_index = {str(model_path / weight_map[name]) for name in weight_map}
+        key_name_list = list(weight_map.keys())
+        files_list = sorted(weight_files_in_index)
+        return key_name_list, files_list, use_safetensors


 def load_tp_checkpoint_v1(
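Callers of the renamed helper now unpack three values; the final flag is what get_weight_iterator uses to choose between the safetensors and pdparams iterators. Illustrative calls with hypothetical paths:

    # Directory containing *.safetensors shards plus an index:
    keys, files, use_safetensors = get_all_weights_file("/models/ckpt")
    # -> (tensor names from the index, sorted shard paths, True)

    # Directory containing *.pdparams (e.g. a previously saved cache):
    keys, files, use_safetensors = get_all_weights_file("/models/ckpt/.cache/<hash>/rank0")
    # -> ([], [".../cache.pdparams"], False)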
@@ -271,7 +399,7 @@ def load_tp_checkpoint_v1(
     """
     load_tp_checkpoint_v1
     """
-    safetensor_keys, safetensor_files = get_all_safetensors(model_path)
+    safetensor_keys, safetensor_files, _ = get_all_weights_file(model_path)

     if use_fastsafetensor:
         weights_iterator = fastsafetensors_weights_iterator(safetensor_files)
@@ -51,7 +51,7 @@ class DefaultModelLoader(BaseModelLoader):
         paddle.device.cuda.empty_cache()
         paddle.device.synchronize()

-    @measure_time
+    @measure_time()
     def load_weights(self, model, fd_config: FDConfig, architectures: str) -> None:
         model_class = ModelRegistry.get_pretrain_cls(architectures)

@@ -20,9 +20,11 @@ from paddleformers.utils.log import logger

 from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
 from fastdeploy.model_executor.load_weight_utils import (
-    fast_weights_iterator,
-    get_all_safetensors,
+    get_weight_iterator,
+    is_weight_cache_enabled,
+    load_weights_form_cache,
     measure_time,
+    save_model,
 )
 from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
 from fastdeploy.model_executor.models.model_base import ModelRegistry
@@ -44,11 +46,14 @@ class DefaultModelLoaderV1(BaseModelLoader):
         paddle.device.cuda.empty_cache()
         paddle.device.synchronize()

-    @measure_time
-    def load_weights(self, model, fd_config: FDConfig) -> None:
-        _, safetensor_files = get_all_safetensors(fd_config.model_config.model)
-        weights_iterator = fast_weights_iterator(safetensor_files)
-        model.load_weights(weights_iterator)
+    @save_model()
+    @measure_time()
+    def load_weights(self, model, fd_config: FDConfig, enable_cache: bool = False) -> None:
+        weights_iterator = get_weight_iterator(fd_config.model_config.model)
+        if enable_cache:
+            load_weights_form_cache(model, weights_iterator)
+        else:
+            model.load_weights(weights_iterator)
         self.clean_memory_fragments()

     def load_model(self, fd_config: FDConfig) -> nn.Layer:
@@ -61,14 +66,15 @@ class DefaultModelLoaderV1(BaseModelLoader):
             architectures = architectures + "RL"

-        with context:
-            model_cls = ModelRegistry.get_class(architectures)
-            model = model_cls(fd_config)
+        enable_cache, _, weight_cache_context = is_weight_cache_enabled(fd_config)
+        with weight_cache_context:
+            with context:
+                model_cls = ModelRegistry.get_class(architectures)
+                model = model_cls(fd_config)

         model.eval()

         # RL model does not need set_state_dict
         if fd_config.load_config.dynamic_load_weight:
             return model
-        self.load_weights(model, fd_config)
+        self.load_weights(model, fd_config, enable_cache)
         return model
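Putting the loader changes together, the lifecycle that falls out of the code above can be sketched as:

    # First run (cache directory absent):
    #   is_weight_cache_enabled -> (False, <dir>, nullcontext); weights are loaded
    #   from the original checkpoint and quantized online, then the @save_model
    #   wrapper writes each rank's state_dict to <dir>/rank<i>/cache.pdparams.
    # Later runs (cache directory present):
    #   is_weight_cache_enabled -> (True, <dir>, context forcing is_checkpoint_bf16=False),
    #   so layers allocate quantized buffers and load_weights_form_cache copies the
    #   pre-quantized tensors in directly, skipping online quantization.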
@@ -199,3 +199,14 @@ def temporary_dtype(dtype: str):
         yield
     finally:
         paddle.set_default_dtype(orig_dtype)
+
+
+@contextmanager
+def switch_config_context(config_obj, config_attr_name, value):
+    """switch_config_context"""
+    origin_value = getattr(config_obj, config_attr_name)
+    setattr(config_obj, config_attr_name, value)
+    try:
+        yield
+    finally:
+        setattr(config_obj, config_attr_name, origin_value)
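switch_config_context is a plain attribute-swapping context manager; a minimal usage sketch with a stand-in config object:

    class _Cfg:  # stand-in for a quant/model config
        is_checkpoint_bf16 = True

    cfg = _Cfg()
    with switch_config_context(cfg, "is_checkpoint_bf16", False):
        assert cfg.is_checkpoint_bf16 is False  # temporarily overridden
    assert cfg.is_checkpoint_bf16 is True  # restored on exit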
@@ -15,7 +15,7 @@ from safetensors.numpy import save_file as safe_save_file
 from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.load_weight_utils import (
-    get_all_safetensors,
+    get_all_weights_file,
     safetensors_weights_iterator,
 )
 from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
@@ -158,7 +158,7 @@ def main():
             Ernie4_5Tokenizer.resource_files_names["vocab_file"] = vocab_file_names[i]
             break
     tokenizer = Ernie4_5Tokenizer.from_pretrained(args.model_name_or_path)
-    _, safetensor_files = get_all_safetensors(args.model_name_or_path)
+    _, safetensor_files, _ = get_all_weights_file(args.model_name_or_path)
     weights_iterator = safetensors_weights_iterator(safetensor_files)
     state_dict = {}
     save_state_dict = {}
@@ -56,7 +56,11 @@ model_param_map = {
             "backend": "triton",
             "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"},
         },
-        {"quant_type": "block_wise_fp8", "backend": "deepgemm", "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"}},
+        {
+            "quant_type": "block_wise_fp8",
+            "backend": "deepgemm",
+            "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17", "FD_USE_DEEP_GEMM": "1"},
+        },
     ],
 },
 "DeepSeek-V3-0324": {
tests/model_loader/test_model_cache.py (new file, 128 lines)
@@ -0,0 +1,128 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+import pytest
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+project_root = os.path.abspath(os.path.join(current_dir, ".."))
+if project_root not in sys.path:
+    sys.path.insert(0, project_root)
+
+from tests.model_loader.utils import (
+    check_tokens_id_and_text_close,
+    form_model_get_output_topp0,
+    get_paddle_model_path,
+    run_with_timeout,
+)
+
+FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
+FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
+
+prompts = ["解释下“温故而知新", "Hello, how are you?"]
+
+model_param_map = {
+    "ernie-4_5-21b-a3b-bf16-paddle": {
+        "tensor_parallel_size": 2,
+        "quantizations": [
+            {
+                "quant_type": "wint4",
+                "env": {"FD_ENABLE_MODEL_CACHE": "1"},
+            }
+        ],
+    }
+}
+
+params = []
+for model, cfg in model_param_map.items():
+    for q in cfg["quantizations"]:
+        if isinstance(q, dict):
+            quant, backend, env = q["quant_type"], q.get("backend", "default"), q.get("env", {})
+        else:
+            quant, backend, env = q, "default", {}
+        params.append(
+            pytest.param(
+                model,
+                cfg.get("tensor_parallel_size", 1),
+                cfg.get("max_model_len", 1024),
+                quant,
+                cfg.get("max_tokens", 32),
+                env,
+                marks=[pytest.mark.core_model],
+                id=f"{model}.{quant}.{backend}",
+            )
+        )
+
+
+@pytest.mark.parametrize(
+    "model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens,env",
+    params,
+)
+def test_model_cache(
+    fd_runner,
+    model_name_or_path: str,
+    tensor_parallel_size: int,
+    max_model_len: int,
+    max_tokens: int,
+    quantization: str,
+    env,
+    monkeypatch,
+) -> None:
+    model_path = get_paddle_model_path(model_name_or_path)
+
+    fd_outputs_v1 = run_with_timeout(
+        target=form_model_get_output_topp0,
+        args=(
+            fd_runner,
+            model_path,
+            tensor_parallel_size,
+            max_model_len,
+            max_tokens,
+            quantization,
+            "default_v1",
+            FD_ENGINE_QUEUE_PORT,
+            prompts,
+            FD_CACHE_QUEUE_PORT,
+        ),
+    )
+
+    if env:
+        for k, v in env.items():
+            monkeypatch.setenv(k, v)
+
+    fd_outputs_v1_with_cache = run_with_timeout(
+        target=form_model_get_output_topp0,
+        args=(
+            fd_runner,
+            model_path,
+            tensor_parallel_size,
+            max_model_len,
+            max_tokens,
+            quantization,
+            "default_v1",
+            FD_ENGINE_QUEUE_PORT,
+            prompts,
+            FD_CACHE_QUEUE_PORT,
+        ),
+    )
+    check_tokens_id_and_text_close(
+        outputs_0_lst=fd_outputs_v1,
+        outputs_1_lst=fd_outputs_v1_with_cache,
+        name_0="default_v1 loader",
+        name_1="default_v1 loader using cache",
+    )
@@ -63,9 +63,14 @@ def run_with_timeout(target, args, timeout=60 * 5):
         print_logs()
         raise RuntimeError("Worker process hung and was terminated")
     try:
-        return result_queue.get(timeout=60)
+        result = result_queue.get(timeout=60)
     except Exception as e:
         raise RuntimeError(f"Failed to get result from worker: {e}")
+    finally:
+        result_queue.close()
+        result_queue.join_thread()
+
+    return result


 def form_model_get_output_topp0(