cache feature (#3857)

bukejiyu
2025-09-07 18:52:46 +08:00
committed by GitHub
parent 30a1c1783f
commit e52ce1c4b1
12 changed files with 346 additions and 56 deletions

View File

@@ -97,6 +97,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))),
# Whether to use new get_output and save_output method (0 or 1)
"FD_USE_GET_SAVE_OUTPUT_V1": lambda: bool(int(os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0"))),
# Whether to enable model cache feature
"FD_ENABLE_MODEL_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_CACHE", "0"))),
}
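
The new switch is read through the environment_variables lambda table, so it can be toggled per process before the loader runs. A minimal hedged sketch of enabling it from Python; the variable name and the fastdeploy.envs access pattern come from this diff, everything else is illustrative:

import os

# Illustrative only: enable the model weight cache for this process.
os.environ["FD_ENABLE_MODEL_CACHE"] = "1"   # the default "0" keeps the cache off

from fastdeploy import envs

if envs.FD_ENABLE_MODEL_CACHE:              # evaluates the lambda registered above
    print("model weight cache is enabled")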

View File

@@ -1053,7 +1053,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]
if layer.fd_config.load_config.load_choices == "default_v1":
if self.quant_config.is_checkpoint_bf16:
layer.up_gate_proj_weight = layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
dtype=layer.weight_dtype,
@@ -1138,7 +1138,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
def process_weights_after_loading(self, layer):
""" """
if not layer.fd_config.load_config.load_choices == "default_v1":
if not self.quant_config.is_checkpoint_bf16:
return
weight_id_map = {"gate_up": 0, "down": 1}
if (

View File

@@ -82,7 +82,7 @@ class MixQuantConfig(QuantConfigBase):
.from_config(
{
"is_permuted": self.is_permuted,
"self.is_checkpoint_bf16": self.is_checkpoint_bf16,
"is_checkpoint_bf16": self.is_checkpoint_bf16,
"hadamard_block_size": self.hadamard_block_size,
}
)
@@ -94,7 +94,7 @@ class MixQuantConfig(QuantConfigBase):
.from_config(
{
"is_permuted": self.is_permuted,
"self.is_checkpoint_bf16": self.is_checkpoint_bf16,
"is_checkpoint_bf16": self.is_checkpoint_bf16,
"hadamard_block_size": self.hadamard_block_size,
}
)
@@ -112,6 +112,6 @@ class MixQuantConfig(QuantConfigBase):
else:
return (
get_quantization_config(self.dense_quant_type)
.from_config({"self.is_checkpoint_bf16": self.is_checkpoint_bf16})
.from_config({"is_checkpoint_bf16": self.is_checkpoint_bf16})
.get_quant_method(layer)
)

View File

@@ -45,6 +45,7 @@ class WeightOnlyConfig(QuantConfigBase):
def __init__(
self,
algo: str,
is_checkpoint_bf16: bool = False,
) -> None:
super().__init__()
self.algo = algo
@@ -56,6 +57,7 @@ class WeightOnlyConfig(QuantConfigBase):
self.quant_max_bound = 0
self.quant_min_bound = 0
self.quant_round_type = 0
self.is_checkpoint_bf16 = is_checkpoint_bf16
def name(self) -> str:
return "weight_only"
@@ -63,7 +65,8 @@ class WeightOnlyConfig(QuantConfigBase):
@classmethod
def from_config(cls, config: dict) -> "WeightOnlyConfig":
algo = config["algo"]
return cls(algo)
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
return cls(algo, is_checkpoint_bf16)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
if current_platform.is_xpu():
@@ -154,12 +157,13 @@ class WINT8Config(WeightOnlyConfig):
weight only int8 config
"""
def __init__(self) -> None:
super().__init__("weight_only_int8")
def __init__(self, is_checkpoint_bf16: bool = False) -> None:
super().__init__("weight_only_int8", is_checkpoint_bf16)
@classmethod
def from_config(cls, config: dict) -> "WINT8Config":
return cls()
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
return cls(is_checkpoint_bf16)
def name(self) -> str:
return "wint8"
@@ -172,12 +176,14 @@ class WINT4Config(WeightOnlyConfig):
def __init__(
self,
is_checkpoint_bf16: bool = False,
) -> None:
super().__init__("weight_only_int4")
super().__init__("weight_only_int4", is_checkpoint_bf16)
@classmethod
def from_config(cls, config: dict) -> "WINT4Config":
return cls()
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
return cls(is_checkpoint_bf16)
def name(self) -> str:
return "wint4"
@@ -196,7 +202,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
self.quant_config = quant_config
def create_weights(self, layer, **extra_weight_attrs):
if layer.fd_config.load_config.load_choices == "default_v1":
if self.quant_config.is_checkpoint_bf16:
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
@@ -259,7 +265,7 @@ class WeightOnlyLinearMethod(QuantMethodBase):
)
def process_weights_after_loading(self, layer) -> None:
if not layer.fd_config.load_config.load_choices == "default_v1":
if not self.quant_config.is_checkpoint_bf16:
return
quanted_weight_tensor, weight_scale_tensor = weight_quantize(
layer.weight,
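
The recurring substitution in this file, as in the Cutlass MoE method earlier, is that weight creation and post-load processing now key off quant_config.is_checkpoint_bf16 instead of load_config.load_choices == "default_v1", with from_config threading the flag through each config class. A short hedged sketch of what the flag selects; the import path is an assumption and the comments paraphrase this hunk rather than execute it:

# Hedged sketch; the module path of WINT4Config is assumed.
from fastdeploy.model_executor.layers.quantization.weight_only import WINT4Config

dynamic_cfg = WINT4Config.from_config({"is_checkpoint_bf16": True})
# create_weights allocates bf16 parameters and process_weights_after_loading
# quantizes them to int4 once the original checkpoint has been read.

cached_cfg = WINT4Config.from_config({"is_checkpoint_bf16": False})
# parameters are created in their already-quantized layout and the
# post-load quantization step is skipped (the cached-weights path).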

View File

@@ -14,9 +14,15 @@
# limitations under the License.
"""
import contextlib
import hashlib
import inspect
import json
import os
import pickle
import time
from functools import wraps
from pathlib import Path
import paddle
import paddle.distributed as dist
@@ -28,22 +34,139 @@ from paddleformers.utils.safetensors import fast_safe_open
from safetensors import safe_open
from tqdm import tqdm
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.linear import KVBatchLinear
from fastdeploy.model_executor.models.tp_utils import (
check_tensor_parallel_prerequisites,
)
from fastdeploy.model_executor.utils import switch_config_context
from fastdeploy.platforms import current_platform
def measure_time(func):
def wrapper(*args, **kwargs):
time_before_load = time.time()
result = func(*args, **kwargs)
time_after_load = time.time()
logger.info(f"Model loading took {time_after_load - time_before_load} seconds")
return result
def pdparams_weight_iterator(paddle_file_list: list[str]):
for pdparams_file in tqdm(
paddle_file_list,
desc="Loading pdparams checkpoint shards",
):
state_dict = paddle.load(pdparams_file)
yield from state_dict.items()
del state_dict
return wrapper
def load_weights_form_cache(model, weights_iterator):
params_dict = dict(model.named_parameters())
for loaded_weight_name, loaded_weight in weights_iterator:
param = params_dict[loaded_weight_name]
param.copy_(loaded_weight, False)
if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False):
model.lm_head.load_state_dict({model.lm_head.weight_key: loaded_weight})
for _, model_sublayer in model.named_sublayers():
if isinstance(model_sublayer, KVBatchLinear):
model_sublayer.process_weights_after_loading()
def get_weight_iterator(model_path: str):
_, files_list, use_safetensors = get_all_weights_file(model_path)
if use_safetensors:
weights_iterator = fast_weights_iterator(files_list)
else:
weights_iterator = pdparams_weight_iterator(files_list)
return weights_iterator
def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
weight_cache_context = contextlib.nullcontext()
weight_cache_dir = None
enable_cache = False
if envs.FD_ENABLE_MODEL_CACHE:
model_weight_cache_path = os.path.join(fd_config.model_config.model, weight_cache_path)
# model_type + quantization + tp_size + ep_size
weight_cache_key = "_".join(
[
fd_config.model_config.model_type,
fd_config.quant_config.name(),
str(fd_config.parallel_config.tensor_parallel_size),
str(fd_config.parallel_config.expert_parallel_size),
]
)
# only tensor parallelism is supported for now
hash_key = hashlib.md5(pickle.dumps(weight_cache_key)).hexdigest()
weight_cache_dir = os.path.join(model_weight_cache_path, hash_key)
if os.path.exists(weight_cache_dir):
logger.info(
f"Loading will prioritize cached models. Users are responsible for ensuring the saved model is correct. If any error occurs, deleting the cache at {weight_cache_dir} may resolve it."
)
enable_cache = True
weight_cache_context = switch_config_context(fd_config.quant_config, "is_checkpoint_bf16", False)
return enable_cache, weight_cache_dir, weight_cache_context
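
The cache location is derived from the model directory plus an md5 over a composite key of model_type, quantization name, tensor-parallel size and expert-parallel size; each tensor-parallel rank then gets its own rank{n}/cache.pdparams file through the save_model decorator below. A self-contained hedged sketch of the derivation with made-up values:

import hashlib
import os
import pickle

# All values below are illustrative; at runtime they come from fd_config.
model_dir = "/models/ernie-demo"                                # fd_config.model_config.model (hypothetical)
cache_key = "_".join(["ernie_moe", "wint4", "2", "1"])          # model_type, quant name, tp_size, ep_size
hash_key = hashlib.md5(pickle.dumps(cache_key)).hexdigest()

weight_cache_dir = os.path.join(model_dir, ".cache", hash_key)  # weight_cache_path defaults to ".cache"
rank_dir = os.path.join(weight_cache_dir, "rank0")              # one subdirectory per tensor-parallel rank
print(os.path.join(rank_dir, "cache.pdparams"))                 # the file _save_model writes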
def save_model(model_arg_name="model", config_arg_name="fd_config"):
@measure_time("Model saving")
def _save_model(model_dict, weight_cache_dir):
# Note: ProcessGroupNCCL does not support the deepcopy protocol, so we patch it here.
paddle.distributed.communication.group.Group.__deepcopy__ = lambda self, _: self
paddle.distributed.communication.group.Group.to_json = lambda self: repr(self)
paddle.save(model_dict, weight_cache_dir)
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
sig = inspect.signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
fd_config = bound_args.arguments.get(config_arg_name, None)
model = bound_args.arguments.get(model_arg_name, None)
enable_cache, weight_cache_dir, _ = is_weight_cache_enabled(fd_config)
assert fd_config is not None, "fd_config cannot be None"
assert model is not None, "model cannot be None"
if enable_cache:
tp_weight_cache_dir = os.path.join(
weight_cache_dir, f"rank{str(fd_config.parallel_config.tensor_parallel_rank)}"
)
context = switch_config_context(fd_config.model_config, "model", tp_weight_cache_dir)
else:
context = contextlib.nullcontext()
with context:
result = func(*args, **kwargs)
if envs.FD_ENABLE_MODEL_CACHE and weight_cache_dir is not None and not os.path.exists(weight_cache_dir):
assert fd_config.quant_config is not None and getattr(
fd_config.quant_config, "is_checkpoint_bf16", False
), "Save cache only for dynamic quantization"
tp_weight_cache_dir = os.path.join(
weight_cache_dir, f"rank{str(fd_config.parallel_config.tensor_parallel_rank)}"
)
logger.info(f"Saving model to {tp_weight_cache_dir}")
os.makedirs(
tp_weight_cache_dir,
exist_ok=True,
)
_save_model(model.state_dict(), os.path.join(tp_weight_cache_dir, "cache.pdparams"))
else:
logger.info("Weights are already cached, skip saving")
return result
return wrapper
return decorator
def measure_time(prefix: str = "Model loading"):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
time_before = time.time()
result = func(*args, **kwargs)
time_after = time.time()
logger.info(f"{prefix} took {time_after - time_before:.3f} seconds")
return result
return wrapper
return decorator
def load_reordered_experts(model_path: str, key_name: str):
@@ -232,33 +355,38 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete
"""
state_dict = {}
_, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
_, safetensor_files, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}"))
weights_iterator = safetensors_weights_iterator(safetensor_files)
for name, weight in weights_iterator:
state_dict[name] = weight
return state_dict
def get_all_safetensors(model_path: str):
def get_all_weights_file(model_path: str):
"""
get_all_weights_file: collect weight key names and weight files (safetensors or pdparams) under model_path.
"""
safe_model_path = os.path.join(model_path, "model.safetensors")
if os.path.exists(safe_model_path):
safetensor_list = [safe_model_path]
with safe_open(safe_model_path, framework="np", device="cpu") as f:
key_name_list = f.keys()
return key_name_list, safetensor_list
model_path = Path(model_path)
use_safetensors = True
if any(model_path.glob("*.pdparams")):
key_name_list = []
files_list = [str(file) for file in model_path.glob("*.pdparams")]
use_safetensors = False
else:
with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
weight_map = json.load(f)["weight_map"]
weight_files_in_index = set()
for weight_name in weight_map:
weight_files_in_index.add(os.path.join(model_path, weight_map[weight_name]))
key_name_list = list(set(weight_map.keys()))
safetensor_list = list(weight_files_in_index)
safetensor_list.sort()
return key_name_list, safetensor_list
safe_model_path = model_path / "model.safetensors"
if safe_model_path.exists():
files_list = [str(safe_model_path)]
with safe_open(safe_model_path, framework="np", device="cpu") as f:
key_name_list = f.keys()
return key_name_list, files_list, use_safetensors
else:
index_file = model_path / "model.safetensors.index.json"
with index_file.open("r") as f:
weight_map = json.load(f)["weight_map"]
weight_files_in_index = {str(model_path / weight_map[name]) for name in weight_map}
key_name_list = list(weight_map.keys())
files_list = sorted(weight_files_in_index)
return key_name_list, files_list, use_safetensors
def load_tp_checkpoint_v1(
@@ -271,7 +399,7 @@ def load_tp_checkpoint_v1(
load_tp_checkpoint_v1
"""
safetensor_keys, safetensor_files = get_all_safetensors(model_path)
safetensor_keys, safetensor_files, _ = get_all_weights_file(model_path)
if use_fastsafetensor:
weights_iterator = fastsafetensors_weights_iterator(safetensor_files)

View File

@@ -51,7 +51,7 @@ class DefaultModelLoader(BaseModelLoader):
paddle.device.cuda.empty_cache()
paddle.device.synchronize()
@measure_time
@measure_time()
def load_weights(self, model, fd_config: FDConfig, architectures: str) -> None:
model_class = ModelRegistry.get_pretrain_cls(architectures)
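
measure_time is no longer applied bare: the rewritten version above is a decorator factory that takes an optional log prefix, so existing call sites gain parentheses, which is the one-line change in this hunk. A small hedged sketch of the new call sites; the decorated functions are placeholders:

# Hedged sketch of the call-site change; the decorated functions are illustrative.
from fastdeploy.model_executor.load_weight_utils import measure_time

@measure_time()                    # new form: the factory must be called, even with the default prefix
def load_weights_example(model):
    return model

@measure_time("Model saving")      # an explicit prefix customizes the log line, as _save_model does
def save_weights_example(state_dict):
    return state_dict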

View File

@@ -20,9 +20,11 @@ from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
from fastdeploy.model_executor.load_weight_utils import (
fast_weights_iterator,
get_all_safetensors,
get_weight_iterator,
is_weight_cache_enabled,
load_weights_form_cache,
measure_time,
save_model,
)
from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
from fastdeploy.model_executor.models.model_base import ModelRegistry
@@ -44,11 +46,14 @@ class DefaultModelLoaderV1(BaseModelLoader):
paddle.device.cuda.empty_cache()
paddle.device.synchronize()
@measure_time
def load_weights(self, model, fd_config: FDConfig) -> None:
_, safetensor_files = get_all_safetensors(fd_config.model_config.model)
weights_iterator = fast_weights_iterator(safetensor_files)
model.load_weights(weights_iterator)
@save_model()
@measure_time()
def load_weights(self, model, fd_config: FDConfig, enable_cache: bool = False) -> None:
weights_iterator = get_weight_iterator(fd_config.model_config.model)
if enable_cache:
load_weights_form_cache(model, weights_iterator)
else:
model.load_weights(weights_iterator)
self.clean_memory_fragments()
def load_model(self, fd_config: FDConfig) -> nn.Layer:
@@ -61,14 +66,15 @@ class DefaultModelLoaderV1(BaseModelLoader):
architectures = architectures + "RL"
with context:
model_cls = ModelRegistry.get_class(architectures)
model = model_cls(fd_config)
enable_cache, _, weight_cache_context = is_weight_cache_enabled(fd_config)
with weight_cache_context:
with context:
model_cls = ModelRegistry.get_class(architectures)
model = model_cls(fd_config)
model.eval()
# RL models do not need set_state_dict
if fd_config.load_config.dynamic_load_weight:
return model
self.load_weights(model, fd_config)
self.load_weights(model, fd_config, enable_cache)
return model
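
Taken together, the v1 loader now probes the cache before constructing the model, builds it as if the checkpoint were already quantized when a cache exists, and persists freshly quantized weights when it does not. A condensed, hedged paraphrase of that flow, not a verbatim excerpt from the diff:

# Condensed paraphrase of the cached load path; names follow this diff, but
# build_model and load_weights are passed in as stand-ins, not the real APIs.
def load_model_sketch(fd_config, build_model, load_weights, is_weight_cache_enabled):
    enable_cache, weight_cache_dir, cache_ctx = is_weight_cache_enabled(fd_config)
    with cache_ctx:                     # nullcontext, or a switch that flips is_checkpoint_bf16 off on a cache hit
        model = build_model(fd_config)  # stands in for ModelRegistry.get_class(architectures)(fd_config)
    model.eval()
    # load_weights is wrapped by @save_model() and @measure_time() in the real loader:
    #   cache hit  -> iterate rank{tp_rank}/cache.pdparams and copy_ into the named parameters
    #   cache miss -> read the original checkpoint, then save model.state_dict() into the cache dir
    load_weights(model, fd_config, enable_cache)
    return model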

View File

@@ -199,3 +199,14 @@ def temporary_dtype(dtype: str):
yield
finally:
paddle.set_default_dtype(orig_dtype)
@contextmanager
def switch_config_context(config_obj, config_attr_name, value):
"""switch_config_context"""
origin_value = getattr(config_obj, config_attr_name)
setattr(config_obj, config_attr_name, value)
try:
yield
finally:
setattr(config_obj, config_attr_name, origin_value)
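
switch_config_context is the small helper both paths rely on: it swaps a single attribute on a config object and restores it on exit, whether that is flipping is_checkpoint_bf16 off while cached weights are read or pointing model_config.model at the per-rank cache directory while saving. A self-contained hedged usage sketch with a stand-in config object:

# Usage sketch; SimpleNamespace stands in for a real quant config object.
from types import SimpleNamespace

from fastdeploy.model_executor.utils import switch_config_context

quant_config = SimpleNamespace(is_checkpoint_bf16=True)
with switch_config_context(quant_config, "is_checkpoint_bf16", False):
    # inside the block the loader treats the checkpoint as already quantized
    assert quant_config.is_checkpoint_bf16 is False
assert quant_config.is_checkpoint_bf16 is True    # the original value is restored on exit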

View File

@@ -15,7 +15,7 @@ from safetensors.numpy import save_file as safe_save_file
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.load_weight_utils import (
get_all_safetensors,
get_all_weights_file,
safetensors_weights_iterator,
)
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
@@ -158,7 +158,7 @@ def main():
Ernie4_5Tokenizer.resource_files_names["vocab_file"] = vocab_file_names[i]
break
tokenizer = Ernie4_5Tokenizer.from_pretrained(args.model_name_or_path)
_, safetensor_files = get_all_safetensors(args.model_name_or_path)
_, safetensor_files, _ = get_all_weights_file(args.model_name_or_path)
weights_iterator = safetensors_weights_iterator(safetensor_files)
state_dict = {}
save_state_dict = {}

View File

@@ -56,7 +56,11 @@ model_param_map = {
"backend": "triton",
"env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"},
},
{"quant_type": "block_wise_fp8", "backend": "deepgemm", "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"}},
{
"quant_type": "block_wise_fp8",
"backend": "deepgemm",
"env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17", "FD_USE_DEEP_GEMM": "1"},
},
],
},
"DeepSeek-V3-0324": {

View File

@@ -0,0 +1,128 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import pytest
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from tests.model_loader.utils import (
check_tokens_id_and_text_close,
form_model_get_output_topp0,
get_paddle_model_path,
run_with_timeout,
)
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
prompts = ["解释下“温故而知新", "Hello, how are you?"]
model_param_map = {
"ernie-4_5-21b-a3b-bf16-paddle": {
"tensor_parallel_size": 2,
"quantizations": [
{
"quant_type": "wint4",
"env": {"FD_ENABLE_MODEL_CACHE": "1"},
}
],
}
}
params = []
for model, cfg in model_param_map.items():
for q in cfg["quantizations"]:
if isinstance(q, dict):
quant, backend, env = q["quant_type"], q.get("backend", "default"), q.get("env", {})
else:
quant, backend, env = q, "default", {}
params.append(
pytest.param(
model,
cfg.get("tensor_parallel_size", 1),
cfg.get("max_model_len", 1024),
quant,
cfg.get("max_tokens", 32),
env,
marks=[pytest.mark.core_model],
id=f"{model}.{quant}.{backend}",
)
)
@pytest.mark.parametrize(
"model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens,env",
params,
)
def test_model_cache(
fd_runner,
model_name_or_path: str,
tensor_parallel_size: int,
max_model_len: int,
max_tokens: int,
quantization: str,
env,
monkeypatch,
) -> None:
model_path = get_paddle_model_path(model_name_or_path)
fd_outputs_v1 = run_with_timeout(
target=form_model_get_output_topp0,
args=(
fd_runner,
model_path,
tensor_parallel_size,
max_model_len,
max_tokens,
quantization,
"default_v1",
FD_ENGINE_QUEUE_PORT,
prompts,
FD_CACHE_QUEUE_PORT,
),
)
if env:
for k, v in env.items():
monkeypatch.setenv(k, v)
fd_outputs_v1_with_cache = run_with_timeout(
target=form_model_get_output_topp0,
args=(
fd_runner,
model_path,
tensor_parallel_size,
max_model_len,
max_tokens,
quantization,
"default_v1",
FD_ENGINE_QUEUE_PORT,
prompts,
FD_CACHE_QUEUE_PORT,
),
)
check_tokens_id_and_text_close(
outputs_0_lst=fd_outputs_v1,
outputs_1_lst=fd_outputs_v1_with_cache,
name_0="default_v1 loader",
name_1="default_v1 loader using cache",
)

View File

@@ -63,9 +63,14 @@ def run_with_timeout(target, args, timeout=60 * 5):
print_logs()
raise RuntimeError("Worker process hung and was terminated")
try:
return result_queue.get(timeout=60)
result = result_queue.get(timeout=60)
except Exception as e:
raise RuntimeError(f"Failed to get result from worker: {e}")
finally:
result_queue.close()
result_queue.join_thread()
return result
def form_model_get_output_topp0(