vet

Mirror of Vet, an AI code review tool
git clone git://git.laack.co/vet.git
Log | Files | Refs | README | LICENSE

async_monkey_patches_test.py (5881B)


      1 import sys
      2 from contextlib import contextmanager
      3 from contextvars import ContextVar
      4 from typing import Any
      5 from typing import Callable
      6 from typing import Generator
      7 from typing import Iterator
      8 
      9 import pytest
     10 from loguru import logger
     11 
     12 
     13 class IncorrectErrorsLoggedDuringTesting(Exception):
     14     pass
     15 
     16 
     17 _expecting_errors: ContextVar[bool] = ContextVar("expecting_errors", default=False)
     18 
     19 
     20 @contextmanager
     21 def check_logged_errors(check_func: Callable[[list[str]], None]) -> Iterator[None]:
     22     """Context manager that intercepts ERROR logs using loguru's sink system.
     23     Then it runs the check function on the accumulated errors.
     24 
     25     Sets the _expecting_errors context variable so that explode_on_error knows to
     26     ignore errors during this block.
     27     """
     28     accumulated_errors: list[str] = []
     29 
     30     token = _expecting_errors.set(True)
     31 
     32     def error_catching_sink(message: Any) -> None:
     33         record = message.record
     34         if record["level"].name == "ERROR":
     35             accumulated_errors.append(record["message"])
     36             sys.stderr.write(f"CAUGHT ERROR LOG: {record['message'].splitlines()[0][:100]}\n")
     37         else:
     38             sys.stderr.write(str(message))
     39 
     40     handler_id = logger.add(error_catching_sink, format="{message}", level="DEBUG")
     41     try:
     42         logger.remove(0)
     43     except ValueError:
     44         pass
     45 
     46     try:
     47         yield
     48     finally:
     49         _expecting_errors.reset(token)
     50         logger.remove(handler_id)
     51         logger.add(sys.stderr, level="DEBUG")
     52         check_func(accumulated_errors)
     53 
     54 
     55 def at_least_check_maker(expected_errors_set: set[str]) -> Callable[[list[str]], None]:
     56     assert isinstance(expected_errors_set, set), "expected_errors must be a set"
     57     expected_errors = list(expected_errors_set)
     58 
     59     def check_func(accumulated_errors: list[str]) -> None:
     60         if len(accumulated_errors) < len(expected_errors):
     61             raise IncorrectErrorsLoggedDuringTesting(
     62                 f"{len(accumulated_errors)=} != {len(expected_errors)=}, {accumulated_errors=}"
     63             )
     64         for expected_error in expected_errors:
     65             for accumulated_error in accumulated_errors:
     66                 if expected_error in accumulated_error:
     67                     break
     68             else:
     69                 raise IncorrectErrorsLoggedDuringTesting(f"{expected_error=} is not in {accumulated_errors=}")
     70 
     71     return check_func
     72 
     73 
     74 @contextmanager
     75 def expect_at_least_logged_errors(expected_errors: set[str]) -> Iterator[None]:
     76     """Context manager that intercepts ERROR logs using loguru's sink system.
     77     Checks that all expected errors are in the accumulated errors, in no particular order.
     78     """
     79     check_func = at_least_check_maker(expected_errors)
     80     with check_logged_errors(check_func):
     81         yield
     82 
     83 
     84 def exact_check_maker(expected_errors: list[str]) -> Callable[[list[str]], None]:
     85     assert isinstance(expected_errors, list), "expected_errors must be a list"
     86 
     87     def check_func(accumulated_errors: list[str]) -> None:
     88         if len(accumulated_errors) != len(expected_errors):
     89             raise IncorrectErrorsLoggedDuringTesting(
     90                 f"{len(accumulated_errors)=} != {len(expected_errors)=}, {accumulated_errors=}"
     91             )
     92         for i, expected_error in enumerate(expected_errors):
     93             if expected_error not in accumulated_errors[i]:
     94                 raise IncorrectErrorsLoggedDuringTesting(
     95                     f"At position {i=}, {expected_error=} is not in {accumulated_errors[i]=}"
     96                 )
     97 
     98     return check_func
     99 
    100 
    101 @contextmanager
    102 def expect_exact_logged_errors(expected_errors: list[str]) -> Iterator[None]:
    103     """Context manager that intercepts ERROR logs using loguru's sink system.
    104     Checks that all expected errors are in the accumulated errors, in the same order."""
    105     check_func = exact_check_maker(expected_errors)
    106     with check_logged_errors(check_func):
    107         yield
    108 
    109 
    110 @pytest.fixture
    111 def explode_on_error() -> Generator[None, None, None]:
    112     """Fixture to explode on error - fails the test if any ERROR logs are recorded."""
    113     accumulated_errors: list[str] = []
    114 
    115     def error_catching_sink(message: Any) -> None:
    116         record = message.record
    117         if record["level"].name == "ERROR":
    118             if not _expecting_errors.get():
    119                 accumulated_errors.append(record["message"])
    120         sys.stderr.write(str(message))
    121 
    122     handler_id = logger.add(
    123         error_catching_sink,
    124         format="{level} | {name}:{function}:{line} - {message}",
    125         level="DEBUG",
    126     )
    127     try:
    128         logger.remove(0)
    129     except ValueError:
    130         pass
    131 
    132     try:
    133         yield
    134     except BaseException:
    135         raise
    136     else:
    137         if len(accumulated_errors) > 0:
    138             raise IncorrectErrorsLoggedDuringTesting(f"Errors logged during testing: {accumulated_errors}")
    139     finally:
    140         logger.remove(handler_id)
    141         logger.add(sys.stderr, level="DEBUG")
    142 
    143 
    144 def test_log_error(explode_on_error: Any) -> None:
    145     with expect_exact_logged_errors(["Something bad happened"]):
    146         logger.error("Something bad happened")
    147 
    148     with pytest.raises(IncorrectErrorsLoggedDuringTesting):
    149         with expect_exact_logged_errors(["Something bad happened"]):
    150             pass
    151 
    152     with pytest.raises(IncorrectErrorsLoggedDuringTesting):
    153         with expect_exact_logged_errors(["Something bad happened"]):
    154             logger.error("Something bad happened")
    155             logger.error("Something else bad happened")
    156 
    157 
    158 def test_log_error_at_least(explode_on_error: Any) -> None:
    159     with expect_at_least_logged_errors({"Something bad happened"}):
    160         logger.error("Something bad happened")
    161         logger.error("Something else bad happened")
    162 
    163     with pytest.raises(IncorrectErrorsLoggedDuringTesting):
    164         with expect_at_least_logged_errors({"Something bad happened", "Something else bad happened"}):
    165             logger.error("Something bad happened")