mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 00:06:38 +08:00
[Feature] multi-source download (#2986)
* multi-source download
* multi-source download
* huggingface download revision
* requirement
* style
* add revision arg
* test
* pre-commit
This commit is contained in:
@@ -46,6 +46,10 @@ class EngineArgs:
|
|||||||
"""
|
"""
|
||||||
The name or path of the model to be used.
|
The name or path of the model to be used.
|
||||||
"""
|
"""
|
||||||
|
revision: Optional[str] = "master"
|
||||||
|
"""
|
||||||
|
The revision for downloading models.
|
||||||
|
"""
|
||||||
model_config_name: Optional[str] = "config.json"
|
model_config_name: Optional[str] = "config.json"
|
||||||
"""
|
"""
|
||||||
The name of the model configuration file.
|
The name of the model configuration file.
|
||||||
@@ -340,6 +344,12 @@ class EngineArgs:
|
|||||||
default=EngineArgs.model,
|
default=EngineArgs.model,
|
||||||
help="Model name or path to be used.",
|
help="Model name or path to be used.",
|
||||||
)
|
)
|
||||||
|
model_group.add_argument(
|
||||||
|
"--revision",
|
||||||
|
type=nullable_str,
|
||||||
|
default=EngineArgs.revision,
|
||||||
|
help="Revision for downloading models",
|
||||||
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--model-config-name",
|
"--model-config-name",
|
||||||
type=nullable_str,
|
type=nullable_str,
|
||||||
|
@@ -66,10 +66,11 @@ class LLM:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
revision: Optional[str] = "master",
|
||||||
tokenizer: Optional[str] = None,
|
tokenizer: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
model = retrive_model_from_server(model)
|
model = retrive_model_from_server(model, revision)
|
||||||
engine_args = EngineArgs(
|
engine_args = EngineArgs(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
@@ -62,7 +62,7 @@ parser.add_argument("--metrics-port", default=8001, type=int, help="port for met
|
|||||||
parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server")
|
parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server")
|
||||||
parser = EngineArgs.add_cli_args(parser)
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.model = retrive_model_from_server(args.model)
|
args.model = retrive_model_from_server(args.model, args.revision)
|
||||||
|
|
||||||
llm_engine = None
|
llm_engine = None
|
||||||
|
|
||||||
|
@@ -30,6 +30,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"FD_DEBUG": lambda: os.getenv("FD_DEBUG", "0"),
|
"FD_DEBUG": lambda: os.getenv("FD_DEBUG", "0"),
|
||||||
# Number of days to keep fastdeploy logs.
|
# Number of days to keep fastdeploy logs.
|
||||||
"FD_LOG_BACKUP_COUNT": lambda: os.getenv("FD_LOG_BACKUP_COUNT", "7"),
|
"FD_LOG_BACKUP_COUNT": lambda: os.getenv("FD_LOG_BACKUP_COUNT", "7"),
|
||||||
|
# Model download source: one of "AISTUDIO", "MODELSCOPE" or "HUGGINGFACE" (defaults to "MODELSCOPE").
|
||||||
|
"FD_MODEL_SOURCE": lambda: os.getenv("FD_MODEL_SOURCE", "MODELSCOPE"),
|
||||||
# Model download cache directory.
|
# Model download cache directory.
|
||||||
"FD_MODEL_CACHE": lambda: os.getenv("FD_MODEL_CACHE", None),
|
"FD_MODEL_CACHE": lambda: os.getenv("FD_MODEL_CACHE", None),
|
||||||
# Maximum number of stop sequences.
|
# Maximum number of stop sequences.
|
||||||
|
@@ -31,7 +31,9 @@ from typing import Literal, TypeVar, Union
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
import yaml
|
import yaml
|
||||||
from aistudio_sdk.snapshot_download import snapshot_download
|
from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download
|
||||||
|
from huggingface_hub._snapshot_download import snapshot_download as huggingface_download
|
||||||
|
from modelscope.hub.snapshot_download import snapshot_download as modelscope_download
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from typing_extensions import TypeIs, assert_never
|
from typing_extensions import TypeIs, assert_never
|
||||||
|
|
||||||
@@ -494,21 +496,53 @@ def none_or_str(value):
|
|||||||
|
|
||||||
def _map_repo_namespace(repo_id, model_source):
    """
    Rewrite the leading organisation prefix of ``repo_id`` for ``model_source``.

    For HUGGINGFACE a leading ``PaddlePaddle`` becomes ``baidu``; for every
    other source a leading ``baidu`` becomes ``PaddlePaddle``. The prefix test
    is case-insensitive and ignores surrounding whitespace. A repo id without
    a matching prefix is returned unchanged.
    """
    stripped = repo_id.strip()
    lowered = stripped.lower()
    if model_source == "HUGGINGFACE":
        # The text under test is already lower-cased, so the prefix literal
        # must be lower-case too; a mixed-case literal could never match.
        if lowered.startswith("paddlepaddle"):
            return "baidu" + stripped[len("paddlepaddle"):]
    elif lowered.startswith("baidu"):
        return "PaddlePaddle" + stripped[len("baidu"):]
    return repo_id


def retrive_model_from_server(model_name_or_path, revision="master"):
    """
    Download pretrained model from MODELSCOPE, AIStudio or HUGGINGFACE automatically

    The download source is read from ``envs.FD_MODEL_SOURCE`` and the cache
    root from ``envs.FD_MODEL_CACHE``.

    Args:
        model_name_or_path: An existing local path, or a repo id to download.
        revision: Revision to download. ``"master"`` is replaced by ``"main"``
            when the source is HUGGINGFACE.

    Returns:
        ``model_name_or_path`` unchanged when it exists on disk; otherwise the
        directory ``<cache>/<repo_id>/<revision>`` the snapshot was downloaded
        into, where ``<cache>`` is ``envs.FD_MODEL_CACHE`` or, when that is
        ``None``, the current user's home directory.

    Raises:
        ValueError: ``envs.FD_MODEL_SOURCE`` is not a supported source.
        Exception: The download failed; the underlying error is chained as
            ``__cause__``.
    """
    # A path that already exists locally is used as-is; nothing is downloaded.
    if os.path.exists(model_name_or_path):
        return model_name_or_path

    model_source = envs.FD_MODEL_SOURCE
    if model_source == "MODELSCOPE":
        download = modelscope_download
    elif model_source == "AISTUDIO":
        download = aistudio_download
    elif model_source == "HUGGINGFACE":
        download = huggingface_download
        if revision == "master":
            revision = "main"
    else:
        raise ValueError(
            f"Unsupported model source: {model_source}, please choose one of ['MODELSCOPE', 'AISTUDIO', 'HUGGINGFACE']"
        )

    try:
        repo_id = _map_repo_namespace(model_name_or_path, model_source)
        cache_root = envs.FD_MODEL_CACHE
        if cache_root is None:
            # expanduser honours $HOME when set and still resolves a home
            # directory when it is not, instead of producing "None/...".
            cache_root = os.path.expanduser("~")
        local_path = f"{cache_root}/{repo_id}/{revision}"
        download(repo_id=repo_id, revision=revision, local_dir=local_path)
    except Exception as err:
        # Keep the original message and exception type, but chain the cause so
        # network/auth failures are not mistaken for a missing model.
        raise Exception(f"The setting model_name_or_path:{model_name_or_path} is not exist.") from err
    return local_path
|
||||||
|
|
||||||
|
|
||||||
|
@@ -30,6 +30,7 @@ use-triton-in-paddle
|
|||||||
crcmod
|
crcmod
|
||||||
fastsafetensors==0.1.14
|
fastsafetensors==0.1.14
|
||||||
msgpack
|
msgpack
|
||||||
|
modelscope
|
||||||
opentelemetry-api>=1.24.0
|
opentelemetry-api>=1.24.0
|
||||||
opentelemetry-sdk>=1.24.0
|
opentelemetry-sdk>=1.24.0
|
||||||
opentelemetry-instrumentation-redis
|
opentelemetry-instrumentation-redis
|
||||||
|
43
test/utils/test_download.py
Normal file
43
test/utils/test_download.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from fastdeploy.utils import retrive_model_from_server
|
||||||
|
|
||||||
|
|
||||||
|
class TestAistudioDownload(unittest.TestCase):
    """
    Tests for ``retrive_model_from_server`` driven by the ``FD_MODEL_SOURCE``
    and ``FD_MODEL_CACHE`` environment variables.
    """

    # The only environment variables these tests modify.
    _ENV_KEYS = ("FD_MODEL_SOURCE", "FD_MODEL_CACHE")

    def setUp(self):
        # Remember the caller's values so tearDown can restore exactly them,
        # rather than clearing the whole process environment.
        self._saved_env = {key: os.environ.get(key) for key in self._ENV_KEYS}
        os.environ["FD_MODEL_CACHE"] = "./models"

    def tearDown(self):
        # Runs even when a test fails, so later tests never inherit leftovers.
        for key, value in self._saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

    def test_retrive_model_from_server_MODELSCOPE(self):
        # NOTE(review): this appears to perform a real download - confirm
        # network access is available wherever the suite runs.
        os.environ["FD_MODEL_SOURCE"] = "MODELSCOPE"

        model_name_or_path = "baidu/ERNIE-4.5-0.3B-PT"
        revision = "master"
        # The expected path uses "PaddlePaddle" in place of the "baidu" prefix.
        expected_path = f"./models/PaddlePaddle/ERNIE-4.5-0.3B-PT/{revision}"
        result = retrive_model_from_server(model_name_or_path, revision)
        self.assertEqual(expected_path, result)

    def test_retrive_model_from_server_unsupported_source(self):
        os.environ["FD_MODEL_SOURCE"] = "UNSUPPORTED_SOURCE"

        model_name_or_path = "baidu/ERNIE-4.5-0.3B-PT"
        with self.assertRaises(ValueError):
            retrive_model_from_server(model_name_or_path)

    def test_retrive_model_from_server_model_not_exist(self):
        os.environ["FD_MODEL_SOURCE"] = "MODELSCOPE"

        model_name_or_path = "non_existing_model"

        with self.assertRaises(Exception):
            retrive_model_from_server(model_name_or_path)


if __name__ == "__main__":
    unittest.main()
|
Reference in New Issue
Block a user