Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[BugFix] fix multinode deployment (#2977)
@@ -138,20 +138,10 @@ class EngineArgs:
     """
     Token slot threshold for preallocating decoder blocks.
     """
 
-    dist_init_ip: Optional[str] = None
-    """
-    The master node ip of multinode deployment
-    """
-
-    nnodes: int = 1
-    """
-    The number of nodes in multinode deployment
-    """
-
-    node_rank: int = 0
-    """
-    The rank of the current node in multinode deployment
-    """
+    ips: Optional[List[str]] = None
+    """
+    The ips of multinode deployment
+    """
 
     swap_space: float = None
@@ -566,24 +556,11 @@ class EngineArgs:
         # Cluster system parameters group
         system_group = parser.add_argument_group("System Configuration")
         system_group.add_argument(
-            "--dist-init-ip",
-            default=EngineArgs.dist_init_ip,
-            help="IP addresses of master node.",
-        )
-
-        system_group.add_argument(
-            "--nnodes",
-            type=int,
-            default=EngineArgs.nnodes,
-            help="The number of all nodes.",
-        )
-
-        system_group.add_argument(
-            "--node-rank",
-            type=int,
-            default=EngineArgs.node_rank,
-            help="node rank id (range [0, nnodes)).",
-        )
+            "--ips",
+            type=lambda s: s.split(",") if s else None,
+            default=EngineArgs.ips,
+            help=
+            "IP addresses of all nodes participating in distributed inference.")
 
         # Performance tuning parameters group
         perf_group = parser.add_argument_group("Performance Tuning")
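
Note on the replacement flag: the argparse type callback splits the comma-separated value before it is stored, so downstream code always receives a list (or None). A minimal standalone sketch of the same parsing behavior, using an illustrative parser rather than FastDeploy's real CLI wiring:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ips",
        type=lambda s: s.split(",") if s else None,
        default=None,
        help="IP addresses of all nodes participating in distributed inference.")

    # The callback runs on the raw string, so the stored value is already a list.
    args = parser.parse_args(["--ips", "192.168.1.10,192.168.1.11"])
    print(args.ips)  # ['192.168.1.10', '192.168.1.11']
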
@@ -899,9 +876,7 @@ class EngineArgs:
             max_num_seqs=self.max_num_seqs,
             speculative_config=speculative_cfg,
             max_num_batched_tokens=self.max_num_batched_tokens,
-            dist_init_ip=self.dist_init_ip,
-            nnodes=self.nnodes,
-            node_rank=self.node_rank,
+            ips=self.ips,
             use_warmup=self.use_warmup,
             engine_worker_queue_port=self.engine_worker_queue_port,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
@@ -6,7 +6,6 @@
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
-#dist_init_ip
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -27,7 +26,6 @@ from fastdeploy.utils import (
     ceil_div,
     check_unified_ckpt,
     get_host_ip,
-    get_random_port,
     is_port_available,
     llm_logger,
 )
@@ -644,9 +642,7 @@ class Config:
         max_model_len: int = 8192,
         max_num_seqs: int = 8,
         max_num_batched_tokens: Optional[int] = None,
-        dist_init_ip: str = None,
-        nnodes: int = 1,
-        node_rank: int = 0,
+        ips: str = None,
         speculative_config: Optional[Dict[str, Any]] = None,
         graph_optimization_config: Optional[Dict[str, Any]] = None,
         use_warmup: bool = False,
@@ -701,15 +697,25 @@ class Config:
         self.tokenizer = tokenizer
         self.max_num_batched_tokens = max_num_batched_tokens
         self.tensor_parallel_size = tensor_parallel_size
-        self.dist_init_ip = dist_init_ip
-        self.nnode = nnodes
-        self.node_rank = node_rank
-        if self.dist_init_ip is None:
+        self.ips = ips
+
+        if self.ips is None:
             self.master_ip = "0.0.0.0"
+        elif isinstance(self.ips, list):
+            self.master_ip = self.ips[0]
         else:
-            self.master_ip = self.dist_init_ip
-            self.dist_init_addr = f"{self.dist_init_ip}:{get_random_port()}"
+            self.ips = self.ips.split(",")
+            self.master_ip = self.ips[0]
+
+        if self.ips is None:
+            self.nnode = 1
+            self.node_rank = 0
+        else:
+            self.nnode = len(self.ips)
+
+            for idx, ip in enumerate(self.ips):
+                if ip == self.master_ip:
+                    self.node_rank = idx
 
         self.max_model_len = max_model_len
         self.max_num_seqs = max_num_seqs
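
The rewritten Config derives everything from ips alone: None means a single node with the master bound to 0.0.0.0, a list is taken as-is, and a raw string is split on commas; the first entry is always the master. A minimal sketch of that normalization under the same rules (resolve_master_ip is a hypothetical helper name, not part of the codebase):

    from typing import List, Optional, Union

    def resolve_master_ip(ips: Optional[Union[str, List[str]]]):
        # None -> single-node default; str -> split into a list first.
        if ips is None:
            return "0.0.0.0", None
        if not isinstance(ips, list):
            ips = ips.split(",")
        return ips[0], ips

    print(resolve_master_ip("10.0.0.1,10.0.0.2"))  # ('10.0.0.1', ['10.0.0.1', '10.0.0.2'])
    print(resolve_master_ip(None))                 # ('0.0.0.0', None)
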
@@ -773,14 +779,11 @@ class Config:
             self.device_ids.split(",").__len__() == self.worker_num_per_node
         ), f"invalid CUDA_VISIBLE_DEVICES, should be equal to {self.worker_num_per_node}"
-
-        assert (
-            self.worker_num_per_node % self.tensor_parallel_size == 0
-        ), f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by worker_num_per_node: {self.worker_num_per_node}"
         self.local_device_ids = self.device_ids.split(",")[: self.tensor_parallel_size]
 
         self.host_ip = get_host_ip()
 
-        if self.dist_init_ip is None or self.host_ip == self.master_ip:
+        if self.ips is None or self.host_ip == self.master_ip:
             self.is_master = True
         else:
             self.is_master = False
@@ -817,9 +820,6 @@ class Config:
         assert is_port_available(
             "0.0.0.0", self.engine_worker_queue_port
         ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
-        assert (
-            self.max_chips_per_node >= self.tensor_parallel_size > 0
-        ), f"tensor_parallel_size: {self.tensor_parallel_size} should be between 1 and {self.max_chips_per_node}"
         assert self.nnode >= 1, f"nnode: {self.nnode} should no less than 1"
         assert self.max_model_len >= 16, f"max_model_len: {self.max_model_len} should be larger than 16"
         assert self.max_num_seqs >= 1, f"max_num_seqs: {self.max_num_seqs} should be larger than 1"
@@ -994,10 +994,6 @@ class LLMEngine:
         Configure environment variables.
         """
         variables = {
-            "PADDLE_TRAINER_ID": 0,
-            "PADDLE_TRAINERS_NUM": 1,
-            "TRAINER_INSTANCES_NUM": 1,
-            "TRAINER_INSTANCES": "0.0.0.0",
             "ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY": 0,
             "LOAD_STATE_DICT_THREAD_NUM": len(self.cfg.device_ids.split(",")),
             "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
@@ -1107,11 +1103,7 @@ class LLMEngine:
             if value:
                 arguments = arguments + f" --{worker_flag}"
         if self.cfg.nnode > 1:
-            pd_cmd = pd_cmd + (
-                f" --master {self.cfg.dist_init_addr}"
-                f" --nnodes {self.cfg.nnode!s}"
-                f" --rank {self.cfg.node_rank!s}"
-            )
+            pd_cmd = pd_cmd + f" --ips {','.join(self.cfg.ips)} --nnodes {len(self.cfg.ips)}"
         pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log"
         llm_logger.info(f"Launch worker service command: {pd_cmd}")
         p = subprocess.Popen(
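
With the old --master/--nnodes/--rank triple gone, the launcher now forwards the whole node list and lets each worker derive its own rank. A rough sketch of the string the new branch appends; the command prefix and the ips value here are hypothetical, only the appended flags come from the diff:

    ips = ["10.0.0.1", "10.0.0.2"]
    pd_cmd = "python -m paddle.distributed.launch"  # illustrative prefix only
    if len(ips) > 1:
        pd_cmd = pd_cmd + f" --ips {','.join(ips)} --nnodes {len(ips)}"
    print(pd_cmd)
    # python -m paddle.distributed.launch --ips 10.0.0.1,10.0.0.2 --nnodes 2
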
@@ -22,6 +22,7 @@ import numpy as np
 from fastdeploy.input.preprocess import InputPreprocessor
 from fastdeploy.inter_communicator import IPCSignal, ZmqClient
 from fastdeploy.metrics.work_metrics import work_process_metrics
+from fastdeploy.platforms import current_platform
 from fastdeploy.utils import EngineError, api_server_logger
 
 
@@ -40,6 +41,7 @@ class EngineClient:
         mm_processor_kwargs,
         enable_mm=False,
         reasoning_parser=None,
+        data_parallel_size=1
     ):
         input_processor = InputPreprocessor(
             tokenizer,
@@ -52,7 +54,10 @@ class EngineClient:
         self.reasoning_parser = reasoning_parser
         self.data_processor = input_processor.create_processor()
         self.max_model_len = max_model_len
-        self.worker_healthy_live_recorded_time_array = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
+        max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
+        array_size = min(
+            max_chips_per_node, tensor_parallel_size * data_parallel_size)
+        self.worker_healthy_live_recorded_time_array = np.zeros(shape=[array_size], dtype=np.int32)
         self.worker_healthy_live_signal = IPCSignal(
             name="worker_healthy_live_signal",
             array=self.worker_healthy_live_recorded_time_array,
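
The health-probe array is now sized by the number of local workers on one node rather than by tensor_parallel_size alone. A quick worked check of the sizing expression; the helper name and the sample values are illustrative:

    def health_array_size(tensor_parallel_size, data_parallel_size, is_iluvatar=False):
        # One slot per local worker, capped at the chips available on a single node.
        max_chips_per_node = 16 if is_iluvatar else 8
        return min(max_chips_per_node, tensor_parallel_size * data_parallel_size)

    print(health_array_size(4, 1))  # 4
    print(health_array_size(8, 2))  # 8, capped at one node's chips
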
@@ -113,10 +113,11 @@ async def lifespan(app: FastAPI):
         args.mm_processor_kwargs,
         args.enable_mm,
         args.reasoning_parser,
+        args.data_parallel_size
     )
     app.state.dynamic_load_weight = args.dynamic_load_weight
-    chat_handler = OpenAIServingChat(engine_client, pid, args.dist_init_ip)
-    completion_handler = OpenAIServingCompletion(engine_client, pid, args.dist_init_ip)
+    chat_handler = OpenAIServingChat(engine_client, pid, args.ips)
+    completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips)
     engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
     engine_client.pid = pid
     app.state.engine_client = engine_client
@@ -19,7 +19,7 @@ import time
 import traceback
 import uuid
 from typing import List, Optional
+import numpy as np
 import aiozmq
 import msgpack
 from aiozmq import zmq
@@ -48,11 +48,16 @@ class OpenAIServingChat:
     OpenAI-style chat completions serving
     """
 
-    def __init__(self, engine_client, pid, dist_init_ip):
+    def __init__(self, engine_client, pid, ips):
         self.engine_client = engine_client
         self.pid = pid
-        self.master_ip = dist_init_ip
+        self.master_ip = ips
         self.host_ip = get_host_ip()
+        if self.master_ip is not None:
+            if isinstance(self.master_ip, list):
+                self.master_ip = self.master_ip[0]
+            else:
+                self.master_ip = self.master_ip.split(",")[0]
 
     def _check_master(self):
         if self.master_ip is None:
@@ -80,6 +85,8 @@ class OpenAIServingChat:
             current_req_dict = request.to_dict_for_infer(request_id)
             current_req_dict["arrival_time"] = time.time()
             prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
+            if isinstance(prompt_token_ids, np.ndarray):
+                prompt_token_ids = prompt_token_ids.tolist()
         except Exception as e:
             return ErrorResponse(code=400, message=str(e))
 
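
Both serving handlers now coerce numpy arrays back into plain Python lists before the token ids travel further, presumably so they serialize cleanly downstream. A tiny self-contained illustration of the guard:

    import numpy as np

    prompt_token_ids = np.array([101, 2023, 102])
    if isinstance(prompt_token_ids, np.ndarray):
        prompt_token_ids = prompt_token_ids.tolist()
    print(type(prompt_token_ids).__name__, prompt_token_ids)  # list [101, 2023, 102]
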
@@ -18,7 +18,7 @@ import asyncio
 import time
 import uuid
 from typing import List
+import numpy as np
 import aiozmq
 import msgpack
 from aiozmq import zmq
@@ -37,11 +37,17 @@ from fastdeploy.utils import api_server_logger, get_host_ip
 
 
 class OpenAIServingCompletion:
-    def __init__(self, engine_client, pid, dist_init_ip):
+    def __init__(self, engine_client, pid, ips):
         self.engine_client = engine_client
         self.pid = pid
-        self.master_ip = dist_init_ip
+        self.master_ip = ips
         self.host_ip = get_host_ip()
+        if self.master_ip is not None:
+            if isinstance(self.master_ip, list):
+                self.master_ip = self.master_ip[0]
+            else:
+                self.master_ip = self.master_ip.split(",")[0]
+
 
     def _check_master(self):
         if self.master_ip is None:
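
The handlers accept ips either as an already-split list (the usual case from EngineArgs) or as a raw comma-separated string, and keep only the first entry as the master address. A quick check of that normalization, using nothing beyond the logic shown above:

    for ips in (["10.0.0.1", "10.0.0.2"], "10.0.0.1,10.0.0.2", None):
        master_ip = ips
        if master_ip is not None:
            if isinstance(master_ip, list):
                master_ip = master_ip[0]
            else:
                master_ip = master_ip.split(",")[0]
        print(ips, "->", master_ip)
    # ['10.0.0.1', '10.0.0.2'] -> 10.0.0.1
    # 10.0.0.1,10.0.0.2 -> 10.0.0.1
    # None -> None
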
@@ -97,7 +103,10 @@ class OpenAIServingCompletion:
             current_req_dict = request.to_dict_for_infer(request_id_idx, prompt)
             try:
                 current_req_dict["arrival_time"] = time.time()
-                prompt_batched_token_ids.append(self.engine_client.format_and_add_data(current_req_dict))
+                prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
+                if isinstance(prompt_token_ids, np.ndarray):
+                    prompt_token_ids = prompt_token_ids.tolist()
+                prompt_batched_token_ids.append(prompt_token_ids)
             except Exception as e:
                 return ErrorResponse(message=str(e), code=400)
 
@@ -100,13 +100,14 @@ class GpuWorker(WorkerBase):
         # 1. Record memory state before profile run
         start_time = time.perf_counter()
         Gb = 1024**3
-        paddle.device.cuda.reset_max_memory_reserved(self.local_rank)
-        paddle.device.cuda.reset_max_memory_allocated(self.local_rank)
-        paddle_reserved_mem_before_run = paddle.device.cuda.max_memory_reserved(self.local_rank)
-        paddle_allocated_mem_before_run = paddle.device.cuda.max_memory_allocated(self.local_rank)  # not reserved
+        local_rank = self.local_rank % self.max_chips_per_node
+        paddle.device.cuda.reset_max_memory_reserved(local_rank)
+        paddle.device.cuda.reset_max_memory_allocated(local_rank)
+        paddle_reserved_mem_before_run = paddle.device.cuda.max_memory_reserved(local_rank)
+        paddle_allocated_mem_before_run = paddle.device.cuda.max_memory_allocated(local_rank)  # not reserved
 
         pynvml.nvmlInit()
-        handle = pynvml.nvmlDeviceGetHandleByIndex(int(self.device_ids[self.local_rank]))
+        handle = pynvml.nvmlDeviceGetHandleByIndex(int(self.device_ids[local_rank]))
         before_run_meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
 
         logger.info(
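
The recurring idiom in these worker changes is folding a global rank onto a per-node device index with a modulo, so multinode ranks address valid local GPUs. A worked illustration, assuming 8 chips per node:

    max_chips_per_node = 8  # 16 on Iluvatar hardware, per the diff above
    for global_rank in (0, 3, 8, 11):
        print(global_rank, "->", global_rank % max_chips_per_node)
    # 0 -> 0, 3 -> 3, 8 -> 0, 11 -> 3
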
@@ -124,8 +125,8 @@ class GpuWorker(WorkerBase):
         self.model_runner.profile_run()
 
         # 3. Statistical memory information
-        paddle_reserved_mem_after_run = paddle.device.cuda.max_memory_reserved(self.local_rank)
-        paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated(self.local_rank)
+        paddle_reserved_mem_after_run = paddle.device.cuda.max_memory_reserved(local_rank)
+        paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated(local_rank)
 
         model_block_memory_used = self.cal_theortical_kvcache()
         paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run
@@ -149,7 +149,7 @@ class PaddleDisWorkerProc:
             self.parallel_config.pod_ip,
             self.parallel_config.engine_worker_queue_port,
         )
+        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
         self.task_queue = TaskQueue(
             address=task_address,
             is_server=False,
@@ -193,7 +193,7 @@ class PaddleDisWorkerProc:
             suffix=self.parallel_config.engine_pid,
             create=False,
         )
-        self.worker_healthy_live_signal.value[self.local_rank % 8] = int(time.time())
+        self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int(time.time())
 
         # init model_weights_status
         workers_model_weights = np.zeros(shape=[1], dtype=np.int32)
@@ -388,7 +388,7 @@ class PaddleDisWorkerProc:
         dist.all_reduce(num_blocks_local, op=dist.ReduceOp.MIN)
         num_blocks_local = num_blocks_local.item()
 
-        if self.local_rank == 0:
+        if self.local_rank % self.max_chips_per_node == 0:
             # 3. Send IPCSignal
             get_profile_block_num = np.zeros(shape=[1], dtype=np.int32)
             self.get_profile_block_num_signal = IPCSignal(