mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Other] Add a interface to get all pretrained models available from hub model server (#1022)
add get model list
This commit is contained in:
@@ -36,5 +36,5 @@ from . import c_lib_wrap as C
|
|||||||
from . import vision
|
from . import vision
|
||||||
from . import pipeline
|
from . import pipeline
|
||||||
from . import text
|
from . import text
|
||||||
from .download import download, download_and_decompress, download_model
|
from .download import download, download_and_decompress, download_model, get_model_list
|
||||||
from . import serving
|
from . import serving
|
||||||
|
@@ -213,6 +213,30 @@ def download_and_decompress(url, path='.', rename=None):
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_list(category: str=None):
|
||||||
|
'''
|
||||||
|
Get all pre-trained models information supported by fd.download_model.
|
||||||
|
Args:
|
||||||
|
category(str): model category, if None, list all models in all categories.
|
||||||
|
Returns:
|
||||||
|
results(dict): a dictionary, key is category, value is a list which contains models information.
|
||||||
|
'''
|
||||||
|
result = model_server.get_model_list()
|
||||||
|
if result['status'] != 0:
|
||||||
|
raise ValueError(
|
||||||
|
'Failed to get pretrained models information from hub model server.'
|
||||||
|
)
|
||||||
|
result = result['data']
|
||||||
|
if category is None:
|
||||||
|
return result
|
||||||
|
elif category in result:
|
||||||
|
return {category: result[category]}
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'No pretrained model in category {} can be downloaded now.'.format(
|
||||||
|
category))
|
||||||
|
|
||||||
|
|
||||||
def download_model(name: str,
|
def download_model(name: str,
|
||||||
path: str=None,
|
path: str=None,
|
||||||
format: str=None,
|
format: str=None,
|
||||||
@@ -237,11 +261,13 @@ def download_model(name: str,
|
|||||||
if format == 'paddle':
|
if format == 'paddle':
|
||||||
if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count(
|
if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count(
|
||||||
"zip") > 0:
|
"zip") > 0:
|
||||||
|
archive_path = fullpath
|
||||||
fullpath = decompress(fullpath)
|
fullpath = decompress(fullpath)
|
||||||
try:
|
try:
|
||||||
os.rename(fullpath,
|
os.rename(fullpath,
|
||||||
os.path.join(os.path.dirname(fullpath), name))
|
os.path.join(os.path.dirname(fullpath), name))
|
||||||
fullpath = os.path.join(os.path.dirname(fullpath), name)
|
fullpath = os.path.join(os.path.dirname(fullpath), name)
|
||||||
|
os.remove(archive_path)
|
||||||
except FileExistsError:
|
except FileExistsError:
|
||||||
pass
|
pass
|
||||||
print('Successfully download model at path: {}'.format(fullpath))
|
print('Successfully download model at path: {}'.format(fullpath))
|
||||||
|
@@ -98,6 +98,20 @@ class ModelServer(object):
|
|||||||
except requests.exceptions.ConnectionError as e:
|
except requests.exceptions.ConnectionError as e:
|
||||||
raise ServerConnectionError(self._url)
|
raise ServerConnectionError(self._url)
|
||||||
|
|
||||||
|
def get_model_list(self):
|
||||||
|
'''
|
||||||
|
Get all pre-trained models information in dataset.
|
||||||
|
Return:
|
||||||
|
result(dict): key is category name, value is a list which contains models \
|
||||||
|
information such as name, format and version.
|
||||||
|
'''
|
||||||
|
api = '{}/{}'.format(self._url, 'fastdeploy_listmodels')
|
||||||
|
try:
|
||||||
|
result = requests.get(api, timeout=self._timeout)
|
||||||
|
return result.json()
|
||||||
|
except requests.exceptions.ConnectionError as e:
|
||||||
|
raise ServerConnectionError(self._url)
|
||||||
|
|
||||||
def is_connected(self):
|
def is_connected(self):
|
||||||
return self.check(self._url)
|
return self.check(self._url)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user