Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-01 23:02:36 +08:00)
[Feature] multi-source download (#2986)
* multi-source download
* multi-source download
* huggingface download revision
* requirement
* style
* add revision arg
* test
* pre-commit
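Editor's note: for orientation, a minimal usage sketch of what this commit enables. The model name and revision are illustrative, and the import path is assumed from FastDeploy's public API, not taken from this diff:

# Hedged sketch: choose a download source via FD_MODEL_SOURCE, then pass a revision.
import os
os.environ["FD_MODEL_SOURCE"] = "MODELSCOPE"  # or "AISTUDIO" / "HUGGINGFACE"

from fastdeploy import LLM  # import path assumed

# The patched constructor resolves the repo to a local snapshot before building EngineArgs.
llm = LLM(model="baidu/ERNIE-4.5-0.3B-PT", revision="master")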
@@ -46,6 +46,10 @@ class EngineArgs:
     """
     The name or path of the model to be used.
     """
+    revision: Optional[str] = "master"
+    """
+    The revision for downloading models.
+    """
     model_config_name: Optional[str] = "config.json"
     """
     The name of the model configuration file.
@@ -340,6 +344,12 @@ class EngineArgs:
             default=EngineArgs.model,
             help="Model name or path to be used.",
         )
+        model_group.add_argument(
+            "--revision",
+            type=nullable_str,
+            default=EngineArgs.revision,
+            help="Revision for downloading models",
+        )
         model_group.add_argument(
             "--model-config-name",
             type=nullable_str,
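Editor's note: a standalone argparse sketch of how the new flag surfaces on the CLI (not the FastDeploy parser itself; nullable_str is replaced by plain str here, and the parsed values are illustrative):

# Self-contained sketch mirroring the two arguments added above.
import argparse

parser = argparse.ArgumentParser()
model_group = parser.add_argument_group("model")
model_group.add_argument("--model", type=str, default=None, help="Model name or path to be used.")
model_group.add_argument("--revision", type=str, default="master", help="Revision for downloading models")

args = parser.parse_args(["--model", "baidu/ERNIE-4.5-0.3B-PT", "--revision", "develop"])
print(args.model, args.revision)  # baidu/ERNIE-4.5-0.3B-PT develop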
@@ -66,10 +66,11 @@ class LLM:
     def __init__(
         self,
         model: str,
+        revision: Optional[str] = "master",
         tokenizer: Optional[str] = None,
         **kwargs,
     ):
-        model = retrive_model_from_server(model)
+        model = retrive_model_from_server(model, revision)
        engine_args = EngineArgs(
             model=model,
             tokenizer=tokenizer,
@@ -62,7 +62,7 @@ parser.add_argument("--metrics-port", default=8001, type=int, help="port for met
 parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server")
 parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
-args.model = retrive_model_from_server(args.model)
+args.model = retrive_model_from_server(args.model, args.revision)

 llm_engine = None

@@ -30,6 +30,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_DEBUG": lambda: os.getenv("FD_DEBUG", "0"),
     # Number of days to keep fastdeploy logs.
     "FD_LOG_BACKUP_COUNT": lambda: os.getenv("FD_LOG_BACKUP_COUNT", "7"),
+    # Model download source; can be set to "AISTUDIO", "MODELSCOPE" or "HUGGINGFACE".
+    "FD_MODEL_SOURCE": lambda: os.getenv("FD_MODEL_SOURCE", "MODELSCOPE"),
     # Model download cache directory.
     "FD_MODEL_CACHE": lambda: os.getenv("FD_MODEL_CACHE", None),
     # Maximum number of stop sequences.
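Editor's note: a self-contained sketch of how this lazy env-var registry behaves. Each value is a lambda read at access time, so changes to the environment after import are still picked up (values copied from the hunk above; the surrounding access machinery is assumed):

# Minimal reproduction of the registry's lazy lookup.
import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    "FD_MODEL_SOURCE": lambda: os.getenv("FD_MODEL_SOURCE", "MODELSCOPE"),
    "FD_MODEL_CACHE": lambda: os.getenv("FD_MODEL_CACHE", None),
}

os.environ["FD_MODEL_SOURCE"] = "HUGGINGFACE"
print(environment_variables["FD_MODEL_SOURCE"]())  # HUGGINGFACE, not the default
print(environment_variables["FD_MODEL_CACHE"]())   # None unless FD_MODEL_CACHE is set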
@@ -31,7 +31,9 @@ from typing import Literal, TypeVar, Union

 import requests
 import yaml
-from aistudio_sdk.snapshot_download import snapshot_download
+from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download
+from huggingface_hub._snapshot_download import snapshot_download as huggingface_download
+from modelscope.hub.snapshot_download import snapshot_download as modelscope_download
 from tqdm import tqdm
 from typing_extensions import TypeIs, assert_never
@@ -494,21 +496,53 @@ def none_or_str(value):

 def retrive_model_from_server(model_name_or_path, revision="master"):
     """
-    Download pretrained model from AIStudio automatically
+    Download pretrained model from MODELSCOPE, AIStudio or HUGGINGFACE automatically
     """
     if os.path.exists(model_name_or_path):
         return model_name_or_path
-    try:
-        repo_id = model_name_or_path
-        if repo_id.lower().strip().startswith("baidu"):
-            repo_id = "PaddlePaddle" + repo_id.strip()[5:]
-        local_path = envs.FD_MODEL_CACHE
-        if local_path is None:
-            local_path = f'{os.getenv("HOME")}/{repo_id}'
-        snapshot_download(repo_id=repo_id, revision=revision, local_dir=local_path)
-        model_name_or_path = local_path
-    except Exception:
-        raise Exception(f"The setting model_name_or_path:{model_name_or_path} is not exist.")
+    model_source = envs.FD_MODEL_SOURCE
+    local_path = envs.FD_MODEL_CACHE
+    repo_id = model_name_or_path
+    if model_source == "MODELSCOPE":
+        try:
+            if repo_id.lower().strip().startswith("baidu"):
+                repo_id = "PaddlePaddle" + repo_id.strip()[5:]
+            if local_path is None:
+                local_path = f'{os.getenv("HOME")}'
+            local_path = f"{local_path}/{repo_id}/{revision}"
+            modelscope_download(repo_id=repo_id, revision=revision, local_dir=local_path)
+            model_name_or_path = local_path
+        except Exception:
+            raise Exception(f"The setting model_name_or_path: {model_name_or_path} does not exist.")
+    elif model_source == "AISTUDIO":
+        try:
+            if repo_id.lower().strip().startswith("baidu"):
+                repo_id = "PaddlePaddle" + repo_id.strip()[5:]
+            if local_path is None:
+                local_path = f'{os.getenv("HOME")}'
+            local_path = f"{local_path}/{repo_id}/{revision}"
+            aistudio_download(repo_id=repo_id, revision=revision, local_dir=local_path)
+            model_name_or_path = local_path
+        except Exception:
+            raise Exception(f"The setting model_name_or_path: {model_name_or_path} does not exist.")
+    elif model_source == "HUGGINGFACE":
+        try:
+            if revision == "master":
+                revision = "main"
+            repo_id = model_name_or_path
+            if repo_id.lower().strip().startswith("paddlepaddle"):
+                repo_id = "baidu" + repo_id.strip()[12:]
+            if local_path is None:
+                local_path = f'{os.getenv("HOME")}'
+            local_path = f"{local_path}/{repo_id}/{revision}"
+            huggingface_download(repo_id=repo_id, revision=revision, local_dir=local_path)
+            model_name_or_path = local_path
+        except Exception:
+            raise Exception(f"The setting model_name_or_path: {model_name_or_path} does not exist.")
+    else:
+        raise ValueError(
+            f"Unsupported model source: {model_source}, please choose one of ['MODELSCOPE', 'AISTUDIO', 'HUGGINGFACE']"
+        )
     return model_name_or_path

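Editor's note: the cache layout this function produces can be checked without any network traffic. A hedged sketch follows; resolve_cache_path is a hypothetical helper of mine that mirrors the ModelScope/AIStudio branch logic above, not part of the commit:

# Hypothetical helper reproducing the cache layout <cache>/<repo_id>/<revision>.
import os
from typing import Optional

def resolve_cache_path(repo_id: str, revision: str = "master", cache: Optional[str] = None) -> str:
    # "baidu/..." repos are remapped to "PaddlePaddle/..." for ModelScope/AIStudio.
    if repo_id.lower().strip().startswith("baidu"):
        repo_id = "PaddlePaddle" + repo_id.strip()[5:]
    base = cache if cache is not None else os.getenv("HOME", ".")
    return f"{base}/{repo_id}/{revision}"

# Matches the expected path asserted in the new unit test below.
assert resolve_cache_path("baidu/ERNIE-4.5-0.3B-PT", "master", "./models") == \
    "./models/PaddlePaddle/ERNIE-4.5-0.3B-PT/master"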
@@ -30,6 +30,7 @@ use-triton-in-paddle
 crcmod
 fastsafetensors==0.1.14
 msgpack
+modelscope
 opentelemetry-api>=1.24.0
 opentelemetry-sdk>=1.24.0
 opentelemetry-instrumentation-redis
test/utils/test_download.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+import os
+import unittest
+
+from fastdeploy.utils import retrive_model_from_server
+
+
+class TestAistudioDownload(unittest.TestCase):
+    def test_retrive_model_from_server_MODELSCOPE(self):
+        os.environ["FD_MODEL_SOURCE"] = "MODELSCOPE"
+        os.environ["FD_MODEL_CACHE"] = "./models"
+
+        model_name_or_path = "baidu/ERNIE-4.5-0.3B-PT"
+        revision = "master"
+        expected_path = f"./models/PaddlePaddle/ERNIE-4.5-0.3B-PT/{revision}"
+        result = retrive_model_from_server(model_name_or_path, revision)
+        self.assertEqual(expected_path, result)
+
+        os.environ.clear()
+
+    def test_retrive_model_from_server_unsupported_source(self):
+        os.environ["FD_MODEL_SOURCE"] = "UNSUPPORTED_SOURCE"
+        os.environ["FD_MODEL_CACHE"] = "./models"
+
+        model_name_or_path = "baidu/ERNIE-4.5-0.3B-PT"
+        with self.assertRaises(ValueError):
+            retrive_model_from_server(model_name_or_path)
+
+        os.environ.clear()
+
+    def test_retrive_model_from_server_model_not_exist(self):
+        os.environ["FD_MODEL_SOURCE"] = "MODELSCOPE"
+        os.environ["FD_MODEL_CACHE"] = "./models"
+
+        model_name_or_path = "non_existing_model"
+
+        with self.assertRaises(Exception):
+            retrive_model_from_server(model_name_or_path)
+
+        os.environ.clear()
+
+
+if __name__ == "__main__":
+    unittest.main()
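Editor's note: the first test above hits the real ModelScope backend. A network-free variant could stub the downloader; a sketch using unittest.mock follows, assuming retrive_model_from_server and modelscope_download both live in fastdeploy.utils (as the import above suggests) and that envs values are read at call time:

# Hypothetical offline variant of the ModelScope test: the downloader is mocked.
import unittest
from unittest import mock

class TestDownloadOffline(unittest.TestCase):
    @mock.patch.dict("os.environ", {"FD_MODEL_SOURCE": "MODELSCOPE", "FD_MODEL_CACHE": "./models"})
    @mock.patch("fastdeploy.utils.modelscope_download")  # assumed patch target
    def test_path_layout_without_network(self, fake_download):
        from fastdeploy.utils import retrive_model_from_server
        result = retrive_model_from_server("baidu/ERNIE-4.5-0.3B-PT", "master")
        self.assertEqual("./models/PaddlePaddle/ERNIE-4.5-0.3B-PT/master", result)
        fake_download.assert_called_once()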