diff --git a/python/fastdeploy/__init__.py b/python/fastdeploy/__init__.py index 42db5c281..31735c685 100644 --- a/python/fastdeploy/__init__.py +++ b/python/fastdeploy/__init__.py @@ -36,5 +36,5 @@ from . import c_lib_wrap as C from . import vision from . import pipeline from . import text -from .download import download, download_and_decompress, download_model +from .download import download, download_and_decompress, download_model, get_model_list from . import serving diff --git a/python/fastdeploy/download.py b/python/fastdeploy/download.py index 0b14ccf8e..7af6042a8 100644 --- a/python/fastdeploy/download.py +++ b/python/fastdeploy/download.py @@ -213,6 +213,30 @@ def download_and_decompress(url, path='.', rename=None): return +def get_model_list(category: str=None): + ''' + Get all pre-trained models information supported by fd.download_model. + Args: + category(str): model category, if None, list all models in all categories. + Returns: + results(dict): a dictionary, key is category, value is a list which contains models information. + ''' + result = model_server.get_model_list() + if result['status'] != 0: + raise ValueError( + 'Failed to get pretrained models information from hub model server.' + ) + result = result['data'] + if category is None: + return result + elif category in result: + return {category: result[category]} + else: + raise ValueError( + 'No pretrained model in category {} can be downloaded now.'.format( + category)) + + def download_model(name: str, path: str=None, format: str=None, @@ -237,11 +261,13 @@ def download_model(name: str, if format == 'paddle': if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count( "zip") > 0: + archive_path = fullpath fullpath = decompress(fullpath) try: os.rename(fullpath, os.path.join(os.path.dirname(fullpath), name)) fullpath = os.path.join(os.path.dirname(fullpath), name) + os.remove(archive_path) except FileExistsError: pass print('Successfully download model at path: {}'.format(fullpath)) diff --git a/python/fastdeploy/utils/hub_model_server.py b/python/fastdeploy/utils/hub_model_server.py index 849763b9f..3eb891e64 100644 --- a/python/fastdeploy/utils/hub_model_server.py +++ b/python/fastdeploy/utils/hub_model_server.py @@ -98,6 +98,20 @@ class ModelServer(object): except requests.exceptions.ConnectionError as e: raise ServerConnectionError(self._url) + def get_model_list(self): + ''' + Get all pre-trained models information in dataset. + Return: + result(dict): key is category name, value is a list which contains models \ + information such as name, format and version. + ''' + api = '{}/{}'.format(self._url, 'fastdeploy_listmodels') + try: + result = requests.get(api, timeout=self._timeout) + return result.json() + except requests.exceptions.ConnectionError as e: + raise ServerConnectionError(self._url) + def is_connected(self): return self.check(self._url)