[Other] Add a interface to get all pretrained models available from hub model server (#1022)

add get model list
This commit is contained in:
chenjian
2023-01-03 09:45:42 +08:00
committed by GitHub
parent 971cc051f4
commit 42f2e8d22b
3 changed files with 41 additions and 1 deletions

View File

@@ -36,5 +36,5 @@ from . import c_lib_wrap as C
from . import vision from . import vision
from . import pipeline from . import pipeline
from . import text from . import text
from .download import download, download_and_decompress, download_model from .download import download, download_and_decompress, download_model, get_model_list
from . import serving from . import serving

View File

@@ -213,6 +213,30 @@ def download_and_decompress(url, path='.', rename=None):
return return
def get_model_list(category: str=None):
'''
Get all pre-trained models information supported by fd.download_model.
Args:
category(str): model category, if None, list all models in all categories.
Returns:
results(dict): a dictionary, key is category, value is a list which contains models information.
'''
result = model_server.get_model_list()
if result['status'] != 0:
raise ValueError(
'Failed to get pretrained models information from hub model server.'
)
result = result['data']
if category is None:
return result
elif category in result:
return {category: result[category]}
else:
raise ValueError(
'No pretrained model in category {} can be downloaded now.'.format(
category))
def download_model(name: str, def download_model(name: str,
path: str=None, path: str=None,
format: str=None, format: str=None,
@@ -237,11 +261,13 @@ def download_model(name: str,
if format == 'paddle': if format == 'paddle':
if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count( if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count(
"zip") > 0: "zip") > 0:
archive_path = fullpath
fullpath = decompress(fullpath) fullpath = decompress(fullpath)
try: try:
os.rename(fullpath, os.rename(fullpath,
os.path.join(os.path.dirname(fullpath), name)) os.path.join(os.path.dirname(fullpath), name))
fullpath = os.path.join(os.path.dirname(fullpath), name) fullpath = os.path.join(os.path.dirname(fullpath), name)
os.remove(archive_path)
except FileExistsError: except FileExistsError:
pass pass
print('Successfully download model at path: {}'.format(fullpath)) print('Successfully download model at path: {}'.format(fullpath))

View File

@@ -98,6 +98,20 @@ class ModelServer(object):
except requests.exceptions.ConnectionError as e: except requests.exceptions.ConnectionError as e:
raise ServerConnectionError(self._url) raise ServerConnectionError(self._url)
def get_model_list(self):
'''
Get all pre-trained models information in dataset.
Return:
result(dict): key is category name, value is a list which contains models \
information such as name, format and version.
'''
api = '{}/{}'.format(self._url, 'fastdeploy_listmodels')
try:
result = requests.get(api, timeout=self._timeout)
return result.json()
except requests.exceptions.ConnectionError as e:
raise ServerConnectionError(self._url)
def is_connected(self): def is_connected(self):
return self.check(self._url) return self.check(self._url)