diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index ba415978b..cdd9e81d9 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -124,9 +124,19 @@ class EngineArgs: Ratio of tokens to process in a block. """ - pod_ips: Optional[List[str]] = None + dist_init_ip: Optional[str] = None """ - List of IP addresses for nodes in the cluster. + The master node ip of multinode deployment + """ + + nnodes: int = 1 + """ + The number of nodes in multinode deployment + """ + + node_rank: int = 0 + """ + The rank of the current node in multinode deployment """ swap_space: float = None @@ -485,11 +495,25 @@ class EngineArgs: # Cluster system parameters group system_group = parser.add_argument_group("System Configuration") system_group.add_argument( - "--pod-ips", - type=lambda s: s.split(",") if s else None, - default=EngineArgs.pod_ips, + "--dist-init-ip", + default=EngineArgs.dist_init_ip, help= - "List of IP addresses for nodes in the cluster (comma-separated).") + "IP addresses of master node.") + + system_group.add_argument( + "--nnodes", + type=int, + default=EngineArgs.nnodes, + help= + "The number of all nodes.") + + system_group.add_argument( + "--node-rank", + type=int, + default=EngineArgs.node_rank, + help= + "node rank id (range [0, nnodes)).") + # Performance tuning parameters group @@ -789,7 +813,9 @@ class EngineArgs: max_num_seqs=self.max_num_seqs, speculative_config=speculative_cfg, max_num_batched_tokens=self.max_num_batched_tokens, - pod_ips=self.pod_ips, + dist_init_ip=self.dist_init_ip, + nnodes=self.nnodes, + node_rank=self.node_rank, use_warmup=self.use_warmup, engine_worker_queue_port=self.engine_worker_queue_port, limit_mm_per_prompt=self.limit_mm_per_prompt, diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py index 2d49aa0ce..02df10328 100644 --- a/fastdeploy/engine/config.py +++ b/fastdeploy/engine/config.py @@ -6,7 +6,7 @@ # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 -# +#dist_init_ip # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -24,7 +24,7 @@ from fastdeploy import envs from fastdeploy.platforms import current_platform from fastdeploy.scheduler import SchedulerConfig from fastdeploy.utils import (ceil_div, check_unified_ckpt, get_host_ip, - is_port_available, llm_logger) + is_port_available, get_random_port, llm_logger) TaskOption = Literal["generate"] @@ -642,7 +642,9 @@ class Config: max_model_len: int = 8192, max_num_seqs: int = 8, max_num_batched_tokens: Optional[int] = None, - pod_ips: Optional[List[str]] = None, + dist_init_ip: str = None, + nnodes: int = 1, + node_rank: int = 0, speculative_config: Optional[Dict[str, Any]] = None, graph_optimization_config: Optional[Dict[str, Any]] = None, use_warmup: bool = False, @@ -675,7 +677,6 @@ class Config: max_model_len (int): Maximum model length. Default is 8192. max_num_seqs (int): Maximum number of sequences. Default is 8. max_num_batched_tokens (Optional[int]): Maximum number of batched tokens. Default is None. - pod_ips (Optional[List[str]]): List of POD IPs. Default is None. mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor. Default is None. speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration. Default is None. graph_optimization_config (Optional[Dict[str, Any]]): Graph optimizaion backend execution configuration. Default is None. @@ -699,7 +700,16 @@ class Config: self.tokenizer = tokenizer self.max_num_batched_tokens = max_num_batched_tokens self.tensor_parallel_size = tensor_parallel_size - self.pod_ips = pod_ips + self.dist_init_ip = dist_init_ip + + self.nnode = nnodes + self.node_rank = node_rank + if self.dist_init_ip is None: + self.master_ip = "0.0.0.0" + else: + self.master_ip = self.dist_init_ip + self.dist_init_addr = f"{self.dist_init_ip}:{get_random_port()}" + self.max_model_len = max_model_len self.max_num_seqs = max_num_seqs self.limit_mm_per_prompt = limit_mm_per_prompt @@ -716,14 +726,8 @@ class Config: self.graph_optimization_config = graph_optimization_config self.guided_decoding_backend = guided_decoding_backend self.disable_any_whitespace = disable_any_whitespace - self.is_master = True self._str_to_list("innode_prefill_ports", int) - self._str_to_list("pod_ips", str) - if self.pod_ips is None: - self.nnode = 1 - else: - self.nnode = len(self.pod_ips) assert self.splitwise_role in ["mixed", "prefill", "decode"] @@ -778,9 +782,9 @@ class Config: self.host_ip = get_host_ip() - if self.pod_ips is None: - self.pod_ips = ["0.0.0.0"] - elif self.host_ip != self.pod_ips[0]: + if self.dist_init_ip is None or self.host_ip == self.master_ip: + self.is_master = True + else: self.is_master = False import paddle diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 89a4f2ca4..44cf24c38 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -174,7 +174,7 @@ class LLMEngine(object): cache_config=self.cfg.cache_config, tensor_parallel_size=self.cfg.tensor_parallel_size, device_ids=device_ids, - pod_ip=self.cfg.pod_ips[0], + pod_ip=self.cfg.master_ip, engine_worker_queue_port=self.cfg.engine_worker_queue_port, pid_suffix=self.ipc_signal_suffix) @@ -239,11 +239,12 @@ class LLMEngine(object): if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: self.dp_processed = [] - for i in range(1, self.cfg.parallel_config.data_parallel_size): + for i in range(1, self.cfg.parallel_config.data_parallel_size // self.cfg.nnode): time.sleep(1) self.dp_processed.append( multiprocessing.Process(target=start_expert_service, - args=(self.cfg, i, + args=(self.cfg, + i + self.cfg.node_rank * self.cfg.worker_num_per_node, self.ipc_signal_suffix))) llm_logger.info(f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}" \ + " data parallel id {}".format(i)) @@ -263,10 +264,11 @@ class LLMEngine(object): try: results = self.scheduler.get_results() if len(results) == 0: - time.sleep(0.001) + time.sleep(0.005) + continue for request_id, contents in results.items(): - for result in contents: - self.zmq_server.send_multipart(request_id, result) + self.zmq_server.send_multipart(request_id, contents) + except Exception as e: llm_logger.error("Unexcepted error happend: {}, {}".format( e, str(traceback.format_exc()))) @@ -1006,8 +1008,6 @@ class LLMEngine(object): ) arguments = ( - f" --nnodes {str(self.cfg.nnode)}" - f" --ips {','.join(self.cfg.pod_ips)}" f" --devices {self.cfg.device_ids} {py_script}" f" --max_num_seqs {self.cfg.max_num_seqs} --max_model_len {self.cfg.max_model_len}" f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}" @@ -1015,7 +1015,7 @@ class LLMEngine(object): f" --device_ids {self.cfg.device_ids}" f" --tensor_parallel_size {self.cfg.tensor_parallel_size}" f" --engine_worker_queue_port {str(self.cfg.engine_worker_queue_port)}" - f" --pod_ip {self.cfg.pod_ips[0]}" + f" --pod_ip {self.cfg.master_ip}" f" --total_block_num {self.cfg.cache_config.total_block_num}" f" --block_size {self.cfg.cache_config.block_size}" f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}" @@ -1056,7 +1056,11 @@ class LLMEngine(object): if value: arguments = arguments + f" --{worker_flag}" if self.cfg.nnode > 1: - pd_cmd = pd_cmd + f" --ips {self.cfg.ips}" + pd_cmd = pd_cmd + ( + f" --master {self.cfg.dist_init_addr}" + f" --nnodes {str(self.cfg.nnode)}" + f" --rank {str(self.cfg.node_rank)}" + ) pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log" llm_logger.info("Launch worker service command: {}".format(pd_cmd)) p = subprocess.Popen( @@ -1157,7 +1161,7 @@ class LLMEngine(object): cache_config=self.cfg.cache_config, tensor_parallel_size=self.cfg.tensor_parallel_size, device_ids=device_ids, - pod_ip=self.cfg.pod_ips[0], + pod_ip=self.cfg.master_ip, engine_worker_queue_port=self.cfg.engine_worker_queue_port, pid_suffix=self.ipc_signal_suffix) def check_health(self, time_interval_threashold=30): @@ -1244,8 +1248,9 @@ class LLMEngine(object): """ start queue service for engine worker communication """ - address = (self.cfg.pod_ips[0], self.cfg.engine_worker_queue_port) - if self.cfg.host_ip == self.cfg.pod_ips[0] or self.cfg.pod_ips[0] == "0.0.0.0": + address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port) + if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0": + llm_logger.info(f"Starting engine worker queue server service at {address}") self.engine_worker_queue_server = EngineWorkerQueue( address=address, is_server=True, @@ -1255,7 +1260,7 @@ class LLMEngine(object): if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != 'mixed': self.cache_task_queue = EngineCacheQueue( - address=(self.cfg.pod_ips[0], self.cfg.cache_config.cache_queue_port), + address=(self.cfg.master_ip, self.cfg.cache_config.cache_queue_port), authkey=b'cache_queue_service', is_server=True, num_client=self.cfg.tensor_parallel_size, @@ -1269,4 +1274,6 @@ class LLMEngine(object): is_server=False, num_client=self.cfg.tensor_parallel_size, client_id=0, - local_data_parallel_id=0) + local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, + local_data_parallel_id= min(self.cfg.worker_num_per_node * self.cfg.node_rank, + self.cfg.parallel_config.data_parallel_size - 1)) diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 26d2d364c..66607da82 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -49,8 +49,8 @@ class ExpertService(object): cfg (Config): Config object containing all the configuration parameters. """ self.cfg = cfg - start_pos = local_data_parallel_id * self.cfg.tensor_parallel_size - end_pos = (local_data_parallel_id + 1) * self.cfg.tensor_parallel_size + start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % self.cfg.worker_num_per_node + end_pos = ((local_data_parallel_id + 1) * self.cfg.tensor_parallel_size) % self.cfg.worker_num_per_node self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[ start_pos:end_pos] self.cfg.local_device_ids = self.cfg.device_ids.split( @@ -65,7 +65,7 @@ class ExpertService(object): self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id - address = (cfg.pod_ips[0], cfg.engine_worker_queue_port) + address = (cfg.master_ip, cfg.engine_worker_queue_port) self.engine_worker_queue = EngineWorkerQueue( address=address, is_server=False, @@ -118,7 +118,7 @@ class ExpertService(object): cache_config=self.cfg.cache_config, tensor_parallel_size=self.cfg.tensor_parallel_size, device_ids=self.cfg.local_device_ids, - pod_ip=self.cfg.pod_ips[0], + pod_ip=self.cfg.master_ip, engine_worker_queue_port=self.cfg.engine_worker_queue_port, pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}" ) diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 6520f9f47..0c0851540 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -20,7 +20,7 @@ import time from dataclasses import asdict, dataclass, fields from typing import Any, Dict, Optional, Union -import numpy +import numpy as np from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.utils import data_processor_logger @@ -181,7 +181,7 @@ class Request: f"sampling_params={self.sampling_params})") -@dataclass +@dataclass(slots=True) class CompletionOutput: """The output data of one completion output of a request. @@ -235,7 +235,7 @@ class CompletionOutput: f"reasoning_content={self.reasoning_content!r}") -@dataclass +@dataclass(slots=True) class RequestMetrics: """Metrics associated with a request. @@ -310,6 +310,10 @@ class RequestOutput: None if decoder-only. num_cached_tokens: The number of tokens with prefix cache hit. """ + __slots__ = ( + 'request_id', 'prompt', 'prompt_token_ids', 'outputs', + 'finished', 'metrics', 'num_cached_tokens', 'error_code', 'error_msg' + ) def __init__( self, @@ -333,6 +337,12 @@ class RequestOutput: self.error_code = error_code self.error_msg = error_msg + + if prompt_token_ids is None: + self.prompt_token_ids = [] + elif isinstance(self.prompt_token_ids, np.ndarray): + self.prompt_token_ids = self.prompt_token_ids.tolist() + def add(self, next_output: "RequestOutput") -> None: """Merge RequestOutput into this one""" @@ -365,11 +375,6 @@ class RequestOutput: def to_dict(self): """convert RequestOutput into a serializable dict """ - if self.prompt_token_ids is None: - self.prompt_token_ids = [] - - if type(self.prompt_token_ids) is numpy.ndarray: - self.prompt_token_ids = self.prompt_token_ids.tolist() return { "request_id": self.request_id, diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py index 6c0ce4997..5601a7e4c 100644 --- a/fastdeploy/entrypoints/llm.py +++ b/fastdeploy/entrypoints/llm.py @@ -85,7 +85,7 @@ class LLM: self.mutex = threading.Lock() self.req_output = dict() - self.master_node_ip = self.llm_engine.cfg.pod_ips[0] + self.master_node_ip = self.llm_engine.cfg.master_ip self._receive_output_thread = threading.Thread( target=self._receive_output, daemon=True) self._receive_output_thread.start() @@ -169,6 +169,8 @@ class LLM: # get output outputs = self._run_engine(req_ids, use_tqdm=use_tqdm) + for i in range(len(outputs)): + outputs[i].prompt = prompts[i] return outputs def chat( diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 3a2ee5e72..5061d60cf 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -122,8 +122,8 @@ async def lifespan(app: FastAPI): args.mm_processor_kwargs, args.enable_mm, args.reasoning_parser) app.state.dynamic_load_weight = args.dynamic_load_weight - chat_handler = OpenAIServingChat(engine_client, pid, args.pod_ips) - completion_handler = OpenAIServingCompletion(engine_client, pid, args.pod_ips) + chat_handler = OpenAIServingChat(engine_client, pid, args.dist_init_ip) + completion_handler = OpenAIServingCompletion(engine_client, pid, args.dist_init_ip) engine_client.create_zmq_client(model=pid, mode=zmq.PUSH) engine_client.pid = pid app.state.engine_client = engine_client diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index fcaee9f9c..d55545428 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -21,6 +21,7 @@ import traceback import uuid from typing import List, Optional +import msgpack import aiozmq from aiozmq import zmq @@ -39,16 +40,16 @@ class OpenAIServingChat: OpenAI-style chat completions serving """ - def __init__(self, engine_client, pid, pod_ips): + def __init__(self, engine_client, pid, dist_init_ip): self.engine_client = engine_client self.pid = pid - self.pod_ips = pod_ips + self.master_ip = dist_init_ip self.host_ip = get_host_ip() def _check_master(self): - if self.pod_ips is None: + if self.master_ip is None: return True - if self.host_ip == self.pod_ips[0]: + if self.host_ip == self.master_ip: return True return False @@ -143,6 +144,8 @@ class OpenAIServingChat: dealer.write([b"", request_id.encode('utf-8')]) choices = [] current_waiting_time = 0 + if request.metadata is not None: + enable_thinking = request.metadata.get("enable_thinking") while num_choices > 0: try: raw_data = await asyncio.wait_for(dealer.read(), timeout=10) @@ -158,102 +161,106 @@ class OpenAIServingChat: raise ValueError(f"Engine is not healthy: {msg}") else: current_waiting_time = 0 - await asyncio.sleep(0.1) + await asyncio.sleep(0.01) continue + response = msgpack.unpackb(raw_data[-1]) + for res in response: + if res.get("error_code", 200) != 200: + raise ValueError("{}".format(res["error_msg"])) - res = json.loads(raw_data[-1].decode('utf-8')) - if res.get("error_code", 200) != 200: - raise ValueError("{}".format(res["error_msg"])) - if request.metadata is not None: - enable_thinking = request.metadata.get("enable_thinking") - self.engine_client.data_processor.process_response_dict( - res, stream=True, enable_thinking=enable_thinking) + self.engine_client.data_processor.process_response_dict( + res, stream=True, enable_thinking=enable_thinking) - if res['metrics']['first_token_time'] is not None: - arrival_time = res['metrics']['first_token_time'] - inference_start_time = res['metrics']['inference_start_time'] - else: - arrival_time = res['metrics']['arrival_time'] - inference_start_time - if first_iteration: - num_prompt_tokens = len(prompt_token_ids) - num_cached_tokens = res.get("num_cached_tokens", 0) - for i in range(num_choices): - choice = ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage(role="assistant", content="", reasoning_content="", tool_calls=None) - ) - if request.metadata is not None and request.metadata.get("training", False): - choice.delta.token_ids = prompt_token_ids - chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice], - model=model_name - ) - if include_continuous_usage: - chunk.usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=0, - total_tokens=num_prompt_tokens, - prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens) - ) - yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n" - first_iteration = False - - output = res["outputs"] - delta_text = output["text"] - raw_top_logprobs = output["top_logprobs"] - logprobs_res = None - if raw_top_logprobs is not None: - top_logprobs = LogprobsLists( - logprob_token_ids=raw_top_logprobs[0], - logprobs=raw_top_logprobs[1], - sampled_token_ranks=raw_top_logprobs[2], - ) - logprobs_res = self.build_logprobs_response( - request_logprobs=request.logprobs, - response_logprobs=top_logprobs, - request_top_logprobs=request.top_logprobs, - ) - - previous_num_tokens += len(output["token_ids"]) - delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \ - token_ids=output.get("token_ids"), tool_calls=output.get("tool_call_content", [])) - - choice = ChatCompletionResponseStreamChoice( - index=0, - delta=delta_message, - logprobs=logprobs_res, - arrival_time=arrival_time - ) - if res["finished"]: - num_choices -= 1 - work_process_metrics.e2e_request_latency.observe(time.time() - res["metrics"]["request_start_time"]) - has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None - max_tokens = request.max_completion_tokens or request.max_tokens - if has_no_token_limit or previous_num_tokens != max_tokens: - choice.finish_reason = "stop" - if self.engine_client.reasoning_parser == "ernie_x1" and \ - output.get("finish_reason", "") == "tool_calls": - choice.finish_reason = "tool_calls" + if res['metrics']['first_token_time'] is not None: + arrival_time = res['metrics']['first_token_time'] + inference_start_time = res['metrics']['inference_start_time'] else: - choice.finish_reason = "length" + arrival_time = res['metrics']['arrival_time'] - inference_start_time + if first_iteration: + num_prompt_tokens = len(prompt_token_ids) + num_cached_tokens = res.get("num_cached_tokens", 0) + for i in range(num_choices): + choice = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role="assistant", content="", reasoning_content="", tool_calls=None) + ) + if request.metadata is not None and request.metadata.get("training", False): + choice.delta.token_ids = prompt_token_ids + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice], + model=model_name + ) + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens, + prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens) + ) + yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n" + first_iteration = False - if res.get("error_msg") is not None and "Recover" in res["error_msg"]: - choice.finish_reason = "recover_stop" + output = res["outputs"] + delta_text = output["text"] + raw_top_logprobs = output["top_logprobs"] + logprobs_res = None + if raw_top_logprobs is not None: + top_logprobs = LogprobsLists( + logprob_token_ids=raw_top_logprobs[0], + logprobs=raw_top_logprobs[1], + sampled_token_ranks=raw_top_logprobs[2], + ) + logprobs_res = self.build_logprobs_response( + request_logprobs=request.logprobs, + response_logprobs=top_logprobs, + request_top_logprobs=request.top_logprobs, + ) - if request.metadata is not None and request.metadata.get("training", False) and delta_text != "": - choice.delta.token_ids = output["token_ids"] - if include_continuous_usage: - chunk.usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=previous_num_tokens, - total_tokens=num_prompt_tokens + previous_num_tokens + previous_num_tokens += len(output["token_ids"]) + delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \ + token_ids=output.get("token_ids"), tool_calls=output.get("tool_call_content", [])) + + choice = ChatCompletionResponseStreamChoice( + index=0, + delta=delta_message, + logprobs=logprobs_res, + arrival_time=arrival_time ) - choices.append(choice) + if res["finished"]: + num_choices -= 1 + work_process_metrics.e2e_request_latency.observe(time.time() - res["metrics"]["request_start_time"]) + has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None + max_tokens = request.max_completion_tokens or request.max_tokens + if has_no_token_limit or previous_num_tokens != max_tokens: + choice.finish_reason = "stop" + if self.engine_client.reasoning_parser == "ernie_x1" and \ + output.get("finish_reason", "") == "tool_calls": + choice.finish_reason = "tool_calls" + else: + choice.finish_reason = "length" - if len(choices) == max_streaming_response_tokens or res["finished"]: + if res.get("error_msg") is not None and "Recover" in res["error_msg"]: + choice.finish_reason = "recover_stop" + + if request.metadata is not None and request.metadata.get("training", False) and delta_text != "": + choice.delta.token_ids = output["token_ids"] + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=previous_num_tokens, + total_tokens=num_prompt_tokens + previous_num_tokens + ) + choices.append(choice) + + if len(choices) == max_streaming_response_tokens or res["finished"]: + chunk.choices = choices + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + choices = [] + + if choices: chunk.choices = choices yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" choices = [] @@ -321,33 +328,38 @@ class OpenAIServingChat: await asyncio.sleep(0.1) continue - data = json.loads(raw_data[-1].decode('utf-8')) - if data.get("error_code", 200) != 200: - raise ValueError("{}".format(data["error_msg"])) - if request.metadata is not None: - enable_thinking = request.metadata.get("enable_thinking") - data = self.engine_client.data_processor.process_response_dict( - data, stream=False, enable_thinking=enable_thinking) - # api_server_logger.debug(f"Client {request_id} received: {data}") - previous_num_tokens += len(data["outputs"]["token_ids"]) - # The logprob for handling the response - output = data["outputs"] - raw_top_logprobs = output["top_logprobs"] - if raw_top_logprobs is not None: - top_logprobs = LogprobsLists( - logprob_token_ids=raw_top_logprobs[0], - logprobs=raw_top_logprobs[1], - sampled_token_ranks=raw_top_logprobs[2], - ) - logprobs_res = self.build_logprobs_response( - request_logprobs=request.logprobs, - response_logprobs=top_logprobs, - request_top_logprobs=request.top_logprobs, - ) - if logprobs_res and logprobs_res.content is not None: - logprob_contents.extend(logprobs_res.content) - if data["finished"]: - final_res = data + response = msgpack.unpackb(raw_data[-1]) + task_is_finished = False + for data in response: + if data.get("error_code", 200) != 200: + raise ValueError("{}".format(data["error_msg"])) + if request.metadata is not None: + enable_thinking = request.metadata.get("enable_thinking") + data = self.engine_client.data_processor.process_response_dict( + data, stream=False, enable_thinking=enable_thinking) + # api_server_logger.debug(f"Client {request_id} received: {data}") + previous_num_tokens += len(data["outputs"]["token_ids"]) + # The logprob for handling the response + output = data["outputs"] + raw_top_logprobs = output["top_logprobs"] + if raw_top_logprobs is not None: + top_logprobs = LogprobsLists( + logprob_token_ids=raw_top_logprobs[0], + logprobs=raw_top_logprobs[1], + sampled_token_ranks=raw_top_logprobs[2], + ) + logprobs_res = self.build_logprobs_response( + request_logprobs=request.logprobs, + response_logprobs=top_logprobs, + request_top_logprobs=request.top_logprobs, + ) + if logprobs_res and logprobs_res.content is not None: + logprob_contents.extend(logprobs_res.content) + if data["finished"]: + final_res = data + task_is_finished = True + break + if task_is_finished: break finally: dealer.close() diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 0c9bf6424..acefc3d17 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -17,6 +17,7 @@ import asyncio import aiozmq import json +import msgpack from aiozmq import zmq from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task import time @@ -44,16 +45,16 @@ from fastdeploy.engine.request import RequestOutput class OpenAIServingCompletion: - def __init__(self, engine_client, pid, pod_ips): + def __init__(self, engine_client, pid, dist_init_ip): self.engine_client = engine_client self.pid = pid - self.pod_ips = pod_ips + self.master_ip = dist_init_ip self.host_ip = get_host_ip() def _check_master(self): - if self.pod_ips is None: + if self.master_ip is None: return True - if self.host_ip == self.pod_ips[0]: + if self.host_ip == self.master_ip: return True return False @@ -179,18 +180,20 @@ class OpenAIServingCompletion: current_waiting_time = 0 await asyncio.sleep(0.1) continue - data = json.loads(raw_data[-1].decode("utf-8")) - rid = int(data["request_id"].split("-")[-1]) - if data.get("error_code", 200) != 200: - raise ValueError("{}".format(data["error_msg"])) + response = msgpack.unpackb(raw_data[-1]) + for data in response: + rid = int(data["request_id"].split("-")[-1]) + if data.get("error_code", 200) != 200: + raise ValueError("{}".format(data["error_msg"])) - self.engine_client.data_processor.process_response_dict( - data, stream=False) - output_tokens[rid] += len(data["outputs"]["token_ids"]) - if data.get("finished", False): - data["output_token_ids"] = output_tokens[rid] - valid_results[rid] = data - num_choices -= 1 + self.engine_client.data_processor.process_response_dict( + data, stream=False) + output_tokens[rid] += len(data["outputs"]["token_ids"]) + if data.get("finished", False): + data["output_token_ids"] = output_tokens[rid] + valid_results[rid] = data + num_choices -= 1 + break return self.request_output_to_completion_response( final_res_batch=valid_results, @@ -238,6 +241,12 @@ class OpenAIServingCompletion: if request.suffix is not None and request.suffix.get("max_streaming_response_tokens", 1) > 1: max_streaming_response_tokens = request.suffix["max_streaming_response_tokens"] choices = [] + chunk = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices + ) current_waiting_time = 0 while num_choices > 0: @@ -256,82 +265,86 @@ class OpenAIServingCompletion: continue - res = json.loads(raw_data[-1].decode('utf-8')) - idx = int(res["request_id"].split("-")[-1]) - if res.get("error_code", 200) != 200: - raise ValueError("{}".format(res["error_msg"])) + response = msgpack.unpackb(raw_data[-1]) + for res in response: + idx = int(res["request_id"].split("-")[-1]) + if res.get("error_code", 200) != 200: + raise ValueError("{}".format(res["error_msg"])) - if first_iteration[idx]: - if request.suffix is not None and request.suffix.get("training", False): + if first_iteration[idx]: + if request.suffix is not None and request.suffix.get("training", False): + chunk = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[CompletionResponseStreamChoice( + index=idx, + text="", + token_ids=list(prompt_batched_token_ids[idx]) + )] + ) + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + first_iteration[idx] = False + + + self.engine_client.data_processor.process_response_dict( + res, stream=True) + if res['metrics'].get('first_token_time') is not None: + arrival_time = res['metrics']['first_token_time'] + inference_start_time[idx] = res['metrics']['inference_start_time'] + else: + arrival_time = res['metrics']['arrival_time'] - inference_start_time[idx] + + output = res["outputs"] + + choices.append(CompletionResponseStreamChoice( + index=idx, + text=output["text"], + token_ids=output.get("token_ids"), + tool_calls=output.get("tool_call_content"), + reasoning_content=output.get("reasoning_content"), + arrival_time=arrival_time + )) + if res["finished"]: + if request.max_tokens is None or output_tokens[idx] + 1 != request.max_tokens: + chunk.choices[0].finish_reason = "stop" + if self.engine_client.reasoning_parser == "ernie_x1" and \ + output.get("finish_reason", "") == "tool_calls": + chunk.choices[0].finish_reason = "tool_calls" + else: + chunk.choices[0].finish_reason = "length" + + output_tokens[idx] += 1 + + if len(choices) == max_streaming_response_tokens or res["finished"]: chunk = CompletionStreamResponse( id=request_id, created=created_time, model=model_name, - choices=[CompletionResponseStreamChoice( - index=idx, - text="", - token_ids=list(prompt_batched_token_ids[idx]) - )] + choices=choices ) yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" - first_iteration[idx] = False + choices = [] - self.engine_client.data_processor.process_response_dict( - res, stream=True) - if res['metrics'].get('first_token_time') is not None: - arrival_time = res['metrics']['first_token_time'] - inference_start_time[idx] = res['metrics']['inference_start_time'] - else: - arrival_time = res['metrics']['arrival_time'] - inference_start_time[idx] - # api_server_logger.info(f"{arrival_time}") - - output = res["outputs"] - - choices.append(CompletionResponseStreamChoice( - index=idx, - text=output["text"], - token_ids=output.get("token_ids"), - tool_calls=output.get("tool_call_content"), - reasoning_content=output.get("reasoning_content"), - arrival_time=arrival_time - )) - if res["finished"]: - if request.max_tokens is None or output_tokens[idx] + 1 != request.max_tokens: - chunk.choices[0].finish_reason = "stop" - if self.engine_client.reasoning_parser == "ernie_x1" and \ - output.get("finish_reason", "") == "tool_calls": - chunk.choices[0].finish_reason = "tool_calls" - else: - chunk.choices[0].finish_reason = "length" - - output_tokens[idx] += 1 - - if len(choices) == max_streaming_response_tokens or res["finished"]: - chunk = CompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=choices - ) - choices = [] - - yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" - - if res["finished"]: - num_choices -= 1 - if getattr(request, "stream_options", None) and request.stream_options.include_usage: - usage_chunk = CompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=[], - usage=UsageInfo( - prompt_tokens=len(prompt_batched_token_ids[idx]), - completion_tokens=output_tokens[idx] + if res["finished"]: + num_choices -= 1 + if getattr(request, "stream_options", None) and request.stream_options.include_usage: + usage_chunk = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[], + usage=UsageInfo( + prompt_tokens=len(prompt_batched_token_ids[idx]), + completion_tokens=output_tokens[idx] + ) ) - ) - yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n" + yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n" + if choices: + chunk.choices = choices + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + choices = [] except Exception as e: diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 97d89f9f5..245647010 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -101,6 +101,10 @@ environment_variables: dict[str, Callable[[], Any]] = { # Whether to use DeepGemm for FP8 blockwise MoE. "FD_USE_DEEP_GEMM": lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))), + + # Whether to use aggregate send. + "FD_USE_AGGREGATE_SEND": + lambda: bool(int(os.getenv("FD_USE_AGGREGATE_SEND", "0"))), } diff --git a/fastdeploy/inter_communicator/zmq_client.py b/fastdeploy/inter_communicator/zmq_client.py index adc4555a2..115331c32 100644 --- a/fastdeploy/inter_communicator/zmq_client.py +++ b/fastdeploy/inter_communicator/zmq_client.py @@ -20,6 +20,7 @@ import threading import time import zmq +import msgpack from fastdeploy import envs from fastdeploy.utils import llm_logger @@ -37,6 +38,7 @@ class ZmqClient: self.router_path = f"/dev/shm/router_{name}.ipc" self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) + self.aggregate_send = envs.FD_USE_AGGREGATE_SEND self.mutex = threading.Lock() self.req_dict = dict() @@ -93,6 +95,16 @@ class ZmqClient: """ return self.socket.recv_pyobj() + def pack_aggregated_data(self, data): + """ + Aggregate multiple responses into one and send them to the client. + """ + result = data[0] + if len(data) > 1: + for response in data[1:]: + result.add(response) + result = msgpack.packb([result.to_dict()]) + return result def send_multipart(self, req_id, data): """ Send a multipart message to the router socket. @@ -116,14 +128,22 @@ class ZmqClient: break try: - result = json.dumps(data.to_dict()).encode('utf-8') + start_send = time.time() + if self.aggregate_send: + result = self.pack_aggregated_data(data) + else: + result = msgpack.packb([response.to_dict() for response in data]) self.router.send_multipart([self.req_dict[req_id], b'', result]) + llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}") + except Exception as e: llm_logger.error(f"Send result to zmq client failed: {e}") - if data.finished: + if data[-1].finished: with self.mutex: - self.req_dict.pop(data.request_id, None) + self.req_dict.pop(req_id, None) + llm_logger.info(f"send_multipart finished, req_id: {req_id}") + def receive_json_once(self, block=False): """ diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index e1efeaa7b..0647b269b 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -505,8 +505,6 @@ class TokenProcessor(object): result.outputs.token_ids.append(token_id) if token_id in task.eos_token_ids or is_prefill or recovery_stop: result.finished = True - result.prompt = task.prompt - result.prompt_token_ids = task.prompt_token_ids if recovery_stop: result.error_msg = "Recover is not supported, the result is incomplete!" llm_logger.info( diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 0316779b4..79ee65b77 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -27,7 +27,8 @@ from datetime import datetime from logging.handlers import BaseRotatingHandler from pathlib import Path from typing import Literal, TypeVar, Union - +import random +import socket import requests import yaml from aistudio_sdk.snapshot_download import snapshot_download @@ -421,6 +422,19 @@ def get_host_ip(): return ip + + +def get_random_port(): + while True: + port = random.randint(49152, 65535) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("0.0.0.0", port)) + return port + except OSError: + continue + + def is_port_available(host, port): """ Check the port is available diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index 70e596359..18c1b4302 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -23,6 +23,7 @@ import pynvml from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request +from fastdeploy.platforms import current_platform from fastdeploy.utils import get_logger from fastdeploy.worker.gpu_model_runner import GPUModelRunner from fastdeploy.worker.output import ModelRunnerOutput @@ -50,11 +51,12 @@ class GpuWorker(WorkerBase): """ Initialize device and construct model runner """ + self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 if self.device_config.device_type == "cuda" and paddle.device.is_compiled_with_cuda( ): # Set evironment variable self.device_ids = self.parallel_config.device_ids.split(",") - self.device = f"gpu:{self.local_rank}" + self.device = f"gpu:{self.local_rank % self.max_chips_per_node}" paddle.device.set_device(self.device) paddle.set_default_dtype(self.parallel_config.dtype) @@ -72,7 +74,7 @@ class GpuWorker(WorkerBase): self.model_runner: GPUModelRunner = GPUModelRunner( fd_config=self.fd_config, device=self.device, - device_id=self.device_ids[self.local_rank], + device_id=self.device_ids[self.local_rank % self.max_chips_per_node], rank=self.rank, local_rank=self.local_rank) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 8775c5de2..99504008c 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -136,9 +136,9 @@ class PaddleDisWorkerProc(): model_weights_status: """ # init worker_ready_signal - max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 + self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 array_size = min( - max_chips_per_node, self.parallel_config.tensor_parallel_size * + self.max_chips_per_node, self.parallel_config.tensor_parallel_size * self.parallel_config.expert_parallel_size) workers_ready = np.zeros(shape=[array_size], dtype=np.int32) self.worker_ready_signal = IPCSignal( @@ -148,10 +148,10 @@ class PaddleDisWorkerProc(): suffix=self.parallel_config.engine_pid, create=False) self.worker_ready_signal.value[self.local_rank % - max_chips_per_node] = 1 + self.max_chips_per_node] = 1 # init worker_healthy_live_signal - workers_alive = np.zeros(shape=[self.ranks], dtype=np.int32) + workers_alive = np.zeros(shape=[array_size], dtype=np.int32) self.worker_healthy_live_signal = IPCSignal( name="worker_healthy_live_signal", array=workers_alive, @@ -205,7 +205,7 @@ class PaddleDisWorkerProc(): Tmp loop function for ep utill DP is supported """ while True: - self.worker_healthy_live_signal.value[self.local_rank] = int( + self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int( time.time()) if self.fd_config.parallel_config.tensor_parallel_rank == 0 and self.task_queue.num_tasks( diff --git a/requirements.txt b/requirements.txt index 1432d9c1f..f5a562254 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,9 +29,11 @@ triton==3.3 use-triton-in-paddle crcmod fastsafetensors==0.1.14 +msgpack opentelemetry-api>=1.24.0 opentelemetry-sdk>=1.24.0 opentelemetry-instrumentation-redis opentelemetry-instrumentation-mysql opentelemetry-distro  opentelemetry-exporter-otlp + diff --git a/requirements_dcu.txt b/requirements_dcu.txt index 75d549a83..7e6d524a9 100644 --- a/requirements_dcu.txt +++ b/requirements_dcu.txt @@ -27,3 +27,4 @@ moviepy use-triton-in-paddle crcmod fastsafetensors==0.1.14 +msgpack \ No newline at end of file diff --git a/requirements_iluvatar.txt b/requirements_iluvatar.txt index 75d549a83..14d2d42dd 100644 --- a/requirements_iluvatar.txt +++ b/requirements_iluvatar.txt @@ -27,3 +27,4 @@ moviepy use-triton-in-paddle crcmod fastsafetensors==0.1.14 +msgpack