Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions funasr/auto/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,18 @@
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils import export_utils
from funasr.utils import misc


def is_npu_available():
    """Check whether an Ascend NPU device is usable.

    Availability requires the optional ``torch_npu`` package; on hosts
    where it is not installed the NPU is treated as unavailable.

    Returns:
        bool: True when ``torch_npu`` imports and reports a ready device.
    """
    try:
        # Optional dependency: only present on Ascend/NPU machines.
        import torch_npu

        return torch_npu.npu.is_available()
    except ImportError:
        return False


def _resolve_ncpu(config, fallback=4):
"""Return a positive integer representing CPU threads from config."""
value = config.get("ncpu", fallback)
Expand All @@ -46,6 +50,7 @@ def _resolve_ncpu(config, fallback=4):
value = fallback
return max(value, 1)


try:
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
from funasr.models.campplus.cluster_backend import ClusterBackend
Expand Down Expand Up @@ -202,11 +207,13 @@ def build_model(**kwargs):
set_all_random_seed(kwargs.get("seed", 0))

device = kwargs.get("device", "cuda")
if ((device =="cuda" and not torch.cuda.is_available())
if (
(device == "cuda" and not torch.cuda.is_available())
or (device == "xpu" and not torch.xpu.is_available())
or (device == "mps" and not torch.backends.mps.is_available())
or (device == "npu" and not is_npu_available())
or kwargs.get("ngpu", 1) == 0):
or kwargs.get("ngpu", 1) == 0
):
device = "cpu"
kwargs["batch_size"] = 1
kwargs["device"] = device
Expand Down Expand Up @@ -573,8 +580,12 @@ def inference_with_vad(self, input, input_len=None, **cfg):
result[k] = []
for t in restored_data[j][k]:
if isinstance(t, dict):
t["start_time"] = (float(t["start_time"]) * 1000 + int(vadsegments[j][0])) / 1000
t["end_time"] = (float(t["end_time"]) * 1000 + int(vadsegments[j][0])) / 1000
t["start_time"] = (
float(t["start_time"]) * 1000 + int(vadsegments[j][0])
) / 1000
t["end_time"] = (
float(t["end_time"]) * 1000 + int(vadsegments[j][0])
) / 1000
else:
t[0] = int(t[0]) + int(vadsegments[j][0])
t[1] = int(t[1]) + int(vadsegments[j][0])
Expand All @@ -600,6 +611,7 @@ def inference_with_vad(self, input, input_len=None, **cfg):
return_raw_text = kwargs.get("return_raw_text", False)
# step.3 compute punc model
raw_text = None
punc_res = None
if self.punc_model is not None:
deep_update(self.punc_kwargs, cfg)
punc_res = self.inference(
Expand Down Expand Up @@ -645,7 +657,12 @@ def inference_with_vad(self, input, input_len=None, **cfg):
and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
can predict timestamp, and speaker diarization relies on timestamps."
)
if kwargs.get("en_post_proc", False):
if punc_res is None:
logging.error(
"Missing punc_model, which is required for punc_segment speaker diarization."
)
sentence_list = []
elif kwargs.get("en_post_proc", False):
sentence_list = timestamp_sentence_en(
punc_res[0]["punc_array"],
result["timestamp"],
Expand All @@ -664,6 +681,11 @@ def inference_with_vad(self, input, input_len=None, **cfg):
elif kwargs.get("sentence_timestamp", False):
if not len(result["text"].strip()):
sentence_list = []
elif punc_res is None:
logging.warning(
"punc_model is required for sentence_timestamp, skipping sentence segmentation."
)
sentence_list = []
else:
if kwargs.get("en_post_proc", False):
sentence_list = timestamp_sentence_en(
Expand Down
119 changes: 119 additions & 0 deletions tests/test_punc_model_none.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Tests for issue #2839: punc_model=None or empty string should not cause UnboundLocalError."""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring states that these tests cover punc_model=None or an empty string. However, the tests only cover the None case. An empty string for punc_model would likely cause a different failure mode during model initialization, not the UnboundLocalError this PR aims to fix. To avoid confusion, please update the docstring to only mention the punc_model=None case.

Suggested change
"""Tests for issue #2839: punc_model=None or empty string should not cause UnboundLocalError."""
"""Tests for issue #2839: punc_model=None should not cause UnboundLocalError."""


import unittest
from unittest.mock import MagicMock, patch
import numpy as np


class TestPuncModelNone(unittest.TestCase):
    """Test that inference_with_vad works when punc_model is None.

    Regression tests for the UnboundLocalError where ``punc_res`` was read
    in the sentence-timestamp / speaker paths without being assigned when
    no punctuation model was configured.
    """

    def _make_auto_model(self, punc_model=None, spk_model=None, spk_mode=None):
        """Create a minimal AutoModel instance with mocked dependencies.

        Uses ``__new__`` to bypass ``AutoModel.__init__`` (which would
        download/load real models) and hand-sets only the attributes that
        ``inference_with_vad`` reads.
        """
        from funasr.auto.auto_model import AutoModel

        am = AutoModel.__new__(AutoModel)
        am.model = MagicMock()
        am.vad_model = MagicMock()
        am.punc_model = punc_model
        am.punc_kwargs = {}
        am.spk_model = spk_model
        am.cb_model = None
        am.spk_mode = spk_mode
        am.vad_kwargs = {}
        am.kwargs = {
            "batch_size_s": 300,
            "batch_size_threshold_s": 60,
            "device": "cpu",
            "disable_pbar": True,
            "frontend": MagicMock(fs=16000),
            "fs": 16000,
        }
        am._reset_runtime_configs = MagicMock()
        return am

    def _setup_mocks(self, am, mock_slice, mock_load, mock_prep):
        """Configure standard mocks for a single-segment VAD + ASR flow.

        ``am.inference`` is replaced with a sequencer: the first call
        returns the VAD result, the second the ASR result, and any later
        call (e.g. an unexpected punc call) returns empty text.
        """
        # VAD returns one segment [0, 16000ms]
        vad_result = [{"key": "test_utt", "value": [[0, 16000]]}]
        # ASR returns text with timestamps
        asr_result = [{"text": "hello world", "timestamp": [[0, 500], [500, 1000]]}]

        # Mutable cell so the closure can count calls across invocations.
        call_count = [0]
        results_seq = [vad_result, asr_result]

        def mock_inference(data, input_len=None, model=None, kwargs=None, **cfg):
            # Return results in call order; fall back to empty text once
            # the scripted sequence is exhausted.
            idx = call_count[0]
            call_count[0] += 1
            if idx < len(results_seq):
                return results_seq[idx]
            return [{"text": ""}]

        am.inference = MagicMock(side_effect=mock_inference)
        # 1 second of silence at 16 kHz stands in for real audio everywhere.
        mock_prep.return_value = (["test_utt"], [np.zeros(16000, dtype=np.float32)])
        mock_load.return_value = np.zeros(16000, dtype=np.float32)
        mock_slice.return_value = ([np.zeros(16000, dtype=np.float32)], [16000])

    # NOTE: @patch decorators apply bottom-up, so the mock arguments arrive
    # in reverse order: prepare_data_iterator, load_audio..., slice_padding...
    @patch("funasr.auto.auto_model.slice_padding_audio_samples")
    @patch("funasr.auto.auto_model.load_audio_text_image_video")
    @patch("funasr.auto.auto_model.prepare_data_iterator")
    def test_punc_model_none_basic(self, mock_prep, mock_load, mock_slice):
        """Basic inference with punc_model=None should not raise UnboundLocalError."""
        am = self._make_auto_model(punc_model=None)
        self._setup_mocks(am, mock_slice, mock_load, mock_prep)

        results = am.inference_with_vad("dummy_input")

        # Without a punc model the raw ASR text passes through unchanged.
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]["text"], "hello world")
        self.assertEqual(results[0]["key"], "test_utt")

    @patch("funasr.auto.auto_model.slice_padding_audio_samples")
    @patch("funasr.auto.auto_model.load_audio_text_image_video")
    @patch("funasr.auto.auto_model.prepare_data_iterator")
    def test_sentence_timestamp_with_punc_model_none(self, mock_prep, mock_load, mock_slice):
        """sentence_timestamp=True with punc_model=None should not crash."""
        am = self._make_auto_model(punc_model=None)
        self._setup_mocks(am, mock_slice, mock_load, mock_prep)

        # This path previously caused UnboundLocalError on punc_res
        results = am.inference_with_vad("dummy_input", sentence_timestamp=True)

        self.assertEqual(len(results), 1)
        # sentence_info should be empty list since punc_res is unavailable
        self.assertEqual(results[0].get("sentence_info"), [])

    @patch("funasr.auto.auto_model.slice_padding_audio_samples")
    @patch("funasr.auto.auto_model.load_audio_text_image_video")
    @patch("funasr.auto.auto_model.prepare_data_iterator")
    def test_punc_model_with_value_still_works(self, mock_prep, mock_load, mock_slice):
        """When punc_model is provided, punc_res should still be used normally."""
        punc_mock = MagicMock()
        am = self._make_auto_model(punc_model=punc_mock)

        # Script all three inference stages: VAD -> ASR -> punctuation.
        vad_result = [{"key": "test_utt", "value": [[0, 16000]]}]
        asr_result = [{"text": "hello world", "timestamp": [[0, 500], [500, 1000]]}]
        punc_result = [{"text": "Hello, world.", "punc_array": [1, 2]}]

        call_count = [0]
        results_seq = [vad_result, asr_result, punc_result]

        def mock_inference(data, input_len=None, model=None, kwargs=None, **cfg):
            # Exactly three calls are expected here; a fourth would raise
            # IndexError and fail the test, which is intentional.
            idx = call_count[0]
            call_count[0] += 1
            return results_seq[idx]

        am.inference = MagicMock(side_effect=mock_inference)
        mock_prep.return_value = (["test_utt"], [np.zeros(16000, dtype=np.float32)])
        mock_load.return_value = np.zeros(16000, dtype=np.float32)
        mock_slice.return_value = ([np.zeros(16000, dtype=np.float32)], [16000])

        results = am.inference_with_vad("dummy_input")

        self.assertEqual(len(results), 1)
        # Text should be updated with punctuated version
        self.assertEqual(results[0]["text"], "Hello, world.")


# Allow running this test module directly (python tests/test_punc_model_none.py).
if __name__ == "__main__":
    unittest.main()