[Feature] support bos download retry (#5137)

* support bos download retry

* update code

* update code
This commit is contained in:
kevin
2025-11-21 10:18:32 +08:00
committed by GitHub
parent 43097a512a
commit 7454480e07
5 changed files with 51 additions and 29 deletions

View File

@@ -550,8 +550,6 @@ class ParallelConfig:
self.use_internode_ll_two_stage: bool = False
# disable sequence parallel moe
self.disable_sequence_parallel_moe: bool = False
# enable async download features
self.enable_async_download_features: bool = False
self.pod_ip: str = None
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).

View File

@@ -467,11 +467,6 @@ class EngineArgs:
Url for router server, such as `0.0.0.0:30000`.
"""
enable_async_download_features: bool = False
"""
Flag to enable async download features. Default is False (disabled).
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -844,12 +839,6 @@ class EngineArgs:
default=EngineArgs.enable_expert_parallel,
help="Enable expert parallelism.",
)
parallel_group.add_argument(
"--enable-async-download-features",
action="store_true",
default=EngineArgs.enable_async_download_features,
help="Enable async download features.",
)
# Load group
load_group = parser.add_argument_group("Load Configuration")

View File

@@ -809,7 +809,7 @@ class ResourceManagerV1(ResourceManager):
def download_bos_features(bos_client, features_urls):
result_list = []
for status, feature in download_from_bos(self.bos_client, features_urls):
for status, feature in download_from_bos(self.bos_client, features_urls, retry=1):
if status:
llm_logger.info(f"request {request.request_id} async download feature: {feature.shape}")
result_list.append(feature)
@@ -819,7 +819,7 @@ class ResourceManagerV1(ResourceManager):
return error_msg
return result_list
if not self.config.parallel_config.enable_async_download_features or not self._has_features_info(request):
if not self._has_features_info(request):
return None
if self.bos_client is None:

View File

@@ -29,6 +29,7 @@ import subprocess
import sys
import tarfile
import time
import traceback
from datetime import datetime
from enum import Enum
from http import HTTPStatus
@@ -976,12 +977,13 @@ def init_bos_client():
return BosClient(cfg)
def download_from_bos(bos_client, bos_links):
def download_from_bos(bos_client, bos_links, retry: int = 0):
"""
Download pickled objects from Baidu Object Storage (BOS).
Args:
bos_client: BOS client instance
bos_links: Single link or list of BOS links in format "bos://bucket-name/path/to/object"
retry: Number of times to retry on failure (only retries on network-related errors)
Yields:
tuple: (success: bool, data: np.ndarray | error_msg: str)
- On success: (True, deserialized_data)
@@ -989,20 +991,39 @@ def download_from_bos(bos_client, bos_links):
Security Note:
Uses pickle deserialization. Only use with trusted data sources.
"""
def _bos_download(bos_client, link):
if link.startswith("bos://"):
link = link.replace("bos://", "")
bucket_name = "/".join(link.split("/")[1:-1])
object_key = link.split("/")[-1]
return bos_client.get_object_as_string(bucket_name, object_key)
if not isinstance(bos_links, list):
bos_links = [bos_links]
for link in bos_links:
try:
if link.startswith("bos://"):
link = link.replace("bos://", "")
bucket_name = "/".join(link.split("/")[1:-1])
object_key = link.split("/")[-1]
response = bos_client.get_object_as_string(bucket_name, object_key)
response = _bos_download(bos_client, link)
yield True, pickle.loads(response)
except Exception as e:
yield False, f"link {link} download error: {str(e)}"
except Exception:
# Only retry on network-related or timeout exceptions
exceptions_msg = str(traceback.format_exc())
if "request rate is too high" not in exceptions_msg or retry <= 0:
yield False, f"Failed to download {link}: {exceptions_msg}"
break
for attempt in range(retry):
try:
llm_logger.warning(f"Retry attempt {attempt + 1}/{retry} for {link}")
response = _bos_download(bos_client, link)
yield True, pickle.loads(response)
break
except Exception:
if attempt == retry - 1: # Last attempt failed
yield False, f"Failed after {retry} retries for {link}: {str(traceback.format_exc())}"
break

View File

@@ -20,7 +20,6 @@ class TestResourceManagerV1(unittest.TestCase):
max_num_seqs=max_num_seqs,
num_gpu_blocks_override=102,
max_num_batched_tokens=3200,
enable_async_download_features=True,
)
args = asdict(engine_args)
@@ -130,9 +129,9 @@ class TestResourceManagerV1(unittest.TestCase):
self.manager.bos_client = mock_client
result = self.manager._download_features(self.request)
self.assertIsNone(result)
self.assertEqual(
self.assertIn(
"request test_request download features error",
self.request.error_message,
"request test_request download features error: link bucket-name/path/to/object1 download error: network error",
)
self.assertEqual(self.request.error_code, 530)
@@ -151,12 +150,27 @@ class TestResourceManagerV1(unittest.TestCase):
self.manager.bos_client = mock_client
result = self.manager._download_features(self.request)
self.assertIsNone(result)
self.assertEqual(
self.assertIn(
"request test_request download features error",
self.request.error_message,
"request test_request download features error: link bucket-name/path/to/object2 download error: timeout",
)
self.assertEqual(self.request.error_code, 530)
def test_download_features_retry(self):
"""Test image feature download with error"""
mock_client = MagicMock()
mock_client.get_object_as_string.side_effect = Exception(
"Your request rate is too high. We have put limits on your bucket."
)
self.request.multimodal_inputs = {"image_feature_urls": ["bos://bucket-name/path/to/object1"]}
self.manager.bos_client = mock_client
result = self.manager._download_features(self.request)
self.assertIsNone(result)
self.assertIn("Failed after 1 retries for bos://bucket-name/path/to/object1", self.request.error_message)
self.assertEqual(self.request.error_code, 530)
if __name__ == "__main__":
unittest.main()