[Log] Add trace log and add loggingInstrumentor tool (#4692)

* add trace logger and trace print

* trigger ci

* fix unittest

* translate notes and add copyright

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
qwes5s5
2025-11-17 11:08:57 +08:00
committed by GitHub
parent 5444af6ff6
commit 36216e62f0
21 changed files with 941 additions and 43 deletions

View File

@@ -49,6 +49,8 @@ from fastdeploy.plugins.token_processor import load_token_processor_plugins
from fastdeploy.router.utils import check_service_health
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.trace.constants import LoggingEventName
from fastdeploy.trace.trace_logger import print as trace_print
from fastdeploy.utils import (
EngineError,
check_download_links,
@@ -364,7 +366,7 @@ class EngineService:
for item in tasks:
item.schedule_start_time = time.time()
trace_print(LoggingEventName.RESOURCE_ALLOCATE_START, item.request_id, getattr(item, "user", ""))
available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch:
self.llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
@@ -398,6 +400,9 @@ class EngineService:
self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
for task in tasks:
task.inference_start_time = time.time()
trace_print(LoggingEventName.RESOURCE_ALLOCATE_END, task.request_id, getattr(task, "user", ""))
trace_print(LoggingEventName.REQUEST_SCHEDULE_END, task.request_id, getattr(task, "user", ""))
trace_print(LoggingEventName.INFERENCE_START, task.request_id, getattr(task, "user", ""))
if not is_prefill:
if not self.cfg.model_config.enable_mm:
self.update_requests_chunk_size(tasks)
@@ -636,7 +641,8 @@ class EngineService:
max_num_batched_tokens=self.cfg.scheduler_config.max_num_batched_tokens,
batch=num_prefill_batch,
)
for task in tasks:
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
if len(tasks) == 0:
time.sleep(0.001)
continue
@@ -689,6 +695,8 @@ class EngineService:
max_num_batched_tokens=max_num_batched_tokens,
batch=num_prefill_batch,
)
for task in tasks:
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
if self.cfg.scheduler_config.splitwise_role == "decode":
# Decode will instert the request sent by prefill to engine,
@@ -761,6 +769,10 @@ class EngineService:
time.sleep(0.001)
# Fetch requests and add them to the scheduling queue
if tasks:
for task in tasks:
trace_print(
LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "")
)
if self.cfg.scheduler_config.splitwise_role == "prefill":
self.resource_manager.add_request_in_p(tasks)
else:
@@ -816,6 +828,10 @@ class EngineService:
]
)
self.resource_manager.get_real_bsz()
for task in tasks:
trace_print(LoggingEventName.RESOURCE_ALLOCATE_END, task.request_id, getattr(task, "user", ""))
trace_print(LoggingEventName.REQUEST_SCHEDULE_END, task.request_id, getattr(task, "user", ""))
trace_print(LoggingEventName.INFERENCE_START, task.request_id, getattr(task, "user", ""))
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
else:
time.sleep(0.005)
@@ -877,6 +893,10 @@ class EngineService:
request.llm_engine_recv_req_timestamp = time.time()
start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
main_process_metrics.requests_number.inc()
self.llm_logger.debug(f"Receive request: {request}")
trace_print(LoggingEventName.PREPROCESSING_END, data["request_id"], data.get("user", ""))
trace_print(LoggingEventName.REQUEST_SCHEDULE_START, data["request_id"], data.get("user", ""))
trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", ""))
self.llm_logger.debug(f"Receive request from api server: {request}")
except Exception as e:
self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")

View File

@@ -38,6 +38,8 @@ from fastdeploy.inter_communicator import (
)
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.platforms import current_platform
from fastdeploy.trace.constants import LoggingEventName
from fastdeploy.trace.trace_logger import print as trace_print
from fastdeploy.utils import (
EngineError,
ParameterError,
@@ -185,6 +187,7 @@ class EngineClient:
"""
task["preprocess_start_time"] = time.time()
trace_print(LoggingEventName.PREPROCESSING_START, task["request_id"], task.get("user", ""))
try:
chat_template_kwargs = task.get("chat_template_kwargs") or {}
chat_template_kwargs.update({"chat_template": task.get("chat_template")})

View File

@@ -40,6 +40,8 @@ from fastdeploy.entrypoints.openai.protocol import (
)
from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.trace.constants import LoggingEventName
from fastdeploy.trace.trace_logger import print as trace_print
from fastdeploy.utils import (
ErrorCode,
ErrorType,
@@ -448,6 +450,7 @@ class OpenAIServingChat:
finally:
await self.engine_client.connection_manager.cleanup_request(request_id)
self.engine_client.semaphore.release()
trace_print(LoggingEventName.POSTPROCESSING_END, request_id, getattr(request, "user", ""))
api_server_logger.info(f"release {request_id} {self.engine_client.semaphore.status()}")
yield "data: [DONE]\n\n"
@@ -599,6 +602,7 @@ class OpenAIServingChat:
choices=choices,
usage=usage,
)
trace_print(LoggingEventName.POSTPROCESSING_END, request_id, getattr(request, "user", ""))
api_server_logger.info(f"Chat response: {res.model_dump_json()}")
return res

View File

@@ -36,6 +36,8 @@ from fastdeploy.entrypoints.openai.protocol import (
PromptTokenUsageInfo,
UsageInfo,
)
from fastdeploy.trace.constants import LoggingEventName
from fastdeploy.trace.trace_logger import print as trace_print
from fastdeploy.utils import (
ErrorCode,
ErrorType,
@@ -316,6 +318,7 @@ class OpenAIServingCompletion:
except Exception as e:
api_server_logger.error(f"Error in completion_full_generator: {e}", exc_info=True)
finally:
trace_print(LoggingEventName.POSTPROCESSING_END, request_id, getattr(request, "user", ""))
self.engine_client.semaphore.release()
if dealer is not None:
await self.engine_client.connection_manager.cleanup_request(request_id)
@@ -551,6 +554,7 @@ class OpenAIServingCompletion:
api_server_logger.error(f"Error in completion_stream_generator: {e}, {str(traceback.format_exc())}")
yield f"data: {ErrorResponse(error=ErrorInfo(message=str(e), code='400', type=ErrorType.INTERNAL_ERROR)).model_dump_json(exclude_unset=True)}\n\n"
finally:
trace_print(LoggingEventName.POSTPROCESSING_END, request_id, getattr(request, "user", ""))
del request
if dealer is not None:
await self.engine_client.connection_manager.cleanup_request(request_id)

View File

@@ -14,38 +14,51 @@
"""
"""
自定义日志格式化器模块
该模块定义了 ColoredFormatter 类,用于在控制台输出带颜色的日志信息,
便于开发者在终端中快速识别不同级别的日志。
Custom log formatter module
This module defines the ColoredFormatter class for outputting colored log information to the console,
helping developers quickly identify different levels of logs in the terminal.
"""
import logging
import re
import time
class ColoredFormatter(logging.Formatter):
"""
自定义日志格式器,用于控制台输出带颜色的日志。
支持的颜色:
- WARNING: 黄色
- ERROR: 红色
- CRITICAL: 红色
- 其他等级: 默认终端颜色
Custom log formatter for console output with colored logs.
Supported colors:
- WARNING: Yellow
- ERROR: Red
- CRITICAL: Red
- Other levels: Default terminal color
"""
COLOR_CODES = {
logging.WARNING: 33, # 黄色
logging.ERROR: 31, # 红色
logging.CRITICAL: 31, # 红色
logging.WARNING: 33, # Yellow
logging.ERROR: 31, # Red
logging.CRITICAL: 31, # Red
}
def format(self, record):
"""
格式化日志记录,并根据日志等级添加 ANSI 颜色前缀和后缀。
Format log record and add ANSI color prefix and suffix based on log level.
Newly supports attributes expansion and otelSpanID/otelTraceID fields.
Args:
record (LogRecord): 日志记录对象。
record (LogRecord): Log record object.
Returns:
str: 带有颜色的日志消息字符串。
str: Colored log message string.
"""
try:
# Add OpenTelemetry-related fields.
if hasattr(record, "otelSpanID") and record.otelSpanID is not None:
record.msg = f"[otel_span_id={record.otelSpanID}] {record.msg}"
if hasattr(record, "otelTraceID") and record.otelTraceID is not None:
record.msg = f"[otel_trace_id={record.otelTraceID}] {record.msg}"
except:
pass
color_code = self.COLOR_CODES.get(record.levelno, 0)
prefix = f"\033[{color_code}m"
suffix = "\033[0m"
@@ -53,3 +66,63 @@ class ColoredFormatter(logging.Formatter):
if color_code:
message = f"{prefix}{message}{suffix}"
return message
class CustomFormatter(logging.Formatter):
"""
Custom log formatter for console output.
Supports field expansion and adds thread, timestamp and other information.
"""
def _format_attributes(self, record):
"""
Expand attributes in record to [attr=value] format
"""
if hasattr(record, "attributes"):
if isinstance(record.attributes, dict):
return " ".join(f"[{k}={v}]" for k, v in record.attributes.items())
return ""
def _camel_to_snake(self, name: str) -> str:
"""Convert camel case to snake case"""
s1 = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name)
return s1.lower()
def format(self, record):
"""
Format log record, with new support for attributes expansion and otelSpanID/otelTraceID fields.
Supports field expansion and adds thread, timestamp and other information.
Args:
record (LogRecord): Log record object.
Returns:
str: Log message string.
"""
try:
log_fields = {
"thread": record.thread,
"thread_name": record.threadName,
"timestamp": int(time.time() * 1000),
}
if hasattr(record, "attributes") and isinstance(record.attributes, dict):
for k, v in record.attributes.items():
log_fields[self._camel_to_snake(k)] = v
# filter out null values.
log_fields = {k: v for k, v in log_fields.items() if not (isinstance(v, str) and v == "")}
log_str = " ".join(f"[{k}={v}]" for k, v in log_fields.items())
if log_str:
record.msg = f"{log_str} {record.msg}"
# Add OpenTelemetry-related fields.
if hasattr(record, "otelSpanID") and record.otelSpanID is not None:
record.msg = f"[otel_span_id={record.otelSpanID}] {record.msg}"
if hasattr(record, "otelTraceID") and record.otelTraceID is not None:
record.msg = f"[otel_trace_id={record.otelTraceID}] {record.msg}"
except:
pass
return super().format(record)

View File

@@ -24,7 +24,7 @@ import threading
from pathlib import Path
from fastdeploy import envs
from fastdeploy.logger.formatters import ColoredFormatter
from fastdeploy.logger.formatters import ColoredFormatter, CustomFormatter
from fastdeploy.logger.handlers import DailyRotatingFileHandler, LazyFileHandler
from fastdeploy.logger.setup_logging import setup_logging
@@ -95,6 +95,71 @@ class FastDeployLogger:
# 其他情况添加fastdeploy前缀
return logging.getLogger(f"fastdeploy.{name}")
def get_trace_logger(self, name, file_name, without_formater=False, print_to_console=False):
"""
Log retrieval method compatible with the original interface
"""
log_dir = envs.FD_LOG_DIR
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
is_debug = int(envs.FD_DEBUG)
# logger = logging.getLogger(name)
# Use namespace for isolation to avoid logger overwrite and confusion issues for compatibility with original interface
legacy_name = f"legacy.{name}"
logger = logging.getLogger(legacy_name)
# Set log level
if is_debug:
logger.setLevel(level=logging.DEBUG)
else:
logger.setLevel(level=logging.INFO)
# Set formatter
formatter = CustomFormatter(
"[%(asctime)s] [%(levelname)-8s] (%(filename)s:%(funcName)s:%(lineno)d) %(message)s"
)
# Clear existing handlers (maintain original logic)
for handler in logger.handlers[:]:
logger.removeHandler(handler)
# Create main log file handler
LOG_FILE = f"{log_dir}/{file_name}"
backup_count = int(envs.FD_LOG_BACKUP_COUNT)
# handler = LazyFileHandler(filename=LOG_FILE, backupCount=backup_count, level=hanlder_level)
handler = DailyRotatingFileHandler(LOG_FILE, backupCount=backup_count)
# Create ERROR log file handler (new feature)
if not file_name.endswith(".log"):
file_name = f"{file_name}.log" if "." not in file_name else file_name.split(".")[0] + ".log"
ERROR_LOG_FILE = os.path.join(log_dir, file_name.replace(".log", "_error.log"))
error_handler = LazyFileHandler(
filename=ERROR_LOG_FILE, backupCount=backup_count, level=logging.ERROR, formatter=None
)
if not without_formater:
handler.setFormatter(formatter)
error_handler.setFormatter(formatter)
# Add file handlers
logger.addHandler(handler)
logger.addHandler(error_handler)
# Console handler
if print_to_console:
console_handler = logging.StreamHandler()
if not without_formater:
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
console_handler.propagate = False
# Set propagate (maintain original logic)
# logger.propagate = False
return logger
def _get_legacy_logger(self, name, file_name, without_formater=False, print_to_console=False):
"""
兼容原有接口的日志获取方式

View File

@@ -5,6 +5,7 @@ from fastapi import FastAPI
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.propagate import extract, inject
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import SpanProcessor, TracerProvider
@@ -116,6 +117,10 @@ def instrument(app: FastAPI):
if traces_enable:
llm_logger.info("Applying instrumentors...")
FastAPIInstrumentor.instrument_app(app)
try:
LoggingInstrumentor().instrument(set_logging_format=True)
except Exception:
pass
except:
llm_logger.info("instrument failed")
pass

View File

@@ -40,6 +40,8 @@ from fastdeploy.engine.request import (
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcServer
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.platforms import current_platform
from fastdeploy.trace.constants import LoggingEventName
from fastdeploy.trace.trace_logger import print as trace_print
from fastdeploy.utils import llm_logger, spec_logger
from fastdeploy.worker.output import LogprobsLists
@@ -774,6 +776,8 @@ class TokenProcessor:
def _record_first_token_metrics(self, task, current_time):
"""Record metrics for first token"""
task.first_token_time = current_time
trace_print(LoggingEventName.FIRST_TOKEN_GENERATED, task.request_id, getattr(task, "user", ""))
trace_print(LoggingEventName.DECODE_START, task.request_id, getattr(task, "user", ""))
main_process_metrics.first_token_latency.set(current_time - task.inference_start_time)
main_process_metrics.time_to_first_token.observe(current_time - task.inference_start_time)
main_process_metrics.request_queue_time.observe(task.schedule_start_time - task.preprocess_end_time)
@@ -783,7 +787,8 @@ class TokenProcessor:
if hasattr(task, "first_token_time"):
decode_time = current_time - task.first_token_time
main_process_metrics.request_decode_time.observe(decode_time)
trace_print(LoggingEventName.INFERENCE_END, task.request_id, getattr(task, "user", ""))
trace_print(LoggingEventName.POSTPROCESSING_START, task.request_id, getattr(task, "user", ""))
main_process_metrics.num_requests_running.dec(1)
main_process_metrics.request_success_total.inc()
main_process_metrics.infer_latency.set(current_time - task.inference_start_time)

View File

@@ -0,0 +1,13 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,66 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
class LoggingEventName(Enum):
"""
Represents various event points in the system.
"""
PREPROCESSING_START = "PREPROCESSING_START"
PREPROCESSING_END = "PREPROCESSING_END"
REQUEST_SCHEDULE_START = "REQUEST_SCHEDULE_START"
REQUEST_QUEUE_START = "REQUEST_QUEUE_START"
REQUEST_QUEUE_END = "REQUEST_QUEUE_END"
RESOURCE_ALLOCATE_START = "RESOURCE_ALLOCATE_START"
RESOURCE_ALLOCATE_END = "RESOURCE_ALLOCATE_END"
REQUEST_SCHEDULE_END = "REQUEST_SCHEDULE_END"
INFERENCE_START = "INFERENCE_START"
FIRST_TOKEN_GENERATED = "FIRST_TOKEN_GENERATED"
DECODE_START = "DECODE_START"
INFERENCE_END = "INFERENCE_END"
POSTPROCESSING_START = "POSTPROCESSING_START"
POSTPROCESSING_END = "POSTPROCESSING_END"
class StageName(Enum):
"""
Represents the main stages in the request processing flow.
"""
PREPROCESSING = "PREPROCESSING"
SCHEDULE = "SCHEDULE"
PREFILL = "PREFILL"
DECODE = "DECODE"
POSTPROCESSING = "POSTPROCESSING"
EVENT_TO_STAGE_MAP = {
LoggingEventName.PREPROCESSING_START: StageName.PREPROCESSING,
LoggingEventName.PREPROCESSING_END: StageName.PREPROCESSING,
LoggingEventName.REQUEST_SCHEDULE_START: StageName.SCHEDULE,
LoggingEventName.REQUEST_QUEUE_START: StageName.SCHEDULE,
LoggingEventName.REQUEST_QUEUE_END: StageName.SCHEDULE,
LoggingEventName.RESOURCE_ALLOCATE_START: StageName.SCHEDULE,
LoggingEventName.RESOURCE_ALLOCATE_END: StageName.SCHEDULE,
LoggingEventName.REQUEST_SCHEDULE_END: StageName.SCHEDULE,
LoggingEventName.INFERENCE_START: StageName.PREFILL,
LoggingEventName.FIRST_TOKEN_GENERATED: StageName.PREFILL,
LoggingEventName.DECODE_START: StageName.DECODE,
LoggingEventName.INFERENCE_END: StageName.DECODE,
LoggingEventName.POSTPROCESSING_START: StageName.POSTPROCESSING,
LoggingEventName.POSTPROCESSING_END: StageName.POSTPROCESSING,
}

View File

@@ -0,0 +1,38 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastdeploy.trace.constants import EVENT_TO_STAGE_MAP
from fastdeploy.utils import trace_logger
def print(event, request_id, user):
"""
Records task tracking log information, including task name, start time, end time, etc.
Args:
task (Task): Task object to be recorded.
"""
try:
trace_logger.info(
"",
extra={
"attributes": {
"request_id": f"{request_id}",
"user_id": f"{user}",
"event": event.value,
"stage": EVENT_TO_STAGE_MAP.get(event).value,
}
},
)
except:
pass

View File

@@ -982,6 +982,7 @@ api_server_logger = get_logger("api_server", "api_server.log")
console_logger = get_logger("console", "console.log", print_to_console=True)
spec_logger = get_logger("speculate", "speculate.log")
zmq_client_logger = get_logger("zmq_client", "zmq_client.log")
trace_logger = FastDeployLogger().get_trace_logger("trace_logger", "trace_logger.log")
router_logger = get_logger("router", "router.log")

View File

@@ -40,6 +40,7 @@ opentelemetry-instrumentation-mysql
opentelemetry-distro 
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
opentelemetry-instrumentation-logging
partial_json_parser
msgspec
einops

View File

@@ -37,4 +37,5 @@ opentelemetry-instrumentation-mysql
opentelemetry-distro 
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
opentelemetry-instrumentation-logging
partial_json_parser

View File

@@ -37,6 +37,7 @@ opentelemetry-instrumentation-mysql
opentelemetry-distro
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
opentelemetry-instrumentation-logging
partial_json_parser
msgspec
safetensors==0.7.0rc0

View File

@@ -40,4 +40,5 @@ opentelemetry-instrumentation-mysql
opentelemetry-distro
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
opentelemetry-instrumentation-logging
partial_json_parser

View File

@@ -16,27 +16,27 @@
import logging
import unittest
from fastdeploy.logger.formatters import ColoredFormatter
from fastdeploy.logger.formatters import ColoredFormatter, CustomFormatter
class TestColoredFormatter(unittest.TestCase):
"""测试 ColoredFormatter """
"""Test ColoredFormatter class"""
def setUp(self):
"""测试前准备"""
"""Test preparation"""
self.formatter = ColoredFormatter("%(levelname)s - %(message)s")
def test_color_codes_definition(self):
"""测试颜色代码定义"""
"""Test color code definition"""
expected_colors = {
logging.WARNING: 33, # 黄色
logging.ERROR: 31, # 红色
logging.CRITICAL: 31, # 红色
logging.WARNING: 33, # yellow
logging.ERROR: 31, # red
logging.CRITICAL: 31, # red
}
self.assertEqual(self.formatter.COLOR_CODES, expected_colors)
def test_format_warning_message(self):
"""测试 WARNING 级别日志格式化(黄色)"""
"""Test WARNING level log formatting (yellow)"""
record = logging.LogRecord(
name="test", level=logging.WARNING, pathname="", lineno=0, msg="This is a warning", args=(), exc_info=None
)
@@ -46,7 +46,7 @@ class TestColoredFormatter(unittest.TestCase):
self.assertEqual(formatted_message, expected)
def test_format_error_message(self):
"""测试 ERROR 级别日志格式化(红色)"""
"""Test ERROR level log formatting (red)"""
record = logging.LogRecord(
name="test", level=logging.ERROR, pathname="", lineno=0, msg="This is an error", args=(), exc_info=None
)
@@ -56,7 +56,7 @@ class TestColoredFormatter(unittest.TestCase):
self.assertEqual(formatted_message, expected)
def test_format_critical_message(self):
"""测试 CRITICAL 级别日志格式化(红色)"""
"""Test CRITICAL level log formatting (red)"""
record = logging.LogRecord(
name="test", level=logging.CRITICAL, pathname="", lineno=0, msg="This is critical", args=(), exc_info=None
)
@@ -66,7 +66,7 @@ class TestColoredFormatter(unittest.TestCase):
self.assertEqual(formatted_message, expected)
def test_format_info_message(self):
"""测试 INFO 级别日志格式化(无颜色)"""
"""Test INFO level log formatting (no color)"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="This is info", args=(), exc_info=None
)
@@ -76,7 +76,7 @@ class TestColoredFormatter(unittest.TestCase):
self.assertEqual(formatted_message, expected)
def test_format_debug_message(self):
"""测试 DEBUG 级别日志格式化(无颜色)"""
"""Test DEBUG level log formatting (no color)"""
record = logging.LogRecord(
name="test", level=logging.DEBUG, pathname="", lineno=0, msg="This is debug", args=(), exc_info=None
)
@@ -86,9 +86,9 @@ class TestColoredFormatter(unittest.TestCase):
self.assertEqual(formatted_message, expected)
def test_format_custom_level(self):
"""测试自定义级别日志格式化(无颜色)"""
# 创建自定义级别
custom_level = 25 # 介于 INFO(20) WARNING(30) 之间
"""Test custom level log formatting (no color)"""
# Create custom level
custom_level = 25 # Between INFO(20) and WARNING(30)
record = logging.LogRecord(
name="test", level=custom_level, pathname="", lineno=0, msg="This is custom level", args=(), exc_info=None
)
@@ -98,6 +98,321 @@ class TestColoredFormatter(unittest.TestCase):
expected = "CUSTOM - This is custom level"
self.assertEqual(formatted_message, expected)
def test_format_with_otel_span_id(self):
"""Test log formatting with otelSpanID"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="This has span", args=(), exc_info=None
)
record.otelSpanID = "span123"
formatted_message = self.formatter.format(record)
expected = "INFO - [otel_span_id=span123] This has span"
self.assertEqual(formatted_message, expected)
def test_format_with_otel_trace_id(self):
"""Test log formatting with otelTraceID"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="This has trace", args=(), exc_info=None
)
record.otelTraceID = "trace456"
formatted_message = self.formatter.format(record)
expected = "INFO - [otel_trace_id=trace456] This has trace"
self.assertEqual(formatted_message, expected)
class TestCustomFormatter(unittest.TestCase):
"""Test CustomFormatter class"""
def setUp(self):
"""Test preparation"""
self.formatter = CustomFormatter("%(levelname)s - %(message)s")
def test_format_with_attributes(self):
"""Test log formatting with attributes"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="This has attrs", args=(), exc_info=None
)
record.attributes = {"key1": "value1", "key2": "value2"}
formatted_message = self.formatter.format(record)
self.assertIn("[key1=value1]", formatted_message)
self.assertIn("[key2=value2]", formatted_message)
self.assertIn("This has attrs", formatted_message)
def test_format_with_camel_case_attributes(self):
"""Test conversion of camelCase attributes"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="This has camelCase", args=(), exc_info=None
)
record.attributes = {"camelCaseKey": "value"}
formatted_message = self.formatter.format(record)
self.assertIn("[camel_case_key=value]", formatted_message)
def test_format_with_empty_attributes(self):
"""Test handling of empty attributes"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="Empty attrs", args=(), exc_info=None
)
record.attributes = {}
formatted_message = self.formatter.format(record)
# Check if thread info and timestamp are included
self.assertIn("[thread=", formatted_message)
self.assertIn("[thread_name=", formatted_message)
self.assertIn("[timestamp=", formatted_message)
self.assertTrue(formatted_message.endswith("Empty attrs"))
def test_format_with_thread_info(self):
"""Test addition of thread information"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="Thread test", args=(), exc_info=None
)
record.thread = 123
record.threadName = "TestThread"
formatted_message = self.formatter.format(record)
self.assertIn("[thread=123]", formatted_message)
self.assertIn("[thread_name=TestThread]", formatted_message)
self.assertIn("[timestamp=", formatted_message) # Check timestamp
def test_format_attributes_method(self):
"""Test _format_attributes method"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="Test attributes", args=(), exc_info=None
)
record.attributes = {"key1": "value1", "key2": "value2"}
# Directly call _format_attributes method
formatted_attrs = self.formatter._format_attributes(record)
self.assertEqual(formatted_attrs, "[key1=value1] [key2=value2]")
def test_format_attributes_method_empty(self):
"""Test _format_attributes method handling empty attributes"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="Test empty", args=(), exc_info=None
)
record.attributes = {}
formatted_attrs = self.formatter._format_attributes(record)
self.assertEqual(formatted_attrs, "")
def test_format_attributes_method_none(self):
"""Test _format_attributes method handling no attributes"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="Test none", args=(), exc_info=None
)
formatted_attrs = self.formatter._format_attributes(record)
self.assertEqual(formatted_attrs, "")
def test_format_attributes_method_invalid_type(self):
"""Test _format_attributes method handling non-dict attributes"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="Test invalid", args=(), exc_info=None
)
record.attributes = "invalid"
formatted_attrs = self.formatter._format_attributes(record)
self.assertEqual(formatted_attrs, "")
def test_camel_to_snake_method(self):
"""Test _camel_to_snake method"""
# Test camelCase to snake_case conversion
self.assertEqual(self.formatter._camel_to_snake("camelCase"), "camel_case")
self.assertEqual(self.formatter._camel_to_snake("CamelCase"), "camel_case")
self.assertEqual(self.formatter._camel_to_snake("camelCaseKey"), "camel_case_key")
self.assertEqual(self.formatter._camel_to_snake("already_snake"), "already_snake")
def test_format_with_empty_string_attributes(self):
"""Test handling of empty string attributes"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="Empty string attrs", args=(), exc_info=None
)
record.attributes = {"key1": "", "key2": "value2"}
formatted_message = self.formatter.format(record)
# Empty string key1 should be filtered out
self.assertNotIn("[key1=]", formatted_message)
self.assertIn("[key2=value2]", formatted_message)
def test_format_with_both_otel_and_attributes(self):
"""Test case with both otel fields and attributes"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="Both otel and attrs", args=(), exc_info=None
)
record.attributes = {"key1": "value1"}
record.otelSpanID = "span123"
record.otelTraceID = "trace456"
formatted_message = self.formatter.format(record)
self.assertIn("[key1=value1]", formatted_message)
self.assertIn("[otel_span_id=span123]", formatted_message)
self.assertIn("[otel_trace_id=trace456]", formatted_message)
def test_format_exception_handling(self):
"""Test exception handling mechanism"""
# Create a record that will cause an exception
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="Exception test", args=(), exc_info=None
)
# Add an attribute that will cause an exception
record.thread = "invalid_thread" # This will cause an exception because thread should be an integer
# Even with exceptions, the format method should return normally
formatted_message = self.formatter.format(record)
self.assertIsInstance(formatted_message, str)
self.assertIn("Exception test", formatted_message)
def test_format_with_none_otel_fields(self):
"""Test handling of None value otel fields"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="None otel", args=(), exc_info=None
)
record.otelSpanID = None
record.otelTraceID = None
formatted_message = self.formatter.format(record)
# None value otel fields should not be added
self.assertNotIn("otel_span_id", formatted_message)
self.assertNotIn("otel_trace_id", formatted_message)
class TestColoredFormatterExceptionHandling(unittest.TestCase):
"""Test ColoredFormatter exception handling"""
def setUp(self):
"""Test preparation"""
self.formatter = ColoredFormatter("%(levelname)s - %(message)s")
def test_format_exception_handling(self):
"""Test ColoredFormatter exception handling mechanism"""
# Create a record that will cause an exception
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="Exception test", args=(), exc_info=None
)
# Add an attribute that will cause an exception
record.otelSpanID = object() # Non-string type, may cause an exception
# Even with exceptions, the format method should return normally
formatted_message = self.formatter.format(record)
self.assertIsInstance(formatted_message, str)
self.assertIn("Exception test", formatted_message)
def test_format_with_none_otel_fields(self):
"""Test handling of None value otel fields"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="None otel", args=(), exc_info=None
)
record.otelSpanID = None
record.otelTraceID = None
formatted_message = self.formatter.format(record)
# None value otel fields should not be added
self.assertNotIn("otel_span_id", formatted_message)
self.assertNotIn("otel_trace_id", formatted_message)
def test_format_with_invalid_otel_fields(self):
"""Test handling of invalid otel fields"""
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="Invalid otel", args=(), exc_info=None
)
# Set invalid attributes to ensure exceptions are caught
record.otelSpanID = 123 # Integer type, not string
record.otelTraceID = 456 # Integer type, not string
formatted_message = self.formatter.format(record)
self.assertIsInstance(formatted_message, str)
self.assertIn("Invalid otel", formatted_message)
def test_colored_formatter_exception_handling_with_forced_error(self):
"""Test ColoredFormatter exception handling - forced exception"""
# Create test record and add special attributes that will cause exceptions
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="Forced error test", args=(), exc_info=None
)
# Add attribute that will cause AttributeError
class BadOtelSpanID:
def __str__(self):
raise AttributeError("Forced attribute error")
record.otelSpanID = BadOtelSpanID()
# Call format method, should catch exception and continue execution
formatted_message = self.formatter.format(record)
self.assertIsInstance(formatted_message, str)
self.assertIn("Forced error test", formatted_message)
def test_custom_colored_formatter_exception_handling_with_forced_error(self):
"""Test CustomFormatter exception handling - forced exception"""
custom_formatter = CustomFormatter("%(levelname)s - %(message)s")
# Create test record and add special attributes that will cause exceptions
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="Forced error test", args=(), exc_info=None
)
# Add attributes that will cause TypeError
class BadAttributes:
def items(self):
raise TypeError("Forced type error")
record.attributes = BadAttributes()
# Call format method, should catch exception and continue execution
formatted_message = custom_formatter.format(record)
self.assertIsInstance(formatted_message, str)
self.assertIn("Forced error test", formatted_message)
def test_colored_formatter_otel_processing_exception(self):
"""Test otel processing exception in ColoredFormatter"""
# Create test record and add special attributes that will cause exceptions
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="Otel exception test", args=(), exc_info=None
)
# Add otelSpanID that will cause Exception
class BadOtelSpanID:
def __str__(self):
raise Exception("Forced otel processing error")
record.otelSpanID = BadOtelSpanID()
# Call format method, should catch exception and continue execution
formatted_message = self.formatter.format(record)
self.assertIsInstance(formatted_message, str)
self.assertIn("Otel exception test", formatted_message)
def test_custom_colored_formatter_thread_processing_exception(self):
"""Test thread processing exception in CustomFormatter"""
custom_formatter = CustomFormatter("%(levelname)s - %(message)s")
# Create test record and add special attributes that will cause exceptions
record = logging.LogRecord(
name="test", level=logging.INFO, pathname="", lineno=0, msg="Thread exception test", args=(), exc_info=None
)
# Add thread attribute that will cause Exception
class BadThread:
def __int__(self):
raise Exception("Forced thread processing error")
record.thread = BadThread()
# Add attribute that will cause AttributeError
class BadOtelSpanID:
def __str__(self):
raise AttributeError("Forced attribute error")
record.otelSpanID = BadOtelSpanID()
# Call format method, should catch exception and continue execution
formatted_message = custom_formatter.format(record)
self.assertIsInstance(formatted_message, str)
self.assertIn("Thread exception test", formatted_message)
if __name__ == "__main__":
unittest.main(verbosity=2)

View File

@@ -14,16 +14,18 @@
import logging
import os
import shutil
import tempfile
import unittest
from unittest.mock import patch
from fastdeploy.logger.handlers import LazyFileHandler
from fastdeploy.logger.logger import FastDeployLogger
class LoggerTests(unittest.TestCase):
"""修改后的测试类,通过实例测试内部方法"""
"""Modified test class, testing internal methods through instances"""
def setUp(self):
self.tmp_dir = tempfile.mkdtemp(prefix="fd_unittest_")
@@ -35,7 +37,7 @@ class LoggerTests(unittest.TestCase):
for p in self.env_patchers:
p.start()
# 创建测试用实例
# Create test instance
self.logger = FastDeployLogger()
def tearDown(self):
@@ -44,7 +46,7 @@ class LoggerTests(unittest.TestCase):
shutil.rmtree(self.tmp_dir, ignore_errors=True)
def test_unified_logger(self):
"""通过实例测试_get_unified_logger"""
"""Test _get_unified_logger through instance"""
test_cases = [(None, "fastdeploy"), ("module", "fastdeploy.module"), ("fastdeploy.utils", "fastdeploy.utils")]
for name, expected in test_cases:
@@ -53,29 +55,118 @@ class LoggerTests(unittest.TestCase):
self.assertEqual(result.name, expected)
def test_main_module_handling(self):
"""测试__main__特殊处理"""
"""Test __main__ special handling"""
with patch("__main__.__file__", "/path/to/test_script.py"):
result = self.logger._get_unified_logger("__main__")
self.assertEqual(result.name, "fastdeploy.main.test_script")
def test_legacy_logger_creation(self):
"""通过实例测试_get_legacy_logger"""
"""Test _get_legacy_logger through instance"""
legacy_logger = self.logger._get_legacy_logger(
"test", "test.log", without_formater=False, print_to_console=True
)
# 验证基础属性
# Verify basic properties
self.assertTrue(legacy_logger.name.startswith("legacy."))
self.assertEqual(legacy_logger.level, logging.INFO)
# 验证handler
self.assertEqual(len(legacy_logger.handlers), 3) # 文件+错误+控制台
# Verify handlers
self.assertEqual(len(legacy_logger.handlers), 3) # file + error + console
def test_logger_propagate(self):
"""测试日志传播设置"""
"""Test log propagation settings"""
legacy_logger = self.logger._get_legacy_logger("test", "test.log")
self.assertTrue(legacy_logger.propagate)
def test_get_trace_logger_basic(self):
"""Test basic functionality of get_trace_logger"""
logger = self.logger.get_trace_logger("test_trace", "trace_test.log")
# Verify basic properties
self.assertTrue(logger.name.startswith("legacy."))
self.assertEqual(logger.level, logging.INFO)
# Verify handler count
self.assertEqual(len(logger.handlers), 2) # main log and error log
def test_get_trace_logger_with_console(self):
"""Test trace logger with console output"""
logger = self.logger.get_trace_logger("test_trace_console", "trace_console_test.log", print_to_console=True)
# Verify handler count
self.assertEqual(len(logger.handlers), 3) # main log + error log + console
def test_get_trace_logger_without_formatter(self):
"""Test trace logger without formatting"""
logger = self.logger.get_trace_logger("test_trace_no_fmt", "trace_no_fmt_test.log", without_formater=True)
# Verify handlers have no formatter
for handler in logger.handlers:
self.assertIsNone(handler.formatter)
def test_get_trace_logger_debug_mode(self):
"""Test trace logger in debug mode"""
with patch("fastdeploy.envs.FD_DEBUG", "1"):
logger = self.logger.get_trace_logger("test_trace_debug", "trace_debug_test.log")
self.assertEqual(logger.level, logging.DEBUG)
def test_get_trace_logger_directory_creation(self):
"""Test line 105: log directory creation functionality"""
import os
from unittest.mock import patch
# Test creation of non-existent directory
with tempfile.TemporaryDirectory() as temp_dir:
test_log_dir = os.path.join(temp_dir, "test_logs")
with patch("fastdeploy.envs.FD_LOG_DIR", test_log_dir):
# Ensure directory does not exist
self.assertFalse(os.path.exists(test_log_dir))
# Call get_trace_logger, should create directory
self.logger.get_trace_logger("test_dir_creation", "test.log")
# Verify directory is created
self.assertTrue(os.path.exists(test_log_dir))
self.assertTrue(os.path.isdir(test_log_dir))
def test_get_trace_logger_handler_cleanup(self):
"""Test line 126: handler cleanup functionality"""
# First create a logger and add some handlers
test_logger = logging.getLogger("legacy.test_cleanup")
initial_handler_count = len(test_logger.handlers)
# Add some test handlers
test_handler1 = logging.StreamHandler()
test_handler2 = logging.StreamHandler()
test_logger.addHandler(test_handler1)
test_logger.addHandler(test_handler2)
# Verify handlers are added
self.assertEqual(len(test_logger.handlers), initial_handler_count + 2)
# Call get_trace_logger, should clean up existing handlers
logger = self.logger.get_trace_logger("test_cleanup", "cleanup_test.log")
# Verify new logger's handler count (should be 2: main log and error log)
self.assertEqual(len(logger.handlers), 2)
def test_log_file_name_handling_error(self):
"""Test log file name handling logic"""
test_cases = [
("test", "test_error.log"),
]
for input_name, expected_name in test_cases:
with self.subTest(input_name=input_name):
# Create logger and get actual processed file name
logger = self.logger.get_trace_logger("test_file_name", input_name)
# Get file name from handler
for handler in logger.handlers:
if isinstance(handler, LazyFileHandler):
actual_name = os.path.basename(handler.filename)
self.assertTrue(actual_name.endswith(expected_name))
class LoggerExtraTests(unittest.TestCase):
def setUp(self):

View File

@@ -0,0 +1,84 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from unittest.mock import MagicMock
from fastdeploy.engine.request import Request
from fastdeploy.output.token_processor import TokenProcessor
class TestTokenProcessorMetrics:
def setup_method(self):
self.mock_cfg = MagicMock()
self.mock_cached_tokens = MagicMock()
self.mock_engine_queue = MagicMock()
self.mock_split_connector = MagicMock()
self.processor = TokenProcessor(
cfg=self.mock_cfg,
cached_generated_tokens=self.mock_cached_tokens,
engine_worker_queue=self.mock_engine_queue,
split_connector=self.mock_split_connector,
)
# Create a complete Request object with all required parameters
self.task = Request(
request_id="test123",
prompt="test prompt",
prompt_token_ids=[1, 2, 3],
prompt_token_ids_len=3,
messages=["test message"],
history=[],
tools=[],
system="test system",
eos_token_ids=[0],
arrival_time=time.time(),
)
self.task.inference_start_time = time.time()
self.task.schedule_start_time = self.task.inference_start_time - 0.1
self.task.preprocess_end_time = self.task.schedule_start_time - 0.05
self.task.preprocess_start_time = self.task.preprocess_end_time - 0.05
self.task.arrival_time = self.task.preprocess_start_time - 0.1
def test_record_first_token_metrics(self, caplog):
current_time = time.time()
with caplog.at_level(logging.INFO):
self.processor._record_first_token_metrics(self.task, current_time)
assert len(caplog.records) == 2
assert "[request_id=test123]" in caplog.text
assert "[event=FIRST_TOKEN_GENERATED]" in caplog.text
assert "[event=DECODE_START]" in caplog.text
# Verify metrics are set
assert hasattr(self.task, "first_token_time")
assert self.task.first_token_time == current_time
def test_record_completion_metrics(self, caplog):
current_time = time.time()
self.task.first_token_time = current_time - 0.5
with caplog.at_level(logging.INFO):
self.processor._record_completion_metrics(self.task, current_time)
assert len(caplog.records) == 2
assert "[request_id=test123]" in caplog.text
assert "[event=INFERENCE_END]" in caplog.text
assert "[event=POSTPROCESSING_START]" in caplog.text
# Verify metrics are updated
assert self.processor.tokens_counter["test123"] == 0 # Just checking counter exists

View File

@@ -0,0 +1,60 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastdeploy.trace.constants import EVENT_TO_STAGE_MAP, LoggingEventName, StageName
class TestLoggingEventName:
def test_enum_values(self):
assert LoggingEventName.PREPROCESSING_START.value == "PREPROCESSING_START"
assert LoggingEventName.PREPROCESSING_END.value == "PREPROCESSING_END"
assert LoggingEventName.REQUEST_SCHEDULE_START.value == "REQUEST_SCHEDULE_START"
assert LoggingEventName.REQUEST_QUEUE_START.value == "REQUEST_QUEUE_START"
assert LoggingEventName.REQUEST_QUEUE_END.value == "REQUEST_QUEUE_END"
assert LoggingEventName.RESOURCE_ALLOCATE_START.value == "RESOURCE_ALLOCATE_START"
assert LoggingEventName.RESOURCE_ALLOCATE_END.value == "RESOURCE_ALLOCATE_END"
assert LoggingEventName.REQUEST_SCHEDULE_END.value == "REQUEST_SCHEDULE_END"
assert LoggingEventName.INFERENCE_START.value == "INFERENCE_START"
assert LoggingEventName.FIRST_TOKEN_GENERATED.value == "FIRST_TOKEN_GENERATED"
assert LoggingEventName.DECODE_START.value == "DECODE_START"
assert LoggingEventName.INFERENCE_END.value == "INFERENCE_END"
assert LoggingEventName.POSTPROCESSING_START.value == "POSTPROCESSING_START"
assert LoggingEventName.POSTPROCESSING_END.value == "POSTPROCESSING_END"
class TestStageName:
def test_enum_values(self):
assert StageName.PREPROCESSING.value == "PREPROCESSING"
assert StageName.SCHEDULE.value == "SCHEDULE"
assert StageName.PREFILL.value == "PREFILL"
assert StageName.DECODE.value == "DECODE"
assert StageName.POSTPROCESSING.value == "POSTPROCESSING"
class TestEventToStageMap:
def test_mapping(self):
assert EVENT_TO_STAGE_MAP[LoggingEventName.PREPROCESSING_START] == StageName.PREPROCESSING
assert EVENT_TO_STAGE_MAP[LoggingEventName.PREPROCESSING_END] == StageName.PREPROCESSING
assert EVENT_TO_STAGE_MAP[LoggingEventName.REQUEST_SCHEDULE_START] == StageName.SCHEDULE
assert EVENT_TO_STAGE_MAP[LoggingEventName.REQUEST_QUEUE_START] == StageName.SCHEDULE
assert EVENT_TO_STAGE_MAP[LoggingEventName.REQUEST_QUEUE_END] == StageName.SCHEDULE
assert EVENT_TO_STAGE_MAP[LoggingEventName.RESOURCE_ALLOCATE_START] == StageName.SCHEDULE
assert EVENT_TO_STAGE_MAP[LoggingEventName.RESOURCE_ALLOCATE_END] == StageName.SCHEDULE
assert EVENT_TO_STAGE_MAP[LoggingEventName.REQUEST_SCHEDULE_END] == StageName.SCHEDULE
assert EVENT_TO_STAGE_MAP[LoggingEventName.INFERENCE_START] == StageName.PREFILL
assert EVENT_TO_STAGE_MAP[LoggingEventName.FIRST_TOKEN_GENERATED] == StageName.PREFILL
assert EVENT_TO_STAGE_MAP[LoggingEventName.DECODE_START] == StageName.DECODE
assert EVENT_TO_STAGE_MAP[LoggingEventName.INFERENCE_END] == StageName.DECODE
assert EVENT_TO_STAGE_MAP[LoggingEventName.POSTPROCESSING_START] == StageName.POSTPROCESSING
assert EVENT_TO_STAGE_MAP[LoggingEventName.POSTPROCESSING_END] == StageName.POSTPROCESSING

View File

@@ -0,0 +1,47 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from unittest.mock import patch
from fastdeploy.trace.constants import LoggingEventName, StageName
from fastdeploy.trace.trace_logger import print as trace_print
class TestTraceLogging:
def test_trace_print(self, caplog):
request_id = "test123"
user = "test_user"
event = LoggingEventName.PREPROCESSING_START
with caplog.at_level(logging.INFO):
trace_print(event, request_id, user)
assert len(caplog.records) == 1
record = caplog.records[0]
assert f"[request_id={request_id}]" in record.message
assert f"[user_id={user}]" in record.message
assert f"[event={event.value}]" in record.message
assert f"[stage={StageName.PREPROCESSING.value}]" in record.message
def test_trace_print_with_logger_error(self, caplog):
request_id = "test123"
user = "test_user"
event = LoggingEventName.PREPROCESSING_START
with patch("logging.Logger.info", side_effect=Exception("Logger error")):
with caplog.at_level(logging.INFO):
trace_print(event, request_id, user)
assert len(caplog.records) == 0