[Other] Add a interface to get all pretrained models available from hub model server (#1022)

add get model list
This commit is contained in:
chenjian
2023-01-03 09:45:42 +08:00
committed by GitHub
parent 971cc051f4
commit 42f2e8d22b
3 changed files with 41 additions and 1 deletions

View File

@@ -213,6 +213,30 @@ def download_and_decompress(url, path='.', rename=None):
return
def get_model_list(category: str=None):
'''
Get all pre-trained models information supported by fd.download_model.
Args:
category(str): model category, if None, list all models in all categories.
Returns:
results(dict): a dictionary, key is category, value is a list which contains models information.
'''
result = model_server.get_model_list()
if result['status'] != 0:
raise ValueError(
'Failed to get pretrained models information from hub model server.'
)
result = result['data']
if category is None:
return result
elif category in result:
return {category: result[category]}
else:
raise ValueError(
'No pretrained model in category {} can be downloaded now.'.format(
category))
def download_model(name: str,
path: str=None,
format: str=None,
@@ -237,11 +261,13 @@ def download_model(name: str,
if format == 'paddle':
if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count(
"zip") > 0:
archive_path = fullpath
fullpath = decompress(fullpath)
try:
os.rename(fullpath,
os.path.join(os.path.dirname(fullpath), name))
fullpath = os.path.join(os.path.dirname(fullpath), name)
os.remove(archive_path)
except FileExistsError:
pass
print('Successfully download model at path: {}'.format(fullpath))