# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
import shutil
import requests
import time
import zipfile
import tarfile
import hashlib
import tqdm
import logging

from fastdeploy.utils.hub_model_server import model_server
import fastdeploy.utils.hub_env as hubenv

DOWNLOAD_RETRY_LIMIT = 3

def md5check(fullname, md5sum=None):
    """
    Check the MD5 digest of fullname against md5sum.
    Returns True if they match, or if no expected digest is given.
    """
    if md5sum is None:
        return True

    logging.info("Checking md5 of file {}...".format(fullname))
    md5 = hashlib.md5()
    with open(fullname, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b""):
            md5.update(chunk)
    calc_md5sum = md5.hexdigest()

    if calc_md5sum != md5sum:
        logging.info("File {} md5 check failed, {}(calc) != "
                     "{}(base)".format(fullname, calc_md5sum, md5sum))
        return False
    return True

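# Illustrative usage sketch (not part of the original module): verify a local
# file against an expected digest before trusting it. The path and md5 value
# below are hypothetical placeholders.
#
#   if not md5check("/tmp/model.tgz", md5sum="<expected-md5-hex>"):
#       raise RuntimeError("md5 mismatch, please re-download the file")
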
def move_and_merge_tree(src, dst):
    """
    Move the src directory to dst. If dst already exists,
    merge src into dst.
    """
    if not osp.exists(dst):
        shutil.move(src, dst)
    else:
        if not osp.isdir(src):
            shutil.move(src, dst)
            return
        for fp in os.listdir(src):
            src_fp = osp.join(src, fp)
            dst_fp = osp.join(dst, fp)
            if osp.isdir(src_fp):
                if osp.isdir(dst_fp):
                    move_and_merge_tree(src_fp, dst_fp)
                else:
                    shutil.move(src_fp, dst_fp)
            elif osp.isfile(src_fp) and \
                    not osp.isfile(dst_fp):
                shutil.move(src_fp, dst_fp)

def download(url, path, rename=None, md5sum=None, show_progress=False):
    """
    Download from url and save the file under path.
    url (str): download url
    path (str): directory to save the downloaded file
    rename (str): save the file under this name instead of the URL basename
    md5sum (str): expected md5 checksum of the downloaded file
    show_progress (bool): show a progress bar while downloading
    """
    if not osp.exists(path):
        os.makedirs(path)

    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    if rename is not None:
        fullname = osp.join(path, rename)
    retry_cnt = 0
    while not (osp.exists(fullname) and md5check(fullname, md5sum)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            logging.debug("{} download failed.".format(fname))
            raise RuntimeError("Download from {} failed. "
                               "Retry limit reached".format(url))

        logging.info("Downloading {} from {}".format(fname, url))

        req = requests.get(url, stream=True)
        if req.status_code != 200:
            raise RuntimeError("Downloading from {} failed with code "
                               "{}!".format(url, req.status_code))

        # To protect against interrupted downloads, write to
        # tmp_fullname first, then move tmp_fullname to fullname
        # once the download has finished.
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size and show_progress:
                for chunk in tqdm.tqdm(
                        req.iter_content(chunk_size=1024),
                        total=(int(total_size) + 1023) // 1024,
                        unit='KB'):
                    f.write(chunk)
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)
        logging.debug("{} download completed.".format(fname))

    return fullname

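# Illustrative usage sketch (not part of the original module): fetch a file
# into a local directory with a progress bar. The URL and directory below are
# hypothetical placeholders; the return value is the full path of the saved
# file.
#
#   saved = download("https://example.com/models/model.tgz", "./downloads",
#                    show_progress=True)
#   print(saved)  # ./downloads/model.tgz
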
def decompress(fname):
    """
    Decompress a zip or tar file.
    """
    logging.info("Decompressing {}...".format(fname))

    # To protect against interrupted decompression, extract into the
    # fpath_tmp directory first; if decompression succeeds, move the
    # extracted files to fpath and delete fpath_tmp.
    fpath = osp.split(fname)[0]
    fpath_tmp = osp.join(fpath, 'tmp')
    if osp.isdir(fpath_tmp):
        shutil.rmtree(fpath_tmp)
    os.makedirs(fpath_tmp)

    if fname.find('.tar') >= 0 or fname.find('.tgz') >= 0:
        with tarfile.open(fname) as tf:

            def is_within_directory(directory, target):
                abs_directory = os.path.abspath(directory)
                abs_target = os.path.abspath(target)

                prefix = os.path.commonprefix([abs_directory, abs_target])

                return prefix == abs_directory

            def safe_extract(tar,
                             path=".",
                             members=None,
                             *,
                             numeric_owner=False):
                # Reject archive members that would escape the extraction
                # directory (path traversal).
                for member in tar.getmembers():
                    member_path = os.path.join(path, member.name)
                    if not is_within_directory(path, member_path):
                        raise Exception("Attempted Path Traversal in Tar File")

                tar.extractall(path, members, numeric_owner=numeric_owner)

            safe_extract(tf, path=fpath_tmp)
    elif fname.find('.zip') >= 0:
        with zipfile.ZipFile(fname) as zf:
            zf.extractall(path=fpath_tmp)
    else:
        raise TypeError("Unsupported compressed file type {}".format(fname))

    for f in os.listdir(fpath_tmp):
        src_dir = osp.join(fpath_tmp, f)
        dst_dir = osp.join(fpath, f)
        move_and_merge_tree(src_dir, dst_dir)

    shutil.rmtree(fpath_tmp)
    logging.debug("{} decompressed.".format(fname))
    return dst_dir

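# Illustrative usage sketch (not part of the original module): extract a
# previously downloaded archive next to itself. The archive path below is a
# hypothetical placeholder; the return value is the directory produced by the
# extraction.
#
#   model_dir = decompress("./downloads/model.tgz")
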
def url2dir(url, path, rename=None):
    full_name = download(url, path, rename, show_progress=True)
    print("File is downloaded, now extracting...")
    if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count("zip") > 0:
        return decompress(full_name)

def download_and_decompress(url, path='.', rename=None):
    """
    Download a file from url into path and decompress it if it is an archive.
    """
    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    # if url.endswith(('tgz', 'tar.gz', 'tar', 'zip')):
    #     fullname = osp.join(path, fname.split('.')[0])
    # nranks is hard-coded to 0, so the multi-rank branch below is currently
    # unused; it expects `local_rank` to be provided by the distributed
    # environment.
    nranks = 0
    if nranks <= 1:
        dst_dir = url2dir(url, path, rename)
        if dst_dir is not None:
            fullname = dst_dir
    else:
        lock_path = fullname + '.lock'
        if not os.path.exists(fullname):
            with open(lock_path, 'w'):
                os.utime(lock_path, None)
            if local_rank == 0:
                dst_dir = url2dir(url, path, rename)
                if dst_dir is not None:
                    fullname = dst_dir
                os.remove(lock_path)
            else:
                while os.path.exists(lock_path):
                    time.sleep(1)
    return fullname

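# Illustrative usage sketch (not part of the original module): download an
# archive and unpack it in one call. The URL below is a hypothetical
# placeholder.
#
#   model_dir = download_and_decompress(
#       "https://example.com/models/model.tgz", path="./downloads")
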
def download_model(name: str,
                   path: str=None,
                   format: str=None,
                   version: str=None):
    '''
    Download a pre-trained model for the FastDeploy inference engine.
    Args:
        name(str): model name
        path(str): local path to save the model. If not set, defaults to hubenv.MODEL_HOME
        format(str): FastDeploy model format
        version(str): FastDeploy model version
    '''
    result = model_server.search_model(name, format, version)
    if path is None:
        path = hubenv.MODEL_HOME
    if result:
        url = result[0]['url']
        format = result[0]['format']
        version = result[0]['version']
        fullpath = download(url, path, show_progress=True)
        model_server.stat_model(name, format, version)
        if format == 'paddle':
            if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count(
                    "zip") > 0:
                fullpath = decompress(fullpath)
                try:
                    os.rename(fullpath,
                              os.path.join(os.path.dirname(fullpath), name))
                    fullpath = os.path.join(os.path.dirname(fullpath), name)
                except FileExistsError:
                    pass
        print('Successfully downloaded model to path: {}'.format(fullpath))
        return fullpath
    else:
        print('ERROR: Could not find a model named {}'.format(name))
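

if __name__ == "__main__":
    # Minimal demonstration sketch, not part of the original module: resolve a
    # model by name through the hub model server and download it into the
    # default cache directory. The model name below is a hypothetical
    # placeholder; use a name that the model server actually hosts.
    model_path = download_model(name="yolov5s", format="paddle")
    print(model_path)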