Files
FastDeploy/fastdeploy/model_executor/model_loader/default_loader.py
lizexu123 c86945ef49 [Feature] support pool (#3827)
* support pool

* update pooling

* add pooler_config and check

* update

* support AutoWeightsLoader load weight

* fix

* update

* delete print

* update pre-commit

* fix

* fix xpu

* fix ModelRegistry->model_registry

* fix Copilot review

* fix pooler.py

* delete StepPooler

* fix abstract

* fix default_loader_v1

* fix Pre Commit

* support torch qwen3 dense

* add test and fix torch-qwen

* fix

* fix

* adapt CI

* fix review

* fix pooling_params.py

* fix

* fix tasks.py 2025

* fix print and logger

* Modify ModelRegistry and delete AutoWeightsLoader

* fix logger

* fix test_embedding

* fix ci bug

* ernie4_5 model_registry

* fix test

* support Qwen3-Embedding-0.6B tp=1 load

* fix extra code

* fix

* delete fix vocab_size

* delete prepare_params_dict

* fix:
2025-09-22 14:09:09 +08:00

93 lines
3.1 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import contextlib
import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
from fastdeploy.model_executor.load_weight_utils import (
load_composite_checkpoint,
measure_time,
)
from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
from fastdeploy.model_executor.models.model_base import ModelRegistry
from fastdeploy.platforms import current_platform
class DefaultModelLoader(BaseModelLoader):
    """ModelLoader that can load registered models.

    Looks up the model class in ``ModelRegistry`` by architecture name and
    instantiates it. Unless ``load_config.dynamic_load_weight`` is set, it
    then fills the model from the checkpoint at
    ``fd_config.model_config.model``.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        logger.info("Load the model and weights using DefaultModelLoader")

    def download_model(self, model_config: ModelConfig) -> None:
        """No-op: this loader does not download anything.

        ``load_weights`` reads the checkpoint from
        ``fd_config.model_config.model``. Presumably that path must already
        be available locally (TODO confirm).
        """

    def clean_memory_fragments(self, state_dict: dict) -> None:
        """Release memory still held by a state dict after it was consumed.

        This only acts on CUDA platforms. The underlying storage of every
        ``paddle.Tensor`` value in ``state_dict`` is cleared; other values
        (e.g. numpy arrays) are skipped. The CUDA allocator cache is then
        emptied and the device is synchronized.

        Args:
            state_dict: Mapping of parameter name to weight value. May be
                empty or ``None``, in which case only the cache is emptied.
        """
        if not current_platform.is_cuda():
            return
        if state_dict:
            # Only the tensor storages matter here; parameter names are unused.
            for value in state_dict.values():
                if isinstance(value, paddle.Tensor):
                    value.value().get_tensor()._clear()
        paddle.device.cuda.empty_cache()
        paddle.device.synchronize()

    @measure_time()
    def load_weights(self, model, fd_config: FDConfig, architectures: str) -> None:
        """Load checkpoint weights into ``model``.

        Args:
            model: The instantiated model layer to populate via
                ``set_state_dict``.
            fd_config: Full FastDeploy config. ``fd_config.model_config.model``
                is the checkpoint location handed to
                ``load_composite_checkpoint``.
            architectures: Registered architecture name. It is used to look up
                the pretrain class passed to ``load_composite_checkpoint``.
        """
        model_class = ModelRegistry.get_pretrain_cls(architectures)
        state_dict = load_composite_checkpoint(
            fd_config.model_config.model,
            model_class,
            fd_config,
            return_numpy=True,
        )
        model.set_state_dict(state_dict)
        # The weights have been handed to the model; free what the dict still holds.
        self.clean_memory_fragments(state_dict)

    def load_model(self, fd_config: FDConfig) -> nn.Layer:
        """Instantiate the configured model and load its weights when applicable.

        The architecture name is the first entry of
        ``fd_config.model_config.architectures``.

        When ``fd_config.load_config.dynamic_load_weight`` is set:
        - the ``"RL"``-suffixed variant of that architecture is built
          under ``paddle.LazyGuard``;
        - it is returned without loading any weights.

        Otherwise the regular model is built and filled via ``load_weights``.

        Returns:
            The model layer, already switched to eval mode.
        """
        architectures = fd_config.model_config.architectures[0]
        logger.info(f"Starting to load model {architectures}")

        if fd_config.load_config.dynamic_load_weight:
            # Importing fastdeploy.rl registers the RL model classes.
            import fastdeploy.rl  # noqa

            architectures = architectures + "RL"
            # paddle.LazyGuard defers parameter initialization during construction.
            context = paddle.LazyGuard()
        else:
            context = contextlib.nullcontext()

        with context:
            model_cls = ModelRegistry.get_class(architectures)
            model = model_cls(fd_config)

        model.eval()

        # RL models do not need set_state_dict; return before loading weights.
        if fd_config.load_config.dynamic_load_weight:
            return model

        # TODO(gongshaotian): Now, only support safetensor
        self.load_weights(model, fd_config, architectures)
        return model