|
17 | 17 | import shutil |
18 | 18 | import tempfile |
19 | 19 | import unittest |
20 | | -import warnings |
21 | 20 | from datetime import datetime, timedelta |
22 | 21 | from typing import Any |
23 | 22 |
|
|
37 | 36 | from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig |
38 | 37 | from feast.on_demand_feature_view import on_demand_feature_view |
39 | 38 | from feast.types import Array, Float32, Float64, Int64, PdfBytes, String, ValueType |
40 | | - |
41 | | - |
42 | | -def check_warnings( |
43 | | - expected_warnings=None, # List of warnings that MUST be present |
44 | | - forbidden_warnings=None, # List of warnings that MUST NOT be present |
45 | | - match_type="contains", # "exact", "contains", "regex" |
46 | | - capture_all=True, # Capture all warnings or just specific types |
47 | | - fail_on_unexpected=False, # Fail if unexpected warnings appear |
48 | | - min_count=None, # Minimum number of expected warnings |
49 | | - max_count=None, # Maximum number of expected warnings |
50 | | -): |
51 | | - """ |
52 | | - Decorator to automatically capture and validate warnings in test methods. |
53 | | -
|
54 | | - Args: |
55 | | - expected_warnings: List of warning messages that MUST be present |
56 | | - forbidden_warnings: List of warning messages that MUST NOT be present |
57 | | - match_type: How to match warnings ("exact", "contains", "regex") |
58 | | - capture_all: Whether to capture all warnings |
59 | | - fail_on_unexpected: Whether to fail if unexpected warnings appear |
60 | | - min_count: Minimum number of warnings expected |
61 | | - max_count: Maximum number of warnings expected |
62 | | - """ |
63 | | - |
64 | | - def decorator(test_func): |
65 | | - def wrapper(*args, **kwargs): |
66 | | - # Setup warning capture |
67 | | - with warnings.catch_warnings(record=True) as warning_list: |
68 | | - warnings.simplefilter("always") |
69 | | - |
70 | | - # Execute the test function |
71 | | - result = test_func(*args, **kwargs) |
72 | | - |
73 | | - # Convert warnings to string messages |
74 | | - captured_messages = [str(w.message) for w in warning_list] |
75 | | - |
76 | | - # Validate expected warnings are present |
77 | | - if expected_warnings: |
78 | | - for expected_warning in expected_warnings: |
79 | | - if not _warning_matches( |
80 | | - expected_warning, captured_messages, match_type |
81 | | - ): |
82 | | - raise AssertionError( |
83 | | - f"Expected warning '{expected_warning}' not found. " |
84 | | - f"Captured warnings: {captured_messages}" |
85 | | - ) |
86 | | - |
87 | | - # Validate forbidden warnings are NOT present |
88 | | - if forbidden_warnings: |
89 | | - for forbidden_warning in forbidden_warnings: |
90 | | - if _warning_matches( |
91 | | - forbidden_warning, captured_messages, match_type |
92 | | - ): |
93 | | - raise AssertionError( |
94 | | - f"Forbidden warning '{forbidden_warning}' was found. " |
95 | | - f"Captured warnings: {captured_messages}" |
96 | | - ) |
97 | | - |
98 | | - # Validate warning count constraints |
99 | | - if min_count is not None and len(warning_list) < min_count: |
100 | | - raise AssertionError( |
101 | | - f"Expected at least {min_count} warnings, got {len(warning_list)}" |
102 | | - ) |
103 | | - |
104 | | - if max_count is not None and len(warning_list) > max_count: |
105 | | - raise AssertionError( |
106 | | - f"Expected at most {max_count} warnings, got {len(warning_list)}" |
107 | | - ) |
108 | | - |
109 | | - # Validate no unexpected warnings (if enabled) |
110 | | - if fail_on_unexpected and expected_warnings: |
111 | | - all_expected = expected_warnings + (forbidden_warnings or []) |
112 | | - for message in captured_messages: |
113 | | - if not any( |
114 | | - _warning_matches(exp, [message], match_type) |
115 | | - for exp in all_expected |
116 | | - ): |
117 | | - raise AssertionError( |
118 | | - f"Unexpected warning found: '{message}'" |
119 | | - ) |
120 | | - |
121 | | - return result |
122 | | - |
123 | | - return wrapper |
124 | | - |
125 | | - return decorator |
126 | | - |
127 | | - |
128 | | -def _warning_matches(pattern, messages, match_type): |
129 | | - """Helper function to check if pattern matches any message""" |
130 | | - for message in messages: |
131 | | - if match_type == "exact": |
132 | | - if pattern == message: |
133 | | - return True |
134 | | - elif match_type == "contains": |
135 | | - if pattern in message: |
136 | | - return True |
137 | | - elif match_type == "regex": |
138 | | - import re |
139 | | - |
140 | | - if re.search(pattern, message): |
141 | | - return True |
142 | | - return False |
| 39 | +from tests.utils.test_wrappers import check_warnings |
143 | 40 |
|
144 | 41 |
|
145 | 42 | class TestOnlineWrites(unittest.TestCase): |
|
0 commit comments