diff --git a/nixos/lib/test-driver/test_driver/logger.py b/nixos/lib/test-driver/test_driver/logger.py index 14c5b1033dd5..f50e3b78a828 100644 --- a/nixos/lib/test-driver/test_driver/logger.py +++ b/nixos/lib/test-driver/test_driver/logger.py @@ -4,9 +4,9 @@ import sys import time import unicodedata from abc import ABC, abstractmethod -from contextlib import contextmanager +from contextlib import ExitStack, contextmanager from queue import Empty, Queue -from typing import Any, Dict, Iterator +from typing import Any, Dict, Iterator, List from xml.sax.saxutils import XMLGenerator from xml.sax.xmlreader import AttributesImpl @@ -49,24 +49,117 @@ class AbstractLogger(ABC): pass -class Logger(AbstractLogger): +class CompositeLogger(AbstractLogger): + def __init__(self, logger_list: List[AbstractLogger]) -> None: + self.logger_list = logger_list + + def add_logger(self, logger: AbstractLogger) -> None: + self.logger_list.append(logger) + + def log(self, message: str, attributes: Dict[str, str] = {}) -> None: + for logger in self.logger_list: + logger.log(message, attributes) + + @contextmanager + def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + with ExitStack() as stack: + for logger in self.logger_list: + stack.enter_context(logger.subtest(name, attributes)) + yield + + @contextmanager + def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + with ExitStack() as stack: + for logger in self.logger_list: + stack.enter_context(logger.nested(message, attributes)) + yield + + def info(self, *args, **kwargs) -> None: # type: ignore + for logger in self.logger_list: + logger.info(*args, **kwargs) + + def warning(self, *args, **kwargs) -> None: # type: ignore + for logger in self.logger_list: + logger.warning(*args, **kwargs) + + def error(self, *args, **kwargs) -> None: # type: ignore + for logger in self.logger_list: + logger.error(*args, **kwargs) + sys.exit(1) + + def print_serial_logs(self, enable: bool) -> None: + for logger in self.logger_list: + logger.print_serial_logs(enable) + + def log_serial(self, message: str, machine: str) -> None: + for logger in self.logger_list: + logger.log_serial(message, machine) + + +class TerminalLogger(AbstractLogger): + def __init__(self) -> None: + self._print_serial_logs = True + + def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str: + if "machine" in attributes: + return f"{attributes['machine']}: {message}" + return message + + @staticmethod + def _eprint(*args: object, **kwargs: Any) -> None: + print(*args, file=sys.stderr, **kwargs) + + def log(self, message: str, attributes: Dict[str, str] = {}) -> None: + self._eprint(self.maybe_prefix(message, attributes)) + + @contextmanager + def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + with self.nested("subtest: " + name, attributes): + yield + + @contextmanager + def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + self._eprint( + self.maybe_prefix( + Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes + ) + ) + + tic = time.time() + yield + toc = time.time() + self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)") + + def info(self, *args, **kwargs) -> None: # type: ignore + self.log(*args, **kwargs) + + def warning(self, *args, **kwargs) -> None: # type: ignore + self.log(*args, **kwargs) + + def error(self, *args, **kwargs) -> None: # type: ignore + self.log(*args, **kwargs) + + def print_serial_logs(self, enable: bool) -> None: + self._print_serial_logs = enable + + def log_serial(self, message: str, machine: str) -> None: + if not self._print_serial_logs: + return + + self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL) + + +class XMLLogger(AbstractLogger): def __init__(self) -> None: self.logfile = os.environ.get("LOGFILE", "/dev/null") self.logfile_handle = codecs.open(self.logfile, "wb") self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8") self.queue: "Queue[Dict[str, str]]" = Queue() - self.xml.startDocument() - self.xml.startElement("logfile", attrs=AttributesImpl({})) - self._print_serial_logs = True - def print_serial_logs(self, enable: bool) -> None: - self._print_serial_logs = enable - - @staticmethod - def _eprint(*args: object, **kwargs: Any) -> None: - print(*args, file=sys.stderr, **kwargs) + self.xml.startDocument() + self.xml.startElement("logfile", attrs=AttributesImpl({})) def close(self) -> None: self.xml.endElement("logfile") @@ -94,17 +187,19 @@ class Logger(AbstractLogger): def error(self, *args, **kwargs) -> None: # type: ignore self.log(*args, **kwargs) - sys.exit(1) def log(self, message: str, attributes: Dict[str, str] = {}) -> None: - self._eprint(self.maybe_prefix(message, attributes)) self.drain_log_queue() self.log_line(message, attributes) + def print_serial_logs(self, enable: bool) -> None: + self._print_serial_logs = enable + def log_serial(self, message: str, machine: str) -> None: + if not self._print_serial_logs: + return + self.enqueue({"msg": message, "machine": machine, "type": "serial"}) - if self._print_serial_logs: - self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL) def enqueue(self, item: Dict[str, str]) -> None: self.queue.put(item) @@ -126,12 +221,6 @@ class Logger(AbstractLogger): @contextmanager def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: - self._eprint( - self.maybe_prefix( - Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes - ) - ) - self.xml.startElement("nest", attrs=AttributesImpl({})) self.xml.startElement("head", attrs=AttributesImpl(attributes)) self.xml.characters(message) @@ -147,4 +236,6 @@ class Logger(AbstractLogger): self.xml.endElement("nest") -rootlog: AbstractLogger = Logger() +terminal_logger = TerminalLogger() +xml_logger = XMLLogger() +rootlog: AbstractLogger = CompositeLogger([terminal_logger, xml_logger]) diff --git a/nixos/lib/test-script-prepend.py b/nixos/lib/test-script-prepend.py index 976992ea0015..9d2efdf97303 100644 --- a/nixos/lib/test-script-prepend.py +++ b/nixos/lib/test-script-prepend.py @@ -4,7 +4,7 @@ from test_driver.driver import Driver from test_driver.vlan import VLan from test_driver.machine import Machine -from test_driver.logger import Logger +from test_driver.logger import AbstractLogger from typing import Callable, Iterator, ContextManager, Optional, List, Dict, Any, Union from typing_extensions import Protocol from pathlib import Path @@ -44,7 +44,7 @@ test_script: Callable[[], None] machines: List[Machine] vlans: List[VLan] driver: Driver -log: Logger +log: AbstractLogger create_machine: CreateMachineProtocol run_tests: Callable[[], None] join_all: Callable[[], None]