mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature] optimize expert parallel (#3196)
* optimize * Update expert_service.py * Update worker_process.py * optimize
This commit is contained in:
@@ -14,11 +14,11 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict
|
||||
|
||||
import msgpack
|
||||
import zmq
|
||||
|
||||
from fastdeploy import envs
|
||||
@@ -34,7 +34,7 @@ class SplitwiseConnector:
|
||||
SplitwiseConnector class for managing and scheduling Splitwise tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, scheduler, worker_queue, resource_manager):
|
||||
def __init__(self, cfg, scheduler, worker_queue, resource_manager, splitwise_queue):
|
||||
"""
|
||||
Initialize the SplitwiseConnector instance.
|
||||
|
||||
@@ -51,6 +51,7 @@ class SplitwiseConnector:
|
||||
self.connect_innode_instances = {}
|
||||
self.temp_cache_info = dict()
|
||||
self.current_request_ids = dict()
|
||||
self.splitwise_queue = splitwise_queue
|
||||
|
||||
if self.cfg.cache_config.pd_comm_port is not None:
|
||||
self.zmq_ctx = zmq.Context()
|
||||
@@ -406,13 +407,19 @@ class SplitwiseConnector:
|
||||
if msg_type == "decode" or msg_type == "prefill":
|
||||
payload = [output.to_dict() for output in payload]
|
||||
|
||||
json_data = json.dumps({"type": msg_type, "payload": payload}).encode("utf-8")
|
||||
req_ids = [task["request_id"] for task in payload]
|
||||
logger.info(f"send message {msg_type} {req_ids}")
|
||||
|
||||
json_data = msgpack.packb({"type": msg_type, "payload": payload})
|
||||
|
||||
return json_data
|
||||
|
||||
def _deserialize_message(self, data: bytes):
|
||||
|
||||
# JSON反序列化
|
||||
message = json.loads(data.decode("utf-8"))
|
||||
message = msgpack.unpackb(data)
|
||||
req_ids = [task["request_id"] for task in message["payload"]]
|
||||
logger.info(f"send message {message['type']} {req_ids}")
|
||||
return message["type"], message["payload"]
|
||||
|
||||
def _process_message(self, message: bytes):
|
||||
@@ -441,7 +448,9 @@ class SplitwiseConnector:
|
||||
"""
|
||||
|
||||
tasks_data = [Request.from_dict(task) for task in tasks]
|
||||
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks_data))
|
||||
req_ids = [task["request_id"] for task in tasks]
|
||||
self.splitwise_queue.append(("decode", tasks_data))
|
||||
logger.debug(f"{req_ids} received prefill data")
|
||||
|
||||
def _handle_decode(self, payload):
|
||||
"""
|
||||
@@ -460,4 +469,6 @@ class SplitwiseConnector:
|
||||
finished=True,
|
||||
)
|
||||
)
|
||||
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))
|
||||
req_ids = [task["request_id"] for task in payload]
|
||||
self.splitwise_queue.append(("decode", tasks))
|
||||
logger.debug(f"{req_ids} received decode data")
|
||||
|
Reference in New Issue
Block a user