[Feature] support seed parameter (#3161)

* support seed

* fix

* add SamplingMetadata seed test

* The next_tokens values are inconsistent!

* add air and rejection seed test

* fix

* add SamplingParams seed test

* fix seed=0

* Default to default

* fix

* fix args_utils

* fix review

* fix review

* fix

* fix

* add xpu,gcu,iluvatar support seed

* fix
This commit is contained in:
lizexu123
2025-08-06 15:20:47 +08:00
committed by GitHub
parent 20839abccf
commit afff4d37ea
15 changed files with 386 additions and 5 deletions

View File

@@ -122,6 +122,7 @@ class ModelConfig:
        self.enable_mm = False
        self.enable_redundant_experts = False
        self.redundant_experts_num = 0
        self.seed = 0
        self.quantization = None
        for key, value in args.items():
            if hasattr(self, key):
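
For orientation, here is a minimal sketch of the config pattern above, using a hypothetical `TinyConfig` in place of `ModelConfig`: any key in `args` that matches an existing attribute (now including `seed`) overrides the default.

```python
class TinyConfig:
    """Hypothetical stand-in for ModelConfig, illustrating the args loop above."""

    def __init__(self, args: dict):
        self.seed = 0            # new default added by this commit
        self.quantization = None
        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)


cfg = TinyConfig({"seed": 1234, "not_a_field": "ignored"})
assert cfg.seed == 1234 and cfg.quantization is None
```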

View File

@@ -316,6 +316,11 @@ class EngineArgs:
    Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
    """
    seed: int = 0
    """
    Random seed to use for initialization. If not set, defaults to 0.
    """
    enable_early_stop: bool = False
    """
    Flag to enable early stop. Default is False (disabled).
@@ -484,6 +489,12 @@ class EngineArgs:
            default=EngineArgs.enable_logprob,
            help="Enable output of token-level log probabilities.",
        )
        model_group.add_argument(
            "--seed",
            type=int,
            default=EngineArgs.seed,
            help="Random seed for initialization. If not specified, defaults to 0.",
        )
        model_group.add_argument(
            "--enable-early-stop",
            action="store_true",
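
A hedged example of passing the new flag when launching the OpenAI-compatible server; the model directory and port are placeholders, and the module path is the one used by the test added later in this commit. Only `--seed` is the argument introduced here.

```python
import subprocess
import sys

# Placeholder launch command; adjust model path and ports for your deployment.
cmd = [
    sys.executable, "-m", "fastdeploy.entrypoints.openai.api_server",
    "--model", "./Qwen3-30B-A3B",   # placeholder model directory
    "--port", "8188",
    "--seed", "42",                 # new flag added in this commit
]
server = subprocess.Popen(cmd)
```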

View File

@@ -43,6 +43,7 @@ class SamplingMetadata:
    top_p: paddle.Tensor
    top_k: Optional[paddle.Tensor] = None
    min_p: Optional[paddle.Tensor] = None
    seed: Optional[paddle.Tensor] = None
    max_num_logprobs: Optional[int] = None
    enable_early_stop: Optional[int] = False
    stop_flags: Optional[paddle.Tensor] = None
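
As a rough illustration (shapes assumed, not taken from this diff): most of the runners below feed `share_inputs["infer_seed"]` into this field, and the sampler reads `seed[0, 0]`, which suggests a 2-D integer buffer along these lines.

```python
import paddle

# Hypothetical per-request seed buffer: one int64 seed per sequence slot.
max_num_seqs = 4
infer_seed = paddle.full(shape=[max_num_seqs, 1], fill_value=42, dtype="int64")
print(int(infer_seed[0, 0]))  # 42, the value the sampler would consume
```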

View File

@@ -282,7 +282,9 @@ class Sampler(nn.Layer):
        probs = min_p_sampling(probs, sampling_metadata.min_p)
        _, next_tokens = top_k_top_p_sampling(
            probs, sampling_metadata.top_p, sampling_metadata.top_k, seed=sampling_metadata.seed[0, 0]
        )
        logprobs_tensors = (
            None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
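
The `top_k_top_p_sampling` kernel itself is not part of this diff; as a rough sketch of the property the `seed` argument is meant to provide, a fixed seed makes repeated categorical draws identical. The real op takes the seed directly, whereas this illustration simply reseeds Paddle's global generator.

```python
import paddle

probs = paddle.to_tensor([[0.1, 0.2, 0.3, 0.4]])

def draw(seed: int) -> int:
    # Illustration only: fix the generator, then sample one token id.
    paddle.seed(seed)
    return int(paddle.multinomial(probs, num_samples=1))

assert draw(42) == draw(42)  # same seed, same sampled token
```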

View File

@@ -29,6 +29,8 @@ from logging.handlers import BaseRotatingHandler
from pathlib import Path
from typing import Literal, TypeVar, Union
import numpy as np
import paddle
import requests
import yaml
from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download
@@ -295,6 +297,13 @@ def extract_tar(tar_path, output_dir):
        raise RuntimeError(f"Extraction failed: {e!s}")


def set_random_seed(seed: int) -> None:
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        paddle.seed(seed)


def download_model(url, output_dir, temp_tar):
    """
    Download the model and extract it to the specified directory.
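
A small check of what the new helper covers, assuming this commit's `fastdeploy.utils` is importable: it reseeds Python's `random`, NumPy, and Paddle in one call.

```python
import random

import numpy as np
import paddle

from fastdeploy.utils import set_random_seed  # helper added above

def snapshot():
    # One draw from each of the three RNG streams the helper touches.
    return random.random(), float(np.random.rand()), float(paddle.rand([1]))

set_random_seed(42)
first = snapshot()
set_random_seed(42)
second = snapshot()
assert first == second  # all three RNG streams were reset by the same call
```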

View File

@@ -540,6 +540,7 @@ class GCUModelRunner(ModelRunnerBase):
            top_p=self.share_inputs["top_p"],
            top_k=self.share_inputs["top_k"],
            min_p=self.share_inputs["min_p"],
            seed=self.share_inputs["infer_seed"],
            step_idx=self.share_inputs["step_idx"],
            pre_token_ids=self.share_inputs["pre_ids"],
            prompt_ids=self.share_inputs["prompt_ids"],

View File

@@ -22,7 +22,7 @@ from paddle import nn
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import get_logger, set_random_seed
from fastdeploy.worker.gcu_model_runner import GCUModelRunner
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.worker_base import WorkerBase
@@ -60,6 +60,7 @@ class GcuWorker(WorkerBase):
        else:
            raise RuntimeError(f"Not support device type: {self.device_config.device}")
        set_random_seed(self.fd_config.model_config.seed)
        # Construct model runner
        self.model_runner: GCUModelRunner = GCUModelRunner(
            fd_config=self.fd_config,
@@ -128,6 +129,7 @@ class GcuWorker(WorkerBase):
        self.model_runner.sot_warmup()
        # 2. Triger cuda grpah capture
        self.model_runner.capture_model()
        set_random_seed(self.fd_config.model_config.seed)

    def check_health(self) -> bool:
        """ """

View File

@@ -131,6 +131,7 @@ class GPUModelRunner(ModelRunnerBase):
            fill_value=4,
            dtype="int64",
        )
        self.restore_chunked_prefill_request = dict()
        # Initialize attention Backend
@@ -813,6 +814,7 @@ class GPUModelRunner(ModelRunnerBase):
            top_p=self.share_inputs["top_p"],
            top_k=self.share_inputs["top_k"],
            min_p=self.share_inputs["min_p"],
            seed=self.share_inputs["infer_seed"],
            step_idx=self.share_inputs["step_idx"],
            pre_token_ids=self.share_inputs["pre_ids"],
            prompt_ids=self.share_inputs["prompt_ids"],

View File

@@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.platforms import current_platform
from fastdeploy.plugins.model_runner import load_model_runner_plugins
from fastdeploy.utils import get_logger, set_random_seed
from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.worker_base import WorkerBase
@@ -75,6 +75,7 @@ class GpuWorker(WorkerBase):
        else:
            raise RuntimeError(f"Not support device type: {self.device_config.device}")
        set_random_seed(self.fd_config.model_config.seed)
        # Construct model runner
        self.model_runner: ModelRunnerBase = ModelRunner(
            fd_config=self.fd_config,
@@ -129,6 +130,7 @@ class GpuWorker(WorkerBase):
        # 2. Profile run
        self.model_runner.profile_run()
        set_random_seed(self.fd_config.model_config.seed)
        # 3. Statistical memory information
        paddle_reserved_mem_after_run = paddle.device.cuda.max_memory_reserved(local_rank)
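
The workers call `set_random_seed` both before constructing the runner and again after profiling or graph capture, presumably so that warm-up work which consumes random numbers does not change the generator state seen by real requests. A hedged, self-contained sketch of that pattern (the `warmup` callable stands in for `profile_run`/`capture_model`):

```python
import paddle

from fastdeploy.utils import set_random_seed  # added earlier in this commit

def first_draw_after_warmup(seed: int, warmup) -> float:
    set_random_seed(seed)
    warmup()                        # may consume an arbitrary amount of randomness
    set_random_seed(seed)           # restore a deterministic state afterwards
    return float(paddle.rand([1]))  # stands in for the first real sampling step

a = first_draw_after_warmup(0, lambda: paddle.rand([8]))
b = first_draw_after_warmup(0, lambda: paddle.rand([128]))  # different warm-up load
assert a == b  # warm-up size no longer affects the first decode
```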

View File

@@ -509,6 +509,7 @@ class IluvatarModelRunner(ModelRunnerBase):
            temperature=self.share_inputs["temperature"],
            top_p=self.share_inputs["top_p"],
            top_k=self.share_inputs["top_k"],
            seed=self.share_inputs["seed"],
            step_idx=self.share_inputs["step_idx"],
            pre_token_ids=self.share_inputs["pre_ids"],
            prompt_ids=self.share_inputs["prompt_ids"],

View File

@@ -23,7 +23,7 @@ from paddle import nn
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import get_logger, set_random_seed
from fastdeploy.worker.iluvatar_model_runner import IluvatarModelRunner
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.worker_base import WorkerBase
@@ -60,6 +60,7 @@ class IluvatarWorker(WorkerBase):
        else:
            raise RuntimeError(f"Not support device type: {self.device_config.device}")
        set_random_seed(self.fd_config.model_config.seed)
        # Construct model runner
        self.model_runner: IluvatarModelRunner = IluvatarModelRunner(
            fd_config=self.fd_config,
@@ -130,6 +131,7 @@ class IluvatarWorker(WorkerBase):
        # 2. Triger cuda grpah capture
        self.model_runner.capture_model()
        set_random_seed(self.fd_config.model_config.seed)

    def check_health(self) -> bool:
        """ """

View File

@@ -677,6 +677,7 @@ class XPUModelRunner(ModelRunnerBase):
            top_p=self.share_inputs["top_p"],
            top_k=self.share_inputs["top_k"],
            min_p=self.share_inputs["min_p"],
            seed=self.share_inputs["infer_seed"],
            step_idx=self.share_inputs["step_idx"],
            pre_token_ids=self.share_inputs["pre_ids"],
            frequency_penalties=self.share_inputs["frequency_score"],

View File

@@ -23,7 +23,7 @@ from paddle import nn
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import get_logger, set_random_seed
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.worker_base import WorkerBase
from fastdeploy.worker.xpu_model_runner import XPUModelRunner
@@ -60,6 +60,7 @@ class XpuWorker(WorkerBase):
        else:
            raise RuntimeError(f"Not support device type: {self.device_config.device}")
        set_random_seed(self.fd_config.model_config.seed)
        # Construct model runner
        self.model_runner: XPUModelRunner = XPUModelRunner(
            fd_config=self.fd_config,
@@ -110,6 +111,7 @@ class XpuWorker(WorkerBase):
        self.model_runner.prepare_profile()
        self.model_runner.profile_run()
        set_random_seed(self.fd_config.model_config.seed)
        total_available_memory = int(total_memory * self.cache_config.gpu_memory_utilization)
        used_memory = xpu_get_used_global_memory(int(self.device_ids[self.local_rank]))

View File

@@ -191,6 +191,29 @@ def test_chat_completion(llm):
        pytest.fail(f"Chat case {i + 1} failed")


def test_seed(llm):
    """
    Test chat completion with same seed
    """
    prompt = "请介绍下中国的四大发明,用一句话概述每个发明。"
    sampling_params = SamplingParams(temperature=0.1, seed=1, max_tokens=100)
    num_runs = 5
    results = []
    try:
        for i in range(num_runs):
            outputs = llm.generate(prompt, sampling_params)
            results.append(outputs[0].outputs.text)
        assert all([result == results[0] for result in results]), "Results are not identical."
        print("All results are identical.")
    except Exception:
        print("Failed during prompt generation.")
        traceback.print_exc()
        pytest.fail("Prompt generation test failed")


if __name__ == "__main__":
    """
    Main entry point for the test script.

View File

@@ -0,0 +1,321 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import signal
import socket
import subprocess
import sys
import time
import pytest
import requests
def is_port_open(host: str, port: int, timeout=1.0):
    """Check if a TCP port is open on the given host."""
    try:
        with socket.create_connection((host, port), timeout):
            return True
    except Exception:
        return False


def kill_process_on_port(port: int):
    """Kill processes that are listening on the given port."""
    try:
        output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip()
        for pid in output.splitlines():
            os.kill(int(pid), signal.SIGKILL)
            print(f"Killed process on port {port}, pid={pid}")
    except subprocess.CalledProcessError:
        pass


def clean_specific_ports(ports_list):
    """Kill all processes occupying the specified ports."""
    for port in ports_list:
        kill_process_on_port(port)
def create_server_process_with_sampling(sampling_class: str, api_port: int, queue_port: int, metrics_port: int):
    """
    Create and start the API server process with specified sampling class and ports.
    Returns the process object.
    """
    base_path = os.getenv("MODEL_PATH")
    if base_path:
        model_path = os.path.join(base_path, "Qwen3-30B-A3B")
    else:
        model_path = "./Qwen3-30B-A3B"
    log_path = f"server_{sampling_class}_{api_port}.log"
    cmd = [
        sys.executable,
        "-m",
        "fastdeploy.entrypoints.openai.api_server",
        "--model",
        model_path,
        "--port",
        str(api_port),
        "--tensor-parallel-size",
        "1",
        "--engine-worker-queue-port",
        str(queue_port),
        "--metrics-port",
        str(metrics_port),
        "--max-model-len",
        "32768",
        "--max-num-seqs",
        "50",
        "--quantization",
        "wint4",
    ]
    env = os.environ.copy()
    env["FD_SAMPLING_CLASS"] = sampling_class
    print(f"Starting server with FD_SAMPLING_CLASS={sampling_class} on port {api_port}")
    with open(log_path, "w") as logfile:
        process = subprocess.Popen(
            cmd,
            stdout=logfile,
            stderr=subprocess.STDOUT,
            start_new_session=True,
            env=env,
        )
    return process


def wait_for_server_ready_on_port(api_port: int, timeout=300):
    """Wait for the API server to be ready on specified port."""
    for _ in range(timeout):
        if is_port_open("127.0.0.1", api_port):
            print(f"API server is up on port {api_port}")
            return True
        time.sleep(1)
    return False
# ==========================
# Fixtures for pytest
# ==========================
@pytest.fixture
def headers():
    """Returns common HTTP request headers."""
    return {"Content-Type": "application/json"}


@pytest.fixture
def consistent_payload():
    """Returns a fixed payload for consistency testing with fixed seed."""
    return {
        "messages": [
            {
                "role": "user",
                "content": "用一句话介绍 PaddlePaddle, 30字以内 /no_think",
            }
        ],
        "temperature": 0.8,
        "seed": 42,  # Fixed seed
        "max_tokens": 50,
    }
@pytest.fixture
def rejection_server():
    """Fixture to manage rejection sampling server lifecycle."""
    sampling_class = "rejection"
    api_port = 8288
    queue_port = 8334
    metrics_port = 8433
    ports_to_clean = [api_port, queue_port, metrics_port]
    # Setup: Clean ports and start server
    clean_specific_ports(ports_to_clean)
    time.sleep(2)
    process = create_server_process_with_sampling(sampling_class, api_port, queue_port, metrics_port)
    # Wait for server to be ready
    if not wait_for_server_ready_on_port(api_port, timeout=300):
        try:
            os.killpg(process.pid, signal.SIGTERM)
        except Exception:
            pass
        pytest.fail(f"Server failed to start for {sampling_class}")
    # Yield server info to test
    server_info = {
        "api_url": f"http://0.0.0.0:{api_port}/v1/chat/completions",
        "process": process,
        "sampling_class": sampling_class,
    }
    yield server_info
    # Teardown: Clean up server
    try:
        os.killpg(process.pid, signal.SIGTERM)
        print(f"Server terminated for {sampling_class}")
    except Exception as e:
        print(f"Failed to terminate server: {e}")
    time.sleep(3)
@pytest.fixture
def air_server():
    """Fixture to manage AIR sampling server lifecycle."""
    sampling_class = "air"
    api_port = 8123
    queue_port = 8534
    metrics_port = 8643
    ports_to_clean = [api_port, queue_port, metrics_port]
    # Setup: Clean ports and start server
    clean_specific_ports(ports_to_clean)
    time.sleep(2)
    process = create_server_process_with_sampling(sampling_class, api_port, queue_port, metrics_port)
    # Wait for server to be ready with detailed error reporting
    if not wait_for_server_ready_on_port(api_port, timeout=300):
        # Check log file for debugging
        log_file = f"server_{sampling_class}_{api_port}.log"
        error_msg = f"Server failed to start for {sampling_class}"
        if os.path.exists(log_file):
            with open(log_file, "r") as f:
                lines = f.readlines()
            print(f"Server startup failed. Last 10 lines of {log_file}:")
            for line in lines[-10:]:
                print(f" {line.strip()}")
        try:
            os.killpg(process.pid, signal.SIGTERM)
        except Exception:
            pass
        pytest.fail(error_msg)
    # Yield server info to test
    server_info = {
        "api_url": f"http://0.0.0.0:{api_port}/v1/chat/completions",
        "process": process,
        "sampling_class": sampling_class,
    }
    yield server_info
    # Teardown: Clean up server
    try:
        os.killpg(process.pid, signal.SIGTERM)
        print(f"Server terminated for {sampling_class}")
    except Exception as e:
        print(f"Failed to terminate server: {e}")
    time.sleep(3)
# ==========================
# Test cases
# ==========================
def test_seed_consistency_rejection_sampling(rejection_server, headers, consistent_payload):
    """
    Test seed consistency for rejection sampling - multiple runs should produce identical results.
    """
    server_info = rejection_server
    api_url = server_info["api_url"]
    sampling_class = server_info["sampling_class"]
    num_runs = 5
    print(f"\n===== Testing seed consistency for {sampling_class.upper()} sampling =====")
    # Run multiple requests with same seed
    results = []
    print(f"Running {num_runs} requests with fixed seed=42:")
    for i in range(num_runs):
        resp = requests.post(api_url, headers=headers, json=consistent_payload)
        assert resp.status_code == 200, f"Request {i+1} failed with status {resp.status_code}"
        content = resp.json()["choices"][0]["message"]["content"]
        results.append(content)
        print(f" Run {i+1}: {content[:50]}...")
        time.sleep(1)
    # Check if all results are identical
    reference_result = results[0]
    all_identical = all(result == reference_result for result in results)
    print(f"\n--- {sampling_class.upper()} Sampling Results ---")
    if all_identical:
        print(f" ALL {num_runs} runs produced IDENTICAL results")
        print(f" Result: {reference_result}")
    else:
        print(" Results are NOT identical:")
        for i, result in enumerate(results):
            status = "yes" if result == reference_result else "no"
            print(f" Run {i+1} {status}: {result}")
    # Use assertion for pytest compatibility
    assert (
        all_identical
    ), f"Rejection sampling should be consistent with fixed seed. Got {len(set(results))} different outputs: {list(set(results))}"
def test_seed_consistency_air_sampling(air_server, headers, consistent_payload):
    """
    Test seed consistency for AIR sampling - multiple runs should produce identical results.
    """
    server_info = air_server
    api_url = server_info["api_url"]
    sampling_class = server_info["sampling_class"]
    num_runs = 5
    print(f"\n===== Testing seed consistency for {sampling_class.upper()} sampling =====")
    # Run multiple requests with same seed
    results = []
    print(f"Running {num_runs} requests with fixed seed=42:")
    for i in range(num_runs):
        resp = requests.post(api_url, headers=headers, json=consistent_payload)
        assert resp.status_code == 200, f"Request {i+1} failed with status {resp.status_code}"
        content = resp.json()["choices"][0]["message"]["content"]
        results.append(content)
        print(f" Run {i+1}: {content[:50]}...")
        time.sleep(1)
    # Check if all results are identical
    reference_result = results[0]
    all_identical = all(result == reference_result for result in results)
    print(f"\n--- {sampling_class.upper()} Sampling Results ---")
    if all_identical:
        print(f" ALL {num_runs} runs produced IDENTICAL results")
        print(f" Result: {reference_result}")
    else:
        print(" Results are NOT identical:")
        for i, result in enumerate(results):
            status = "yes" if result == reference_result else "no"
            print(f" Run {i+1} {status}: {result}")
    # Use assertion for pytest compatibility
    assert (
        all_identical
    ), f"AIR sampling should be consistent with fixed seed. Got {len(set(results))} different outputs: {list(set(results))}"