[Feature] optimize expert parallel (#3196)

* optimize

* Update expert_service.py

* Update worker_process.py

* optimize
This commit is contained in:
ltd0924
2025-08-05 17:34:24 +08:00
committed by GitHub
parent dcf9c2daff
commit b20ffe3697
7 changed files with 174 additions and 134 deletions

View File

@@ -14,11 +14,11 @@
# limitations under the License.
"""
import json
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict
import msgpack
import zmq
from fastdeploy import envs
@@ -34,7 +34,7 @@ class SplitwiseConnector:
SplitwiseConnector class for managing and scheduling Splitwise tasks.
"""
def __init__(self, cfg, scheduler, worker_queue, resource_manager):
def __init__(self, cfg, scheduler, worker_queue, resource_manager, splitwise_queue):
"""
Initialize the SplitwiseConnector instance.
@@ -51,6 +51,7 @@ class SplitwiseConnector:
self.connect_innode_instances = {}
self.temp_cache_info = dict()
self.current_request_ids = dict()
self.splitwise_queue = splitwise_queue
if self.cfg.cache_config.pd_comm_port is not None:
self.zmq_ctx = zmq.Context()
@@ -406,13 +407,19 @@ class SplitwiseConnector:
if msg_type == "decode" or msg_type == "prefill":
payload = [output.to_dict() for output in payload]
json_data = json.dumps({"type": msg_type, "payload": payload}).encode("utf-8")
req_ids = [task["request_id"] for task in payload]
logger.info(f"send message {msg_type} {req_ids}")
json_data = msgpack.packb({"type": msg_type, "payload": payload})
return json_data
def _deserialize_message(self, data: bytes):
# Decode a serialized message (bytes) and return its (type, payload) pair.
# NOTE(review): this view looks like a diff rendered without +/- markers; the
# json.loads line below appears to be the removed version and the
# msgpack.unpackb line its replacement — confirm against the real file.
message = json.loads(data.decode("utf-8"))
message = msgpack.unpackb(data)
# Gather request ids for the log line; assumes each payload entry is a mapping
# with a "request_id" key — TODO confirm for every message type.
req_ids = [task["request_id"] for task in message["payload"]]
# NOTE(review): this routine deserializes, yet the log text says "send message";
# it looks copied from the serializer — confirm the intended wording.
logger.info(f"send message {message['type']} {req_ids}")
return message["type"], message["payload"]
def _process_message(self, message: bytes):
@@ -441,7 +448,9 @@ class SplitwiseConnector:
"""
tasks_data = [Request.from_dict(task) for task in tasks]
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks_data))
req_ids = [task["request_id"] for task in tasks]
self.splitwise_queue.append(("decode", tasks_data))
logger.debug(f"{req_ids} received prefill data")
def _handle_decode(self, payload):
"""
@@ -460,4 +469,6 @@ class SplitwiseConnector:
finished=True,
)
)
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))
req_ids = [task["request_id"] for task in payload]
self.splitwise_queue.append(("decode", tasks))
logger.debug(f"{req_ids} received decode data")