mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-03 07:46:50 +08:00
[Feature] multi source download (#3005)
* multi-source download * multi-source download * huggingface download revision * requirement * style * add revision arg * test * pre-commit * Change default download * change requirements.txt * modify English Documentation * documentation
This commit is contained in:
@@ -32,8 +32,6 @@ from typing import Literal, TypeVar, Union
|
||||
import requests
|
||||
import yaml
|
||||
from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download
|
||||
from huggingface_hub._snapshot_download import snapshot_download as huggingface_download
|
||||
from modelscope.hub.snapshot_download import snapshot_download as modelscope_download
|
||||
from tqdm import tqdm
|
||||
from typing_extensions import TypeIs, assert_never
|
||||
|
||||
@@ -496,25 +494,14 @@ def none_or_str(value):
|
||||
|
||||
def retrive_model_from_server(model_name_or_path, revision="master"):
|
||||
"""
|
||||
Download pretrained model from MODELSCOPE, AIStudio or HUGGINGFACE automatically
|
||||
Download pretrained model from AIStudio, MODELSCOPE or HUGGINGFACE automatically
|
||||
"""
|
||||
if os.path.exists(model_name_or_path):
|
||||
return model_name_or_path
|
||||
model_source = envs.FD_MODEL_SOURCE
|
||||
local_path = envs.FD_MODEL_CACHE
|
||||
repo_id = model_name_or_path
|
||||
if model_source == "MODELSCOPE":
|
||||
try:
|
||||
if repo_id.lower().strip().startswith("baidu"):
|
||||
repo_id = "PaddlePaddle" + repo_id.strip()[5:]
|
||||
if local_path is None:
|
||||
local_path = f'{os.getenv("HOME")}'
|
||||
local_path = f"{local_path}/{repo_id}/{revision}"
|
||||
modelscope_download(repo_id=repo_id, revision=revision, local_dir=local_path)
|
||||
model_name_or_path = local_path
|
||||
except Exception:
|
||||
raise Exception(f"The setting model_name_or_path:{model_name_or_path} is not exist.")
|
||||
elif model_source == "AISTUDIO":
|
||||
if model_source == "AISTUDIO":
|
||||
try:
|
||||
if repo_id.lower().strip().startswith("baidu"):
|
||||
repo_id = "PaddlePaddle" + repo_id.strip()[5:]
|
||||
@@ -525,8 +512,27 @@ def retrive_model_from_server(model_name_or_path, revision="master"):
|
||||
model_name_or_path = local_path
|
||||
except Exception:
|
||||
raise Exception(f"The setting model_name_or_path:{model_name_or_path} is not exist.")
|
||||
elif model_source == "MODELSCOPE":
|
||||
try:
|
||||
from modelscope.hub.snapshot_download import (
|
||||
snapshot_download as modelscope_download,
|
||||
)
|
||||
|
||||
if repo_id.lower().strip().startswith("baidu"):
|
||||
repo_id = "PaddlePaddle" + repo_id.strip()[5:]
|
||||
if local_path is None:
|
||||
local_path = f'{os.getenv("HOME")}'
|
||||
local_path = f"{local_path}/{repo_id}/{revision}"
|
||||
modelscope_download(repo_id=repo_id, revision=revision, local_dir=local_path)
|
||||
model_name_or_path = local_path
|
||||
except Exception:
|
||||
raise Exception(f"The setting model_name_or_path:{model_name_or_path} is not exist.")
|
||||
elif model_source == "HUGGINGFACE":
|
||||
try:
|
||||
from huggingface_hub._snapshot_download import (
|
||||
snapshot_download as huggingface_download,
|
||||
)
|
||||
|
||||
if revision == "master":
|
||||
revision = "main"
|
||||
repo_id = model_name_or_path
|
||||
|
Reference in New Issue
Block a user