async_monkey_patches_test.py (5881B)
import sys
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any
from typing import Callable
from typing import Generator
from typing import Iterator
from typing import Optional

import pytest
from loguru import logger


class IncorrectErrorsLoggedDuringTesting(Exception):
    """Raised when the ERROR logs recorded during a test do not match expectations."""


# True while a check_logged_errors() block is active, so explode_on_error knows
# the ERROR logs it sees are expected (they are checked by that block instead).
_expecting_errors: ContextVar[bool] = ContextVar("expecting_errors", default=False)

# ID of the stderr handler that the helpers below temporarily remove so output is
# not duplicated alongside their own sinks. Loguru's built-in handler has ID 0;
# every restore registers a new handler with a fresh ID, so the current ID is
# tracked here. None means no handler is tracked (it is currently suspended, or
# it was removed by someone else and must not be resurrected).
_stderr_handler_id: Optional[int] = 0


@contextmanager
def _stderr_handler_suspended() -> Iterator[None]:
    """Remove the tracked stderr handler for the duration of the block, then restore it.

    Only the outermost active suspension removes and re-adds the handler; nested
    suspensions find nothing tracked and do nothing. This keeps exactly one stderr
    handler around instead of adding another one on every exit.
    """
    global _stderr_handler_id
    removed = False
    if _stderr_handler_id is not None:
        try:
            logger.remove(_stderr_handler_id)
            removed = True
        except ValueError:
            # Already removed elsewhere; do not add it back afterwards.
            pass
        _stderr_handler_id = None
    try:
        yield
    finally:
        if removed:
            _stderr_handler_id = logger.add(sys.stderr, level="DEBUG")


@contextmanager
def check_logged_errors(check_func: Callable[[list[str]], None]) -> Iterator[None]:
    """Context manager that intercepts ERROR logs using loguru's sink system.
    Then it runs the check function on the accumulated errors.

    Sets the _expecting_errors context variable so that explode_on_error knows to
    ignore errors during this block.

    Args:
        check_func: called on exit with the messages of all ERROR records logged
            inside the block. It should raise (e.g. IncorrectErrorsLoggedDuringTesting)
            if they are not as expected. It runs in a ``finally`` clause, so it also
            runs when the body raised, and an exception it raises takes precedence.
    """
    accumulated_errors: list[str] = []

    def error_catching_sink(message: Any) -> None:
        record = message.record
        if record["level"].name == "ERROR":
            accumulated_errors.append(record["message"])
            # splitlines() returns [] for an empty message; fall back to "" so the
            # preview cannot raise IndexError inside the sink.
            first_line = (record["message"].splitlines() or [""])[0]
            sys.stderr.write(f"CAUGHT ERROR LOG: {first_line[:100]}\n")
        else:
            sys.stderr.write(str(message))

    handler_id = logger.add(error_catching_sink, format="{message}", level="DEBUG")
    token = _expecting_errors.set(True)
    try:
        with _stderr_handler_suspended():
            yield
    finally:
        _expecting_errors.reset(token)
        logger.remove(handler_id)
        check_func(accumulated_errors)


def at_least_check_maker(expected_errors_set: set[str]) -> Callable[[list[str]], None]:
    """Build a check requiring each expected string to occur in some logged error.

    Matching is by substring, in any order. Extra logged errors are allowed, but
    at least as many errors as expected strings must have been logged.

    Raises:
        AssertionError: if ``expected_errors_set`` is not a set.
    """
    assert isinstance(expected_errors_set, set), "expected_errors must be a set"
    # Sorted so that failure reports do not depend on set iteration order.
    expected_errors = sorted(expected_errors_set)

    def check_func(accumulated_errors: list[str]) -> None:
        if len(accumulated_errors) < len(expected_errors):
            raise IncorrectErrorsLoggedDuringTesting(
                f"{len(accumulated_errors)=} < {len(expected_errors)=}, {accumulated_errors=}"
            )
        for expected_error in expected_errors:
            if not any(expected_error in accumulated_error for accumulated_error in accumulated_errors):
                raise IncorrectErrorsLoggedDuringTesting(f"{expected_error=} is not in {accumulated_errors=}")

    return check_func


@contextmanager
def expect_at_least_logged_errors(expected_errors: set[str]) -> Iterator[None]:
    """Context manager that intercepts ERROR logs using loguru's sink system.
    Checks that all expected errors are in the accumulated errors, in no particular order.
    """
    check_func = at_least_check_maker(expected_errors)
    with check_logged_errors(check_func):
        yield


def exact_check_maker(expected_errors: list[str]) -> Callable[[list[str]], None]:
    """Build a check requiring exactly ``len(expected_errors)`` logged errors.

    The i-th expected string must be a substring of the i-th logged error.

    Raises:
        AssertionError: if ``expected_errors`` is not a list.
    """
    assert isinstance(expected_errors, list), "expected_errors must be a list"

    def check_func(accumulated_errors: list[str]) -> None:
        if len(accumulated_errors) != len(expected_errors):
            raise IncorrectErrorsLoggedDuringTesting(
                f"{len(accumulated_errors)=} != {len(expected_errors)=}, {accumulated_errors=}"
            )
        for i, expected_error in enumerate(expected_errors):
            if expected_error not in accumulated_errors[i]:
                raise IncorrectErrorsLoggedDuringTesting(
                    f"At position {i=}, {expected_error=} is not in {accumulated_errors[i]=}"
                )

    return check_func


@contextmanager
def expect_exact_logged_errors(expected_errors: list[str]) -> Iterator[None]:
    """Context manager that intercepts ERROR logs using loguru's sink system.
    Checks that all expected errors are in the accumulated errors, in the same order."""
    check_func = exact_check_maker(expected_errors)
    with check_logged_errors(check_func):
        yield


@pytest.fixture
def explode_on_error() -> Generator[None, None, None]:
    """Fixture to explode on error - fails the test if any ERROR logs are recorded.

    ERROR logs emitted inside a check_logged_errors() block are ignored here,
    because that block checks them itself.
    """
    accumulated_errors: list[str] = []

    def error_catching_sink(message: Any) -> None:
        record = message.record
        if record["level"].name == "ERROR" and not _expecting_errors.get():
            accumulated_errors.append(record["message"])
        sys.stderr.write(str(message))

    handler_id = logger.add(
        error_catching_sink,
        format="{level} | {name}:{function}:{line} - {message}",
        level="DEBUG",
    )
    try:
        with _stderr_handler_suspended():
            yield
        # Reached when the generator is resumed normally (pytest teardown); an
        # exception thrown into the generator propagates without this check.
        if accumulated_errors:
            raise IncorrectErrorsLoggedDuringTesting(f"Errors logged during testing: {accumulated_errors}")
    finally:
        logger.remove(handler_id)


def test_log_error(explode_on_error: Any) -> None:
    with expect_exact_logged_errors(["Something bad happened"]):
        logger.error("Something bad happened")

    with pytest.raises(IncorrectErrorsLoggedDuringTesting):
        with expect_exact_logged_errors(["Something bad happened"]):
            pass

    with pytest.raises(IncorrectErrorsLoggedDuringTesting):
        with expect_exact_logged_errors(["Something bad happened"]):
            logger.error("Something bad happened")
            logger.error("Something else bad happened")


def test_log_error_at_least(explode_on_error: Any) -> None:
    with expect_at_least_logged_errors({"Something bad happened"}):
        logger.error("Something bad happened")
        logger.error("Something else bad happened")

    with pytest.raises(IncorrectErrorsLoggedDuringTesting):
        with expect_at_least_logged_errors({"Something bad happened", "Something else bad happened"}):
            logger.error("Something bad happened")