mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Other] Add a interface to get all pretrained models available from hub model server (#1022)
add get model list
This commit is contained in:
@@ -36,5 +36,5 @@ from . import c_lib_wrap as C
|
||||
from . import vision
|
||||
from . import pipeline
|
||||
from . import text
|
||||
from .download import download, download_and_decompress, download_model
|
||||
from .download import download, download_and_decompress, download_model, get_model_list
|
||||
from . import serving
|
||||
|
@@ -213,6 +213,30 @@ def download_and_decompress(url, path='.', rename=None):
|
||||
return
|
||||
|
||||
|
||||
def get_model_list(category: str=None):
|
||||
'''
|
||||
Get all pre-trained models information supported by fd.download_model.
|
||||
Args:
|
||||
category(str): model category, if None, list all models in all categories.
|
||||
Returns:
|
||||
results(dict): a dictionary, key is category, value is a list which contains models information.
|
||||
'''
|
||||
result = model_server.get_model_list()
|
||||
if result['status'] != 0:
|
||||
raise ValueError(
|
||||
'Failed to get pretrained models information from hub model server.'
|
||||
)
|
||||
result = result['data']
|
||||
if category is None:
|
||||
return result
|
||||
elif category in result:
|
||||
return {category: result[category]}
|
||||
else:
|
||||
raise ValueError(
|
||||
'No pretrained model in category {} can be downloaded now.'.format(
|
||||
category))
|
||||
|
||||
|
||||
def download_model(name: str,
|
||||
path: str=None,
|
||||
format: str=None,
|
||||
@@ -237,11 +261,13 @@ def download_model(name: str,
|
||||
if format == 'paddle':
|
||||
if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count(
|
||||
"zip") > 0:
|
||||
archive_path = fullpath
|
||||
fullpath = decompress(fullpath)
|
||||
try:
|
||||
os.rename(fullpath,
|
||||
os.path.join(os.path.dirname(fullpath), name))
|
||||
fullpath = os.path.join(os.path.dirname(fullpath), name)
|
||||
os.remove(archive_path)
|
||||
except FileExistsError:
|
||||
pass
|
||||
print('Successfully download model at path: {}'.format(fullpath))
|
||||
|
@@ -98,6 +98,20 @@ class ModelServer(object):
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise ServerConnectionError(self._url)
|
||||
|
||||
def get_model_list(self):
|
||||
'''
|
||||
Get all pre-trained models information in dataset.
|
||||
Return:
|
||||
result(dict): key is category name, value is a list which contains models \
|
||||
information such as name, format and version.
|
||||
'''
|
||||
api = '{}/{}'.format(self._url, 'fastdeploy_listmodels')
|
||||
try:
|
||||
result = requests.get(api, timeout=self._timeout)
|
||||
return result.json()
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise ServerConnectionError(self._url)
|
||||
|
||||
def is_connected(self):
|
||||
return self.check(self._url)
|
||||
|
||||
|
Reference in New Issue
Block a user