mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-01 06:42:23 +08:00
[Feature] multi-source download (#2986)
* multi-source download * multi-source download * huggingface download revision * requirement * style * add revision arg * test * pre-commit
This commit is contained in:
@@ -31,7 +31,9 @@ from typing import Literal, TypeVar, Union
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
from aistudio_sdk.snapshot_download import snapshot_download
|
||||
from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download
|
||||
from huggingface_hub._snapshot_download import snapshot_download as huggingface_download
|
||||
from modelscope.hub.snapshot_download import snapshot_download as modelscope_download
|
||||
from tqdm import tqdm
|
||||
from typing_extensions import TypeIs, assert_never
|
||||
|
||||
@@ -494,21 +496,53 @@ def none_or_str(value):
|
||||
|
||||
def retrive_model_from_server(model_name_or_path, revision="master"):
|
||||
"""
|
||||
Download pretrained model from AIStudio automatically
|
||||
Download pretrained model from MODELSCOPE, AIStudio or HUGGINGFACE automatically
|
||||
"""
|
||||
if os.path.exists(model_name_or_path):
|
||||
return model_name_or_path
|
||||
try:
|
||||
repo_id = model_name_or_path
|
||||
if repo_id.lower().strip().startswith("baidu"):
|
||||
repo_id = "PaddlePaddle" + repo_id.strip()[5:]
|
||||
local_path = envs.FD_MODEL_CACHE
|
||||
if local_path is None:
|
||||
local_path = f'{os.getenv("HOME")}/{repo_id}'
|
||||
snapshot_download(repo_id=repo_id, revision=revision, local_dir=local_path)
|
||||
model_name_or_path = local_path
|
||||
except Exception:
|
||||
raise Exception(f"The setting model_name_or_path:{model_name_or_path} is not exist.")
|
||||
model_source = envs.FD_MODEL_SOURCE
|
||||
local_path = envs.FD_MODEL_CACHE
|
||||
repo_id = model_name_or_path
|
||||
if model_source == "MODELSCOPE":
|
||||
try:
|
||||
if repo_id.lower().strip().startswith("baidu"):
|
||||
repo_id = "PaddlePaddle" + repo_id.strip()[5:]
|
||||
if local_path is None:
|
||||
local_path = f'{os.getenv("HOME")}'
|
||||
local_path = f"{local_path}/{repo_id}/{revision}"
|
||||
modelscope_download(repo_id=repo_id, revision=revision, local_dir=local_path)
|
||||
model_name_or_path = local_path
|
||||
except Exception:
|
||||
raise Exception(f"The setting model_name_or_path:{model_name_or_path} is not exist.")
|
||||
elif model_source == "AISTUDIO":
|
||||
try:
|
||||
if repo_id.lower().strip().startswith("baidu"):
|
||||
repo_id = "PaddlePaddle" + repo_id.strip()[5:]
|
||||
if local_path is None:
|
||||
local_path = f'{os.getenv("HOME")}'
|
||||
local_path = f"{local_path}/{repo_id}/{revision}"
|
||||
aistudio_download(repo_id=repo_id, revision=revision, local_dir=local_path)
|
||||
model_name_or_path = local_path
|
||||
except Exception:
|
||||
raise Exception(f"The setting model_name_or_path:{model_name_or_path} is not exist.")
|
||||
elif model_source == "HUGGINGFACE":
|
||||
try:
|
||||
if revision == "master":
|
||||
revision = "main"
|
||||
repo_id = model_name_or_path
|
||||
if repo_id.lower().strip().startswith("PaddlePaddle"):
|
||||
repo_id = "baidu" + repo_id.strip()[12:]
|
||||
if local_path is None:
|
||||
local_path = f'{os.getenv("HOME")}'
|
||||
local_path = f"{local_path}/{repo_id}/{revision}"
|
||||
huggingface_download(repo_id=repo_id, revision=revision, local_dir=local_path)
|
||||
model_name_or_path = local_path
|
||||
except Exception:
|
||||
raise Exception(f"The setting model_name_or_path:{model_name_or_path} is not exist.")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported model source: {model_source}, please choose one of ['MODELSCOPE', 'AISTUDIO', 'HUGGINGFACE']"
|
||||
)
|
||||
return model_name_or_path
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user