-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathtest_punc_model_none.py
More file actions
119 lines (95 loc) · 4.97 KB
/
test_punc_model_none.py
File metadata and controls
119 lines (95 loc) · 4.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Tests for issue #2839: punc_model=None should not cause UnboundLocalError."""
import unittest
from unittest.mock import MagicMock, patch
import numpy as np
class TestPuncModelNone(unittest.TestCase):
    """Test that inference_with_vad works when punc_model is None.

    Every model call goes through a scripted ``MagicMock`` installed as
    ``am.inference``, and the audio helpers in ``funasr.auto.auto_model`` are
    patched, so the tests exercise only the control flow of
    ``inference_with_vad``.
    """

    def _make_auto_model(self, punc_model=None, spk_model=None, spk_mode=None):
        """Create a minimal AutoModel instance with mocked dependencies.

        The instance is built with ``AutoModel.__new__`` so ``__init__`` never
        runs; only the attributes assigned below exist on it.

        Args:
            punc_model: Value stored as ``am.punc_model`` (``None`` or a mock).
            spk_model: Value stored as ``am.spk_model``.
            spk_mode: Value stored as ``am.spk_mode``.

        Returns:
            The partially initialised ``AutoModel`` instance.
        """
        # Imported here rather than at module level, so importing this test
        # module does not itself import funasr.
        from funasr.auto.auto_model import AutoModel

        am = AutoModel.__new__(AutoModel)
        am.model = MagicMock()
        am.vad_model = MagicMock()
        am.punc_model = punc_model
        am.punc_kwargs = {}
        am.spk_model = spk_model
        am.cb_model = None
        am.spk_mode = spk_mode
        am.vad_kwargs = {}
        am.kwargs = {
            "batch_size_s": 300,
            "batch_size_threshold_s": 60,
            "device": "cpu",
            "disable_pbar": True,
            "frontend": MagicMock(fs=16000),
            "fs": 16000,
        }
        am._reset_runtime_configs = MagicMock()
        return am

    def _setup_mocks(self, am, mock_slice, mock_load, mock_prep, extra_results=()):
        """Configure standard mocks for a single-segment VAD + ASR flow.

        ``am.inference`` is replaced by a mock that replays a fixed script,
        one entry per call: the VAD result, then the ASR result, then each
        entry of ``extra_results`` in order. Any call past the end of the
        script returns ``[{"text": ""}]``.

        Args:
            am: Instance produced by ``_make_auto_model``.
            mock_slice: Patch object for ``slice_padding_audio_samples``.
            mock_load: Patch object for ``load_audio_text_image_video``.
            mock_prep: Patch object for ``prepare_data_iterator``.
            extra_results: Additional scripted return values for inference
                calls after ASR (for example a punctuation result).
        """
        # VAD returns one segment [0, 16000ms]
        vad_result = [{"key": "test_utt", "value": [[0, 16000]]}]
        # ASR returns text with timestamps
        asr_result = [{"text": "hello world", "timestamp": [[0, 500], [500, 1000]]}]
        scripted = iter([vad_result, asr_result, *extra_results])

        def mock_inference(data, input_len=None, model=None, kwargs=None, **cfg):
            # Results are chosen purely by call order; the arguments are ignored.
            return next(scripted, [{"text": ""}])

        am.inference = MagicMock(side_effect=mock_inference)
        mock_prep.return_value = (["test_utt"], [np.zeros(16000, dtype=np.float32)])
        mock_load.return_value = np.zeros(16000, dtype=np.float32)
        mock_slice.return_value = ([np.zeros(16000, dtype=np.float32)], [16000])

    @patch("funasr.auto.auto_model.slice_padding_audio_samples")
    @patch("funasr.auto.auto_model.load_audio_text_image_video")
    @patch("funasr.auto.auto_model.prepare_data_iterator")
    def test_punc_model_none_basic(self, mock_prep, mock_load, mock_slice):
        """Basic inference with punc_model=None should not raise UnboundLocalError."""
        am = self._make_auto_model(punc_model=None)
        self._setup_mocks(am, mock_slice, mock_load, mock_prep)

        results = am.inference_with_vad("dummy_input")

        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]["text"], "hello world")
        self.assertEqual(results[0]["key"], "test_utt")

    @patch("funasr.auto.auto_model.slice_padding_audio_samples")
    @patch("funasr.auto.auto_model.load_audio_text_image_video")
    @patch("funasr.auto.auto_model.prepare_data_iterator")
    def test_sentence_timestamp_with_punc_model_none(self, mock_prep, mock_load, mock_slice):
        """sentence_timestamp=True with punc_model=None should not crash."""
        am = self._make_auto_model(punc_model=None)
        self._setup_mocks(am, mock_slice, mock_load, mock_prep)

        # This path previously caused UnboundLocalError on punc_res
        results = am.inference_with_vad("dummy_input", sentence_timestamp=True)

        self.assertEqual(len(results), 1)
        # sentence_info should be empty list since punc_res is unavailable
        self.assertEqual(results[0].get("sentence_info"), [])

    @patch("funasr.auto.auto_model.slice_padding_audio_samples")
    @patch("funasr.auto.auto_model.load_audio_text_image_video")
    @patch("funasr.auto.auto_model.prepare_data_iterator")
    def test_punc_model_with_value_still_works(self, mock_prep, mock_load, mock_slice):
        """When punc_model is provided, punc_res should still be used normally."""
        am = self._make_auto_model(punc_model=MagicMock())
        punc_result = [{"text": "Hello, world.", "punc_array": [1, 2]}]
        self._setup_mocks(am, mock_slice, mock_load, mock_prep, extra_results=[punc_result])

        results = am.inference_with_vad("dummy_input")

        self.assertEqual(len(results), 1)
        # Text should be updated with punctuated version
        self.assertEqual(results[0]["text"], "Hello, world.")
        # The script is VAD, ASR, punctuation: a fourth inference call would
        # mean the flow consumed more results than this scenario provides.
        self.assertEqual(am.inference.call_count, 3)
# Run the suite only when this file is executed as a script, not when it is
# imported (for example by a test collector).
if __name__ == "__main__":
    unittest.main()