Files
FastDeploy/tests/e2e/test_Qwen2-7B-Instruct_serving.py
qwes5s5 2ee91d7a96 [metrics] Add several observability metrics (#3868) (#4011)
* [metrics] Add several observability metrics (#3868)

* Add several observability metrics

* [wenxin-tools-584] 【可观测性】支持查看本节点的并发数、剩余block_size、排队请求数等信息

* adjust some metrics and md files

* trigger ci

* adjust ci file

* trigger ci

* trigger ci

---------

Co-authored-by: K11OntheBoat <your_email@example.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>

* version adjust

---------

Co-authored-by: K11OntheBoat <your_email@example.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-09-10 10:59:57 +08:00

819 lines
29 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import json
import os
import re
import shutil
import signal
import socket
import subprocess
import sys
import time
import openai
import pytest
import requests
from jsonschema import validate
# Service ports. Each one may be overridden through the environment variable
# of the same name; otherwise the default written here is used.
FD_API_PORT = int(os.environ.get("FD_API_PORT", 8188))
FD_ENGINE_QUEUE_PORT = int(os.environ.get("FD_ENGINE_QUEUE_PORT", 8133))
FD_METRICS_PORT = int(os.environ.get("FD_METRICS_PORT", 8233))
FD_CACHE_QUEUE_PORT = int(os.environ.get("FD_CACHE_QUEUE_PORT", 8333))
# Every port used by the suite; processes holding them are killed before and
# after the test session.
PORTS_TO_CLEAN = [
    FD_API_PORT,
    FD_ENGINE_QUEUE_PORT,
    FD_METRICS_PORT,
    FD_CACHE_QUEUE_PORT,
]
def is_port_open(host: str, port: int, timeout=1.0):
    """
    Probe a TCP endpoint by trying to connect to it.

    Returns True when a connection to (host, port) is established within
    `timeout` seconds, and False when the attempt raises any exception.
    """
    reachable = False
    try:
        # The context manager closes the probe socket as soon as we are done.
        with socket.create_connection((host, port), timeout):
            reachable = True
    except Exception:
        reachable = False
    return reachable
def kill_process_on_port(port: int):
    """
    Kill every process that `lsof -i:<port> -t` reports for the given port.

    Each reported pid receives SIGKILL, except the current process and its
    parent, which are skipped. When `lsof` reports nothing (non-zero exit) or
    is not installed, the function returns without doing anything. A process
    that exits between the `lsof` call and the kill is reported and skipped
    instead of aborting the cleanup.
    """
    try:
        # Argument list instead of a shell string: no shell is spawned.
        output = subprocess.check_output(["lsof", f"-i:{port}", "-t"]).decode().strip()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # CalledProcessError: lsof found no matching process.
        # FileNotFoundError: lsof is missing; the shell-based call used to
        # surface this as a non-zero exit, which was ignored as well.
        return

    protected_pids = (os.getpid(), os.getppid())
    for pid_text in output.splitlines():
        pid = int(pid_text)
        if pid in protected_pids:
            print(f"Skip killing current process (pid={pid}) on port {port}")
            continue
        try:
            os.kill(pid, signal.SIGKILL)
        except ProcessLookupError:
            # The process went away on its own after lsof listed it.
            print(f"Process on port {port} already exited, pid={pid}")
            continue
        print(f"Killed process on port {port}, pid={pid}")
def clean_ports():
    """
    Free every port in PORTS_TO_CLEAN by killing whatever holds it, then
    pause for two seconds.
    """
    for occupied_port in PORTS_TO_CLEAN:
        kill_process_on_port(occupied_port)
    # NOTE(review): a single pause after all ports were handled is assumed
    # here (rather than one pause per port) - confirm.
    time.sleep(2)
@pytest.fixture(scope="session", autouse=True)
def setup_and_run_server():
    """
    Session-wide fixture that owns the API server used by every test.

    Steps:
    - kill whatever holds the test ports and remove a stale ``log`` directory
    - start ``fastdeploy.entrypoints.openai.api_server`` as a subprocess in
      its own session, with stdout/stderr redirected to ``server.log``
    - poll the API port up to 300 times, one second apart (about 5 minutes)
    - after all tests, send SIGTERM to the server's process group and clean
      the ports again

    Raises:
        RuntimeError: if the API port never opens during the polling window.
    """
    print("Pre-test port cleanup...")
    clean_ports()

    print("log dir clean ")
    if os.path.exists("log") and os.path.isdir("log"):
        shutil.rmtree("log")

    # The model directory lives under $MODEL_PATH when that is set,
    # otherwise next to the current working directory.
    base_path = os.getenv("MODEL_PATH")
    if base_path:
        model_path = os.path.join(base_path, "Qwen2-7B-Instruct")
    else:
        model_path = "./Qwen2-7B-Instruct"

    log_path = "server.log"
    cmd = [
        sys.executable,
        "-m",
        "fastdeploy.entrypoints.openai.api_server",
        "--model",
        model_path,
        "--port",
        str(FD_API_PORT),
        "--tensor-parallel-size",
        "1",
        "--engine-worker-queue-port",
        str(FD_ENGINE_QUEUE_PORT),
        "--metrics-port",
        str(FD_METRICS_PORT),
        "--cache-queue-port",
        str(FD_CACHE_QUEUE_PORT),
        "--max-model-len",
        "32768",
        "--max-num-seqs",
        "128",
        "--quantization",
        "wint8",
    ]

    # Start subprocess in new process group
    with open(log_path, "w") as logfile:
        process = subprocess.Popen(
            cmd,
            stdout=logfile,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # Enables killing full group via os.killpg
        )

    # Poll the API port: 300 attempts, one second apart.
    for _ in range(300):
        if is_port_open("127.0.0.1", FD_API_PORT):
            print(f"API server is up on port {FD_API_PORT}")
            break
        time.sleep(1)
    else:
        print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...")
        try:
            os.killpg(process.pid, signal.SIGTERM)
        except Exception as e:
            print(f"Failed to kill process group: {e}")
        raise RuntimeError(f"API server did not start on port {FD_API_PORT}")

    yield  # Run tests

    print("\n===== Post-test server cleanup... =====")
    try:
        os.killpg(process.pid, signal.SIGTERM)
        print(f"API server (pid={process.pid}) terminated")
    except Exception as e:
        print(f"Failed to terminate API server: {e}")
    # Port cleanup runs in its own guard so that a failed SIGTERM (for
    # example, the server already exited) no longer skips it.
    try:
        clean_ports()
    except Exception as e:
        print(f"Failed to clean ports: {e}")
@pytest.fixture(scope="session")
def api_url(request):
    """
    Returns the chat-completions endpoint URL of the locally started server.

    The URL targets the loopback address 127.0.0.1, the same host the
    server readiness probe connects to. 0.0.0.0 is a wildcard meant for
    binding and is not a portable address to connect to.
    """
    return f"http://127.0.0.1:{FD_API_PORT}/v1/chat/completions"
@pytest.fixture(scope="session")
def metrics_url(request):
    """
    Returns the metrics endpoint URL of the locally started server.

    The URL targets the loopback address 127.0.0.1 rather than 0.0.0.0,
    which is a wildcard meant for binding and is not a portable address to
    connect to.
    """
    return f"http://127.0.0.1:{FD_METRICS_PORT}/metrics"
@pytest.fixture
def headers():
    """
    Provides the HTTP headers shared by the raw HTTP requests: a JSON
    content type.
    """
    common_headers = {"Content-Type": "application/json"}
    return common_headers
@pytest.fixture
def consistent_payload():
    """
    Builds the request body used by the consistency test: one fixed user
    message together with pinned sampling parameters (temperature, top_p
    and seed).
    """
    payload = {"messages": [{"role": "user", "content": "用一句话介绍 PaddlePaddle"}]}
    payload["temperature"] = 0.9
    payload["top_p"] = 0  # pinned to reduce randomness
    payload["seed"] = 13  # fixed random seed
    return payload
# ==========================
# JSON Schema for validating chat API responses
# ==========================
# Schema of the `message` object inside one choice.
_chat_message_schema = {
    "type": "object",
    "properties": {
        "role": {"type": "string"},
        "content": {"type": "string"},
    },
    "required": ["role", "content"],
}

# Schema of one entry of the `choices` array.
_chat_choice_schema = {
    "type": "object",
    "properties": {
        "message": _chat_message_schema,
        "index": {"type": "number"},
        "finish_reason": {"type": "string"},
    },
    "required": ["message", "index", "finish_reason"],
}

# Schema of the whole chat completion response body.
chat_response_schema = {
    "type": "object",
    "properties": {
        "id": {"type": "string"},
        "object": {"type": "string"},
        "created": {"type": "number"},
        "model": {"type": "string"},
        "choices": {"type": "array", "items": _chat_choice_schema},
    },
    "required": ["id", "object", "created", "model", "choices"],
}
# ==========================
# Helper function to calculate difference rate between two texts
# ==========================
def calculate_diff_rate(text1, text2):
    """
    Return how different two strings are, as the Levenshtein edit distance
    divided by the length of the longer string.

    The result is a float in [0, 1]; 0.0 means the inputs are identical.
    """
    if text1 == text2:
        return 0.0

    longest = max(len(text1), len(text2))
    if longest == 0:
        return 0.0

    # Only the previous row of the edit-distance table is needed to compute
    # the next one. Row 0 is the cost of building each prefix of text2 from
    # an empty string.
    previous_row = list(range(len(text2) + 1))
    for row_index, left_char in enumerate(text1, start=1):
        current_row = [row_index]
        for col_index, right_char in enumerate(text2, start=1):
            if left_char == right_char:
                cost = previous_row[col_index - 1]
            else:
                cost = 1 + min(
                    previous_row[col_index],  # deletion
                    current_row[col_index - 1],  # insertion
                    previous_row[col_index - 1],  # substitution
                )
            current_row.append(cost)
        previous_row = current_row

    return previous_row[-1] / longest
# ==========================
# Valid prompt test cases for parameterized testing
# ==========================
# Each entry is a complete `messages` list holding one user turn.
valid_prompts = [
    [{"role": "user", "content": user_text}]
    for user_text in (
        "你好",
        "用一句话介绍 FastDeploy",
    )
]
@pytest.mark.parametrize("messages", valid_prompts)
def test_valid_chat(messages, api_url, headers):
    """
    A well-formed chat request must return HTTP 200 and a body that matches
    chat_response_schema.
    """
    request_body = {"messages": messages}
    response = requests.post(api_url, headers=headers, json=request_body)
    assert response.status_code == 200
    validate(instance=response.json(), schema=chat_response_schema)
# ==========================
# Consistency test for repeated runs with fixed payload
# ==========================
def test_consistency_between_runs(api_url, headers, consistent_payload):
    """
    Send the same fixed payload twice and require the two answers to differ
    by less than 5% (normalized edit distance).
    """
    answers = []
    for _ in range(2):
        resp = requests.post(api_url, headers=headers, json=consistent_payload)
        assert resp.status_code == 200
        answers.append(resp.json()["choices"][0]["message"]["content"])

    first_answer, second_answer = answers
    diff_rate = calculate_diff_rate(first_answer, second_answer)
    # The two runs must stay below the allowed difference threshold.
    assert diff_rate < 0.05, f"Output difference too large ({diff_rate:.4%})"
# ==========================
# Invalid prompt tests
# ==========================
# Malformed `messages` values that the server is expected to reject.
invalid_prompts = [
    list(),  # no messages at all
    [dict()],  # a message with neither role nor content
    [dict(role="user")],  # role without content
    [dict(content="hello")],  # content without role
]
@pytest.mark.parametrize("messages", invalid_prompts)
def test_invalid_chat(messages, api_url, headers):
    """
    A malformed `messages` value must be answered with an error status
    (400 or above).
    """
    request_body = {"messages": messages}
    response = requests.post(api_url, headers=headers, json=request_body)
    assert response.status_code >= 400, "Invalid request should return an error status code"
# ==========================
# Test for input exceeding context length
# ==========================
def test_exceed_context_length(api_url, headers):
    """
    Send a prompt far longer than the configured context window and require
    the server either to reject it (non-200 status) or to mention "token" in
    the response body.
    """
    oversized_text = "你好," * 20000
    request_body = {"messages": [{"role": "user", "content": oversized_text}]}
    resp = requests.post(api_url, headers=headers, json=request_body)

    # The body may not be JSON at all on an error; treat that as empty.
    try:
        response_json = resp.json()
    except Exception:
        response_json = {}

    # Any non-200 status already counts as a rejection; only a 200 answer
    # needs its body inspected.
    if resp.status_code == 200:
        assert (
            "token" in json.dumps(response_json).lower()
        ), f"Expected token limit error or similar, but got a normal response: {response_json}"
# ==========================
# Multi-turn Conversation Test
# ==========================
def test_multi_turn_conversation(api_url, headers):
    """
    Send a three-turn conversation (user, assistant, user) and require
    HTTP 200 and a body that matches chat_response_schema.
    """
    turns = (
        ("user", "你是谁?"),
        ("assistant", "我是AI助手"),
        ("user", "你能做什么?"),
    )
    messages = [{"role": role, "content": content} for role, content in turns]
    response = requests.post(api_url, headers=headers, json={"messages": messages})
    assert response.status_code == 200
    validate(instance=response.json(), schema=chat_response_schema)
# ==========================
# Concurrent Performance Test
# ==========================
def test_concurrent_perf(api_url, headers):
    """
    Fire 8 identical chat requests from 8 worker threads, require every one
    to return HTTP 200, and print the per-request response times.
    """
    prompts = [{"role": "user", "content": "Introduce FastDeploy."}]

    def timed_request(_):
        """
        Send one request and return its elapsed time in seconds.
        """
        resp = requests.post(api_url, headers=headers, json={"messages": prompts})
        assert resp.status_code == 200
        return resp.elapsed.total_seconds()

    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
        # map() yields results in submission order and re-raises the first
        # failing request's assertion.
        durations = list(pool.map(timed_request, range(8)))
    print("\nResponse time for each request:", durations)
# ==========================
# Metrics Endpoint Test
# ==========================
def test_metrics_endpoint(metrics_url):
    """
    Test the metrics monitoring endpoint.

    The endpoint must answer HTTP 200 with a text/plain body. Every metric
    listed in `expected_metrics` must appear at least once as a sample line
    starting with ``fastdeploy:<name>``, and each such sample's value (the
    text after the last space) must parse as a float >= 0.
    """
    # Metric names without the "fastdeploy:" prefix. Missing metrics are
    # reported in this order.
    expected_metrics = (
        "num_requests_running",
        "num_requests_waiting",
        "time_to_first_token_seconds_sum",
        "time_per_output_token_seconds_sum",
        "e2e_request_latency_seconds_sum",
        "request_inference_time_seconds_sum",
        "request_queue_time_seconds_sum",
        "request_prefill_time_seconds_sum",
        "request_decode_time_seconds_sum",
        "prompt_tokens_total",
        "generation_tokens_total",
        "request_prompt_tokens_sum",
        "request_generation_tokens_sum",
        "gpu_cache_usage_perc",
        "request_params_max_tokens_sum",
        "request_success_total",
        "cache_config_info",
        "available_batch_size",
        "hit_req_rate",
        "hit_token_rate",
        "cpu_hit_token_rate",
        "gpu_hit_token_rate",
    )

    resp = requests.get(metrics_url, timeout=5)
    assert resp.status_code == 200, f"Unexpected status code: {resp.status_code}"
    assert "text/plain" in resp.headers["Content-Type"], "Content-Type is not text/plain"

    # Keep only sample lines: drop "#" comment lines and blank lines.
    metric_lines = [line for line in resp.text.split("\n") if not line.startswith("#") and line.strip() != ""]

    # Check the value of every sample that belongs to an expected metric and
    # remember which metrics were seen.
    found = set()
    for line in metric_lines:
        for name in expected_metrics:
            if line.startswith(f"fastdeploy:{name}"):
                _, value = line.rsplit(" ", 1)
                assert float(value) >= 0, f"{name} 值错误"
                found.add(name)
                # No expected name is a prefix of another, so a line can
                # match at most one entry.
                break

    # Each failure message names the metric that is actually missing.
    for name in expected_metrics:
        assert name in found, f"缺少 fastdeploy:{name} 指标"
# ==========================
# OpenAI Client chat.completions Test
# ==========================
@pytest.fixture
def openai_client():
    """
    Returns an OpenAI client pointed at the locally started server's /v1 API.

    The client connects to the loopback address 127.0.0.1 rather than
    0.0.0.0, which is a wildcard meant for binding and is not a portable
    address to connect to. The API key is a placeholder value.
    """
    ip = "127.0.0.1"
    service_http_port = str(FD_API_PORT)
    client = openai.Client(
        base_url=f"http://{ip}:{service_http_port}/v1",
        api_key="EMPTY_API_KEY",
    )
    return client
# Non-streaming test
def test_non_streaming_chat(openai_client):
    """Non-streaming chat completion: the reply must carry at least one choice with message content."""
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ]
    response = openai_client.chat.completions.create(
        model="default",
        messages=conversation,
        temperature=1,
        max_tokens=1024,
        stream=False,
    )

    assert hasattr(response, "choices")
    assert len(response.choices) > 0
    first_choice = response.choices[0]
    assert hasattr(first_choice, "message")
    assert hasattr(first_choice.message, "content")
# Streaming test
def test_streaming_chat(openai_client, capsys):
    """Streaming chat completion: more than two content deltas must arrive."""
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "List 3 countries and their capitals."},
        {
            "role": "assistant",
            "content": "China(Beijing), France(Paris), Australia(Canberra).",
        },
        {"role": "user", "content": "OK, tell more."},
    ]
    stream = openai_client.chat.completions.create(
        model="default",
        messages=conversation,
        temperature=1,
        max_tokens=1024,
        stream=True,
    )

    # Collect the content of every chunk whose first choice carries a delta.
    output = []
    for chunk in stream:
        first_choice = chunk.choices[0]
        if hasattr(first_choice, "delta") and hasattr(first_choice.delta, "content"):
            output.append(first_choice.delta.content)
    assert len(output) > 2
# ==========================
# OpenAI Client completions Test
# ==========================
def test_non_streaming(openai_client):
    """Non-streaming text completion: the reply must carry at least one choice."""
    request_kwargs = {
        "model": "default",
        "prompt": "Hello, how are you?",
        "temperature": 1,
        "max_tokens": 1024,
        "stream": False,
    }
    response = openai_client.completions.create(**request_kwargs)

    # Only the response structure is checked, not the generated text.
    assert hasattr(response, "choices")
    assert len(response.choices) > 0
def test_streaming(openai_client, capsys):
    """Streaming text completion: at least one chunk must arrive."""
    stream = openai_client.completions.create(
        model="default",
        prompt="Hello, how are you?",
        temperature=1,
        max_tokens=1024,
        stream=True,
    )

    # Gather the text of the first choice of every streamed chunk.
    output = [chunk.choices[0].text for chunk in stream]
    assert len(output) > 0
def test_profile_reset_block_num():
    """
    Check that the profiled total_block_num stays within 5% of the baseline.

    Reads ./log/config.log, takes the first line containing
    "Reset block num", extracts the integer after "total_block_num:" and
    compares it with the baseline value. The test fails if the log file is
    missing, if no such line exists, or if the number cannot be extracted.
    """
    log_file = "./log/config.log"
    baseline = 32562

    if not os.path.exists(log_file):
        pytest.fail(f"Log file not found: {log_file}")

    # Stream the log and stop at the first matching line instead of loading
    # the whole file into memory.
    with open(log_file, "r") as f:
        target_line = next((line.strip() for line in f if "Reset block num" in line), None)

    if target_line is None:
        pytest.fail("日志中没有Reset block num信息")

    match = re.search(r"total_block_num:(\d+)", target_line)
    if not match:
        pytest.fail(f"Failed to extract total_block_num from line: {target_line}")

    # The pattern captures digits only, so the conversion cannot fail.
    actual_value = int(match.group(1))

    lower_bound = baseline * (1 - 0.05)
    upper_bound = baseline * (1 + 0.05)
    print(f"Reset total_block_num: {actual_value}. baseline: {baseline}")

    # The two message parts are separated explicitly; previously they ran
    # together without any delimiter.
    assert lower_bound <= actual_value <= upper_bound, (
        f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内. "
        f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
    )
def test_prompt_token_ids_in_non_streaming_completion(openai_client):
    """
    Non-streaming completion API: token ids may be passed through
    `prompt_token_ids` (with an empty `prompt`) or directly through `prompt`,
    either as one id list or as a batch of id lists. In every case the
    number of choices and the reported prompt token count are checked.
    """
    single_ids = [5209, 626, 274, 45954, 1071, 3265, 3934, 1869, 93937]
    batched_ids = [single_ids, [1, 2, 3]]

    def complete(**request_kwargs):
        """Run one non-streaming completion with the shared sampling settings."""
        return openai_client.completions.create(
            model="default",
            temperature=1,
            max_tokens=5,
            stream=False,
            **request_kwargs,
        )

    # One token id list in `prompt_token_ids`
    response = complete(prompt="", extra_body={"prompt_token_ids": single_ids})
    assert len(response.choices) == 1
    assert response.usage.prompt_tokens == 9

    # A batch of token id lists in `prompt_token_ids`
    response = complete(prompt="", extra_body={"prompt_token_ids": batched_ids})
    assert len(response.choices) == 2
    assert response.usage.prompt_tokens == 9 + 3

    # One token id list in `prompt`
    response = complete(prompt=single_ids)
    assert len(response.choices) == 1
    assert response.usage.prompt_tokens == 9

    # A batch of token id lists in `prompt`
    response = complete(prompt=batched_ids)
    assert len(response.choices) == 2
    assert response.usage.prompt_tokens == 9 + 3
def test_prompt_token_ids_in_streaming_completion(openai_client):
    """
    Streaming completion API: token ids may be passed through
    `prompt_token_ids` (with an empty `prompt`) or directly through `prompt`,
    either as one id list or as a batch of id lists. Chunks that carry
    choices must have no usage; the prompt token counts of the remaining
    chunks are summed and checked.
    """
    single_ids = [5209, 626, 274, 45954, 1071, 3265, 3934, 1869, 93937]
    batched_ids = [single_ids, [1, 2, 3]]

    def streamed_prompt_tokens(**request_kwargs):
        """Stream one completion and return the summed prompt token usage."""
        stream = openai_client.completions.create(
            model="default",
            temperature=1,
            max_tokens=5,
            stream=True,
            stream_options={"include_usage": True},
            **request_kwargs,
        )
        total = 0
        for chunk in stream:
            if len(chunk.choices) > 0:
                assert chunk.usage is None
            else:
                total += chunk.usage.prompt_tokens
        return total

    # One token id list in `prompt_token_ids`
    assert streamed_prompt_tokens(prompt="", extra_body={"prompt_token_ids": single_ids}) == 9

    # A batch of token id lists in `prompt_token_ids`
    assert streamed_prompt_tokens(prompt="", extra_body={"prompt_token_ids": batched_ids}) == 9 + 3

    # One token id list in `prompt`
    assert streamed_prompt_tokens(prompt=single_ids) == 9

    # A batch of token id lists in `prompt`
    assert streamed_prompt_tokens(prompt=batched_ids) == 9 + 3