mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Bug fix] Test td cache messager (#3242)
* support disable cache task in decode node * fix bugs * Update engine.py * Update expert_service.py * Update splitwise_connector.py * Optimize log for debug * Optimize log for debug * fix bug --------- Co-authored-by: ltd0924 <ltd0924@sina.com> Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
This commit is contained in:
@@ -17,8 +17,9 @@
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import time
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
@@ -196,7 +197,9 @@ class CacheMessager:
|
||||
|
||||
self.gpu_id = gpu_id
|
||||
self.cache_info = dict()
|
||||
self.rank_id = self.rank + local_data_parallel_id * self.nranks # align with engine worker rank (paddle.distributed.launch)
|
||||
self.rank_id = (
|
||||
self.rank + local_data_parallel_id * self.nranks
|
||||
) # align with engine worker rank (paddle.distributed.launch)
|
||||
|
||||
connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
|
||||
connect_rdma_thread.daemon = True
|
||||
@@ -284,7 +287,7 @@ class CacheMessager:
|
||||
if not self.cache_info:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
|
||||
logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
|
||||
for req_id, item in list(self.cache_info.items()):
|
||||
if "status" not in item:
|
||||
continue
|
||||
@@ -364,7 +367,7 @@ class CacheMessager:
|
||||
|
||||
except Exception as e:
|
||||
logger.info(f"prefill layerwise send cache thread has exception: {e}")
|
||||
|
||||
|
||||
def _handle_connect_task(self):
|
||||
while True:
|
||||
try:
|
||||
@@ -465,7 +468,8 @@ def main():
|
||||
if __name__ == "__main__":
|
||||
|
||||
args = parse_args()
|
||||
logger = get_logger("cache_messager", "cache_messager.log")
|
||||
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
|
||||
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
|
||||
|
||||
logger.info("create cache messager...")
|
||||
logger.info(f"{args}")
|
||||
|
Reference in New Issue
Block a user