mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] support bos download retry (#5137)
* support bos download retry * update code * update code
This commit is contained in:
@@ -550,8 +550,6 @@ class ParallelConfig:
|
||||
self.use_internode_ll_two_stage: bool = False
|
||||
# disable sequence parallel moe
|
||||
self.disable_sequence_parallel_moe: bool = False
|
||||
# enable async download features
|
||||
self.enable_async_download_features: bool = False
|
||||
|
||||
self.pod_ip: str = None
|
||||
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
|
||||
|
||||
@@ -467,11 +467,6 @@ class EngineArgs:
|
||||
Url for router server, such as `0.0.0.0:30000`.
|
||||
"""
|
||||
|
||||
enable_async_download_features: bool = False
|
||||
"""
|
||||
Flag to enable async download features. Default is False (disabled).
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Post-initialization processing to set default tokenizer if not provided.
|
||||
@@ -844,12 +839,6 @@ class EngineArgs:
|
||||
default=EngineArgs.enable_expert_parallel,
|
||||
help="Enable expert parallelism.",
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--enable-async-download-features",
|
||||
action="store_true",
|
||||
default=EngineArgs.enable_async_download_features,
|
||||
help="Enable async download features.",
|
||||
)
|
||||
|
||||
# Load group
|
||||
load_group = parser.add_argument_group("Load Configuration")
|
||||
|
||||
@@ -809,7 +809,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
|
||||
def download_bos_features(bos_client, features_urls):
|
||||
result_list = []
|
||||
for status, feature in download_from_bos(self.bos_client, features_urls):
|
||||
for status, feature in download_from_bos(self.bos_client, features_urls, retry=1):
|
||||
if status:
|
||||
llm_logger.info(f"request {request.request_id} async download feature: {feature.shape}")
|
||||
result_list.append(feature)
|
||||
@@ -819,7 +819,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
return error_msg
|
||||
return result_list
|
||||
|
||||
if not self.config.parallel_config.enable_async_download_features or not self._has_features_info(request):
|
||||
if not self._has_features_info(request):
|
||||
return None
|
||||
|
||||
if self.bos_client is None:
|
||||
|
||||
@@ -29,6 +29,7 @@ import subprocess
|
||||
import sys
|
||||
import tarfile
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
@@ -976,12 +977,13 @@ def init_bos_client():
|
||||
return BosClient(cfg)
|
||||
|
||||
|
||||
def download_from_bos(bos_client, bos_links):
|
||||
def download_from_bos(bos_client, bos_links, retry: int = 0):
|
||||
"""
|
||||
Download pickled objects from Baidu Object Storage (BOS).
|
||||
Args:
|
||||
bos_client: BOS client instance
|
||||
bos_links: Single link or list of BOS links in format "bos://bucket-name/path/to/object"
|
||||
retry: Number of times to retry on failure (only retries on network-related errors)
|
||||
Yields:
|
||||
tuple: (success: bool, data: np.ndarray | error_msg: str)
|
||||
- On success: (True, deserialized_data)
|
||||
@@ -989,20 +991,39 @@ def download_from_bos(bos_client, bos_links):
|
||||
Security Note:
|
||||
Uses pickle deserialization. Only use with trusted data sources.
|
||||
"""
|
||||
|
||||
def _bos_download(bos_client, link):
|
||||
if link.startswith("bos://"):
|
||||
link = link.replace("bos://", "")
|
||||
|
||||
bucket_name = "/".join(link.split("/")[1:-1])
|
||||
object_key = link.split("/")[-1]
|
||||
return bos_client.get_object_as_string(bucket_name, object_key)
|
||||
|
||||
if not isinstance(bos_links, list):
|
||||
bos_links = [bos_links]
|
||||
|
||||
for link in bos_links:
|
||||
try:
|
||||
if link.startswith("bos://"):
|
||||
link = link.replace("bos://", "")
|
||||
|
||||
bucket_name = "/".join(link.split("/")[1:-1])
|
||||
object_key = link.split("/")[-1]
|
||||
response = bos_client.get_object_as_string(bucket_name, object_key)
|
||||
response = _bos_download(bos_client, link)
|
||||
yield True, pickle.loads(response)
|
||||
except Exception as e:
|
||||
yield False, f"link {link} download error: {str(e)}"
|
||||
except Exception:
|
||||
# Only retry on network-related or timeout exceptions
|
||||
exceptions_msg = str(traceback.format_exc())
|
||||
|
||||
if "request rate is too high" not in exceptions_msg or retry <= 0:
|
||||
yield False, f"Failed to download {link}: {exceptions_msg}"
|
||||
break
|
||||
|
||||
for attempt in range(retry):
|
||||
try:
|
||||
llm_logger.warning(f"Retry attempt {attempt + 1}/{retry} for {link}")
|
||||
response = _bos_download(bos_client, link)
|
||||
yield True, pickle.loads(response)
|
||||
break
|
||||
except Exception:
|
||||
if attempt == retry - 1: # Last attempt failed
|
||||
yield False, f"Failed after {retry} retries for {link}: {str(traceback.format_exc())}"
|
||||
break
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ class TestResourceManagerV1(unittest.TestCase):
|
||||
max_num_seqs=max_num_seqs,
|
||||
num_gpu_blocks_override=102,
|
||||
max_num_batched_tokens=3200,
|
||||
enable_async_download_features=True,
|
||||
)
|
||||
args = asdict(engine_args)
|
||||
|
||||
@@ -130,9 +129,9 @@ class TestResourceManagerV1(unittest.TestCase):
|
||||
self.manager.bos_client = mock_client
|
||||
result = self.manager._download_features(self.request)
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(
|
||||
self.assertIn(
|
||||
"request test_request download features error",
|
||||
self.request.error_message,
|
||||
"request test_request download features error: link bucket-name/path/to/object1 download error: network error",
|
||||
)
|
||||
self.assertEqual(self.request.error_code, 530)
|
||||
|
||||
@@ -151,12 +150,27 @@ class TestResourceManagerV1(unittest.TestCase):
|
||||
self.manager.bos_client = mock_client
|
||||
result = self.manager._download_features(self.request)
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(
|
||||
self.assertIn(
|
||||
"request test_request download features error",
|
||||
self.request.error_message,
|
||||
"request test_request download features error: link bucket-name/path/to/object2 download error: timeout",
|
||||
)
|
||||
self.assertEqual(self.request.error_code, 530)
|
||||
|
||||
def test_download_features_retry(self):
|
||||
"""Test image feature download with error"""
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_object_as_string.side_effect = Exception(
|
||||
"Your request rate is too high. We have put limits on your bucket."
|
||||
)
|
||||
|
||||
self.request.multimodal_inputs = {"image_feature_urls": ["bos://bucket-name/path/to/object1"]}
|
||||
|
||||
self.manager.bos_client = mock_client
|
||||
result = self.manager._download_features(self.request)
|
||||
self.assertIsNone(result)
|
||||
self.assertIn("Failed after 1 retries for bos://bucket-name/path/to/object1", self.request.error_message)
|
||||
self.assertEqual(self.request.error_code, 530)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user