diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 668073298..0eed87c55 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -531,16 +531,16 @@ def retrive_model_from_server(model_name_or_path, revision="master"): local_path = f"{local_path}/{repo_id}" aistudio_download(repo_id=repo_id, revision=revision, local_dir=local_path) model_name_or_path = local_path - except Exception: + except requests.exceptions.ConnectTimeout: if os.path.exists(local_path): llm_logger.error( f"Failed to connect to aistudio, but detected that the model directory {local_path} exists. Attempting to start." ) return local_path - else: - raise Exception( - f"The {revision} of {model_name_or_path} is not exist. Please check the model name or revision." - ) + except Exception: + raise Exception( + f"The {revision} of {model_name_or_path} is not exist. Please check the model name or revision." + ) elif model_source == "MODELSCOPE": try: from modelscope.hub.snapshot_download import ( @@ -554,6 +554,12 @@ def retrive_model_from_server(model_name_or_path, revision="master"): local_path = f"{local_path}/{repo_id}" modelscope_download(repo_id=repo_id, revision=revision, local_dir=local_path) model_name_or_path = local_path + except requests.exceptions.ConnectTimeout: + if os.path.exists(local_path): + llm_logger.error( + f"Failed to connect to modelscope, but detected that the model directory {local_path} exists. Attempting to start." + ) + return local_path except Exception: raise Exception( f"The {revision} of {model_name_or_path} is not exist. Please check the model name or revision." diff --git a/test/utils/test_download.py b/test/utils/test_download.py index 44be39cd5..19949f8ac 100644 --- a/test/utils/test_download.py +++ b/test/utils/test_download.py @@ -50,21 +50,6 @@ class TestAistudioDownload(unittest.TestCase): os.environ.clear() - def test_retrive_model_from_aistudio_server_(self): - """ - Test case for retrieving a model from AI Studio server. - """ - os.environ["FD_MODEL_SOURCE"] = "AISTUDIO" - os.environ["FD_MODEL_CACHE"] = "./models" - - model_name_or_path = "baidu/ERNIE-4.5-0.3B-PT" - revision = "aaa" - expected_path = "./models/PaddlePaddle/ERNIE-4.5-0.3B-PT" - result = retrive_model_from_server(model_name_or_path, revision) - self.assertEqual(expected_path, result) - - os.environ.clear() - if __name__ == "__main__": unittest.main()