FastDeploy/fastdeploy/transformer_utils/config.py
lizexu123 c86945ef49 [Feature] support pool (#3827)

import json
from pathlib import Path
from typing import Optional, Union

import huggingface_hub
from huggingface_hub import hf_hub_download, try_to_load_from_cache
from huggingface_hub.utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    LocalEntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
)

from fastdeploy.utils import get_logger

logger = get_logger("transformer_config", "transformer_config.log")


def file_or_path_exists(model, config_name):
    if (local_path := Path(model)).exists():
        return (local_path / config_name).is_file()
    return False


def get_pooling_config_name(pooling_name: str):
    if "pooling_mode_" in pooling_name:
        pooling_name = pooling_name.replace("pooling_mode_", "")
    if "_" in pooling_name:
        pooling_name = pooling_name.split("_")[0]
    if "lasttoken" in pooling_name:
        pooling_name = "last"
    supported_pooling_types = ["LAST", "ALL", "CLS", "STEP", "MEAN"]
    pooling_type_name = pooling_name.upper()
    if pooling_type_name in supported_pooling_types:
        return pooling_type_name
    raise NotImplementedError(f"Pooling type {pooling_type_name} not supported")
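
# Illustrative trace (hypothetical inputs, modeled on typical sentence-transformers keys):
#   get_pooling_config_name("pooling_mode_mean_tokens")
#     -> strip the "pooling_mode_" prefix        -> "mean_tokens"
#     -> keep the text before the first "_"      -> "mean"
#     -> upper-case and check the allowed list   -> "MEAN"
#   get_pooling_config_name("pooling_mode_lasttoken") returns "LAST", while an
#   unsupported key such as "pooling_mode_weightedmean_tokens" raises NotImplementedError.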


def try_get_local_file(model: Union[str, Path], file_name: str, revision: Optional[str] = "main") -> Optional[Path]:
    file_path = Path(model) / file_name
    if file_path.is_file():
        return file_path
    else:
        try:
            cached_filepath = try_to_load_from_cache(repo_id=model, filename=file_name, revision=revision)
            if isinstance(cached_filepath, str):
                return Path(cached_filepath)
        except ValueError:
            ...
    return None
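
# Lookup order sketch (hypothetical paths): a plain local file wins first, then the
# Hugging Face cache is consulted via try_to_load_from_cache, and a None result means
# the caller has to fall back to an actual download.
#   try_get_local_file("/models/my-embedder", "modules.json")   # local checkout
#   try_get_local_file("org/repo", "modules.json")              # previously cached hub repo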


def get_hf_file_to_dict(file_name: str, model: Union[str, Path], revision: Optional[str] = "main"):
    """
    Downloads a file from the Hugging Face Hub and returns
    its contents as a dictionary.

    Parameters:
    - file_name (str): The name of the file to download.
    - model (str): The name of the model on the Hugging Face Hub.
    - revision (str): The specific version of the model.

    Returns:
    - config_dict (dict): A dictionary containing
      the contents of the downloaded file.
    """
    file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)

    if file_path is None:
        try:
            hf_hub_file = hf_hub_download(model, file_name, revision=revision)
        except huggingface_hub.errors.OfflineModeIsEnabled:
            return None
        except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError, LocalEntryNotFoundError) as e:
            logger.debug("File or repository not found in hf_hub_download: %s", e)
            return None
        except HfHubHTTPError as e:
            logger.warning(
                "Cannot connect to Hugging Face Hub. Skipping file download for '%s':", file_name, exc_info=e
            )
            return None
        file_path = Path(hf_hub_file)

    if file_path is not None and file_path.is_file():
        with open(file_path) as file:
            return json.load(file)

    return None
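
# Usage sketch (hypothetical repo id): fetch a JSON file from a local checkout, the
# HF cache, or the Hub, and get None back when offline, when the repo or file is
# missing, or when the Hub is unreachable.
#   modules = get_hf_file_to_dict("modules.json", "org/my-embedding-model", revision="main")
#   if modules is not None:
#       ...  # parse the sentence-transformers module list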


def get_pooling_config(model: str, revision: Optional[str] = "main"):
    """
    This function gets the pooling and normalize
    config from the model - only applies to
    sentence-transformers models.

    Args:
        model (str): The name of the Hugging Face model.
        revision (str, optional): The specific version
            of the model to use. Defaults to 'main'.

    Returns:
        dict: A dictionary containing the pooling
        type and whether normalization is used.
    """
    modules_file_name = "modules.json"

    modules_dict = None
    if file_or_path_exists(model, config_name=modules_file_name):
        modules_dict = get_hf_file_to_dict(modules_file_name, model, revision)

    if modules_dict is None:
        return None

    pooling = next((item for item in modules_dict if item["type"] == "sentence_transformers.models.Pooling"), None)
    normalize = bool(
        next((item for item in modules_dict if item["type"] == "sentence_transformers.models.Normalize"), False)
    )

    if pooling:
        pooling_file_name = "{}/config.json".format(pooling["path"])
        pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision)
        if pooling_dict is None:
            return None
        pooling_type_name = next((item for item, val in pooling_dict.items() if val is True), None)
        if pooling_type_name is not None:
            pooling_type_name = get_pooling_config_name(pooling_type_name)

        return {"pooling_type": pooling_type_name, "normalize": normalize}

    return None
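
# Usage sketch (hypothetical local checkout laid out like a sentence-transformers
# repo, i.e. modules.json plus <pooling path>/config.json):
#   cfg = get_pooling_config("/models/my-embedding-model")
#   if cfg is not None:
#       pooling_type, normalize = cfg["pooling_type"], cfg["normalize"]
# Models without a sentence-transformers layout yield None.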