-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathweb_demo.py
More file actions
1503 lines (1215 loc) · 65.2 KB
/
web_demo.py
File metadata and controls
1503 lines (1215 loc) · 65.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
"""
A gradio demo for Qwen3 TTS models (All-in-One).
Final Polish: 9 Official Roles, Smart Availability Check, Auto-Download.
"""
print("Script started...")
import argparse
import gc
import re
import threading
import warnings
from collections import defaultdict
# Suppress noisy warnings from transformers
warnings.filterwarnings("ignore", message=".*pad_token_id.*")
warnings.filterwarnings("ignore", category=UserWarning, module="gradio")
# Global lock to prevent duplicate concurrent generations
_generation_lock = threading.Lock()
_active_generations = defaultdict(bool)
import logging
import os
import tempfile
import traceback
import glob
import json
from collections import defaultdict
from dataclasses import asdict
from time import time
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
import numpy as np
import torch
from qwen_tts import Qwen3TTSModel, VoiceClonePromptItem
try:
from huggingface_hub import snapshot_download
except ImportError:
snapshot_download = None
# --- Constants ---
# The 9 official preset speakers shipped with the CustomVoice checkpoint.
# "name" is the speaker id passed to generate_custom_voice(); "lang" and
# "desc" are display-only strings for the UI dropdown.
OFFICIAL_SPEAKERS = [
    {"name": "Vivian", "lang": "Chinese", "desc": "Bright, slightly edgy young female"},
    {"name": "Serena", "lang": "Chinese", "desc": "Warm, gentle young female"},
    {"name": "Uncle_Fu", "lang": "Chinese", "desc": "Seasoned male, low mellow timbre"},
    {"name": "Dylan", "lang": "Chinese (Beijing)", "desc": "Youthful Beijing male, clear"},
    {"name": "Eric", "lang": "Chinese (Sichuan)", "desc": "Lively Chengdu male, husky"},
    {"name": "Ryan", "lang": "English", "desc": "Dynamic male, strong rhythmic"},
    {"name": "Aiden", "lang": "English", "desc": "Sunny American male, clear midrange"},
    {"name": "Ono_Anna", "lang": "Japanese", "desc": "Playful Japanese female"},
    {"name": "Sohee", "lang": "Korean", "desc": "Warm Korean female, rich emotion"},
]
# HuggingFace repo id and the local folder name (under ./models) used by the
# auto-download helper below.
CUSTOM_VOICE_REPO = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
CUSTOM_VOICE_DIR_NAME = "Qwen3-TTS-12Hz-1.7B-CustomVoice"
# Fallback language list, used when the loaded model does not expose
# get_supported_languages() (see _check_model_ready callers).
FULL_LANGUAGES = [
    "Auto", "Chinese", "English", "Japanese", "Korean",
    "German", "French", "Russian", "Spanish", "Portuguese", "Italian"
]
# --- Global State Management ---
class ModelManager:
    """Owns the single loaded Qwen3TTSModel and swaps checkpoints on demand.

    All gradio callbacks go through the module-level MANAGER singleton, so at
    most one model is resident at a time (VRAM / unified-memory constraint).
    """

    def __init__(self):
        self.tts: Optional[Qwen3TTSModel] = None  # currently loaded model (None = nothing loaded)
        self.ckpt_path: str = ""  # NOTE(review): never written elsewhere in view — appears unused
        self.model_type: str = "Unknown"  # "base" | "custom_voice" | "voice_design" once loaded
        self.device: str = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype: torch.dtype = torch.bfloat16
        self.attn_impl: Optional[str] = None  # attn_implementation passed to from_pretrained
        self.use_flash_attn = False # Default safely

    def load(self, model_path: str, device: str = "cpu", dtype: str = "auto", use_flash_attn: bool = False):
        """Load a checkpoint from `model_path`; returns a human-readable status string.

        Never raises: all failures are returned as strings containing "Error"
        (callers test for that substring, see smart_switch).
        """
        try:
            if not model_path or model_path.strip() == "":
                return "Error: No model selected."
            print(f"Loading model from {model_path} on {device} (FlashAttn={use_flash_attn})...")
            # Dropdown labels are prefixed with availability dots; strip them to get the path.
            clean_path = model_path.replace("⚫ ", "").replace("⚪ ", "").strip() # clean status chars
            if not os.path.exists(clean_path):
                return f"Error: Path not found: {clean_path}"
            self.device = device
            self.use_flash_attn = use_flash_attn # Persist this!
            if dtype == "auto":
                # NOTE(review): stores the literal string "auto" into a field annotated
                # torch.dtype; from_pretrained accepts it, but the annotation is loose.
                self.dtype = "auto"
            elif dtype == "float16":
                self.dtype = torch.float16
            elif dtype == "bfloat16":
                self.dtype = torch.bfloat16
            else:
                self.dtype = torch.float32
            self.attn_impl = "flash_attention_2" if use_flash_attn else "sdpa" # default to sdpa if not flash
            # On MPS, usually we just let it be None or sdpa equivalent.
            # Transformers handles it, but passing "flash_attention_2" when not installed throws error.
            if not use_flash_attn: self.attn_impl = None
            self.tts = Qwen3TTSModel.from_pretrained(
                clean_path,
                device_map=self.device,
                dtype=self.dtype,
                attn_implementation=self.attn_impl,
            )
            # Prefer the model's self-reported type; fall back to path-name heuristics.
            self.model_type = getattr(self.tts.model, "tts_model_type", "base")
            if self.model_type not in ("custom_voice", "voice_design", "base"):
                if "CustomVoice" in clean_path: self.model_type = "custom_voice"
                elif "VoiceDesign" in clean_path: self.model_type = "voice_design"
                else: self.model_type = "base"
            print(f"Model loaded successfully. Type detected: {self.model_type}", flush=True)
            return f"Loaded: {self.model_type}"
        except Exception as e:
            traceback.print_exc()
            return f"Error: {str(e)}"

    def smart_switch(self, target_type: str):
        """
        Ensure a model of `target_type` is loaded, scanning ./models for a checkpoint.

        Returns (True, "Loaded...") or (False, "Error details...")
        """
        # 1. Check current
        if self.tts is not None and self.model_type == target_type:
            return True, f"Already loaded {target_type}"
        debug_msg = [f"Request: {target_type}"]
        # 2. Find target model path
        models_dir = os.path.abspath("./models")
        candidates = []
        debug_msg.append(f"Scanning {models_dir}...")
        if os.path.exists(models_dir):
            for d in os.listdir(models_dir):
                full = os.path.join(models_dir, d)
                if not os.path.isdir(full): continue
                debug_msg.append(f" Check: {d}")
                # Match Logic: compare with separators removed so that e.g.
                # "custom_voice" matches "...-CustomVoice".
                d_normalized = d.lower().replace("_", "").replace("-", "")
                t_normalized = target_type.lower().replace("_", "").replace("-", "")
                if t_normalized in d_normalized:
                    candidates.append(full)
                    continue
                # Specific fallbacks
                if target_type == "custom_voice" and "custom" in d.lower(): candidates.append(full)
                if target_type == "base" and "base" in d.lower(): candidates.append(full)
                if target_type == "voice_design" and "design" in d.lower(): candidates.append(full)
        candidates = list(set(candidates))
        def sort_key(p):
            # Prefer the 1.7B / 12Hz variants when several checkpoints match.
            s = os.path.basename(p)
            score = 0
            if "1.7B" in s: score += 100
            if "12Hz" in s: score += 10
            return (score, s)
        candidates.sort(key=sort_key, reverse=True)
        if not candidates:
            return False, "\n".join(debug_msg) + f"\nNo candidates for {target_type}"
        target_path = candidates[0]
        debug_msg.append(f"Selected: {target_path}")
        # 3. Unload the current model and release accelerator memory before loading.
        if self.tts is not None:
            del self.tts
            self.tts = None
            if torch.cuda.is_available(): torch.cuda.empty_cache()
            if torch.backends.mps.is_available(): torch.mps.empty_cache()
            import gc; gc.collect()
        # 4. Load
        # Detect device again
        dev = self.device if self.device else ("cuda" if torch.cuda.is_available() else "mps")
        # Use PERSISTED Flash Attn setting
        use_fa = self.use_flash_attn
        res = self.load(target_path, dev, "bfloat16", use_fa)
        if "Error" in res:
            return False, f"Load Error: {res}"
        return True, f"Switched to {target_path}"


MANAGER = ModelManager()
# --- Helpers ---
def _title_case_display(s: str) -> str:
s = (s or "").strip()
return " ".join([w[:1].upper() + w[1:] if w else "" for w in s.replace("_", " ").split()])
def _build_choices_and_map(items: Optional[List[str]]) -> Tuple[List[str], Dict[str, str]]:
if not items: return [], {}
display = [_title_case_display(x) for x in items]
mapping = {d: r for d, r in zip(display, items)}
return display, mapping
def _wav_to_gradio_audio(wav: np.ndarray, sr: int) -> Tuple[int, np.ndarray]:
wav = np.asarray(wav, dtype=np.float32)
return sr, wav
def _audio_to_tuple(audio: Any) -> Optional[Tuple[np.ndarray, int]]:
if audio is None: return None
if isinstance(audio, tuple) and len(audio) == 2 and isinstance(audio[0], int):
sr, wav = audio
return _normalize_audio(wav), int(sr)
if isinstance(audio, dict) and "sampling_rate" in audio and "data" in audio:
sr = int(audio["sampling_rate"])
wav = _normalize_audio(audio["data"])
return wav, sr
return None
def _normalize_audio(wav, eps=1e-12, clip=True):
x = np.asarray(wav)
if np.issubdtype(x.dtype, np.integer):
info = np.iinfo(x.dtype)
y = x.astype(np.float32) / max(abs(info.min), info.max) if info.min < 0 else (x.astype(np.float32) - (info.max + 1) / 2.0) / ((info.max + 1) / 2.0)
elif np.issubdtype(x.dtype, np.floating):
y = x.astype(np.float32)
m = np.max(np.abs(y)) if y.size else 0.0
if m > 1.0 + 1e-6: y = y / (m + eps)
else: raise TypeError(f"Unsupported dtype: {x.dtype}")
if clip: y = np.clip(y, -1.0, 1.0)
if y.ndim > 1: y = np.mean(y, axis=-1).astype(np.float32)
return y
def _check_model_ready(required_type: str = None):
    """Return (model, None) when a suitable model is loaded, else (None, message).

    For required_type="base" a mismatch only yields a warning message; other
    required types yield an error message.
    """
    if MANAGER.tts is None:
        return None, "Please load a model first (请先加载模型)."
    if required_type and MANAGER.model_type != required_type:
        if required_type == "base":
            return None, f"Warning: Loaded model is `{MANAGER.model_type}`, expected `base`."
        return None, f"Error: Loaded `{MANAGER.model_type}`, but feature needs `{required_type}`."
    return MANAGER.tts, None
def _gen_kwargs(temp, topp, topk, rep_pen):
return {"temperature": temp, "top_p": topp, "top_k": int(topk), "repetition_penalty": rep_pen, "do_sample": True}
# --- Smart Funcs ---
def smart_split_text(text, max_len=300):
    """
    Splits long text into chunks by punctuation to avoid OOM.
    Prioritizes splitting at: newlines, then sentence endings [. ? ! 。 ? !],
    then commas [, ,], finally a hard cut at max_len.

    FIX: the fallback comma search previously repeated rfind(',') on the same
    ASCII comma, leaving the fullwidth-comma branch dead; it now searches the
    fullwidth comma. Also removed a redundant in-function `import re` (the
    module already imports re at the top).

    Returns a list of non-empty chunks (text shorter than max_len is returned
    as a single-element list unchanged).
    """
    if len(text) < max_len:
        return [text]
    # Sentence-ending delimiters (fullwidth + ASCII); the capture group keeps
    # the delimiter so it can be re-attached to its sentence.
    split_pat = r'([。?!.?!])'
    chunks = []
    curr = ""
    # Priority 1: newlines (the user's natural paragraphs).
    for line in text.split('\n'):
        line = line.strip()
        if not line:
            continue
        # Priority 2: sentence endings.
        parts = re.split(split_pat, line)
        sentences = []
        for i in range(0, len(parts) - 1, 2):
            sent = parts[i] + parts[i + 1]  # re-attach the delimiter
            if sent.strip():
                sentences.append(sent)
        # Trailing text with no delimiter.
        if len(parts) % 2 == 1 and parts[-1].strip():
            sentences.append(parts[-1])
        for sent in sentences:
            if len(curr) + len(sent) > max_len:
                if curr.strip():
                    chunks.append(curr)
                curr = sent
            else:
                curr += sent
    if curr.strip():
        chunks.append(curr)
    # Priority 3: force-split any chunk still longer than max_len
    # (text with no sentence punctuation).
    final_chunks = []
    for c in chunks:
        while len(c) > max_len:
            comma_split = c[:max_len].rfind(',')
            if comma_split == -1:
                # FIX: fall back to the fullwidth (Chinese) comma.
                comma_split = c[:max_len].rfind('\uff0c')
            if comma_split > max_len // 2:  # Good split
                final_chunks.append(c[:comma_split + 1])
                c = c[comma_split + 1:]
            else:  # Hard split
                final_chunks.append(c[:max_len])
                c = c[max_len:]
        if c:
            final_chunks.append(c)
    return final_chunks
def generate_long_text_safe(gen_func_callback):
    """
    Placeholder for a long-text segmentation/concatenation wrapper.

    The original implementation was an abandoned stub (a bare try/except-pass
    around `pass` plus unused locals); chunking is instead done by the callers
    themselves via smart_split_text() (see _run_voice_clone_internal).
    Kept for interface compatibility: always returns None.

    Args:
        gen_func_callback: intended signature function(text_chunk) -> (wav, sr);
            currently never invoked.
    """
    # Intentionally a no-op — callers loop over smart_split_text() directly.
    return None
def scan_local_models():
    """Scan ./models for checkpoints; returns (label, path) pairs for the dropdown.

    Present directories get a filled-dot label and their path; known official
    checkpoints that are missing get a hollow-dot placeholder with an empty
    path. Available entries sort before missing ones.
    """
    models_dir = os.path.join(".", "models")
    known_models = [
        "Qwen3-TTS-12Hz-1.7B-Base",
        "Qwen3-TTS-12Hz-1.7B-VoiceDesign",
        "Qwen3-TTS-12Hz-1.7B-CustomVoice",
    ]
    choices = []
    if os.path.exists(models_dir):
        for entry in os.listdir(models_dir):
            full_path = os.path.join(models_dir, entry)
            if os.path.isdir(full_path) and not entry.startswith("."):
                choices.append((f"⚫ {entry}", full_path))
    # Add placeholder rows for known checkpoints not found on disk.
    present_paths = [path for _, path in choices]
    for model_name in known_models:
        if not any(model_name in p for p in present_paths):
            choices.append((f"⚪ {model_name} (Not Found)", ""))
    # Available first
    choices.sort(key=lambda item: (not item[1], item[0]))
    return choices
def check_custom_voice_available():
    """Checks if ANY CustomVoice model directory is present in ./models."""
    models_dir = "./models"
    if not os.path.exists(models_dir):
        return False
    return any(
        "CustomVoice" in name and os.path.isdir(os.path.join(models_dir, name))
        for name in os.listdir(models_dir)
    )
def get_speaker_choices():
    """Build dropdown labels for the official speakers.

    When no CustomVoice checkpoint is on disk, every entry is greyed-out with
    a download hint; otherwise labels are "Name | Language".
    """
    available = check_custom_voice_available()
    labels = []
    for spk in OFFICIAL_SPEAKERS:
        if available:
            labels.append(f"{spk['name']} | {spk['lang']}")
        else:
            labels.append(f"⚪ {spk['name']} (Click to Download Model)")
    return labels
def download_custom_voice_model():
    """Download the official CustomVoice checkpoint into ./models.

    Requires huggingface_hub (snapshot_download is None when the optional
    import failed at module load). Returns a status string, never raises.
    """
    if snapshot_download is None:
        return "Error: huggingface_hub not installed."
    target_dir = os.path.join("./models", CUSTOM_VOICE_DIR_NAME)
    print(f"Downloading {CUSTOM_VOICE_REPO} to {target_dir}...")
    try:
        snapshot_download(repo_id=CUSTOM_VOICE_REPO, local_dir=target_dir)
    except Exception as e:
        return f"Download Failed: {e}"
    return f"Successfully downloaded to {target_dir}. Please Click Refresh and Load!"
# --- Audiobook Engine ---
class AudiobookEngine:
    """Multi-speaker script renderer.

    Parses a "Role: line" script, groups lines by the model type each voice
    needs (official speakers -> custom_voice, saved .pt clones -> base),
    generates each group under a single model load (via MANAGER.smart_switch),
    then reassembles the audio in script order with punctuation-based pauses.
    """

    def __init__(self):
        pass

    @staticmethod
    def parse_script(text):
        """
        Parses script text into lines.
        Format: RoleName: Text content... [p=1.0]
        Returns list of dict: {'role': str, 'text': str, 'original': str}
        """
        lines = []
        raw_lines = text.strip().split('\n')
        for l in raw_lines:
            l = l.strip()
            if not l: continue
            # Normalization: collapse CJK ellipsis / em-dash into ASCII forms.
            l = l.replace("……", "...").replace("——", ",")
            role = "Narrator" # Default
            content = l
            # Check for Role: format
            # NOTE(review): only the ASCII ":" is recognized as a separator; a
            # fullwidth "：" line falls through to Narrator.
            if ":" in l:
                parts = l.split(":", 1)
                potential_role = parts[0].strip()
                # Simple heuristic: Role names usually short (< 20 chars)
                if len(potential_role) < 20:
                    role = potential_role
                    content = parts[1].strip()
            lines.append({"role": role, "text": content, "original": l})
        return lines

    @staticmethod
    def _parse_pauses(text, defaults):
        """
        Splits text by [p=x] tags.
        Returns list of segments: [{'text': str, 'pause': float}]

        NOTE(review): `defaults` is currently unread; punctuation-based pauses
        are applied later, in render()'s reassembly step, not here.
        """
        import re
        segments = []
        # We don't split by punctuation here: TTS needs sentence context.
        # Only explicit [p=x] tags inside the line create pause boundaries;
        # line-break pauses are handled by render().
        parts = re.split(r'(\[p=[\d\.]+\])', text)
        current_text = ""
        for p in parts:
            if p.startswith("[p=") and p.endswith("]"):
                # It's a pause tag; malformed values fall back to 0.5s.
                try:
                    val = float(p[3:-1])
                except: val = 0.5
                segments.append({"text": current_text, "pause": val})
                current_text = ""
            else:
                current_text += p
        if current_text:
            segments.append({"text": current_text, "pause": 0.0}) # Last segment has 0 pause (handled by line break later)
        return segments

    @staticmethod
    def generate_silence(duration_sec, sr=24000):
        # Zero-filled float32 buffer; empty array for non-positive durations.
        if duration_sec <= 0: return np.array([], dtype=np.float32)
        return np.zeros(int(sr * duration_sec), dtype=np.float32)

    @staticmethod
    def normalize_audio(wav, target_rms=0.1):
        """
        Normalize audio to consistent RMS level to prevent volume fluctuations.
        This ensures all chunks have similar loudness when concatenated.
        """
        if len(wav) == 0:
            return wav
        # Calculate current RMS
        current_rms = np.sqrt(np.mean(wav**2))
        # Avoid division by zero and skip if audio is silent
        if current_rms < 1e-6:
            return wav
        # Scale to target RMS
        scaling_factor = target_rms / current_rms
        normalized = wav * scaling_factor
        # Prevent clipping
        max_val = np.abs(normalized).max()
        if max_val > 0.95:
            normalized = normalized * (0.95 / max_val)
        return normalized

    @staticmethod
    def trim_audio(wav, threshold=0.01):
        """Removes leading/trailing noise/hallucinations.

        NOTE(review): the 50ms buffer assumes a 24000 Hz sample rate (hard-coded).
        """
        abs_wav = np.abs(wav)
        mask = abs_wav > threshold
        if not np.any(mask): return wav
        start = np.argmax(mask)
        # Give a 50ms buffer after threshold starts to avoid click
        start = max(0, start - int(24000 * 0.05))
        end = len(wav) - np.argmax(mask[::-1])
        end = min(len(wav), end + int(24000 * 0.05))
        return wav[start:end]

    def render(self, script_text, role_map, pause_config, language="Auto", temperature=0.7, top_p=0.8, seed=42, progress=gr.Progress()):
        """
        Render using Grouped Generation Strategy (Map-Reduce).

        Args:
            script_text: raw script, one "Role: text" per line.
            role_map: dict mapping role name -> voice (official speaker label
                or a saved .pt clone filename).
            pause_config: dict of pause durations in seconds; keys used here:
                "newline", "period", "question", "comma".
            language, temperature, top_p, seed: generation settings shared by
                all lines.
            progress: gradio progress callback.

        Returns:
            ((sample_rate, wav), status_message) on success,
            (None, error_message) when nothing could be generated.

        Side effects: writes per-line WAVs, the final WAV and run.log into a
        timestamped ./outputs/audiobook_* directory, and may swap the globally
        loaded model via MANAGER.smart_switch.
        """
        import datetime
        import soundfile as sf
        import gc
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        out_dir = os.path.join(".", "outputs", f"audiobook_{timestamp}")
        os.makedirs(out_dir, exist_ok=True)
        log_path = os.path.join(out_dir, "run.log")
        def log(msg):
            # Mirror every message to stdout and the per-run log file.
            print(msg)
            with open(log_path, "a", encoding="utf-8") as f: f.write(msg + "\n")
        log(f"--- Audiobook Render (Grouped) Started: {timestamp} ---")
        log(f"Language Mode: {language} | Temp: {temperature} | TopP: {top_p} | Seed: {seed}")
        # Stability: Set seed for consistent run-to-run results
        if seed is not None:
            torch.manual_seed(seed)
            if torch.cuda.is_available(): torch.cuda.manual_seed(seed)
            if torch.backends.mps.is_available(): torch.mps.manual_seed(seed)
        # MEMORY FIX: Force garbage collection at start
        gc.collect()
        if torch.cuda.is_available(): torch.cuda.empty_cache()
        if torch.backends.mps.is_available(): torch.mps.empty_cache()
        # 1. Parse & Analyze
        parsed = self.parse_script(script_text)
        tasks = [] # List of dict with metadata
        for i, line in enumerate(parsed):
            role = line['role']
            text = line['text']
            # Resolve voice: exact role, else "Narrator", else a "default" entry.
            voice = role_map.get(role, role_map.get("Narrator", role_map.get("default", None)))
            if not voice:
                log(f"Warning: No voice for {role}, skipping line {i+1}")
                continue
            # Saved .pt clone prompts need the base model; named speakers need custom_voice.
            model_type = "base" if voice.endswith(".pt") else "custom_voice"
            tasks.append({
                "index": i,
                "role": role,
                "text": text,
                "voice": voice,
                "model_type": model_type,
                "audio": None # Placeholder
            })
        log(f"Total Tasks: {len(tasks)}")
        # 2. Group by Model Type
        groups = {}
        for t in tasks:
            mt = t['model_type']
            if mt not in groups: groups[mt] = []
            groups[mt].append(t)
        # 3. Execute Groups
        # Order: CustomVoice first (usually narrators), then Base (clones)
        priority = ["custom_voice", "base"]
        for m_type in priority:
            if m_type not in groups: continue
            group_tasks = groups[m_type]
            if not group_tasks: continue
            log(f"\n>>> Switching to Model: {m_type} for {len(group_tasks)} lines...")
            # Smart Switch
            success, switch_msg = MANAGER.smart_switch(m_type)
            log(f"Switch Result: {switch_msg}")
            if not success:
                log(f"CRITICAL: Failed to switch. Skipping this group.")
                continue
            # Execute Items
            for t_idx, task in enumerate(group_tasks):
                progress((t_idx / len(group_tasks)), desc=f"Generating Group {m_type} ({t_idx+1}/{len(group_tasks)})...")
                raw_txt = task['text']
                vc = task['voice']
                # Handling segments ([p=x] tags)
                segments = self._parse_pauses(raw_txt, pause_config)
                line_wavs = []
                sr = 24000
                try:
                    for seg in segments:
                        raw_txt_seg = seg['text'].strip()
                        if raw_txt_seg:
                            # Buffering: Add leading space to stabilize model start state
                            # This often fixes "filler sounds" or starting pronunciation issues
                            txt = " " + raw_txt_seg
                            wav = None
                            if m_type == "base":
                                # Clone voice: load the saved prompt items from ./voices.
                                fpath = os.path.join("./voices", vc)
                                payload = torch.load(fpath, map_location="cpu")
                                items = [VoiceClonePromptItem(**x) for x in payload["items"]]
                                # MEMORY FIX: No Grad context to prevent graph growth
                                with torch.no_grad():
                                    w_list, _sr = MANAGER.tts.generate_voice_clone(
                                        text=txt,
                                        language=language,
                                        voice_clone_prompt=items,
                                        temperature=temperature,
                                        top_p=top_p
                                    )
                                    wav = w_list[0]; sr = _sr
                                del payload, items # Explicit delete
                            elif m_type == "custom_voice":
                                # Official speaker label looks like "Name | Lang" (maybe dot-prefixed).
                                spk = vc.split("|")[0].replace("⚫", "").strip()
                                # MEMORY FIX: No Grad context
                                with torch.no_grad():
                                    w_list, _sr = MANAGER.tts.generate_custom_voice(
                                        text=txt,
                                        speaker=spk,
                                        language=language,
                                        temperature=temperature,
                                        top_p=top_p
                                    )
                                    wav = w_list[0]; sr = _sr
                            if wav is not None:
                                # FIX: Trim hallucinations/fillers at the start
                                wav = self.trim_audio(wav)
                                # FIX: Add tiny silence (0.05s) to avoid syllable clipping
                                pad = self.generate_silence(0.05, sr)
                                line_wavs.append(pad)
                                line_wavs.append(wav)
                        # Add requested pause from the [p=x] tag, if any.
                        if seg['pause'] > 0:
                            line_wavs.append(self.generate_silence(seg['pause'], sr))
                    if line_wavs:
                        final_wav = np.concatenate(line_wavs)
                        # Save temp per-line file for debugging/recovery.
                        fname = f"seg_{task['index']:03d}.wav"
                        fpath = os.path.join(out_dir, fname)
                        sf.write(fpath, final_wav, sr)
                        task['audio'] = (final_wav, sr)
                        log(f" -> Generated line {task['index']} ({task['role']})")
                    else:
                        log(f" -> Skipped line {task['index']}")
                except Exception as e:
                    # A failed line is logged and left with audio=None; the
                    # render continues with the remaining lines.
                    log(f" -> Error Line {task['index']}: {e}")
                    import traceback
                    traceback.print_exc()
                # MEMORY FIX: Cleanup after each group loop
                gc.collect()
                if torch.backends.mps.is_available(): torch.mps.empty_cache()
            # End of Group: Optional cleanup?
            # We let smart_switch handle next cleanup
        # 4. Reassemble
        log("\n>>> Reassembling Audio...")
        final_segments = []
        global_sr = 24000
        # Sort tasks by index to restore order
        tasks.sort(key=lambda x: x['index'])
        for i, task in enumerate(tasks):
            # 1. Content
            if task['audio']:
                wav, sr = task['audio']
                final_segments.append(wav)
                global_sr = sr # Assume consistent SR
            else:
                log(f"Warning: Missing audio for line {i}")
            # 2. Tag Pauses inside text? (Already processed in _parse_pauses)
            # For MVP Grouped, keep it simple: only line-level pauses here.
            # 3. Line Pause, scaled by the line's final punctuation.
            last_char = task['text'].strip()[-1] if task['text'].strip() else ""
            delay = pause_config.get("newline", 0.5)
            if last_char in ['。', '.']: delay = max(delay, pause_config.get("period", 1.0))
            elif last_char in ['?', '?']: delay = max(delay, pause_config.get("question", 0.6))
            elif last_char in [',', ',']: delay = max(delay, pause_config.get("comma", 0.3))
            if delay > 0:
                final_segments.append(self.generate_silence(delay, global_sr))
        if not final_segments:
            err_msg = "Generation Failed. All lines were skipped. Check run.log for 'CRITICAL' errors."
            log(err_msg)
            return None, err_msg
        # Check if we have AT LEAST ONE valid audio segment (not just silence)
        # NOTE(review): has_content is computed but never used; the
        # failed_lines check below is the actual guard.
        has_content = any(len(s) > 24000 for s in final_segments) # silence is usually short? No, pause can be long.
        # Better: check if we generated any tasks successfully.
        failed_lines = sum(1 for t in tasks if t['audio'] is None)
        if failed_lines == len(tasks):
            err_msg = "Generation Failed. No speech generated. Check run.log."
            log(err_msg)
            return None, err_msg
        merged = np.concatenate(final_segments)
        final_file = os.path.join(out_dir, "final_audiobook.wav")
        sf.write(final_file, merged, global_sr)
        log(f"Done! Saved to {final_file}")
        return (global_sr, merged), f"Success! Audio saved to outputs."


AUDIOBOOK_ENGINE = AudiobookEngine()
def get_all_voice_choices():
    """Single flat list for the voice dropdown.

    Official speaker labels ("Name | Lang" or greyed-out placeholders) come
    first, followed by saved clone-voice filenames ("*.pt").
    """
    return [*get_speaker_choices(), *get_my_voices_list()]
def run_audiobook_logic(script, role_state, p_line, p_period, p_comma, p_question, p_hyphen, language, temp, top_p, seed):
    """Validate UI inputs, assemble the pause config and delegate to the audiobook engine."""
    if not script:
        return None, "Please enter a script."
    if not role_state:
        return None, "Please configure roles first."
    pauses = {
        "newline": p_line,
        "period": p_period,
        "comma": p_comma,
        "question": p_question,
        "hyphen": p_hyphen,
    }
    return AUDIOBOOK_ENGINE.render(
        script, role_state, pauses,
        language=language, temperature=temp, top_p=top_p, seed=int(seed),
    )
def add_role_to_state(role, voice, current_state):
    """Add or overwrite a role->voice binding.

    Returns (state dict, [[role, voice], ...] table rows, status message).
    The incoming state dict is mutated in place (gradio State object).
    """
    state = {} if current_state is None else current_state

    def as_rows(mapping):
        return [[k, v] for k, v in mapping.items()]

    if not role or not voice:
        return state, as_rows(state), f"Error: Missing Name or Voice. Current keys: {len(state)}"
    state[role] = voice
    return state, as_rows(state), f"Added: {role} -> {voice}"
# Cast presets are stored as {"preset name": {role: voice, ...}} JSON.
PRESETS_FILE = "presets.json"


def get_presets_list():
    """Return the saved preset names, or [] when the file is missing/corrupt.

    FIX: the original used a bare `except:`, which also swallowed
    KeyboardInterrupt/SystemExit; narrowed to `except Exception`.
    """
    if not os.path.exists(PRESETS_FILE): return []
    try:
        with open(PRESETS_FILE, "r", encoding='utf-8') as f:
            data = json.load(f)
        return list(data.keys())
    except Exception:
        # Unreadable or malformed preset file -> behave as "no presets".
        return []
def save_preset_logic(name, current_state):
    """Persist the current cast under `name` in PRESETS_FILE.

    Returns (status message, gr.update for the presets dropdown). Never raises.
    """
    try:
        if not name or not name.strip():
            return "Error: Name cannot be empty.", gr.update()
        if not current_state:
            return "Error: Cast is empty.", gr.update()
        data = {}
        if os.path.exists(PRESETS_FILE):
            try:
                with open(PRESETS_FILE, "r", encoding='utf-8') as f:
                    data = json.load(f)
            except Exception as e:
                # Corrupt file: start fresh rather than fail the save.
                print(f"Error reading presets: {e}")
        data[name.strip()] = current_state
        with open(PRESETS_FILE, "w", encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        return f"Saved '{name}'", gr.update(choices=list(data.keys()))
    except Exception as e:
        traceback.print_exc()
        return f"Error: {str(e)}", gr.update()
def load_preset_logic(name):
    """Load a saved cast preset.

    Returns (state dict, [[role, voice], ...] table rows, status message);
    empty state on any failure.
    """
    if not name:
        return {}, [], "Please select a preset."
    if not os.path.exists(PRESETS_FILE):
        return {}, [], "No presets file found."
    try:
        with open(PRESETS_FILE, "r", encoding='utf-8') as f:
            presets = json.load(f)
        cast = presets.get(name, {})
        rows = [[role, voice] for role, voice in cast.items()]
        return cast, rows, f"Loaded '{name}'"
    except Exception as e:
        return {}, [], f"Error: {e}"
def on_select_role(evt: gr.SelectData, data):
    """
    When a user clicks a row in the dataframe, populate the inputs for editing.

    Returns (role, voice, status) for a valid row; no-op gr.update()s otherwise.

    FIX: the original used a bare `except:` which also swallowed
    KeyboardInterrupt/SystemExit; narrowed to the failures actually expected
    from a malformed event/row (missing index, short row, non-indexable data).
    """
    try:
        row_idx = evt.index[0]
        # data is a list of lists: [[role, voice], [role, voice], ...]
        if row_idx < len(data):
            row = data[row_idx]
            return row[0], row[1], f"Editing: {row[0]}"
    except (AttributeError, IndexError, KeyError, TypeError):
        pass
    return gr.update(), gr.update(), ""
# --- Logic Functions ---
def run_voice_clone(ref_aud, ref_txt, use_xvec, text, lang_disp, t, p, k, r):
    """Voice Clone tab entry point: de-duplicates concurrent clicks, then delegates.

    FIX: the original checked and set _active_generations WITHOUT holding the
    module's _generation_lock (which was otherwise never used), so two
    near-simultaneous clicks could both pass the check (check-then-act race).
    The check-and-set is now atomic under the lock.
    """
    # NOTE(review): id(text) only distinguishes string objects, not contents;
    # kept from the original for key compatibility.
    gen_id = f"clone_{id(text)}_{len(text)}"
    with _generation_lock:
        if _active_generations.get(gen_id, False):
            return None, "Generation already in progress, please wait..."
        _active_generations[gen_id] = True
    try:
        return _run_voice_clone_internal(ref_aud, ref_txt, use_xvec, text, lang_disp, t, p, k, r)
    finally:
        with _generation_lock:
            _active_generations[gen_id] = False
def _run_voice_clone_internal(ref_aud, ref_txt, use_xvec, text, lang_disp, t, p, k, r):
    """Generate cloned speech from a reference recording, chunking long text.

    Args:
        ref_aud: gradio audio value (tuple or dict form) of the reference voice.
        ref_txt: transcript of the reference audio.
        use_xvec: x-vector-only cloning mode flag passed through to the model.
        text: text to synthesize (auto-sliced at ~300 chars).
        lang_disp: display-form language name, mapped back to the raw name.
        t, p, k, r: temperature / top_p / top_k / repetition_penalty.

    Returns:
        ((sample_rate, wav), status) on success, (None, error message) otherwise.
    """
    model, err = _check_model_ready("base")
    if err: return None, err
    try:
        if not text: return None, "Text required (请输入文本)."
        at = _audio_to_tuple(ref_aud)
        if at is None: return None, "Audio required (请上传音频)."
        # Models that expose get_supported_languages() override the static list.
        supported_langs = getattr(model.model, "get_supported_languages", lambda: FULL_LANGUAGES)()
        _, lang_map = _build_choices_and_map(supported_langs)
        language = lang_map.get(lang_disp, "Auto")
        # AUTO-SLICE LOGIC: split long input to keep per-call memory bounded.
        print(f"[PROGRESS] Generating audio from {len(text)} chars...")
        chunks = smart_split_text(text, max_len=300)
        print(f"[PROGRESS] Text split into {len(chunks)} chunks.")
        full_wav_list = []
        final_sr = 24000
        for idx, chunk in enumerate(chunks):
            print(f"[PROGRESS] Processing chunk {idx+1}/{len(chunks)}...")
            # SAFETY CHECK: Skip empty or punctuation-only chunks
            if not any(c.isalnum() for c in chunk):
                continue
            # MEMORY SAFE CALL: collect + flush MPS cache before each chunk.
            gc.collect()
            if torch.backends.mps.is_available(): torch.mps.empty_cache()
            with torch.no_grad():
                wavs, _sr = model.generate_voice_clone(
                    text=chunk, language=language, ref_audio=at, ref_text=ref_txt, x_vector_only_mode=use_xvec,
                    **_gen_kwargs(t, p, k, r)
                )
            w = wavs[0]
            # Trim & Normalize & Pad
            w = AudiobookEngine.trim_audio(w)
            w = AudiobookEngine.normalize_audio(w) # Ensure consistent volume
            w = np.concatenate([w, AudiobookEngine.generate_silence(0.3, _sr)]) # 0.3s pause between chunks
            full_wav_list.append(w)
            final_sr = _sr
        # NOTE(review): if every chunk was punctuation-only, full_wav_list is
        # empty and np.concatenate raises — caught by the handler below.
        merged = np.concatenate(full_wav_list)
        return _wav_to_gradio_audio(merged, final_sr), f"Success (Sliced {len(chunks)} parts)."
    except Exception as e:
        traceback.print_exc()
        return None, f"Error: {e}"
def save_voice_named(ref_aud, ref_txt, use_xvec, name):
    """Create a voice-clone prompt from reference audio and save it as ./voices/<name>.pt.

    Switches the loaded model to `base` if needed (prompt creation requires it).
    Returns a status string; never raises.

    FIX: ensures ./voices exists before torch.save — previously failed with
    FileNotFoundError on a fresh install with no voices directory.
    """
    # 1. Validate inputs first (avoid unnecessary model load)
    at = _audio_to_tuple(ref_aud)
    if at is None: return "Error: Audio required."
    if not name or not name.strip(): return "Error: Name required."
    # Sanitize to a filesystem-safe name.
    name = "".join(x for x in name if x.isalnum() or x in " _-")
    # 2. Ensure Base model is loaded for prompt creation
    success, msg = MANAGER.smart_switch("base")
    if not success: return f"Error switching to base model: {msg}"
    model = MANAGER.tts
    if model is None: return "Error: Model failed to load."
    try:
        items = model.create_voice_clone_prompt(
            ref_audio=at, ref_text=ref_txt, x_vector_only_mode=use_xvec,
        )
        payload = {"items": [asdict(it) for it in items]}
        os.makedirs("./voices", exist_ok=True)  # FIX: dir may not exist yet
        out_path = os.path.join("./voices", f"{name}.pt")
        torch.save(payload, out_path)
        return f"Saved to {out_path}"
    except Exception as e:
        return f"Error: {e}"
def save_designed_voice_logic(audio_info, text, name):
    """
    Saves the generated audio from Voice Design as a new Cloned Voice.
    The generated audio becomes the reference audio and the input text the
    reference text. Returns a status string.
    """
    if audio_info is None:
        return "Error: No audio generated to save (请先生成语音)."
    if not text:
        return "Error: Original text required for reference."
    if not name:
        return "Error: Name required."
    # Delegate to the Clone tab's save path (this switches the loaded model
    # to `base` internally).
    try:
        result = save_voice_named(audio_info, text, False, name)
    finally:
        # Always restore the Voice Design model so the user can keep
        # designing, regardless of save success/failure.
        MANAGER.smart_switch("voice_design")
    if "Error" not in result:
        return result + "\n(Success! Design Model Reloaded & Ready)"
    return result
def get_my_voices_list():
    """List saved clone-voice files (*.pt) in ./voices, sorted alphabetically."""
    voices_dir = "./voices"
    if not os.path.exists(voices_dir):
        return []
    return sorted(f for f in os.listdir(voices_dir) if f.endswith(".pt"))
def run_my_voice_logic(text, lang_disp, voice_file, instruct, t, p, k, r):
    """My Voices tab entry point: de-duplicates concurrent clicks, then delegates.

    FIX: same race as run_voice_clone — the in-flight check and set on
    _active_generations were not performed under _generation_lock, so two
    near-simultaneous clicks could both start generating. The check-and-set
    is now atomic under the lock.
    """
    gen_id = f"myvoice_{voice_file}_{len(text)}"
    with _generation_lock:
        if _active_generations.get(gen_id, False):
            return None, "Generation already in progress, please wait..."
        _active_generations[gen_id] = True
    try:
        return _run_my_voice_logic_internal(text, lang_disp, voice_file, instruct, t, p, k, r)
    finally:
        with _generation_lock:
            _active_generations[gen_id] = False
def _run_my_voice_logic_internal(text, lang_disp, voice_file, instruct, t, p, k, r):
model, err = _check_model_ready("base")
if err: return None, err
try:
if not voice_file: return None, "Please select a voice."
voice_path = os.path.join("./voices", voice_file)
if not os.path.exists(voice_path): return None, "Voice file not found."
# Load prompt
payload = torch.load(voice_path, map_location="cpu")
# Reconstruct items
# We need VoiceClonePromptItem which is imported
items = [VoiceClonePromptItem(**x) for x in payload["items"]]
supported_langs = getattr(model.model, "get_supported_languages", lambda: FULL_LANGUAGES)()
_, lang_map = _build_choices_and_map(supported_langs)
language = lang_map.get(lang_disp, "Auto")
# AUTO-SLICE LOGIC
print(f"[PROGRESS] Generating audio from {len(text)} chars...")
chunks = smart_split_text(text, max_len=300)
full_wav_list = []
print(f"[PROGRESS] Text split into {len(chunks)} chunks.")