[LLM] support multi node deploy (#2708)

* [LLM] support multi node deploy

* Update engine.py

* fix bugs

* fix

* [LLM] support multi node deploy

* [LLM] support multi node deploy

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Authored by ltd0924 on 2025-07-06 10:33:51 +08:00, committed by GitHub
parent 04a8e1ef2b
commit 68b4755587
13 changed files with 157 additions and 87 deletions

View File

@@ -37,6 +37,7 @@ class CacheMessager(object):
     def __init__(self,
                  splitwise_role,
                  transfer_protocol,
+                 pod_ip,
                  engine_worker_queue_port,
                  local_data_parallel_id,
                  gpu_cache_kvs,
@@ -69,7 +70,7 @@ class CacheMessager(object):
         self.gpu_cache_kvs = gpu_cache_kvs
         self.rank = rank
         self.nranks = nranks
-        address = ('0.0.0.0', engine_worker_queue_port)
+        address = (pod_ip, engine_worker_queue_port)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,

View File

@@ -71,6 +71,10 @@ def parse_args():
                         type=int,
                         default=9923,
                         help="cache queue port")
+    parser.add_argument("--pod_ip",
+                        type=str,
+                        default="0.0.0.0",
+                        help="pod ip")
     parser.add_argument("--engine_worker_queue_port",
                         type=int,
                         default=9923,
@@ -144,7 +148,7 @@ class CacheTransferManager:
         self.rank = rank
         self.device = device

-        address = ('0.0.0.0', args.cache_queue_port)
+        address = (args.pod_ip, args.cache_queue_port)
        self.cache_task_queue = EngineCacheQueue(
            address=address,
            is_server=False,
@@ -236,6 +240,7 @@ class CacheTransferManager:
         self.cache_messager = CacheMessager(
             splitwise_role=args.splitwise_role,
             transfer_protocol=args.protocol,
+            pod_ip=args.pod_ip,
             engine_worker_queue_port=args.engine_worker_queue_port,
             local_data_parallel_id=args.local_data_parallel_id,
             gpu_cache_kvs=self.gpu_cache_kvs,

View File

@@ -109,7 +109,7 @@ class PrefixCacheManager:
     def launch_cache_manager(self, cache_config, tensor_parallel_size, \
-                             device_ids, engine_worker_queue_port, pid_suffix):
+                             device_ids, pod_ip, engine_worker_queue_port, pid_suffix):
         """
         launch_cache_manager function used to initialize the cache manager.
         """
@@ -123,7 +123,7 @@ class PrefixCacheManager:
             create=True)
         self.cache_task_queue = EngineCacheQueue(
-            address=('127.0.0.1', cache_config.cache_queue_port),
+            address=(pod_ip, cache_config.cache_queue_port),
             authkey=b'cache_queue_service',
             is_server=False,
             num_client=tensor_parallel_size,
@@ -166,6 +166,7 @@ class PrefixCacheManager:
             f" --cache_dtype {cache_config.cache_dtype}" +
             f" --cache_queue_port {cache_config.cache_queue_port}" +
             f" --enable_splitwise {int(self.enable_splitwise)}" +
+            f" --pod_ip {pod_ip}" +
             f" --engine_worker_queue_port {engine_worker_queue_port}" +
             f" --num_gpu_blocks {cache_config.total_block_num}" +
             f" --num_cpu_blocks {cache_config.num_cpu_blocks}" +

View File

@@ -122,10 +122,7 @@ class EngineArgs:
     """
     Ratio of tokens to process in a block.
     """
-    nnode: int = 1
-    """
-    Number of nodes in the cluster.
-    """
     pod_ips: Optional[List[str]] = None
     """
     List of IP addresses for nodes in the cluster.
@@ -485,10 +482,7 @@ class EngineArgs:
             default=EngineArgs.pod_ips,
             help=
             "List of IP addresses for nodes in the cluster (comma-separated).")
-        system_group.add_argument("--nnode",
-                                  type=int,
-                                  default=EngineArgs.nnode,
-                                  help="Number of nodes in the cluster.")

         # Performance tuning parameters group
         perf_group = parser.add_argument_group("Performance Tuning")
@@ -773,7 +767,6 @@ class EngineArgs:
             max_num_seqs=self.max_num_seqs,
             speculative_config=speculative_cfg,
             max_num_batched_tokens=self.max_num_batched_tokens,
-            nnode=self.nnode,
             pod_ips=self.pod_ips,
             use_warmup=self.use_warmup,
             engine_worker_queue_port=self.engine_worker_queue_port,
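
With --nnode removed, the node count is now implied by --pod_ips. A minimal sketch, with placeholder IPs, of how a comma-separated pod list maps to a node count (this helper is illustrative, not part of the commit):

from typing import List, Optional

def parse_pod_ips(pod_ips_arg: Optional[str]) -> List[str]:
    # comma-separated --pod_ips value -> list of node IPs
    return [] if pod_ips_arg is None else [ip for ip in pod_ips_arg.split(",") if ip]

def derive_nnode(pod_ips: List[str]) -> int:
    # no pod_ips given: single-node default, matching the old --nnode 1
    return len(pod_ips) or 1

assert derive_nnode(parse_pod_ips(None)) == 1
assert derive_nnode(parse_pod_ips("192.168.0.1,192.168.0.2")) == 2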

View File

@@ -505,7 +505,6 @@ class Config:
                  model_name_or_path: str = None,
                  tokenizer: str = None,
                  tensor_parallel_size: int = 8,
-                 nnode: int = 1,
                  max_model_len: int = 8192,
                  max_num_seqs: int = 8,
                  max_num_batched_tokens: Optional[int] = None,
@@ -539,7 +538,6 @@ class Config:
             model_name_or_path (str): Model directory path or model name.
             tokenizer (str): Default is the model.
             tensor_parallel_size (int): Tensor parallel size. Default is 8.
-            nnode (int): Number of nodes. Default is 1.
             max_model_len (int): Maximum model length. Default is 8192.
             max_num_seqs (int): Maximum number of sequences. Default is 8.
             max_num_batched_tokens (Optional[int]): Maximum number of batched tokens. Default is None.
@@ -565,7 +563,6 @@ class Config:
         self.tokenizer = tokenizer
         self.max_num_batched_tokens = max_num_batched_tokens
         self.tensor_parallel_size = tensor_parallel_size
-        self.nnode = nnode
         self.pod_ips = pod_ips
         self.max_model_len = max_model_len
         self.max_num_seqs = max_num_seqs
@@ -585,12 +582,15 @@ class Config:
         self.max_capture_batch_size = max_capture_batch_size
         self.guided_decoding_backend = guided_decoding_backend
         self.disable_any_whitespace = disable_any_whitespace
+        self.is_master = True
+        self._str_to_list("innode_prefill_ports", int)
+        self._str_to_list("pod_ips", str)

-        if self.innode_prefill_ports is not None:
-            if not isinstance(self.innode_prefill_ports, list):
-                ports = str(self.innode_prefill_ports).split(',')
-                self.innode_prefill_ports = [int(port) for port in ports]
+        if self.pod_ips is None:
+            self.nnode = 1
+        else:
+            self.nnode = len(self.pod_ips)

         assert self.splitwise_role in ["mixed", "prefill", "decode"]

         # TODO
@@ -609,14 +609,15 @@ class Config:
         num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
         if num_ranks > 8:
-            local_num_ranks = 8
-            self.nnode = ceil_div(num_ranks, local_num_ranks)
+            self.worker_num_per_node = 8
+            nnode = ceil_div(num_ranks, self.worker_num_per_node)
+            assert nnode == self.nnode, \
+                f"nnode: {nnode}, but got {self.nnode}"
         else:
-            local_num_ranks = num_ranks
+            self.worker_num_per_node = num_ranks

         self.engine_worker_queue_port = engine_worker_queue_port
-        self.device_ids = ",".join([str(i) for i in range(min((self.tensor_parallel_size * \
-            self.parallel_config.expert_parallel_size), 8))])
+        self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)])
         self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)

         self.read_from_config()
@@ -628,16 +629,21 @@ class Config:
         """
         calculate some parameters
         """
-        total_rank = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
-        assert self.device_ids.split(',').__len__() == min(total_rank, 8), \
-            f"invalid CUDA_VISIBLE_DEVICES, should be equal to {min(total_rank, 8)}"
+        assert self.device_ids.split(',').__len__() == self.worker_num_per_node, \
+            f"invalid CUDA_VISIBLE_DEVICES, should be equal to {self.worker_num_per_node}"
+        assert self.worker_num_per_node % self.tensor_parallel_size == 0, \
+            f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by worker_num_per_node: {self.worker_num_per_node}"
         self.local_device_ids = self.device_ids.split(
             ',')[:self.tensor_parallel_size]
-        assert self.tensor_parallel_size % self.nnode == 0, \
-            f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by nnode: {self.nnode}"
-        self.worker_num_per_node = total_rank // self.nnode
         self.host_ip = get_host_ip()
+        if self.pod_ips is None:
+            self.pod_ips = ["0.0.0.0"]
+        elif self.host_ip != self.pod_ips[0]:
+            self.is_master = False

         import paddle
         self.paddle_commit_id = paddle.version.commit
@@ -808,5 +814,16 @@ class Config:
                     "return_full_hidden_states")
         reset_value(self.cache_config, "cache_dtype", "infer_model_dtype")

+    def _check_master(self):
+        return self.is_master
+
+    def _str_to_list(self, attr_name, default_type):
+        if hasattr(self, attr_name):
+            val = getattr(self, attr_name)
+            if type(val) is str:
+                setattr(self, attr_name, [default_type(i) for i in val.split(",")])
+            else:
+                setattr(self, attr_name, val)
+
     def __str__(self) -> str:
         return json.dumps(self.__dict__, indent=4)
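
Taken together, the Config changes infer the node count from pod_ips, cap the per-node worker count at eight ranks, and treat the first pod IP as the master. A condensed, standalone sketch of that arithmetic (hypothetical helper and values; the real class also validates CUDA_VISIBLE_DEVICES and takes expert parallelism from parallel_config):

def ceil_div(a: int, b: int) -> int:
    # same rounding the Config code relies on
    return -(-a // b)

def plan_node(pod_ips, host_ip, tensor_parallel_size, expert_parallel_size=1):
    """Sketch only: node count from pod_ips, at most 8 ranks per node,
    first pod IP is the master."""
    nnode = 1 if pod_ips is None else len(pod_ips)
    num_ranks = tensor_parallel_size * expert_parallel_size
    worker_num_per_node = 8 if num_ranks > 8 else num_ranks
    assert ceil_div(num_ranks, worker_num_per_node) == nnode, "pod_ips does not match rank count"
    is_master = True
    if pod_ips is None:
        pod_ips = ["0.0.0.0"]
    elif host_ip != pod_ips[0]:
        is_master = False
    return nnode, worker_num_per_node, is_master

# e.g. a 16-rank run spread over two hosts (hypothetical IPs):
print(plan_node(["10.0.0.1", "10.0.0.2"], "10.0.0.2", 16))   # -> (2, 8, False)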

View File

@@ -98,30 +98,7 @@ class LLMEngine(object):
             cfg.mm_processor_kwargs,
             cfg.enable_mm)

-        address = ('0.0.0.0', self.cfg.engine_worker_queue_port)
-        self.engine_worker_queue_server = EngineWorkerQueue(
-            address=address,
-            is_server=True,
-            num_client=self.cfg.tensor_parallel_size,
-            local_data_parallel_size=self.cfg.parallel_config.
-            data_parallel_size)
-        self.engine_worker_queue = EngineWorkerQueue(
-            address=address,
-            is_server=False,
-            num_client=self.cfg.tensor_parallel_size,
-            client_id=0,
-            local_data_parallel_id=0)
-
-        if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != 'mixed':
-            self.cache_task_queue = EngineCacheQueue(
-                address=('127.0.0.1', self.cfg.cache_config.cache_queue_port),
-                authkey=b'cache_queue_service',
-                is_server=True,
-                num_client=self.cfg.tensor_parallel_size,
-                client_id=-1,
-                local_data_parallel_size=self.cfg.parallel_config.
-                data_parallel_size)
+        self.start_queue_service()

         self.resource_manager = ResourceManager(cfg.max_num_seqs, cfg,
                                                 cfg.tensor_parallel_size,
@@ -198,9 +175,12 @@ class LLMEngine(object):
                 or self.cfg.splitwise_role != "mixed"):
             device_ids = self.cfg.device_ids.split(",")
             self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
-                self.cfg.cache_config, self.cfg.tensor_parallel_size,
-                device_ids, self.cfg.engine_worker_queue_port,
-                self.ipc_signal_suffix)
+                cache_config=self.cfg.cache_config,
+                tensor_parallel_size=self.cfg.tensor_parallel_size,
+                device_ids=device_ids,
+                pod_ip=self.cfg.pod_ips[0],
+                engine_worker_queue_port=self.cfg.engine_worker_queue_port,
+                pid_suffix=self.ipc_signal_suffix)

         self.worker_proc = self._start_worker_service()
         console_logger.info("Waitting worker processes ready...")
@@ -850,10 +830,7 @@ class LLMEngine(object):
         Initialize shared memory to indicate engine status
         """
         # worker_ready_signatensor_parallel_size
-        array_size = min(
-            8, self.cfg.tensor_parallel_size *
-            self.cfg.parallel_config.data_parallel_size)
-        worker_ready_signal_data = np.zeros(shape=[array_size], dtype=np.int32)
+        worker_ready_signal_data = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
         self.worker_ready_signal = IPCSignal(name="worker_ready_signal",
                                              array=worker_ready_signal_data,
                                              dtype=np.int32,
@@ -889,7 +866,7 @@ class LLMEngine(object):
                                              create=True)

         # worker_live_signal: lets the engine detect whether each worker process is alive by recording the time of every step
-        worker_healthy_live_recorded_time_array = np.zeros(shape=[array_size],
+        worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node],
                                                            dtype=np.int32)
         self.worker_healthy_live_signal = IPCSignal(
             name="worker_healthy_live_signal",
@@ -899,7 +876,7 @@ class LLMEngine(object):
             create=True)

         if self.do_profile:
-            get_profile_block_num = np.zeros([array_size], dtype=np.int32)
+            get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32)
             self.get_profile_block_num_signal = IPCSignal(
                 name="get_profile_block_num",
                 array=get_profile_block_num,
@@ -1028,6 +1005,7 @@ class LLMEngine(object):
         arguments = (
             f" --nnodes {str(self.cfg.nnode)}"
+            f" --ips {','.join(self.cfg.pod_ips)}"
             f" --devices {self.cfg.device_ids} {py_script}"
             f" --max_num_seqs {self.cfg.max_num_seqs} --max_model_len {self.cfg.max_model_len}"
             f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}"
@@ -1035,6 +1013,7 @@ class LLMEngine(object):
             f" --device_ids {self.cfg.device_ids}"
             f" --tensor_parallel_size {self.cfg.tensor_parallel_size}"
             f" --engine_worker_queue_port {str(self.cfg.engine_worker_queue_port)}"
+            f" --pod_ip {self.cfg.pod_ips[0]}"
             f" --total_block_num {self.cfg.cache_config.total_block_num}"
             f" --block_size {self.cfg.cache_config.block_size}"
             f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}"
@@ -1171,10 +1150,12 @@ class LLMEngine(object):
         if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
             device_ids = self.cfg.device_ids.split(",")
             self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
-                self.cfg.cache_config, self.cfg.tensor_parallel_size,
-                device_ids, self.cfg.engine_worker_queue_port,
-                self.ipc_signal_suffix)
+                cache_config=self.cfg.cache_config,
+                tensor_parallel_size=self.cfg.tensor_parallel_size,
+                device_ids=device_ids,
+                pod_ip=self.cfg.pod_ips[0],
+                engine_worker_queue_port=self.cfg.engine_worker_queue_port,
+                pid_suffix=self.ipc_signal_suffix)

     def check_health(self, time_interval_threashold=30):
         """
         Check the health of the model server by checking whether all workers are alive.
@@ -1254,3 +1235,34 @@ class LLMEngine(object):
             except Exception:
                 pass
         return True
+
+    def start_queue_service(self):
+        """
+        start queue service for engine worker communication
+        """
+        address = (self.cfg.pod_ips[0], self.cfg.engine_worker_queue_port)
+        if self.cfg.host_ip == self.cfg.pod_ips[0] or self.cfg.pod_ips[0] == "0.0.0.0":
+            self.engine_worker_queue_server = EngineWorkerQueue(
+                address=address,
+                is_server=True,
+                num_client=self.cfg.tensor_parallel_size,
+                local_data_parallel_size=self.cfg.parallel_config.
+                data_parallel_size)
+
+            if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != 'mixed':
+                self.cache_task_queue = EngineCacheQueue(
+                    address=(self.cfg.pod_ips[0], self.cfg.cache_config.cache_queue_port),
+                    authkey=b'cache_queue_service',
+                    is_server=True,
+                    num_client=self.cfg.tensor_parallel_size,
+                    client_id=-1,
+                    local_data_parallel_size=self.cfg.parallel_config.
+                    data_parallel_size)
+
+        self.engine_worker_queue = EngineWorkerQueue(
+            address=address,
+            is_server=False,
+            num_client=self.cfg.tensor_parallel_size,
+            client_id=0,
+            local_data_parallel_id=0)
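
The net effect of start_queue_service is that only the node whose IP matches pod_ips[0] (or the 0.0.0.0 single-node default) hosts the queue servers, while every node, master included, attaches a client. A condensed sketch of that decision, using stand-in factories rather than the real EngineWorkerQueue/EngineCacheQueue classes:

def start_queue_service(host_ip, pod_ips, port, make_server, make_client):
    address = (pod_ips[0], port)
    server = None
    # Only the master (first pod IP, or the 0.0.0.0 default) binds the server side.
    if host_ip == pod_ips[0] or pod_ips[0] == "0.0.0.0":
        server = make_server(address)
    # Every node, master included, connects a client to the master's address.
    client = make_client(address)
    return server, client

server, client = start_queue_service(
    host_ip="10.0.0.2", pod_ips=["10.0.0.1", "10.0.0.2"], port=9923,
    make_server=lambda addr: ("server", addr), make_client=lambda addr: ("client", addr))
print(server, client)   # None ('client', ('10.0.0.1', 9923)) -- non-master node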

View File

@@ -65,7 +65,7 @@ class ExpertService(object):
         self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id

-        address = ('0.0.0.0', cfg.engine_worker_queue_port)
+        address = (cfg.pod_ips[0], cfg.engine_worker_queue_port)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,

View File

@@ -85,10 +85,16 @@ class LLM:
         self.mutex = threading.Lock()
         self.req_output = dict()
+        self.master_node_ip = self.llm_engine.cfg.pod_ips[0]
         self._receive_output_thread = threading.Thread(
             target=self._receive_output, daemon=True)
         self._receive_output_thread.start()

+    def _check_master(self):
+        """
+        Check if the current node is the master node.
+        """
+        return self.llm_engine.cfg._check_master()
+
     def _receive_output(self):
         """
@@ -130,6 +136,10 @@ class LLM:
             Union[str, list[str]]: The generated response.
         """
+        if not self._check_master():
+            err_msg = f"Only master node can accept completion request, please send request to master node: {self.master_node_ip}"
+            raise ValueError(err_msg)
+
         if sampling_params is None:
             sampling_params = self.default_sampling_params
@@ -182,6 +192,11 @@ class LLM:
         Returns:
             Union[str, list[str]]: The generated response.
         """
+        if not self._check_master():
+            err_msg = f"Only master node can accept completion request, please send request to master node: {self.master_node_ip}"
+            raise ValueError(err_msg)
+
         if sampling_params is None:
             sampling_params = self.default_sampling_params
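
On a non-master node the LLM entrypoint now refuses work up front instead of queuing it. A toy illustration of that guard (standalone helper, not the LLM class itself):

def guard_master(is_master: bool, master_node_ip: str) -> None:
    # mirrors the new check in generate/chat: reject early on non-master nodes
    if not is_master:
        raise ValueError("Only master node can accept completion request, "
                         f"please send request to master node: {master_node_ip}")

guard_master(True, "10.0.0.1")          # master node: request proceeds
try:
    guard_master(False, "10.0.0.1")     # any other node: rejected before reaching the engine
except ValueError as err:
    print(err)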

View File

@@ -120,8 +120,8 @@ async def lifespan(app: FastAPI):
                           args.mm_processor_kwargs, args.enable_mm,
                           args.reasoning_parser)
     app.state.dynamic_load_weight = args.dynamic_load_weight
-    chat_handler = OpenAIServingChat(engine_client, pid)
-    completion_handler = OpenAIServingCompletion(engine_client, pid)
+    chat_handler = OpenAIServingChat(engine_client, pid, args.pod_ips)
+    completion_handler = OpenAIServingCompletion(engine_client, pid, args.pod_ips)
     engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
     engine_client.pid = pid
     app.state.engine_client = engine_client

View File

@@ -38,9 +38,7 @@ from fastdeploy.entrypoints.openai.protocol import (
     ErrorResponse,
 )
 from fastdeploy.metrics.work_metrics import work_process_metrics
-from fastdeploy.utils import api_server_logger
+from fastdeploy.utils import api_server_logger, get_host_ip
 from fastdeploy.engine.request import RequestOutput
@@ -50,9 +48,18 @@ class OpenAIServingChat:
     OpenAI-style chat completions serving
     """

-    def __init__(self, engine_client, pid):
+    def __init__(self, engine_client, pid, pod_ips):
         self.engine_client = engine_client
         self.pid = pid
+        self.pod_ips = pod_ips
+        self.host_ip = get_host_ip()
+
+    def _check_master(self):
+        if self.pod_ips is None:
+            return True
+        if self.host_ip == self.pod_ips[0]:
+            return True
+        return False

     async def create_chat_completion(
         self,
@@ -61,6 +68,11 @@ class OpenAIServingChat:
         """
         Create a new chat completion using the specified parameters.
         """
+        if not self._check_master():
+            err_msg = f"Only master node can accept completion request, please send request to master node: {self.pod_ips[0]}"
+            api_server_logger.error(err_msg)
+            return ErrorResponse(message=err_msg, code=400)
+
         if request.user is not None:
             request_id = f"chatcmpl-{request.user}-{uuid.uuid4()}"
         else:
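
The _check_master logic is small enough to test in isolation; the same gate is added to the completion handler in the next file. A standalone version with the host IP passed in explicitly (the real class obtains it via FastDeploy's get_host_ip):

def check_master(pod_ips, host_ip) -> bool:
    # No pod_ips configured means single-node deployment: always master.
    if pod_ips is None:
        return True
    # Otherwise only the node matching the first pod IP may serve requests.
    return host_ip == pod_ips[0]

assert check_master(None, "10.0.0.5")
assert check_master(["10.0.0.1", "10.0.0.2"], "10.0.0.1")
assert not check_master(["10.0.0.1", "10.0.0.2"], "10.0.0.2")   # handler returns ErrorResponse(code=400)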

View File

@@ -39,19 +39,32 @@ from fastdeploy.entrypoints.openai.protocol import (
     ToolCall,
     FunctionCall
 )
-from fastdeploy.utils import api_server_logger
+from fastdeploy.utils import api_server_logger, get_host_ip
 from fastdeploy.engine.request import RequestOutput


 class OpenAIServingCompletion:
-    def __init__(self, engine_client, pid):
+    def __init__(self, engine_client, pid, pod_ips):
         self.engine_client = engine_client
         self.pid = pid
+        self.pod_ips = pod_ips
+        self.host_ip = get_host_ip()
+
+    def _check_master(self):
+        if self.pod_ips is None:
+            return True
+        if self.host_ip == self.pod_ips[0]:
+            return True
+        return False

     async def create_completion(self, request: CompletionRequest):
         """
         Create a completion for the given prompt.
         """
+        if not self._check_master():
+            err_msg = f"Only master node can accept completion request, please send request to master node: {self.pod_ips[0]}"
+            api_server_logger.error(err_msg)
+            return ErrorResponse(message=err_msg, code=400)
+
         created_time = int(time.time())
         if request.user is not None:
             request_id = f"cmpl-{request.user}-{uuid.uuid4()}"

View File

@@ -131,8 +131,7 @@ class Worker:
             rank=self.rank)
         self.prefill_tracker = PrefillTracker(args.engine_pid)

-        # Only applicable for standalone (single-machine) inference
-        address = ('0.0.0.0', self.args.engine_worker_queue_port)
+        address = (self.args.pod_ip, self.args.engine_worker_queue_port)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
@@ -324,8 +323,9 @@ class Worker:
         infer_seed_increment = paddle.full(shape=[self.args.max_num_seqs, 1],
                                            fill_value=4,
                                            dtype="int64")
-        self.nnode = 1
+        self.nnode = int((self.nranks + 7) // 8)
+        mp_num_per_node = self.nranks // self.nnode
         while True:
             if self.rank == 0:
                 if self.model_weights_status_signal.value[0] != 0:
@@ -342,7 +342,6 @@ class Worker:
                 self.insert_step = False

             self.worker_healthy_live_signal.value[self.rank] = int(time.time())
-            mp_num_per_node = self.nranks
             if self.rank % mp_num_per_node == 0:
                 if self.engine_worker_queue.num_tasks(

View File

@@ -103,7 +103,7 @@ class PaddleDisWorkerProc():
             rank=self.ranks)

         # Initialize task queue
-        task_address = ('0.0.0.0',
+        task_address = (self.parallel_config.pod_ip,
                         self.parallel_config.engine_worker_queue_port)
         self.task_queue = TaskQueue(
@@ -218,7 +218,8 @@ class PaddleDisWorkerProc():
         TODO(gongshaotian): support remote calling of functions that control worker.
         """
         # Currently, only support single node
-        self.nnode = 1
+        self.nnode = int((self.parallel_config.tensor_parallel_degree + 7) // 8)
+        mp_num_per_node = self.parallel_config.tensor_parallel_degree // self.nnode
         req_ids = []
         while True:
             if self.local_rank == 0:
@@ -236,8 +237,7 @@ class PaddleDisWorkerProc():
                     time.time())

                 # The first worker detects whether there are tasks in the task queue
-                mp_num_per_node = self.ranks / self.nnode
-                if self.local_rank % mp_num_per_node == 0:
+                if self.local_rank % mp_num_per_node == 0:
                     if self.task_queue.num_tasks() > 0:
                         if self.nnode > 1:
                             self.task_queue.read_finish_flag.set(1)
@@ -412,6 +412,7 @@ def parse_args():
                         help="max batch size")
     parser.add_argument("--total_block_num", type=int, default=2000)
     parser.add_argument("--block_size", type=int, default=64)
+    parser.add_argument("--pod_ip", type=str, default="127.0.0.1")
     parser.add_argument("--engine_worker_queue_port", type=int, default=9923)
     parser.add_argument("--max_model_len",
                         type=int,
@@ -600,6 +601,7 @@ def initialize_fd_config(args: argparse.Namespace) -> FDConfig:
     parallel_config.max_num_seqs = args.max_num_seqs
     parallel_config.max_block_num = args.total_block_num
     parallel_config.block_size = args.block_size
+    parallel_config.pod_ip = args.pod_ip
     parallel_config.engine_worker_queue_port = args.engine_worker_queue_port
     parallel_config.max_model_len = args.max_model_len
     model_config.max_seq_len = args.max_model_len
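
Both worker entry points now derive the node layout the same way, assuming at most eight ranks per node. The arithmetic in isolation (plain Python, hypothetical rank counts):

def node_layout(nranks: int):
    # up to 8 ranks per node, as in the worker-side derivation
    nnode = (nranks + 7) // 8
    mp_num_per_node = nranks // nnode
    return nnode, mp_num_per_node

for nranks in (4, 8, 16):
    print(nranks, node_layout(nranks))   # 4 -> (1, 4), 8 -> (1, 8), 16 -> (2, 8)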