def download_model(name: str,
                   path: str=None,
                   format: str=None,
                   version: str=None):
    '''
    Download pre-trained model for FastDeploy inference engine.

    The model server is queried for `name` (optionally narrowed by `format`
    and `version`); the first hit is downloaded into `path`. For models in
    'paddle' format that are shipped as an archive, the archive is
    decompressed and the result is renamed to `name`.

    Args:
        name: model name
        path(str): local path for saving model. If not set, default is hubenv.MODEL_HOME
        format(str): FastDeploy model format
        version(str) : FastDeploy model version
    Return:
        fullpath(str): location of the downloaded model, or None when the
            server has no model matching the query
    '''
    result = model_server.search_model(name, format, version)
    if not result:
        print('ERROR: Could not find a model named {}'.format(name))
        return None
    if path is None:
        path = hubenv.MODEL_HOME

    # Only the first search hit is used. Its format/version replace the
    # (possibly unset) query values for the statistics call below.
    hit = result[0]
    url = hit['url']
    model_format = hit['format']
    model_version = hit['version']

    fullpath = download(url, path, show_progress=True)
    # Download statistics; the boolean result is intentionally not checked.
    model_server.stat_model(name, model_format, model_version)

    if model_format == 'paddle':
        # Match real archive suffixes only; a bare "zip" would also match
        # unrelated URLs that merely contain those three letters.
        if any(suffix in url for suffix in ('.tgz', '.tar', '.zip')):
            # NOTE(review): assumes decompress() returns the path of the
            # extracted content, and that the rename belongs to the archive
            # branch (indentation was not recoverable) - confirm.
            fullpath = decompress(fullpath)
            target = os.path.join(os.path.dirname(fullpath), name)
            try:
                os.rename(fullpath, target)
            except OSError:
                # An already existing target is tolerated, as before. On
                # Unix a non-empty target directory raises a plain OSError
                # rather than FileExistsError, so test for existence
                # instead of relying on the exception subclass.
                if not os.path.exists(target):
                    raise
            else:
                fullpath = target
    print('Successfully download model at path: {}'.format(fullpath))
    return fullpath
class HubConfig:
    '''
    FastDeploy model management configuration class.

    The configuration is kept in `self.data` (currently only the model
    server url) and persisted as YAML in `config.yaml` under
    hubenv.CONF_HOME. Values found in that file override the defaults.
    '''

    def __init__(self):
        self._initialize()
        self.file = os.path.join(hubenv.CONF_HOME, 'config.yaml')

        if not os.path.exists(self.file):
            # No configuration file yet: write the defaults out.
            self.flush()
            return

        with open(self.file, 'r') as file:
            try:
                # NOTE(review): FullLoader accepts more than plain data;
                # flush() only ever writes JSON-compatible values, so
                # yaml.safe_load would probably suffice - confirm before
                # switching.
                cfg = yaml.load(file, Loader=yaml.FullLoader)
            except Exception:
                # Best effort: an unreadable configuration file must not
                # prevent start-up, the defaults stay in effect. Unlike a
                # bare `except`, this lets KeyboardInterrupt/SystemExit
                # through.
                return
        self._merge(cfg)

    def _initialize(self):
        # Set default configuration values.
        self.data = {}
        self.data['server'] = 'http://paddlepaddle.org.cn/paddlehub'

    def _merge(self, cfg):
        '''
        Overlay a parsed configuration onto the current values.

        Anything that is not a mapping (e.g. None, which is what an empty
        file parses to) is ignored so the defaults remain untouched.
        '''
        if isinstance(cfg, dict):
            self.data.update(cfg)

    def _plain_data(self):
        '''Return a copy of the configuration reduced to JSON-compatible types.'''
        return json.loads(json.dumps(self.data))

    def reset(self):
        '''Reset configuration to default.'''
        self._initialize()
        self.flush()

    @property
    def server(self):
        '''Model server url.'''
        return self.data['server']

    @server.setter
    def server(self, url: str):
        # Every change is written to the configuration file immediately.
        self.data['server'] = url
        self.flush()

    def flush(self):
        '''Flush the current configuration into the configuration file.'''
        with open(self.file, 'w') as file:
            yaml.dump(self._plain_data(), file)

    def __str__(self):
        return yaml.dump(self._plain_data())
def _get_user_home():
    '''Return the home directory of the current user.'''
    return os.path.expanduser('~')


def _get_hub_home():
    '''
    Return the root directory of the fastdeploy model hub.

    The FASTDEPLOY_HUB_HOME environment variable takes precedence when it
    is set; otherwise `.fastdeploy` below the user home is used. A path
    that does not exist yet is returned unchanged.

    Raises:
        RuntimeError: FASTDEPLOY_HUB_HOME points at something that exists
            but is not a directory.
    '''
    configured = os.environ.get('FASTDEPLOY_HUB_HOME')
    if configured is None:
        return os.path.join(_get_user_home(), '.fastdeploy')

    if os.path.exists(configured) and not os.path.isdir(configured):
        raise RuntimeError(
            'The environment variable FASTDEPLOY_HUB_HOME {} is not a directory.'.
            format(configured))
    return configured


def _get_sub_home(directory):
    '''Return `directory` below the hub home, creating it when missing.'''
    sub_home = os.path.join(_get_hub_home(), directory)
    os.makedirs(sub_home, exist_ok=True)
    return sub_home


# Resolved once at import time; the two sub homes are created on disk here.
USER_HOME = _get_user_home()
HUB_HOME = _get_hub_home()
MODEL_HOME = _get_sub_home('models')
CONF_HOME = _get_sub_home('conf')
class ServerConnectionError(Exception):
    '''
    Raised when the FastDeploy model server cannot be reached.

    Args:
        url(str) : Url of the server the connection was attempted to
    '''

    def __init__(self, url: str):
        # Hand the url to Exception as well so that `args` is populated.
        super().__init__(url)
        self.url = url

    def __str__(self):
        tips = 'Can\'t connect to FastDeploy Model Server: {}'.format(self.url)
        return tips


class ModelServer(object):
    '''
    FastDeploy server source

    Args:
        url(str) : Url of the server
        timeout(int) : Request timeout
    '''

    def __init__(self, url: str, timeout: int=10):
        self._url = url
        self._timeout = timeout

    @staticmethod
    def _succeeded(result) -> bool:
        '''Tell whether a decoded server reply is a mapping with status 0.'''
        return isinstance(result, dict) and result.get('status') == 0

    def search_model(self, name: str, format: str=None,
                     version: str=None) -> List[dict]:
        '''
        Search model from model server.

        Args:
            name(str) : FastDeploy model name
            format(str): FastDeploy model format
            version(str) : FastDeploy model version
        Return:
            result(list): search results, or None when the server reports
                a failure, an empty result or a malformed reply
        '''
        params = {'name': name}
        # Unset filters are left out of the query entirely.
        if format:
            params['format'] = format
        if version:
            params['version'] = version
        result = self.request(path='fastdeploy_search', params=params)
        if self._succeeded(result) and result.get('data'):
            return result['data']
        return None

    def stat_model(self, name: str, format: str, version: str):
        '''
        Note a record when download a model for statistics.

        Args:
            name(str) : FastDeploy model name
            format(str): FastDeploy model format
            version(str) : FastDeploy model version
        Return:
            is_successful(bool): True if successful, False otherwise
        '''
        params = {
            'name': name,
            'format': format,
            'version': version,
            'from': 'fastdeploy',
        }
        try:
            result = self.request(path='stat', params=params)
        except Exception:
            # Statistics are best effort: no failure here may break the
            # caller, so every request error is reported as False.
            return False
        # A reply without a usable 'status' counts as a failure as well
        # instead of escaping as KeyError/TypeError.
        return self._succeeded(result)

    def request(self, path: str, params: dict) -> dict:
        '''
        Request server.

        Args:
            path(str) : Api path below the server url
            params(dict) : Query parameters of the GET request
        Return:
            result(dict): the JSON decoded reply
        Raises:
            ServerConnectionError: the connection to the server failed.
                Other request or decoding errors are not translated.
        '''
        api = '{}/{}'.format(self._url, path)
        try:
            result = requests.get(api, params=params, timeout=self._timeout)
            return result.json()
        except requests.exceptions.ConnectionError as e:
            # Keep the original error attached as the cause for debugging.
            raise ServerConnectionError(self._url) from e

    def is_connected(self):
        '''Check whether this server answers, using the instance timeout.'''
        return self.check(self._url, timeout=self._timeout)

    @classmethod
    def check(cls, url: str, timeout: int=10) -> bool:
        '''
        Check if the specified url is a valid model server

        Args:
            url(str) : Url to check
            timeout(int) : Request timeout, so an unresponsive host cannot
                block the caller indefinitely
        Return:
            is_valid(bool): True if `url + '/search'` answers with 200
        '''
        try:
            r = requests.get(url + '/search', timeout=timeout)
        except Exception:
            # Any failure means "not a usable server". KeyboardInterrupt
            # and SystemExit are deliberately not swallowed.
            return False
        return r.status_code == 200