Skip to content

Commit 06d21ce

Browse files
author
yokotoka
committed
fix: implement streaming mode for large audio files to prevent timeouts
1 parent 0631f3f commit 06d21ce

6 files changed

Lines changed: 156 additions & 13 deletions

File tree

examples/audio_file.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
palabra = PalabraAI()
66
reader = FileReader("./speech/es.mp3")
77
writer = FileWriter("./es2en_out.wav")
8-
cfg = Config(SourceLang(ES, reader), [TargetLang(EN, writer)])
8+
cfg = Config(SourceLang(ES, reader), [TargetLang(EN, writer)], debug=True)
99
palabra.run(cfg)

src/palabra_ai/task/io/base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,24 +200,32 @@ async def _exit(self):
200200
return await super()._exit()
201201

202202
async def set_task(self):
203+
debug(f"set_task() STARTED for {self.name} id={id(self)}")
203204
debug("Setting task configuration...")
204205
await aio.sleep(SLEEP_INTERVAL_LONG)
206+
debug(f"set_task() creating receiver for {self.name} id={id(self)}")
205207
async with self.out_msg_foq.receiver(self, self.stopper) as msgs_out:
208+
debug(f"set_task() receiver created for {self.name}")
206209
await self.push_in_msg(SetTaskMessage.from_config(self.cfg))
207210
start_time = time.perf_counter()
208211
await aio.sleep(SLEEP_INTERVAL_LONG)
209212
while start_time + BOOT_TIMEOUT > time.perf_counter():
210213
await self.push_in_msg(GetTaskMessage())
211214
msg = await anext(msgs_out)
212215
if isinstance(msg, CurrentTaskMessage):
213-
debug(f"Received current task: {msg.data}")
216+
debug(f"set_task() SUCCESS: Received current task: {msg.data}")
214217
return
215218
# Handle error messages from server
216219
from palabra_ai.message import ErrorMessage
217220

218221
if isinstance(msg, ErrorMessage):
219222
debug(f"Received error from server: {msg.data}")
220-
msg.raise_() # This will raise the appropriate exception
223+
# Don't immediately fail on NOT_FOUND - it may be temporary
224+
if msg.data.get("data", {}).get("code") == "NOT_FOUND":
225+
debug("Got NOT_FOUND error, will retry...")
226+
else:
227+
# For other errors, raise immediately
228+
msg.raise_()
221229
debug(f"Received unexpected message: {msg}")
222230
await aio.sleep(SLEEP_INTERVAL_LONG)
223231
debug("Timeout waiting for task configuration")

src/palabra_ai/task/io/webrtc.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
)
1919
from palabra_ai.task.io.base import Io
2020
from palabra_ai.util.aio import shutdown
21-
from palabra_ai.util.logger import debug
21+
from palabra_ai.util.logger import debug, error
2222

2323
PALABRA_PEER_PREFIX = "palabra_translator_"
2424
PALABRA_TRACK_PREFIX = "translation_"
@@ -141,15 +141,23 @@ async def send_frame(self, frame: AudioFrame) -> None:
141141
return await self.in_audio_source.capture_frame(frame.to_rtc())
142142

143143
async def boot(self):
144+
debug(f"WebrtcIo.boot() STARTED for {self.name} id={id(self)}")
144145
await self.room.connect(
145146
self.credentials.webrtc_url, self.credentials.jwt_token, self.room_options
146147
)
147148
self.room.on("data_received", self.on_data_received)
148149
lang = self.cfg.targets[0].lang # TODO: many langs
149150
self.peer = await self.peer_appears()
151+
debug(f"WebrtcIo.boot() creating in_msg_sender task for {self.name}")
150152
self.sub_tg.create_task(self.in_msg_sender(), name="Io:in_msg_sender")
151153

152-
await self.set_task()
154+
debug(f"WebrtcIo.boot() calling set_task() for {self.name}")
155+
try:
156+
await self.set_task()
157+
debug(f"WebrtcIo.boot() set_task() completed for {self.name}")
158+
except Exception as e:
159+
error(f"WebrtcIo.boot() set_task() FAILED: {e}", exc_info=True)
160+
raise
153161

154162
self.in_track_name = self.in_track_name or f"{uuid.uuid4()}_{lang.code}"
155163
# noinspection PyTypeChecker

src/palabra_ai/util/fanout_queue.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@ def _get_id(self, subscriber: Any) -> str:
2929
if isinstance(subscriber, str):
3030
return subscriber
3131
elif isinstance(subscriber, object):
32-
# Use the object's id if it's not a string
32+
# Include queue instance ID to avoid collisions between different queues
33+
queue_id = id(self)
34+
subscriber_id = id(subscriber)
3335
if name := getattr(subscriber, "name", None):
34-
return f"{name}_{id(subscriber)}"
35-
return f"{type(subscriber)}_{id(subscriber)}"
36+
return f"{name}_{subscriber_id}_{queue_id}"
37+
return f"{type(subscriber)}_{subscriber_id}_{queue_id}"
3638
else:
3739
raise TypeError(
3840
f"Subscriber must be a string or an object, got: {type(subscriber)}"
@@ -132,10 +134,10 @@ async def message_generator(
132134
break
133135
# Otherwise continue waiting
134136

135-
debug(f"Starting subscriber {subscriber_id}")
137+
debug(f"Starting subscriber {subscriber_id} for queue {type(self).__name__}")
136138

137139
# Subscribe
138-
_ = self.subscribe(subscriber_id, maxsize=0)
140+
_ = self.subscribe(subscriber, maxsize=0)
139141
subscription = self.subscribers[subscriber_id]
140142
generator = message_generator(subscription)
141143

tests/test_task_io_base.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from palabra_ai.audio import AudioFrame
88
from palabra_ai.enum import Channel, Direction
99
from palabra_ai.message import (
10-
Message, EndTaskMessage, SetTaskMessage, GetTaskMessage, CurrentTaskMessage
10+
Message, EndTaskMessage, SetTaskMessage, GetTaskMessage, CurrentTaskMessage, ErrorMessage
1111
)
1212
from palabra_ai.constant import BYTES_PER_SAMPLE, SLEEP_INTERVAL_LONG
1313
from palabra_ai.util.fanout_queue import FanoutQueue
@@ -245,6 +245,96 @@ async def mock_receiver():
245245
assert any("Setting task configuration" in str(call) for call in mock_debug.call_args_list)
246246
assert any("Received current task" in str(call) for call in mock_debug.call_args_list)
247247

248+
@pytest.mark.asyncio
249+
async def test_set_task_not_found_error_retry(self, mock_config, mock_credentials, mock_reader, mock_writer):
250+
"""Test set_task handles NOT_FOUND errors and retries"""
251+
io = ConcreteIo(
252+
cfg=mock_config,
253+
credentials=mock_credentials,
254+
reader=mock_reader,
255+
writer=mock_writer
256+
)
257+
258+
io.push_in_msg = AsyncMock()
259+
260+
# Mock subscription to return NOT_FOUND error then success
261+
async def mock_receiver():
262+
# First return NOT_FOUND error
263+
error_msg = ErrorMessage(
264+
message_type="error",
265+
timestamp=0.0,
266+
raw={"data": {"code": "NOT_FOUND", "desc": "No active task found"}},
267+
data={"data": {"code": "NOT_FOUND", "desc": "No active task found"}}
268+
)
269+
yield error_msg
270+
271+
# Then return success
272+
yield CurrentTaskMessage(timestamp=0.0, data={"task": "test"})
273+
274+
with patch.object(io.out_msg_foq, 'receiver') as mock_receiver_ctx:
275+
mock_receiver_ctx.return_value.__aenter__.return_value = mock_receiver()
276+
277+
with patch('palabra_ai.task.io.base.debug') as mock_debug:
278+
await io.set_task()
279+
280+
# Verify NOT_FOUND was logged but didn't cause immediate failure
281+
debug_calls = [str(call) for call in mock_debug.call_args_list]
282+
assert any("Got NOT_FOUND error, will retry" in call for call in debug_calls)
283+
assert any("set_task() SUCCESS" in call for call in debug_calls)
284+
285+
@pytest.mark.asyncio
286+
async def test_set_task_other_error_immediate_failure(self, mock_config, mock_credentials, mock_reader, mock_writer):
287+
"""Test set_task raises immediately for non-NOT_FOUND errors"""
288+
io = ConcreteIo(
289+
cfg=mock_config,
290+
credentials=mock_credentials,
291+
reader=mock_reader,
292+
writer=mock_writer
293+
)
294+
295+
io.push_in_msg = AsyncMock()
296+
297+
# Mock subscription to return other error
298+
async def mock_receiver():
299+
error_msg = MagicMock(spec=ErrorMessage)
300+
error_msg.data = {"data": {"code": "SERVER_ERROR", "desc": "Internal server error"}}
301+
error_msg.raise_ = MagicMock(side_effect=RuntimeError("Server error"))
302+
yield error_msg
303+
304+
with patch.object(io.out_msg_foq, 'receiver') as mock_receiver_ctx:
305+
mock_receiver_ctx.return_value.__aenter__.return_value = mock_receiver()
306+
307+
with pytest.raises(RuntimeError, match="Server error"):
308+
await io.set_task()
309+
310+
@pytest.mark.asyncio
311+
async def test_set_task_debug_logging(self, mock_config, mock_credentials, mock_reader, mock_writer):
312+
"""Test set_task produces expected debug messages"""
313+
io = ConcreteIo(
314+
cfg=mock_config,
315+
credentials=mock_credentials,
316+
reader=mock_reader,
317+
writer=mock_writer
318+
)
319+
320+
io.push_in_msg = AsyncMock()
321+
322+
# Mock subscription to return success immediately
323+
async def mock_receiver():
324+
yield CurrentTaskMessage(timestamp=0.0, data={"task": "test"})
325+
326+
with patch.object(io.out_msg_foq, 'receiver') as mock_receiver_ctx:
327+
mock_receiver_ctx.return_value.__aenter__.return_value = mock_receiver()
328+
329+
with patch('palabra_ai.task.io.base.debug') as mock_debug:
330+
await io.set_task()
331+
332+
# Check for new debug messages
333+
debug_calls = [str(call) for call in mock_debug.call_args_list]
334+
assert any("set_task() STARTED" in call for call in debug_calls)
335+
assert any("set_task() creating receiver" in call for call in debug_calls)
336+
assert any("set_task() receiver created" in call for call in debug_calls)
337+
248338
@pytest.mark.asyncio
249339
async def test_set_task_timeout(self, mock_config, mock_credentials, mock_reader, mock_writer):
250340
"""Test set_task method with timeout"""

tests/test_util_fanout_queue.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,29 @@ def test_get_id_object_with_name(self):
2626
obj = MagicMock()
2727
obj.name = "test_object"
2828
subscriber_id = self.queue._get_id(obj)
29+
# Should include name, object id, and queue id
2930
assert subscriber_id.startswith("test_object_")
31+
assert str(id(obj)) in subscriber_id
32+
assert str(id(self.queue)) in subscriber_id
3033

3134
def test_get_id_object_without_name(self):
3235
"""Test _get_id with object without name attribute"""
3336
obj = MagicMock()
3437
del obj.name # Remove name attribute
3538
subscriber_id = self.queue._get_id(obj)
39+
# Should include type name, object id, and queue id
3640
assert "MagicMock" in subscriber_id
41+
assert str(id(obj)) in subscriber_id
42+
assert str(id(self.queue)) in subscriber_id
3743

3844
def test_get_id_integer(self):
3945
"""Test _get_id with integer (which is an object in Python)"""
4046
# In Python, integers are objects, so this should work
4147
subscriber_id = self.queue._get_id(123)
42-
assert "<class 'int'>_" in subscriber_id # Should contain type name and id
48+
# Should contain type name, object id, and queue id
49+
assert "<class 'int'>_" in subscriber_id
50+
assert str(id(123)) in subscriber_id
51+
assert str(id(self.queue)) in subscriber_id
4352

4453
def test_is_subscribed(self):
4554
"""Test is_subscribed method"""
@@ -256,4 +265,30 @@ async def test_receiver_cleanup_on_exception(self):
256265
raise RuntimeError("Test error")
257266

258267
# Check subscriber was cleaned up
259-
assert "test_subscriber" not in self.queue.subscribers
268+
assert "test_subscriber" not in self.queue.subscribers
269+
270+
def test_queue_collision_prevention(self):
271+
"""Test that multiple queues with same subscriber don't collide"""
272+
queue1 = FanoutQueue()
273+
queue2 = FanoutQueue()
274+
275+
# Same object subscribed to both queues
276+
obj = MagicMock()
277+
obj.name = "test_obj"
278+
279+
# Subscribe to both queues
280+
sub1 = queue1.subscribe(obj)
281+
sub2 = queue2.subscribe(obj)
282+
283+
# Should create different subscriber IDs due to queue ID inclusion
284+
assert sub1.id_ != sub2.id_
285+
assert str(id(queue1)) in sub1.id_
286+
assert str(id(queue2)) in sub2.id_
287+
288+
# Publish to each queue
289+
queue1.publish("msg1")
290+
queue2.publish("msg2")
291+
292+
# Each should receive only its own message
293+
assert sub1.q.get_nowait() == "msg1"
294+
assert sub2.q.get_nowait() == "msg2"

0 commit comments

Comments
 (0)