Skip to content

Commit 9bc52fd

Browse files
fix: enhance connection state checks (#2137)
* fix: enhance connection state checks
* feat: support prefetch

---------

Co-authored-by: hanhandi <1540984562@qq.com>
1 parent a95b1d4 commit 9bc52fd

4 files changed

Lines changed: 218 additions & 11 deletions

File tree

ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py

Lines changed: 50 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -305,7 +305,9 @@ def is_connected(self) -> bool:
305305
# ):
306306
# return True # Still consider connected during finalize grace period
307307

308-
return self.connected and self.client is not None
308+
return (
309+
self.connected and self.client is not None and self.client.connected
310+
)
309311

310312
@override
311313
async def send_audio(
@@ -351,6 +353,8 @@ async def send_audio(
351353
return True
352354

353355
except Exception as e:
356+
if self.stopped:
357+
return False
354358
self.ten_env.log(LogLevel.ERROR, f"Error sending audio: {e}")
355359
await self._handle_error(e)
356360
return False
@@ -466,16 +470,32 @@ async def _handle_reconnect(self) -> None:
466470
finally:
467471
self._reconnecting = False
468472

473+
def _result_level_asr_info_fields(
    self, result: ASRResponse
) -> dict[str, Any]:
    """Collect vendor result-level values destined for metadata.asr_info.

    Looks at the vendor ``result`` payload carried on ``result`` and forwards
    its ``prefetch`` entry, but only when that key is actually present.

    Args:
        result: Vendor response whose ``result`` attribute holds the raw
            result payload.

    Returns:
        ``{"prefetch": <value>}`` when the payload is a non-empty dict that
        contains ``prefetch``; otherwise an empty dict.
    """
    payload = result.result
    # Missing, empty, or non-dict payloads carry nothing to forward.
    if payload and isinstance(payload, dict) and "prefetch" in payload:
        return {"prefetch": payload["prefetch"]}
    return {}
486+
469487
def _build_metadata_with_asr_info(
470488
self,
471489
base_metadata: dict[str, Any] | None = None,
472490
additional_fields: dict[str, Any] | None = None,
491+
result_level_fields: dict[str, Any] | None = None,
473492
) -> dict[str, Any]:
474493
"""Build metadata according to protocol: session_id at root, others in asr_info.
475494
476495
Args:
477496
base_metadata: Base metadata dict (defaults to self.metadata if None)
478497
additional_fields: Additional fields to add to asr_info
498+
result_level_fields: Vendor result-level fields (e.g. prefetch) for asr_info
479499
480500
Returns:
481501
Metadata dict with structure: {"session_id": "...", "asr_info": {...}}
@@ -501,6 +521,10 @@ def _build_metadata_with_asr_info(
501521
if additional_fields:
502522
asr_info.update(additional_fields)
503523

524+
# Vendor result-level fields (e.g. prefetch under result.{prefetch})
525+
if result_level_fields:
526+
asr_info.update(result_level_fields)
527+
504528
# Build final metadata structure
505529
metadata: dict[str, Any] = {}
506530
if session_id is not None:
@@ -510,7 +534,9 @@ def _build_metadata_with_asr_info(
510534
return metadata
511535

512536
def _extract_final_result_metadata(
513-
self, utterance: Utterance
537+
self,
538+
utterance: Utterance,
539+
result_level_fields: dict[str, Any] | None = None,
514540
) -> dict[str, Any]:
515541
"""Extract metadata from utterance additions.
516542
@@ -524,11 +550,14 @@ def _extract_final_result_metadata(
524550
additional_fields = utterance.additions
525551

526552
return self._build_metadata_with_asr_info(
527-
additional_fields=additional_fields
553+
additional_fields=additional_fields,
554+
result_level_fields=result_level_fields,
528555
)
529556

530557
def _extract_non_final_result_metadata(
531-
self, utterance: Utterance
558+
self,
559+
utterance: Utterance,
560+
result_level_fields: dict[str, Any] | None = None,
532561
) -> dict[str, Any]:
533562
"""Extract metadata from utterance additions for non-final results.
534563
@@ -550,7 +579,8 @@ def _extract_non_final_result_metadata(
550579
additional_fields["source"] = additions["source"]
551580

552581
return self._build_metadata_with_asr_info(
553-
additional_fields=additional_fields
582+
additional_fields=additional_fields,
583+
result_level_fields=result_level_fields,
554584
)
555585

556586
def _calculate_utterance_start_ms(
@@ -719,6 +749,8 @@ async def _on_asr_result(self, result: ASRResponse) -> None:
719749
category=LOG_CATEGORY_VENDOR,
720750
)
721751

752+
result_level_asr_info = self._result_level_asr_info_fields(result)
753+
722754
# Process utterances: send definite=true individually,
723755
# and concatenate adjacent definite=false utterances together
724756
if not result.utterances:
@@ -728,7 +760,9 @@ async def _on_asr_result(self, result: ASRResponse) -> None:
728760
result.start_ms
729761
)
730762
# Build metadata according to protocol: session_id at root, others in asr_info
731-
metadata = self._build_metadata_with_asr_info()
763+
metadata = self._build_metadata_with_asr_info(
764+
result_level_fields=result_level_asr_info
765+
)
732766
await self._send_asr_result_from_text(
733767
text=result.text,
734768
is_final=False,
@@ -767,11 +801,13 @@ async def _on_asr_result(self, result: ASRResponse) -> None:
767801
if is_final:
768802
has_final_result = True
769803
metadata = self._extract_final_result_metadata(
770-
utterance
804+
utterance,
805+
result_level_fields=result_level_asr_info,
771806
)
772807
else:
773808
metadata = self._extract_non_final_result_metadata(
774-
utterance
809+
utterance,
810+
result_level_fields=result_level_asr_info,
775811
)
776812

777813
await self._send_asr_result_from_text(
@@ -816,10 +852,14 @@ async def _on_asr_result(self, result: ASRResponse) -> None:
816852
"start_time": first.start_time,
817853
"duration_ms": last.end_time - first.start_time,
818854
"metadata": (
819-
self._extract_final_result_metadata(last)
855+
self._extract_final_result_metadata(
856+
last,
857+
result_level_fields=result_level_asr_info,
858+
)
820859
if is_final
821860
else self._extract_non_final_result_metadata(
822-
last
861+
last,
862+
result_level_fields=result_level_asr_info,
823863
)
824864
),
825865
"utterance": last, # Keep reference for timestamp tracking

ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
{
22
"type": "extension",
33
"name": "bytedance_llm_based_asr",
4-
"version": "0.3.18",
4+
"version": "0.3.20",
55
"dependencies": [
66
{
77
"type": "system",
Lines changed: 63 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,63 @@
1+
import os
2+
import sys
3+
import types
4+
from unittest.mock import AsyncMock, MagicMock
5+
6+
import pytest
7+
8+
9+
extension_dir = os.path.join(os.path.dirname(__file__), "..")
10+
sys.path.insert(0, extension_dir)
11+
12+
package = types.ModuleType("bytedance_llm_based_asr")
13+
package.__path__ = [extension_dir]
14+
sys.modules["bytedance_llm_based_asr"] = package
15+
16+
from bytedance_llm_based_asr import config as config_module
17+
18+
sys.modules["bytedance_llm_based_asr.config"] = config_module
19+
20+
from bytedance_llm_based_asr import extension as extension_module
21+
22+
sys.modules["bytedance_llm_based_asr.extension"] = extension_module
23+
24+
from bytedance_llm_based_asr.extension import BytedanceASRLLMExtension
25+
26+
27+
@pytest.fixture
def mock_ten_env():
    """Provide a stand-in ten_env whose logging helpers are plain mocks.

    The environment itself is an ``AsyncMock``; each logging helper is
    replaced with a synchronous ``MagicMock`` so tests can assert on log
    calls directly.
    """
    env = AsyncMock()
    for log_method in (
        "log",
        "log_debug",
        "log_info",
        "log_warn",
        "log_error",
    ):
        setattr(env, log_method, MagicMock())
    return env
36+
37+
38+
@pytest.fixture
def extension(mock_ten_env):
    """Build an extension flagged as connected whose client reports disconnected.

    The extension-level ``connected`` flag is True while the mocked client's
    ``connected`` attribute is False, and the client's ``send_audio`` is an
    awaitable mock.
    """
    # Prepare the fake client up front, then attach it to the extension.
    fake_client = MagicMock()
    fake_client.connected = False
    fake_client.send_audio = AsyncMock()

    ext = BytedanceASRLLMExtension("test_extension")
    ext.ten_env = mock_ten_env
    ext.connected = True
    ext.client = fake_client
    return ext
47+
48+
49+
@pytest.mark.asyncio
async def test_handle_audio_frame_buffers_when_client_state_is_disconnected(
    extension, mock_ten_env
):
    """A frame handled while the client is disconnected is buffered, not sent.

    Checks that ``is_connected()`` is False, exactly one frame is queued in
    ``buffered_frames``, ``send_audio`` is never awaited, and the
    "not connected" debug message is logged once.
    """
    audio_frame = MagicMock(**{"get_buf.return_value": b"\x00\x01"})

    await extension._handle_audio_frame(mock_ten_env, audio_frame)

    assert extension.is_connected() is False
    assert extension.buffered_frames.qsize() == 1
    extension.client.send_audio.assert_not_awaited()
    expected_message = "send_frame: service not connected."
    mock_ten_env.log_debug.assert_called_once_with(expected_message)
Lines changed: 104 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,104 @@
1+
import os
2+
import sys
3+
import types
4+
from unittest.mock import AsyncMock, MagicMock
5+
6+
import pytest
7+
8+
9+
extension_dir = os.path.join(os.path.dirname(__file__), "..")
10+
sys.path.insert(0, extension_dir)
11+
12+
package = types.ModuleType("bytedance_llm_based_asr")
13+
package.__path__ = [extension_dir]
14+
sys.modules["bytedance_llm_based_asr"] = package
15+
16+
from bytedance_llm_based_asr import config as config_module
17+
18+
sys.modules["bytedance_llm_based_asr.config"] = config_module
19+
20+
from bytedance_llm_based_asr import extension as extension_module
21+
22+
sys.modules["bytedance_llm_based_asr.extension"] = extension_module
23+
24+
from bytedance_llm_based_asr.extension import BytedanceASRLLMExtension
25+
26+
27+
@pytest.fixture
def mock_ten_env():
    """Provide a stand-in ten_env whose logging helpers are plain mocks.

    The environment itself is an ``AsyncMock``; each logging helper is
    replaced with a synchronous ``MagicMock`` so tests can assert on log
    calls directly.
    """
    env = AsyncMock()
    log_helpers = ("log", "log_debug", "log_info", "log_warn", "log_error")
    for helper_name in log_helpers:
        setattr(env, helper_name, MagicMock())
    return env
36+
37+
38+
@pytest.fixture
def mock_frame():
    """Provide a mocked audio frame whose ``lock_buf()`` returns two bytes."""
    return MagicMock(**{"lock_buf.return_value": b"\x00\x01"})
43+
44+
45+
@pytest.fixture
def extension(mock_ten_env):
    """Build an extension marked connected with mocked client and error sink.

    The client's ``send_audio`` and the extension's ``send_asr_error`` are
    both awaitable mocks so tests can assert whether they were awaited.
    """
    # Prepare the fake client up front, then attach it to the extension.
    fake_client = MagicMock()
    fake_client.send_audio = AsyncMock()

    ext = BytedanceASRLLMExtension("test_extension")
    ext.ten_env = mock_ten_env
    ext.connected = True
    ext.client = fake_client
    ext.send_asr_error = AsyncMock()
    return ext
54+
55+
56+
@pytest.mark.asyncio
async def test_send_audio_suppresses_stop_phase_disconnect_error(
    extension, mock_ten_env, mock_frame
):
    """A disconnect error raised after stop is swallowed silently.

    With ``stopped`` set, ``send_audio`` must return False without sending an
    ASR error or logging, and the locked frame buffer must be unlocked once.
    """
    disconnect_error = RuntimeError("Not connected to ASR service")
    extension.stopped = True
    extension.client.send_audio.side_effect = disconnect_error

    outcome = await extension.send_audio(mock_frame, "session-1")

    assert outcome is False
    extension.send_asr_error.assert_not_awaited()
    mock_ten_env.log.assert_not_called()
    mock_ten_env.log_debug.assert_not_called()
    mock_frame.unlock_buf.assert_called_once_with(b"\x00\x01")
72+
73+
74+
@pytest.mark.asyncio
async def test_send_audio_reports_runtime_disconnect_error(
    extension, mock_ten_env, mock_frame
):
    """A disconnect error raised while still running is logged and reported.

    With ``stopped`` cleared, ``send_audio`` must return False, await
    ``send_asr_error`` once, log once, and unlock the frame buffer once.
    """
    disconnect_error = RuntimeError("Not connected to ASR service")
    extension.stopped = False
    extension.client.send_audio.side_effect = disconnect_error

    outcome = await extension.send_audio(mock_frame, "session-1")

    assert outcome is False
    extension.send_asr_error.assert_awaited_once()
    mock_ten_env.log.assert_called_once()
    mock_frame.unlock_buf.assert_called_once_with(b"\x00\x01")
89+
90+
91+
@pytest.mark.asyncio
async def test_send_audio_does_not_suppress_other_stop_phase_errors(
    extension, mock_ten_env, mock_frame
):
    """A non-disconnect error raised after stop produces no error report or log.

    With ``stopped`` set and the client raising ``ValueError``, the
    assertions require ``send_audio`` to return False without awaiting
    ``send_asr_error`` or logging, and the frame buffer to be unlocked once.
    """
    # NOTE(review): the test name says "does not suppress", yet every
    # assertion below verifies that the error IS suppressed. Consider renaming
    # the test or revisiting the intended behavior.
    unexpected_error = ValueError("unexpected failure")
    extension.stopped = True
    extension.client.send_audio.side_effect = unexpected_error

    outcome = await extension.send_audio(mock_frame, "session-1")

    assert outcome is False
    extension.send_asr_error.assert_not_awaited()
    mock_ten_env.log.assert_not_called()
    mock_ten_env.log_debug.assert_not_called()
    mock_frame.unlock_buf.assert_called_once_with(b"\x00\x01")

0 commit comments

Comments
 (0)