diff --git a/custom_ops/gpu_ops/remote_cache_kv_ipc.h b/custom_ops/gpu_ops/remote_cache_kv_ipc.h index 3c09af1e4..759e1d650 100644 --- a/custom_ops/gpu_ops/remote_cache_kv_ipc.h +++ b/custom_ops/gpu_ops/remote_cache_kv_ipc.h @@ -18,88 +18,94 @@ #include #include #include +#include #include +#include #include #include -#include -#include #include #include "driver_types.h" +#include "msg_utils.h" #include "paddle/extension.h" #include "paddle/phi/core/allocator.h" #include "paddle/phi/core/dense_tensor.h" -#include "msg_utils.h" struct RemoteCacheKvIpc { - struct save_cache_kv_complete_signal_layerwise_meta_data{ - int32_t layer_id=-1; - void * shm_ptr=nullptr; - int shm_fd=-1; - save_cache_kv_complete_signal_layerwise_meta_data(){} - save_cache_kv_complete_signal_layerwise_meta_data(int32_t layer_id_, - void* shm_ptr_, - int shm_fd_) - :layer_id(layer_id_), shm_ptr(shm_ptr_), shm_fd(shm_fd_){ + struct save_cache_kv_complete_signal_layerwise_meta_data { + int32_t layer_id = -1; + void* shm_ptr = nullptr; + int shm_fd = -1; + save_cache_kv_complete_signal_layerwise_meta_data() {} + save_cache_kv_complete_signal_layerwise_meta_data(int32_t layer_id_, + void* shm_ptr_, + int shm_fd_) + : layer_id(layer_id_), shm_ptr(shm_ptr_), shm_fd(shm_fd_) {} + }; + + struct save_cache_kv_complete_signal_layerwise_meta_data_per_query { + int layer_id_; + int num_layers_; + bool inited = false; + struct msgdatakv msg_sed; + int msgid; + + save_cache_kv_complete_signal_layerwise_meta_data_per_query() {} + + void init(const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int rank, + const int num_layers, + const int real_bsz) { + layer_id_ = 0; + num_layers_ = num_layers; + msg_sed.mtype = 1; + int encoder_count = 0; + for (int i = 0; i < real_bsz; i++) { + if (seq_lens_encoder[i] > 0) { + msg_sed.mtext[3 * encoder_count + 2] = i; + msg_sed.mtext[3 * encoder_count + 3] = seq_lens_decoder[i]; + msg_sed.mtext[3 * encoder_count + 4] = seq_lens_encoder[i]; + encoder_count++; } - }; + } + msg_sed.mtext[0] = encoder_count; - struct save_cache_kv_complete_signal_layerwise_meta_data_per_query{ - int layer_id_; - int num_layers_; - bool inited = false; - struct msgdatakv msg_sed; - int msgid; + if (!inited) { + // just init once + const int msg_id = 1024 + rank; + key_t key = ftok("/opt/", msg_id); + msgid = msgget(key, IPC_CREAT | 0666); + inited = true; + } + } - save_cache_kv_complete_signal_layerwise_meta_data_per_query(){} - - void init(const int *seq_lens_encoder, - const int *seq_lens_decoder, - const int rank, - const int num_layers, - const int real_bsz) { - layer_id_ = 0; - num_layers_ = num_layers; - msg_sed.mtype = 1; - int encoder_count = 0; - for (int i = 0; i < real_bsz; i++) { - if (seq_lens_encoder[i] > 0) { - msg_sed.mtext[3 * encoder_count + 2] = i; - msg_sed.mtext[3 * encoder_count + 3] = seq_lens_decoder[i]; - msg_sed.mtext[3 * encoder_count + 4] = seq_lens_encoder[i]; - encoder_count++; - } - } - msg_sed.mtext[0] = encoder_count; - - if (!inited) { - // just init once - const int msg_id = 1024 + rank; - key_t key = ftok("/opt/", msg_id); - msgid = msgget(key, IPC_CREAT | 0666); - inited = true; - } + void CUDART_CB send_signal() { + if (inited) { + msg_sed.mtext[1] = layer_id_; + if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) { + printf("kv signal full msg buffer\n"); } + layer_id_ = (layer_id_ + 1); + assert(layer_id_ <= num_layers_); + } + } + }; - void CUDART_CB send_signal() { - msg_sed.mtext[1] = layer_id_; - if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) { - printf("kv signal full msg buffer\n"); - } - layer_id_ = (layer_id_ + 1); - assert(layer_id_ <= num_layers_); - } - }; + static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data + kv_complete_signal_meta_data; + static RemoteCacheKvIpc:: + save_cache_kv_complete_signal_layerwise_meta_data_per_query + kv_complete_signal_meta_data_per_query; + static void* kv_complete_signal_identity_ptr; + static bool kv_complete_signal_shmem_opened; - static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data kv_complete_signal_meta_data; - static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data_per_query kv_complete_signal_meta_data_per_query; - static void* kv_complete_signal_identity_ptr; - static bool kv_complete_signal_shmem_opened; - - static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data open_shm_and_get_complete_signal_meta_data( - const int rank_id, - const int device_id, - const bool keep_pd_step_flag); - static void CUDART_CB save_cache_kv_complete_signal_layerwise(void* meta_data); - static void CUDART_CB save_cache_kv_complete_signal_layerwise_per_query(void* meta_data); + static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data + open_shm_and_get_complete_signal_meta_data(const int rank_id, + const int device_id, + const bool keep_pd_step_flag); + static void CUDART_CB + save_cache_kv_complete_signal_layerwise(void* meta_data); + static void CUDART_CB + save_cache_kv_complete_signal_layerwise_per_query(void* meta_data); }; diff --git a/custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h b/custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h index 1cc4531c6..4835e2a82 100644 --- a/custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h +++ b/custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h @@ -72,12 +72,14 @@ struct RemoteCacheKvIpc { } void send_signal() { - msg_sed.mtext[1] = layer_id_; - if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) { - printf("kv signal full msg buffer\n"); + if (inited) { + msg_sed.mtext[1] = layer_id_; + if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) { + printf("kv signal full msg buffer\n"); + } + layer_id_ = (layer_id_ + 1); + assert(layer_id_ <= num_layers_); } - layer_id_ = (layer_id_ + 1); - assert(layer_id_ <= num_layers_); } }; diff --git a/examples/splitwise/start_v1_tp1.sh b/examples/splitwise/start_v1_tp1.sh index 31eca8ab7..5ccebc1ed 100644 --- a/examples/splitwise/start_v1_tp1.sh +++ b/examples/splitwise/start_v1_tp1.sh @@ -68,7 +68,6 @@ nohup python -m fastdeploy.entrypoints.openai.api_server \ --cache-transfer-protocol "rdma" \ --rdma-comm-ports "$((P_PORT + 4))" \ --pd-comm-port "$((P_PORT + 5))" \ - --num-gpu-blocks-override 2000 \ --router "0.0.0.0:${ROUTER_PORT}" \ 2>&1 >${FD_LOG_DIR}/nohup & diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index dc3d64099..feb54f481 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -687,8 +687,8 @@ class CacheMessagerV1: for engine_idx, _ in batch_engine_signals: task = self.idx_cache_task_dict[engine_idx] if task["status"] == "finished" or ("error" in task["status"]): - target_id = int(task["rdma_ports"][self.rank]) if task["transfer_protocol"] == "ipc": + target_id = int(task["device_ids"][self.rank]) self.messager["ipc"].write_block_by_sync(target_id) self.engine_worker_queue.finish_send_cache_barrier.wait() self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]]) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 17a9944ff..86e78c5e8 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -517,18 +517,6 @@ class EngineArgs: f"The number of rdma comm ports must be equal to number of ranks ({self.data_parallel_size=} * {self.tensor_parallel_size=} = {self.data_parallel_size * self.tensor_parallel_size}), but got {len(self.rdma_comm_ports)}." ) - if envs.ENABLE_V1_KVCACHE_SCHEDULER == 1: - if "ipc" in self.cache_transfer_protocol: - # FIXME: support ipc cache transfer protocol - raise NotImplementedError( - "only support rdma cache transfer protocol " "when using ENABLE_V1_KVCACHE_SCHEDULER." - ) - # FIXME: fix this bug - if self.splitwise_role == "prefill" and self.num_gpu_blocks_override is None: - raise NotImplementedError( - "please set num_gpu_blocks_override for prefill " "instance using ENABLE_V1_KVCACHE_SCHEDULER." - ) - if not current_platform.is_cuda() and not current_platform.is_xpu(): envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 if self.guided_decoding_backend != "off": diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 8ce4ac909..17a650253 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -1001,7 +1001,7 @@ class ResourceManagerV1(ResourceManager): request.need_prefill_tokens + self.config.cache_config.block_size - 1 ) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks): - request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks)) + request.block_tables = self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks) request.num_computed_tokens = request.need_prefill_tokens request.disaggregate_info["block_tables"] = request.block_tables allocated_position = self.get_available_position() diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 26889efcd..b945b1b63 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -140,6 +140,8 @@ class ForwardMeta: block_tables: Optional[paddle.Tensor] = None # KV caches caches: Optional[list[paddle.Tensor]] = None + # Flag of profile run + is_dummy_or_profile_run: bool = False def clear_caches(self): """Safely clean up the caches""" diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index cae006e96..346251a30 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -178,7 +178,7 @@ class AppendAttentionBackend(AttentionBackend): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag: + if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index 31d6d7488..b39ab0a88 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -231,7 +231,7 @@ class FlashAttentionBackend(AttentionBackend): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag: + if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 8df65d39d..cda5684e6 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -214,7 +214,7 @@ class MLAAttentionBackend(AttentionBackend): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag: + if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index e547da97d..edbf67d4b 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1229,7 +1229,7 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["mask_rollback"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") - def _prepare_inputs(self) -> None: + def _prepare_inputs(self, is_dummy_or_profile_run=False) -> None: """Prepare the model inputs""" if envs.ENABLE_V1_KVCACHE_SCHEDULER: recover_decode_task( @@ -1280,7 +1280,7 @@ class GPUModelRunner(ModelRunnerBase): max_bad_tokens_len = np.max(self.share_inputs["bad_tokens_len"].numpy()) # Initialize forward meta data - self.initialize_forward_meta() + self.initialize_forward_meta(is_dummy_or_profile_run=is_dummy_or_profile_run) # Get sampling metadata self.sampling_metadata = SamplingMetadata( @@ -1334,7 +1334,7 @@ class GPUModelRunner(ModelRunnerBase): """Get current model""" return self.model - def initialize_forward_meta(self): + def initialize_forward_meta(self, is_dummy_or_profile_run=False): """ Initialize forward meta and attention meta data """ @@ -1386,6 +1386,9 @@ class GPUModelRunner(ModelRunnerBase): only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph ) + # Set forward_meta.is_dummy_or_profile_run to True to skip init_kv_signal_per_query for attention backends + self.forward_meta.is_dummy_or_profile_run = is_dummy_or_profile_run + # Initialzie attention meta data for attn_backend in self.attn_backends: attn_backend.init_attention_metadata(self.forward_meta) @@ -1778,7 +1781,7 @@ class GPUModelRunner(ModelRunnerBase): while True: # 1. Initialize forward meta and attention meta data - self._prepare_inputs() + self._prepare_inputs(is_dummy_or_profile_run=True) # 2. Padding inputs for cuda graph self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph diff --git a/tests/e2e/test_ernie_03b_pd_router_v1.py b/tests/e2e/test_ernie_03b_pd_router_v1.py new file mode 100644 index 000000000..b8be6e3fa --- /dev/null +++ b/tests/e2e/test_ernie_03b_pd_router_v1.py @@ -0,0 +1,418 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Test splitwise deployment which uses local_scheduler + router, +# and ENABLE_V1_KVCACHE_SCHEDULER is 1 + +import json +import os +import shutil +import signal +import subprocess +import sys +import time + +import pytest +import requests +from utils.serving_utils import ( + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + clean_ports, + get_registered_number, +) + +# Read ports from environment variables; use default values if not set +FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433)) +FD_ROUTER_PORT = int(os.getenv("FD_ROUTER_PORT", 8533)) + +# List of ports to clean before and after tests +PORTS_TO_CLEAN = [ + FD_API_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + FD_CACHE_QUEUE_PORT, + FD_CONNECTOR_PORT, + FD_API_PORT + 1, + FD_ENGINE_QUEUE_PORT + 1, + FD_METRICS_PORT + 1, + FD_CACHE_QUEUE_PORT + 1, + FD_CONNECTOR_PORT + 1, + FD_ROUTER_PORT, +] + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_run_server(): + """ + Pytest fixture that runs once per test session: + - Cleans ports before tests + - Starts the API server as a subprocess + - Waits for server port to open (up to 30 seconds) + - Tears down server after all tests finish + """ + print("Pre-test port cleanup...") + clean_ports(PORTS_TO_CLEAN) + + print("log dir clean ") + if os.path.exists("log_router") and os.path.isdir("log_router"): + shutil.rmtree("log_router") + if os.path.exists("log_prefill") and os.path.isdir("log_prefill"): + shutil.rmtree("log_prefill") + if os.path.exists("log_decode") and os.path.isdir("log_decode"): + shutil.rmtree("log_decode") + + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle") + else: + model_path = "baidu/ERNIE-4.5-0.3B-Paddle" + print(f"model_path: {model_path}") + + # router + print("start router...") + env_router = os.environ.copy() + env_router["FD_LOG_DIR"] = "log_router" + router_log_path = "router.log" + + router_cmd = [ + sys.executable, + "-m", + "fastdeploy.router.launch", + "--port", + str(FD_ROUTER_PORT), + "--splitwise", + ] + + with open(router_log_path, "w") as logfile: + process_router = subprocess.Popen( + router_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + env=env_router, + ) + + # prefill实例 + print("start prefill...") + env_prefill = os.environ.copy() + env_prefill["CUDA_VISIBLE_DEVICES"] = "0" + env_prefill["ENABLE_V1_KVCACHE_SCHEDULER"] = "1" + env_prefill["FD_LOG_DIR"] = "log_prefill" + prefill_log_path = "server.log" + prefill_cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT), + "--tensor-parallel-size", + "1", + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT), + "--metrics-port", + str(FD_METRICS_PORT), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT), + "--max-model-len", + "8192", + "--max-num-seqs", + "20", + "--quantization", + "wint8", + "--splitwise-role", + "prefill", + "--cache-transfer-protocol", + "ipc", + "--pd-comm-port", + str(FD_CONNECTOR_PORT), + "--router", + f"0.0.0.0:{FD_ROUTER_PORT}", + ] + + # Start subprocess in new process group + with open(prefill_log_path, "w") as logfile: + process_prefill = subprocess.Popen( + prefill_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + env=env_prefill, + ) + time.sleep(1) + + # decode实例 + print("start decode...") + env_decode = os.environ.copy() + env_decode["CUDA_VISIBLE_DEVICES"] = "1" + env_decode["ENABLE_V1_KVCACHE_SCHEDULER"] = "1" + env_decode["FD_LOG_DIR"] = "log_decode" + decode_log_path = "decode_server.log" + decode_cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT + 1), + "--tensor-parallel-size", + "1", + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT + 1), + "--metrics-port", + str(FD_METRICS_PORT + 1), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT + 1), + "--max-model-len", + "8192", + "--max-num-seqs", + "20", + "--quantization", + "wint8", + "--splitwise-role", + "decode", + "--cache-transfer-protocol", + "ipc", + "--pd-comm-port", + str(FD_CONNECTOR_PORT + 1), + "--router", + f"0.0.0.0:{FD_ROUTER_PORT}", + ] + + # Start subprocess in new process group + with open(decode_log_path, "w") as logfile: + process_decode = subprocess.Popen( + decode_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + env=env_decode, + ) + + # Wait up to 300 seconds for API server to be ready + for _ in range(60): + registered_numbers = get_registered_number(f"0.0.0.0:{FD_ROUTER_PORT}") + if registered_numbers["prefill"] >= 1 and registered_numbers["decode"] >= 1: + print("Prefill and decode servers are both online") + break + time.sleep(5) + else: + print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...") + try: + os.killpg(process_prefill.pid, signal.SIGTERM) + os.killpg(process_decode.pid, signal.SIGTERM) + clean_ports() + except Exception as e: + print(f"Failed to kill process group: {e}") + raise RuntimeError(f"API server did not start on port {FD_API_PORT}") + + yield # Run tests + + print("\n===== Post-test server cleanup... =====") + try: + os.killpg(process_router.pid, signal.SIGTERM) + os.killpg(process_prefill.pid, signal.SIGTERM) + os.killpg(process_decode.pid, signal.SIGTERM) + clean_ports(PORTS_TO_CLEAN) + print(f"Prefill server (pid={process_prefill.pid}) terminated") + print(f"Decode server (pid={process_decode.pid}) terminated") + except Exception as e: + print(f"Failed to terminate API server: {e}") + + +@pytest.fixture(scope="session") +def api_url(request): + """ + Returns the API endpoint URL for chat completions. + """ + return f"http://0.0.0.0:{FD_ROUTER_PORT}/v1/chat/completions" + + +@pytest.fixture(scope="session") +def metrics_url(request): + """ + Returns the metrics endpoint URL. + """ + return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" + + +@pytest.fixture +def headers(): + """ + Returns common HTTP request headers. + """ + return {"Content-Type": "application/json"} + + +def test_metrics_config(metrics_url): + timeout = 600 + url = metrics_url.replace("metrics", "config-info") + res = requests.get(url, timeout=timeout) + assert res.status_code == 200 + + +def send_request(url, payload, timeout=600): + """ + 发送请求到指定的URL,并返回响应结果。 + """ + headers = { + "Content-Type": "application/json", + } + + try: + res = requests.post(url, headers=headers, json=payload, timeout=timeout) + print("🟢 接收响应中...\n") + return res + except requests.exceptions.Timeout: + print(f"❌ 请求超时(超过 {timeout} 秒)") + return None + except requests.exceptions.RequestException as e: + print(f"❌ 请求失败:{e}") + return None + + +def get_stream_chunks(response): + """解析流式返回,生成chunk List[dict]""" + chunks = [] + + if response.status_code == 200: + for line in response.iter_lines(decode_unicode=True): + if line: + if line.startswith("data: "): + line = line[len("data: ") :] + + if line.strip() == "[DONE]": + break + + try: + chunk = json.loads(line) + chunks.append(chunk) + except Exception as e: + print(f"解析失败: {e}, 行内容: {line}") + else: + print(f"请求失败,状态码: {response.status_code}") + print("返回内容:", response.text) + + return chunks + + +def test_chat_usage_stream(api_url): + """测试流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + print("Decode Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_chat_usage_non_stream(api_url): + """测试非流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["message"]["content"] + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_stream(api_url): + """测试流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + api_url = api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["text"] for x in chunks[:-1]]) + print("Decode Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_non_stream(api_url): + """测试非流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + api_url = api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["text"] + print("Decode Response:", result) + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"