mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Other] Add a interface to get all pretrained models available from hub model server (#1022)
add get model list
This commit is contained in:
@@ -213,6 +213,30 @@ def download_and_decompress(url, path='.', rename=None):
|
||||
return
|
||||
|
||||
|
||||
def get_model_list(category: str=None):
|
||||
'''
|
||||
Get all pre-trained models information supported by fd.download_model.
|
||||
Args:
|
||||
category(str): model category, if None, list all models in all categories.
|
||||
Returns:
|
||||
results(dict): a dictionary, key is category, value is a list which contains models information.
|
||||
'''
|
||||
result = model_server.get_model_list()
|
||||
if result['status'] != 0:
|
||||
raise ValueError(
|
||||
'Failed to get pretrained models information from hub model server.'
|
||||
)
|
||||
result = result['data']
|
||||
if category is None:
|
||||
return result
|
||||
elif category in result:
|
||||
return {category: result[category]}
|
||||
else:
|
||||
raise ValueError(
|
||||
'No pretrained model in category {} can be downloaded now.'.format(
|
||||
category))
|
||||
|
||||
|
||||
def download_model(name: str,
|
||||
path: str=None,
|
||||
format: str=None,
|
||||
@@ -237,11 +261,13 @@ def download_model(name: str,
|
||||
if format == 'paddle':
|
||||
if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count(
|
||||
"zip") > 0:
|
||||
archive_path = fullpath
|
||||
fullpath = decompress(fullpath)
|
||||
try:
|
||||
os.rename(fullpath,
|
||||
os.path.join(os.path.dirname(fullpath), name))
|
||||
fullpath = os.path.join(os.path.dirname(fullpath), name)
|
||||
os.remove(archive_path)
|
||||
except FileExistsError:
|
||||
pass
|
||||
print('Successfully download model at path: {}'.format(fullpath))
|
||||
|
Reference in New Issue
Block a user