|
7 | 7 | from palabra_ai.audio import AudioFrame |
8 | 8 | from palabra_ai.enum import Channel, Direction |
9 | 9 | from palabra_ai.message import ( |
10 | | - Message, EndTaskMessage, SetTaskMessage, GetTaskMessage, CurrentTaskMessage |
| 10 | + Message, EndTaskMessage, SetTaskMessage, GetTaskMessage, CurrentTaskMessage, ErrorMessage |
11 | 11 | ) |
12 | 12 | from palabra_ai.constant import BYTES_PER_SAMPLE, SLEEP_INTERVAL_LONG |
13 | 13 | from palabra_ai.util.fanout_queue import FanoutQueue |
@@ -245,6 +245,96 @@ async def mock_receiver(): |
245 | 245 | assert any("Setting task configuration" in str(call) for call in mock_debug.call_args_list) |
246 | 246 | assert any("Received current task" in str(call) for call in mock_debug.call_args_list) |
247 | 247 |
|
| 248 | + @pytest.mark.asyncio |
| 249 | + async def test_set_task_not_found_error_retry(self, mock_config, mock_credentials, mock_reader, mock_writer): |
| 250 | + """Test set_task handles NOT_FOUND errors and retries""" |
| 251 | + io = ConcreteIo( |
| 252 | + cfg=mock_config, |
| 253 | + credentials=mock_credentials, |
| 254 | + reader=mock_reader, |
| 255 | + writer=mock_writer |
| 256 | + ) |
| 257 | + |
| 258 | + io.push_in_msg = AsyncMock() |
| 259 | + |
| 260 | + # Mock subscription to return NOT_FOUND error then success |
| 261 | + async def mock_receiver(): |
| 262 | + # First return NOT_FOUND error |
| 263 | + error_msg = ErrorMessage( |
| 264 | + message_type="error", |
| 265 | + timestamp=0.0, |
| 266 | + raw={"data": {"code": "NOT_FOUND", "desc": "No active task found"}}, |
| 267 | + data={"data": {"code": "NOT_FOUND", "desc": "No active task found"}} |
| 268 | + ) |
| 269 | + yield error_msg |
| 270 | + |
| 271 | + # Then return success |
| 272 | + yield CurrentTaskMessage(timestamp=0.0, data={"task": "test"}) |
| 273 | + |
| 274 | + with patch.object(io.out_msg_foq, 'receiver') as mock_receiver_ctx: |
| 275 | + mock_receiver_ctx.return_value.__aenter__.return_value = mock_receiver() |
| 276 | + |
| 277 | + with patch('palabra_ai.task.io.base.debug') as mock_debug: |
| 278 | + await io.set_task() |
| 279 | + |
| 280 | + # Verify NOT_FOUND was logged but didn't cause immediate failure |
| 281 | + debug_calls = [str(call) for call in mock_debug.call_args_list] |
| 282 | + assert any("Got NOT_FOUND error, will retry" in call for call in debug_calls) |
| 283 | + assert any("set_task() SUCCESS" in call for call in debug_calls) |
| 284 | + |
| 285 | + @pytest.mark.asyncio |
| 286 | + async def test_set_task_other_error_immediate_failure(self, mock_config, mock_credentials, mock_reader, mock_writer): |
| 287 | + """Test set_task raises immediately for non-NOT_FOUND errors""" |
| 288 | + io = ConcreteIo( |
| 289 | + cfg=mock_config, |
| 290 | + credentials=mock_credentials, |
| 291 | + reader=mock_reader, |
| 292 | + writer=mock_writer |
| 293 | + ) |
| 294 | + |
| 295 | + io.push_in_msg = AsyncMock() |
| 296 | + |
| 297 | + # Mock subscription to return other error |
| 298 | + async def mock_receiver(): |
| 299 | + error_msg = MagicMock(spec=ErrorMessage) |
| 300 | + error_msg.data = {"data": {"code": "SERVER_ERROR", "desc": "Internal server error"}} |
| 301 | + error_msg.raise_ = MagicMock(side_effect=RuntimeError("Server error")) |
| 302 | + yield error_msg |
| 303 | + |
| 304 | + with patch.object(io.out_msg_foq, 'receiver') as mock_receiver_ctx: |
| 305 | + mock_receiver_ctx.return_value.__aenter__.return_value = mock_receiver() |
| 306 | + |
| 307 | + with pytest.raises(RuntimeError, match="Server error"): |
| 308 | + await io.set_task() |
| 309 | + |
| 310 | + @pytest.mark.asyncio |
| 311 | + async def test_set_task_debug_logging(self, mock_config, mock_credentials, mock_reader, mock_writer): |
| 312 | + """Test set_task produces expected debug messages""" |
| 313 | + io = ConcreteIo( |
| 314 | + cfg=mock_config, |
| 315 | + credentials=mock_credentials, |
| 316 | + reader=mock_reader, |
| 317 | + writer=mock_writer |
| 318 | + ) |
| 319 | + |
| 320 | + io.push_in_msg = AsyncMock() |
| 321 | + |
| 322 | + # Mock subscription to return success immediately |
| 323 | + async def mock_receiver(): |
| 324 | + yield CurrentTaskMessage(timestamp=0.0, data={"task": "test"}) |
| 325 | + |
| 326 | + with patch.object(io.out_msg_foq, 'receiver') as mock_receiver_ctx: |
| 327 | + mock_receiver_ctx.return_value.__aenter__.return_value = mock_receiver() |
| 328 | + |
| 329 | + with patch('palabra_ai.task.io.base.debug') as mock_debug: |
| 330 | + await io.set_task() |
| 331 | + |
| 332 | + # Check for new debug messages |
| 333 | + debug_calls = [str(call) for call in mock_debug.call_args_list] |
| 334 | + assert any("set_task() STARTED" in call for call in debug_calls) |
| 335 | + assert any("set_task() creating receiver" in call for call in debug_calls) |
| 336 | + assert any("set_task() receiver created" in call for call in debug_calls) |
| 337 | + |
248 | 338 | @pytest.mark.asyncio |
249 | 339 | async def test_set_task_timeout(self, mock_config, mock_credentials, mock_reader, mock_writer): |
250 | 340 | """Test set_task method with timeout""" |
|
0 commit comments