From e14587a95438c58687b1f27af39c0eaf902d76c7 Mon Sep 17 00:00:00 2001 From: Yzc216 <101054010+Yzc216@users.noreply.github.com> Date: Thu, 24 Jul 2025 14:26:37 +0800 Subject: [PATCH] [Feature] multi-source download (#2986) * multi-source download * multi-source download * huggingface download revision * requirement * style * add revision arg * test * pre-commit --- fastdeploy/engine/args_utils.py | 10 ++++ fastdeploy/entrypoints/llm.py | 3 +- fastdeploy/entrypoints/openai/api_server.py | 2 +- fastdeploy/envs.py | 2 + fastdeploy/utils.py | 60 ++++++++++++++++----- requirements.txt | 1 + test/utils/test_download.py | 43 +++++++++++++++ 7 files changed, 106 insertions(+), 15 deletions(-) create mode 100644 test/utils/test_download.py diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index b1e464a6a..756ced2c3 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -46,6 +46,10 @@ class EngineArgs: """ The name or path of the model to be used. """ + revision: Optional[str] = "master" + """ + The revision for downloading models. + """ model_config_name: Optional[str] = "config.json" """ The name of the model configuration file. 
@@ -340,6 +344,12 @@ class EngineArgs: default=EngineArgs.model, help="Model name or path to be used.", ) + model_group.add_argument( + "--revision", + type=nullable_str, + default=EngineArgs.revision, + help="Revision for downloading models", + ) model_group.add_argument( "--model-config-name", type=nullable_str, diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py index e6356981f..1204a67f9 100644 --- a/fastdeploy/entrypoints/llm.py +++ b/fastdeploy/entrypoints/llm.py @@ -66,10 +66,11 @@ class LLM: def __init__( self, model: str, + revision: Optional[str] = "master", tokenizer: Optional[str] = None, **kwargs, ): - model = retrive_model_from_server(model) + model = retrive_model_from_server(model, revision) engine_args = EngineArgs( model=model, tokenizer=tokenizer, diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 3e05e7367..8e37f0d01 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -62,7 +62,7 @@ parser.add_argument("--metrics-port", default=8001, type=int, help="port for met parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server") parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() -args.model = retrive_model_from_server(args.model) +args.model = retrive_model_from_server(args.model, args.revision) llm_engine = None diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 40203b485..964e2f078 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -30,6 +30,8 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_DEBUG": lambda: os.getenv("FD_DEBUG", "0"), # Number of days to keep fastdeploy logs. "FD_LOG_BACKUP_COUNT": lambda: os.getenv("FD_LOG_BACKUP_COUNT", "7"), + # Model download source, can set "AISTUDIO", "MODELSCOPE" or "HUGGINGFACE". + "FD_MODEL_SOURCE": lambda: os.getenv("FD_MODEL_SOURCE", "MODELSCOPE"), # Model download cache directory. 
"FD_MODEL_CACHE": lambda: os.getenv("FD_MODEL_CACHE", None), # Maximum number of stop sequences. diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index a5cf5b3e0..f9cfe9d40 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -31,7 +31,9 @@ from typing import Literal, TypeVar, Union import requests import yaml -from aistudio_sdk.snapshot_download import snapshot_download +from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download +from huggingface_hub._snapshot_download import snapshot_download as huggingface_download +from modelscope.hub.snapshot_download import snapshot_download as modelscope_download from tqdm import tqdm from typing_extensions import TypeIs, assert_never @@ -494,21 +496,53 @@ def none_or_str(value): def retrive_model_from_server(model_name_or_path, revision="master"): """ - Download pretrained model from AIStudio automatically + Download pretrained model from MODELSCOPE, AIStudio or HUGGINGFACE automatically """ if os.path.exists(model_name_or_path): return model_name_or_path - try: - repo_id = model_name_or_path - if repo_id.lower().strip().startswith("baidu"): - repo_id = "PaddlePaddle" + repo_id.strip()[5:] - local_path = envs.FD_MODEL_CACHE - if local_path is None: - local_path = f'{os.getenv("HOME")}/{repo_id}' - snapshot_download(repo_id=repo_id, revision=revision, local_dir=local_path) - model_name_or_path = local_path - except Exception: - raise Exception(f"The setting model_name_or_path:{model_name_or_path} is not exist.") + model_source = envs.FD_MODEL_SOURCE + local_path = envs.FD_MODEL_CACHE + repo_id = model_name_or_path + if model_source == "MODELSCOPE": + try: + if repo_id.lower().strip().startswith("baidu"): + repo_id = "PaddlePaddle" + repo_id.strip()[5:] + if local_path is None: + local_path = f'{os.getenv("HOME")}' + local_path = f"{local_path}/{repo_id}/{revision}" + modelscope_download(repo_id=repo_id, revision=revision, local_dir=local_path) + model_name_or_path = local_path + except 
Exception: + raise Exception(f"The setting model_name_or_path:{model_name_or_path} is not exist.") + elif model_source == "AISTUDIO": + try: + if repo_id.lower().strip().startswith("baidu"): + repo_id = "PaddlePaddle" + repo_id.strip()[5:] + if local_path is None: + local_path = f'{os.getenv("HOME")}' + local_path = f"{local_path}/{repo_id}/{revision}" + aistudio_download(repo_id=repo_id, revision=revision, local_dir=local_path) + model_name_or_path = local_path + except Exception: + raise Exception(f"The setting model_name_or_path:{model_name_or_path} is not exist.") + elif model_source == "HUGGINGFACE": + try: + if revision == "master": + revision = "main" + repo_id = model_name_or_path + if repo_id.lower().strip().startswith("paddlepaddle"): + repo_id = "baidu" + repo_id.strip()[12:] + if local_path is None: + local_path = f'{os.getenv("HOME")}' + local_path = f"{local_path}/{repo_id}/{revision}" + huggingface_download(repo_id=repo_id, revision=revision, local_dir=local_path) + model_name_or_path = local_path + except Exception: + raise Exception(f"The setting model_name_or_path:{model_name_or_path} is not exist.") + else: + raise ValueError( + f"Unsupported model source: {model_source}, please choose one of ['MODELSCOPE', 'AISTUDIO', 'HUGGINGFACE']" + ) return model_name_or_path diff --git a/requirements.txt b/requirements.txt index f9166c8c2..4717f532a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,6 +30,7 @@ use-triton-in-paddle crcmod fastsafetensors==0.1.14 msgpack +modelscope opentelemetry-api>=1.24.0 opentelemetry-sdk>=1.24.0 opentelemetry-instrumentation-redis diff --git a/test/utils/test_download.py b/test/utils/test_download.py new file mode 100644 index 000000000..f479c693f --- /dev/null +++ b/test/utils/test_download.py @@ -0,0 +1,43 @@ +import os +import unittest + +from fastdeploy.utils import retrive_model_from_server + + +class TestAistudioDownload(unittest.TestCase): + def test_retrive_model_from_server_MODELSCOPE(self): + 
os.environ["FD_MODEL_SOURCE"] = "MODELSCOPE" + os.environ["FD_MODEL_CACHE"] = "./models" + + model_name_or_path = "baidu/ERNIE-4.5-0.3B-PT" + revision = "master" + expected_path = f"./models/PaddlePaddle/ERNIE-4.5-0.3B-PT/{revision}" + result = retrive_model_from_server(model_name_or_path, revision) + self.assertEqual(expected_path, result) + + os.environ.clear() + + def test_retrive_model_from_server_unsupported_source(self): + os.environ["FD_MODEL_SOURCE"] = "UNSUPPORTED_SOURCE" + os.environ["FD_MODEL_CACHE"] = "./models" + + model_name_or_path = "baidu/ERNIE-4.5-0.3B-PT" + with self.assertRaises(ValueError): + retrive_model_from_server(model_name_or_path) + + os.environ.clear() + + def test_retrive_model_from_server_model_not_exist(self): + os.environ["FD_MODEL_SOURCE"] = "MODELSCOPE" + os.environ["FD_MODEL_CACHE"] = "./models" + + model_name_or_path = "non_existing_model" + + with self.assertRaises(Exception): + retrive_model_from_server(model_name_or_path) + + os.environ.clear() + + +if __name__ == "__main__": + unittest.main()