diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 5cb1f9be3..0ed723eef 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -122,6 +122,7 @@ class ModelConfig: self.enable_mm = False self.enable_redundant_experts = False self.redundant_experts_num = 0 + self.seed = 0 self.quantization = None for key, value in args.items(): if hasattr(self, key): diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index d8c57ae45..e609475b6 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -316,6 +316,11 @@ class EngineArgs: Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values. """ + seed: int = 0 + """ + Random seed to use for initialization. If not set, defaults to 0. + """ + enable_early_stop: bool = False """ Flag to enable early stop. Default is False (disabled). @@ -484,6 +489,12 @@ class EngineArgs: default=EngineArgs.enable_logprob, help="Enable output of token-level log probabilities.", ) + model_group.add_argument( + "--seed", + type=int, + default=EngineArgs.seed, + help="Random seed for initialization. If not specified, defaults to 0.", + ) model_group.add_argument( "--enable-early-stop", action="store_true", diff --git a/fastdeploy/model_executor/layers/sample/meta_data.py b/fastdeploy/model_executor/layers/sample/meta_data.py index 9cca5af27..06281a5a5 100644 --- a/fastdeploy/model_executor/layers/sample/meta_data.py +++ b/fastdeploy/model_executor/layers/sample/meta_data.py @@ -43,6 +43,7 @@ class SamplingMetadata: top_p: paddle.Tensor top_k: Optional[paddle.Tensor] = None min_p: Optional[paddle.Tensor] = None + seed: Optional[paddle.Tensor] = None max_num_logprobs: Optional[int] = None enable_early_stop: Optional[int] = False stop_flags: Optional[paddle.Tensor] = None diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 412a7eda7..bf6b191c1 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -282,7 +282,9 @@ class Sampler(nn.Layer): probs = min_p_sampling(probs, sampling_metadata.min_p) - _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k) + _, next_tokens = top_k_top_p_sampling( + probs, sampling_metadata.top_p, sampling_metadata.top_k, seed=sampling_metadata.seed[0, 0] + ) logprobs_tensors = ( None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens) diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 1640fd23f..d59de85df 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -29,6 +29,8 @@ from logging.handlers import BaseRotatingHandler from pathlib import Path from typing import Literal, TypeVar, Union +import numpy as np +import paddle import requests import yaml from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download @@ -295,6 +297,13 @@ def extract_tar(tar_path, output_dir): raise RuntimeError(f"Extraction failed: {e!s}") +def set_random_seed(seed: int) -> None: + if seed is not None: + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + + def download_model(url, output_dir, temp_tar): """ 下载模型,并将其解压到指定目录。 diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index e0086b503..b52b35bc4 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -540,6 +540,7 @@ class GCUModelRunner(ModelRunnerBase): top_p=self.share_inputs["top_p"], top_k=self.share_inputs["top_k"], min_p=self.share_inputs["min_p"], + seed=self.share_inputs["infer_seed"], step_idx=self.share_inputs["step_idx"], pre_token_ids=self.share_inputs["pre_ids"], prompt_ids=self.share_inputs["prompt_ids"], diff --git a/fastdeploy/worker/gcu_worker.py b/fastdeploy/worker/gcu_worker.py index a16836780..54b4fa7e9 100644 --- a/fastdeploy/worker/gcu_worker.py +++ b/fastdeploy/worker/gcu_worker.py @@ -22,7 +22,7 @@ from paddle import nn from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request -from fastdeploy.utils import get_logger +from fastdeploy.utils import get_logger, set_random_seed from fastdeploy.worker.gcu_model_runner import GCUModelRunner from fastdeploy.worker.output import ModelRunnerOutput from fastdeploy.worker.worker_base import WorkerBase @@ -60,6 +60,7 @@ class GcuWorker(WorkerBase): else: raise RuntimeError(f"Not support device type: {self.device_config.device}") + set_random_seed(self.fd_config.model_config.seed) # Construct model runner self.model_runner: GCUModelRunner = GCUModelRunner( fd_config=self.fd_config, @@ -128,6 +129,7 @@ class GcuWorker(WorkerBase): self.model_runner.sot_warmup() # 2. Triger cuda grpah capture self.model_runner.capture_model() + set_random_seed(self.fd_config.model_config.seed) def check_health(self) -> bool: """ """ diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 70c37d245..5d63be3bb 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -131,6 +131,7 @@ class GPUModelRunner(ModelRunnerBase): fill_value=4, dtype="int64", ) + self.restore_chunked_prefill_request = dict() # Initialize attention Backend @@ -813,6 +814,7 @@ class GPUModelRunner(ModelRunnerBase): top_p=self.share_inputs["top_p"], top_k=self.share_inputs["top_k"], min_p=self.share_inputs["min_p"], + seed=self.share_inputs["infer_seed"], step_idx=self.share_inputs["step_idx"], pre_token_ids=self.share_inputs["pre_ids"], prompt_ids=self.share_inputs["prompt_ids"], diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index 94dc5fc19..53619f8f9 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -27,7 +27,7 @@ from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request from fastdeploy.platforms import current_platform from fastdeploy.plugins.model_runner import load_model_runner_plugins -from fastdeploy.utils import get_logger +from fastdeploy.utils import get_logger, set_random_seed from fastdeploy.worker.model_runner_base import ModelRunnerBase from fastdeploy.worker.output import ModelRunnerOutput from fastdeploy.worker.worker_base import WorkerBase @@ -75,6 +75,7 @@ class GpuWorker(WorkerBase): else: raise RuntimeError(f"Not support device type: {self.device_config.device}") + set_random_seed(self.fd_config.model_config.seed) # Construct model runner self.model_runner: ModelRunnerBase = ModelRunner( fd_config=self.fd_config, @@ -129,6 +130,7 @@ class GpuWorker(WorkerBase): # 2. Profile run self.model_runner.profile_run() + set_random_seed(self.fd_config.model_config.seed) # 3. Statistical memory information paddle_reserved_mem_after_run = paddle.device.cuda.max_memory_reserved(local_rank) diff --git a/fastdeploy/worker/iluvatar_model_runner.py b/fastdeploy/worker/iluvatar_model_runner.py index 4a7aaaf8d..b385cd30e 100644 --- a/fastdeploy/worker/iluvatar_model_runner.py +++ b/fastdeploy/worker/iluvatar_model_runner.py @@ -509,6 +509,7 @@ class IluvatarModelRunner(ModelRunnerBase): temperature=self.share_inputs["temperature"], top_p=self.share_inputs["top_p"], top_k=self.share_inputs["top_k"], + seed=self.share_inputs["seed"], step_idx=self.share_inputs["step_idx"], pre_token_ids=self.share_inputs["pre_ids"], prompt_ids=self.share_inputs["prompt_ids"], diff --git a/fastdeploy/worker/iluvatar_worker.py b/fastdeploy/worker/iluvatar_worker.py index cd899619b..3c2a28201 100644 --- a/fastdeploy/worker/iluvatar_worker.py +++ b/fastdeploy/worker/iluvatar_worker.py @@ -23,7 +23,7 @@ from paddle import nn from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request -from fastdeploy.utils import get_logger +from fastdeploy.utils import get_logger, set_random_seed from fastdeploy.worker.iluvatar_model_runner import IluvatarModelRunner from fastdeploy.worker.output import ModelRunnerOutput from fastdeploy.worker.worker_base import WorkerBase @@ -60,6 +60,7 @@ class IluvatarWorker(WorkerBase): else: raise RuntimeError(f"Not support device type: {self.device_config.device}") + set_random_seed(self.fd_config.model_config.seed) # Construct model runner self.model_runner: IluvatarModelRunner = IluvatarModelRunner( fd_config=self.fd_config, @@ -130,6 +131,7 @@ class IluvatarWorker(WorkerBase): # 2. Triger cuda grpah capture self.model_runner.capture_model() + set_random_seed(self.fd_config.model_config.seed) def check_health(self) -> bool: """ """ diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index a153e5556..6ea5ff7e5 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -677,6 +677,7 @@ class XPUModelRunner(ModelRunnerBase): top_p=self.share_inputs["top_p"], top_k=self.share_inputs["top_k"], min_p=self.share_inputs["min_p"], + seed=self.share_inputs["infer_seed"], step_idx=self.share_inputs["step_idx"], pre_token_ids=self.share_inputs["pre_ids"], frequency_penalties=self.share_inputs["frequency_score"], diff --git a/fastdeploy/worker/xpu_worker.py b/fastdeploy/worker/xpu_worker.py index a5993abb4..c3912179f 100644 --- a/fastdeploy/worker/xpu_worker.py +++ b/fastdeploy/worker/xpu_worker.py @@ -23,7 +23,7 @@ from paddle import nn from fastdeploy import envs from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request -from fastdeploy.utils import get_logger +from fastdeploy.utils import get_logger, set_random_seed from fastdeploy.worker.output import ModelRunnerOutput from fastdeploy.worker.worker_base import WorkerBase from fastdeploy.worker.xpu_model_runner import XPUModelRunner @@ -60,6 +60,7 @@ class XpuWorker(WorkerBase): else: raise RuntimeError(f"Not support device type: {self.device_config.device}") + set_random_seed(self.fd_config.model_config.seed) # Construct model runner self.model_runner: XPUModelRunner = XPUModelRunner( fd_config=self.fd_config, @@ -110,6 +111,7 @@ class XpuWorker(WorkerBase): self.model_runner.prepare_profile() self.model_runner.profile_run() + set_random_seed(self.fd_config.model_config.seed) total_available_memory = int(total_memory * self.cache_config.gpu_memory_utilization) used_memory = xpu_get_used_global_memory(int(self.device_ids[self.local_rank])) diff --git a/test/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py b/test/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py index 6fcfb42e3..de18c3d2f 100644 --- a/test/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py +++ b/test/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py @@ -191,6 +191,29 @@ def test_chat_completion(llm): pytest.fail(f"Chat case {i + 1} failed") +def test_seed(llm): + """ + Test chat completion with same seed + """ + prompt = "请介绍下中国的四大发明,用一句话概述每个发明。" + sampling_params = SamplingParams(temperature=0.1, seed=1, max_tokens=100) + num_runs = 5 + + results = [] + try: + for i in range(num_runs): + outputs = llm.generate(prompt, sampling_params) + results.append(outputs[0].outputs.text) + + assert all([result == results[0] for result in results]), "Results are not identical." + print("All results are identical.") + + except Exception: + print("Failed during prompt generation.") + traceback.print_exc() + pytest.fail("Prompt generation test failed") + + if __name__ == "__main__": """ Main entry point for the test script. diff --git a/test/ci_use/Qwen3-MoE/test_sampling_consistency.py b/test/ci_use/Qwen3-MoE/test_sampling_consistency.py new file mode 100644 index 000000000..d01067f2c --- /dev/null +++ b/test/ci_use/Qwen3-MoE/test_sampling_consistency.py @@ -0,0 +1,321 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import signal +import socket +import subprocess +import sys +import time + +import pytest +import requests + + +def is_port_open(host: str, port: int, timeout=1.0): + """Check if a TCP port is open on the given host.""" + try: + with socket.create_connection((host, port), timeout): + return True + except Exception: + return False + + +def kill_process_on_port(port: int): + """Kill processes that are listening on the given port.""" + try: + output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip() + for pid in output.splitlines(): + os.kill(int(pid), signal.SIGKILL) + print(f"Killed process on port {port}, pid={pid}") + except subprocess.CalledProcessError: + pass + + +def clean_specific_ports(ports_list): + """Kill all processes occupying the specified ports.""" + for port in ports_list: + kill_process_on_port(port) + + +def create_server_process_with_sampling(sampling_class: str, api_port: int, queue_port: int, metrics_port: int): + """ + Create and start the API server process with specified sampling class and ports. + Returns the process object. + """ + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "Qwen3-30B-A3B") + else: + model_path = "./Qwen3-30B-A3B" + + log_path = f"server_{sampling_class}_{api_port}.log" + cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(api_port), + "--tensor-parallel-size", + "1", + "--engine-worker-queue-port", + str(queue_port), + "--metrics-port", + str(metrics_port), + "--max-model-len", + "32768", + "--max-num-seqs", + "50", + "--quantization", + "wint4", + ] + + env = os.environ.copy() + env["FD_SAMPLING_CLASS"] = sampling_class + + print(f"Starting server with FD_SAMPLING_CLASS={sampling_class} on port {api_port}") + + with open(log_path, "w") as logfile: + process = subprocess.Popen( + cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, + env=env, + ) + + return process + + +def wait_for_server_ready_on_port(api_port: int, timeout=300): + """Wait for the API server to be ready on specified port.""" + for _ in range(timeout): + if is_port_open("127.0.0.1", api_port): + print(f"API server is up on port {api_port}") + return True + time.sleep(1) + return False + + +# ========================== +# Fixtures for pytest +# ========================== + + +@pytest.fixture +def headers(): + """Returns common HTTP request headers.""" + return {"Content-Type": "application/json"} + + +@pytest.fixture +def consistent_payload(): + """Returns a fixed payload for consistency testing with fixed seed.""" + return { + "messages": [ + { + "role": "user", + "content": "用一句话介绍 PaddlePaddle, 30字以内 /no_think", + } + ], + "temperature": 0.8, + "seed": 42, # Fixed seed + "max_tokens": 50, + } + + +@pytest.fixture +def rejection_server(): + """Fixture to manage rejection sampling server lifecycle.""" + sampling_class = "rejection" + api_port = 8288 + queue_port = 8334 + metrics_port = 8433 + ports_to_clean = [api_port, queue_port, metrics_port] + + # Setup: Clean ports and start server + clean_specific_ports(ports_to_clean) + time.sleep(2) + process = create_server_process_with_sampling(sampling_class, api_port, queue_port, metrics_port) + + # Wait for server to be ready + if not wait_for_server_ready_on_port(api_port, timeout=300): + try: + os.killpg(process.pid, signal.SIGTERM) + except Exception: + pass + pytest.fail(f"Server failed to start for {sampling_class}") + + # Yield server info to test + server_info = { + "api_url": f"http://0.0.0.0:{api_port}/v1/chat/completions", + "process": process, + "sampling_class": sampling_class, + } + + yield server_info + + # Teardown: Clean up server + try: + os.killpg(process.pid, signal.SIGTERM) + print(f"Server terminated for {sampling_class}") + except Exception as e: + print(f"Failed to terminate server: {e}") + time.sleep(3) + + +@pytest.fixture +def air_server(): + """Fixture to manage AIR sampling server lifecycle.""" + sampling_class = "air" + api_port = 8123 + queue_port = 8534 + metrics_port = 8643 + ports_to_clean = [api_port, queue_port, metrics_port] + + # Setup: Clean ports and start server + clean_specific_ports(ports_to_clean) + time.sleep(2) + process = create_server_process_with_sampling(sampling_class, api_port, queue_port, metrics_port) + + # Wait for server to be ready with detailed error reporting + if not wait_for_server_ready_on_port(api_port, timeout=300): + # Check log file for debugging + log_file = f"server_{sampling_class}_{api_port}.log" + error_msg = f"Server failed to start for {sampling_class}" + + if os.path.exists(log_file): + with open(log_file, "r") as f: + lines = f.readlines() + print(f"Server startup failed. Last 10 lines of {log_file}:") + for line in lines[-10:]: + print(f" {line.strip()}") + + try: + os.killpg(process.pid, signal.SIGTERM) + except Exception: + pass + pytest.fail(error_msg) + + # Yield server info to test + server_info = { + "api_url": f"http://0.0.0.0:{api_port}/v1/chat/completions", + "process": process, + "sampling_class": sampling_class, + } + + yield server_info + + # Teardown: Clean up server + try: + os.killpg(process.pid, signal.SIGTERM) + print(f"Server terminated for {sampling_class}") + except Exception as e: + print(f"Failed to terminate server: {e}") + time.sleep(3) + + +# ========================== +# Test cases +# ========================== + + +def test_seed_consistency_rejection_sampling(rejection_server, headers, consistent_payload): + """ + Test seed consistency for rejection sampling - multiple runs should produce identical results. + """ + server_info = rejection_server + api_url = server_info["api_url"] + sampling_class = server_info["sampling_class"] + num_runs = 5 + + print(f"\n===== Testing seed consistency for {sampling_class.upper()} sampling =====") + + # Run multiple requests with same seed + results = [] + print(f"Running {num_runs} requests with fixed seed=42:") + + for i in range(num_runs): + resp = requests.post(api_url, headers=headers, json=consistent_payload) + assert resp.status_code == 200, f"Request {i+1} failed with status {resp.status_code}" + + content = resp.json()["choices"][0]["message"]["content"] + results.append(content) + print(f" Run {i+1}: {content[:50]}...") + time.sleep(1) + + # Check if all results are identical + reference_result = results[0] + all_identical = all(result == reference_result for result in results) + + print(f"\n--- {sampling_class.upper()} Sampling Results ---") + if all_identical: + print(f" ALL {num_runs} runs produced IDENTICAL results") + print(f" Result: {reference_result}") + else: + print(" Results are NOT identical:") + for i, result in enumerate(results): + status = "yes" if result == reference_result else "no" + print(f" Run {i+1} {status}: {result}") + + # Use assertion for pytest compatibility + assert ( + all_identical + ), f"Rejection sampling should be consistent with fixed seed. Got {len(set(results))} different outputs: {list(set(results))}" + + +def test_seed_consistency_air_sampling(air_server, headers, consistent_payload): + """ + Test seed consistency for AIR sampling - multiple runs should produce identical results. + """ + server_info = air_server + api_url = server_info["api_url"] + sampling_class = server_info["sampling_class"] + num_runs = 5 + + print(f"\n===== Testing seed consistency for {sampling_class.upper()} sampling =====") + + # Run multiple requests with same seed + results = [] + print(f"Running {num_runs} requests with fixed seed=42:") + + for i in range(num_runs): + resp = requests.post(api_url, headers=headers, json=consistent_payload) + assert resp.status_code == 200, f"Request {i+1} failed with status {resp.status_code}" + + content = resp.json()["choices"][0]["message"]["content"] + results.append(content) + print(f" Run {i+1}: {content[:50]}...") + time.sleep(1) + + # Check if all results are identical + reference_result = results[0] + all_identical = all(result == reference_result for result in results) + + print(f"\n--- {sampling_class.upper()} Sampling Results ---") + if all_identical: + print(f" ALL {num_runs} runs produced IDENTICAL results") + print(f" Result: {reference_result}") + else: + print(" Results are NOT identical:") + for i, result in enumerate(results): + status = "yes" if result == reference_result else "no" + print(f" Run {i+1} {status}: {result}") + + # Use assertion for pytest compatibility + assert ( + all_identical + ), f"AIR sampling should be consistent with fixed seed. Got {len(set(results))} different outputs: {list(set(results))}"