[Feature] multi-source download (#2986)

* multi-source download

* multi-source download

* huggingface download revision

* requirement

* style

* add revision arg

* test

* pre-commit
Yzc216
2025-07-24 14:26:37 +08:00
committed by GitHub
parent 87a2f4191d
commit e14587a954
7 changed files with 106 additions and 15 deletions

View File

@@ -46,6 +46,10 @@ class EngineArgs:
     """
     The name or path of the model to be used.
     """
+    revision: Optional[str] = "master"
+    """
+    The revision for downloading models.
+    """
     model_config_name: Optional[str] = "config.json"
     """
     The name of the model configuration file.
@@ -340,6 +344,12 @@ class EngineArgs:
             default=EngineArgs.model,
             help="Model name or path to be used.",
         )
+        model_group.add_argument(
+            "--revision",
+            type=nullable_str,
+            default=EngineArgs.revision,
+            help="Revision for downloading models.",
+        )
        model_group.add_argument(
            "--model-config-name",
            type=nullable_str,
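
For context, a minimal sketch of constructing EngineArgs with the new field directly; the import path is an assumption, the model id is a placeholder, and the remaining fields are assumed to keep their defaults:

from fastdeploy.engine.args_utils import EngineArgs  # assumed import path

# revision accepts a branch, tag, or commit id; it defaults to "master".
engine_args = EngineArgs(model="baidu/ERNIE-4.5-0.3B-PT", revision="master")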

View File

@@ -66,10 +66,11 @@ class LLM:
     def __init__(
         self,
         model: str,
+        revision: Optional[str] = "master",
         tokenizer: Optional[str] = None,
         **kwargs,
     ):
-        model = retrive_model_from_server(model)
+        model = retrive_model_from_server(model, revision)
         engine_args = EngineArgs(
             model=model,
             tokenizer=tokenizer,
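
With the new parameter, offline callers can pin the exact snapshot they download. A minimal sketch, assuming the top-level import path (model id and revision are placeholders):

from fastdeploy import LLM  # assumed import path

# The pinned revision is forwarded to retrive_model_from_server before the engine starts.
llm = LLM(model="baidu/ERNIE-4.5-0.3B-PT", revision="master")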

View File

@@ -62,7 +62,7 @@ parser.add_argument("--metrics-port", default=8001, type=int, help="port for metrics server")
 parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server")
 parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
-args.model = retrive_model_from_server(args.model)
+args.model = retrive_model_from_server(args.model, args.revision)
 llm_engine = None
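
A sketch of the same flow in isolation, mirroring how the server resolves its model at startup; the import paths are assumptions and the model id is a placeholder:

import argparse

from fastdeploy.engine.args_utils import EngineArgs  # assumed import path
from fastdeploy.utils import retrive_model_from_server

parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args(["--model", "baidu/ERNIE-4.5-0.3B-PT", "--revision", "master"])
args.model = retrive_model_from_server(args.model, args.revision)  # now a local snapshot path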

View File

@@ -30,6 +30,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_DEBUG": lambda: os.getenv("FD_DEBUG", "0"),
     # Number of days to keep fastdeploy logs.
     "FD_LOG_BACKUP_COUNT": lambda: os.getenv("FD_LOG_BACKUP_COUNT", "7"),
+    # Model download source; can be set to "AISTUDIO", "MODELSCOPE" or "HUGGINGFACE".
+    "FD_MODEL_SOURCE": lambda: os.getenv("FD_MODEL_SOURCE", "MODELSCOPE"),
     # Model download cache directory.
     "FD_MODEL_CACHE": lambda: os.getenv("FD_MODEL_CACHE", None),
     # Maximum number of stop sequences.
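
Because both knobs are environment variables, the hub and the cache location can be switched per deployment without code changes. A hypothetical configuration, set before the model is resolved (values are examples only):

import os

os.environ["FD_MODEL_SOURCE"] = "HUGGINGFACE"  # or "MODELSCOPE" / "AISTUDIO"
os.environ["FD_MODEL_CACHE"] = "/opt/models"   # snapshots land under <cache>/<repo_id>/<revision>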

View File

@@ -31,7 +31,9 @@ from typing import Literal, TypeVar, Union
 import requests
 import yaml
-from aistudio_sdk.snapshot_download import snapshot_download
+from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download
+from huggingface_hub._snapshot_download import snapshot_download as huggingface_download
+from modelscope.hub.snapshot_download import snapshot_download as modelscope_download
 from tqdm import tqdm
 from typing_extensions import TypeIs, assert_never
@@ -494,21 +496,53 @@ def none_or_str(value):
 def retrive_model_from_server(model_name_or_path, revision="master"):
     """
-    Download pretrained model from AIStudio automatically
+    Download pretrained model from ModelScope, AIStudio, or Hugging Face automatically
     """
     if os.path.exists(model_name_or_path):
         return model_name_or_path
-    try:
-        repo_id = model_name_or_path
-        if repo_id.lower().strip().startswith("baidu"):
-            repo_id = "PaddlePaddle" + repo_id.strip()[5:]
-        local_path = envs.FD_MODEL_CACHE
-        if local_path is None:
-            local_path = f'{os.getenv("HOME")}/{repo_id}'
-        snapshot_download(repo_id=repo_id, revision=revision, local_dir=local_path)
-        model_name_or_path = local_path
-    except Exception:
-        raise Exception(f"The setting model_name_or_path:{model_name_or_path} is not exist.")
+    model_source = envs.FD_MODEL_SOURCE
+    local_path = envs.FD_MODEL_CACHE
+    repo_id = model_name_or_path
+    if model_source == "MODELSCOPE":
+        try:
+            if repo_id.lower().strip().startswith("baidu"):
+                repo_id = "PaddlePaddle" + repo_id.strip()[5:]
+            if local_path is None:
+                local_path = f'{os.getenv("HOME")}'
+            local_path = f"{local_path}/{repo_id}/{revision}"
+            modelscope_download(repo_id=repo_id, revision=revision, local_dir=local_path)
+            model_name_or_path = local_path
+        except Exception:
+            raise Exception(f"The setting model_name_or_path: {model_name_or_path} does not exist.")
+    elif model_source == "AISTUDIO":
+        try:
+            if repo_id.lower().strip().startswith("baidu"):
+                repo_id = "PaddlePaddle" + repo_id.strip()[5:]
+            if local_path is None:
+                local_path = f'{os.getenv("HOME")}'
+            local_path = f"{local_path}/{repo_id}/{revision}"
+            aistudio_download(repo_id=repo_id, revision=revision, local_dir=local_path)
+            model_name_or_path = local_path
+        except Exception:
+            raise Exception(f"The setting model_name_or_path: {model_name_or_path} does not exist.")
+    elif model_source == "HUGGINGFACE":
+        try:
+            # Hugging Face names the default branch "main" rather than "master".
+            if revision == "master":
+                revision = "main"
+            repo_id = model_name_or_path
+            # Repos published under "PaddlePaddle" map to the "baidu" namespace on Hugging Face.
+            if repo_id.strip().lower().startswith("paddlepaddle"):
+                repo_id = "baidu" + repo_id.strip()[12:]
+            if local_path is None:
+                local_path = f'{os.getenv("HOME")}'
+            local_path = f"{local_path}/{repo_id}/{revision}"
+            huggingface_download(repo_id=repo_id, revision=revision, local_dir=local_path)
+            model_name_or_path = local_path
+        except Exception:
+            raise Exception(f"The setting model_name_or_path: {model_name_or_path} does not exist.")
+    else:
+        raise ValueError(
+            f"Unsupported model source: {model_source}, please choose one of ['MODELSCOPE', 'AISTUDIO', 'HUGGINGFACE']"
+        )
     return model_name_or_path
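
End to end, the resolver maps one logical repo id onto whichever hub is selected, rewriting the namespace ("baidu" to "PaddlePaddle" and back) and translating "master" to "main" for Hugging Face. A usage sketch mirroring the new unit test below (network access happens on first call):

import os

from fastdeploy.utils import retrive_model_from_server

os.environ["FD_MODEL_SOURCE"] = "MODELSCOPE"
os.environ["FD_MODEL_CACHE"] = "./models"

# "baidu/..." is rewritten to "PaddlePaddle/..." for ModelScope, so the snapshot
# is downloaded to ./models/PaddlePaddle/ERNIE-4.5-0.3B-PT/master.
local_dir = retrive_model_from_server("baidu/ERNIE-4.5-0.3B-PT", revision="master")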

View File

@@ -30,6 +30,7 @@ use-triton-in-paddle
 crcmod
 fastsafetensors==0.1.14
 msgpack
+modelscope
 opentelemetry-api>=1.24.0
 opentelemetry-sdk>=1.24.0
 opentelemetry-instrumentation-redis

View File

@@ -0,0 +1,43 @@
+import os
+import unittest
+
+from fastdeploy.utils import retrive_model_from_server
+
+
+class TestAistudioDownload(unittest.TestCase):
+    def test_retrive_model_from_server_MODELSCOPE(self):
+        os.environ["FD_MODEL_SOURCE"] = "MODELSCOPE"
+        os.environ["FD_MODEL_CACHE"] = "./models"
+        model_name_or_path = "baidu/ERNIE-4.5-0.3B-PT"
+        revision = "master"
+        expected_path = f"./models/PaddlePaddle/ERNIE-4.5-0.3B-PT/{revision}"
+        result = retrive_model_from_server(model_name_or_path, revision)
+        self.assertEqual(expected_path, result)
+        os.environ.clear()
+
+    def test_retrive_model_from_server_unsupported_source(self):
+        os.environ["FD_MODEL_SOURCE"] = "UNSUPPORTED_SOURCE"
+        os.environ["FD_MODEL_CACHE"] = "./models"
+        model_name_or_path = "baidu/ERNIE-4.5-0.3B-PT"
+        with self.assertRaises(ValueError):
+            retrive_model_from_server(model_name_or_path)
+        os.environ.clear()
+
+    def test_retrive_model_from_server_model_not_exist(self):
+        os.environ["FD_MODEL_SOURCE"] = "MODELSCOPE"
+        os.environ["FD_MODEL_CACHE"] = "./models"
+        model_name_or_path = "non_existing_model"
+        with self.assertRaises(Exception):
+            retrive_model_from_server(model_name_or_path)
+        os.environ.clear()
+
+
+if __name__ == "__main__":
+    unittest.main()