1717import torch
1818import numpy as np
1919import threading
20- import time
2120from torch .nn import functional as F
2221from contextlib import nullcontext
2322import uuid
@@ -57,6 +56,7 @@ def __init__(self,
5756 # dict used to store session related variable
5857 self .tts_speech_token_dict = {}
5958 self .llm_end_dict = {}
59+ self .token_condition_dict = {}
6060 self .mel_overlap_dict = {}
6161 self .flow_cache_dict = {}
6262 self .hift_cache_dict = {}
@@ -125,12 +125,18 @@ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uui
125125 continue
126126 else :
127127 cur_silent_token_num = 0
128- self .tts_speech_token_dict [uuid ].append (i )
129- self .llm_end_dict [uuid ] = True
128+ with self .lock :
129+ self .tts_speech_token_dict [uuid ].append (i )
130+ self .token_condition_dict [uuid ].notify ()
131+ with self .lock :
132+ self .llm_end_dict [uuid ] = True
133+ self .token_condition_dict [uuid ].notify ()
130134
131135 def vc_job (self , source_speech_token , uuid ):
132- self .tts_speech_token_dict [uuid ] = source_speech_token .flatten ().tolist ()
133- self .llm_end_dict [uuid ] = True
136+ with self .lock :
137+ self .tts_speech_token_dict [uuid ] = source_speech_token .flatten ().tolist ()
138+ self .llm_end_dict [uuid ] = True
139+ self .token_condition_dict [uuid ].notify ()
134140
135141 def token2wav (self , token , prompt_token , prompt_feat , embedding , uuid , finalize = False , speed = 1.0 ):
136142 with torch .cuda .amp .autocast (self .fp16 ):
@@ -181,6 +187,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
181187 this_uuid = str (uuid .uuid1 ())
182188 with self .lock :
183189 self .tts_speech_token_dict [this_uuid ], self .llm_end_dict [this_uuid ] = [], False
190+ self .token_condition_dict [this_uuid ] = threading .Condition (self .lock )
184191 self .hift_cache_dict [this_uuid ] = None
185192 self .mel_overlap_dict [this_uuid ] = torch .zeros (1 , 80 , 0 )
186193 self .flow_cache_dict [this_uuid ] = torch .zeros (1 , 80 , 0 , 2 )
@@ -192,10 +199,18 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
192199 if stream is True :
193200 token_hop_len = self .token_min_hop_len
194201 while True :
195- time .sleep (0.1 )
196- if len (self .tts_speech_token_dict [this_uuid ]) >= token_hop_len + self .token_overlap_len :
197- this_tts_speech_token = torch .tensor (self .tts_speech_token_dict [this_uuid ][:token_hop_len + self .token_overlap_len ]) \
198- .unsqueeze (dim = 0 )
202+ with self .lock :
203+ while len (self .tts_speech_token_dict [this_uuid ]) < token_hop_len + self .token_overlap_len and \
204+ self .llm_end_dict [this_uuid ] is False :
205+ self .token_condition_dict [this_uuid ].wait ()
206+ if len (self .tts_speech_token_dict [this_uuid ]) >= token_hop_len + self .token_overlap_len :
207+ this_tts_speech_token_slice = self .tts_speech_token_dict [this_uuid ][:token_hop_len + self .token_overlap_len ]
208+ elif self .llm_end_dict [this_uuid ] is True :
209+ break
210+ else :
211+ continue
212+ this_tts_speech_token = torch .tensor (this_tts_speech_token_slice ).unsqueeze (dim = 0 )
213+ if this_tts_speech_token .shape [1 ] != 0 :
199214 this_tts_speech = self .token2wav (token = this_tts_speech_token ,
200215 prompt_token = flow_prompt_speech_token ,
201216 prompt_feat = prompt_speech_feat ,
@@ -207,8 +222,6 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
207222 self .tts_speech_token_dict [this_uuid ] = self .tts_speech_token_dict [this_uuid ][token_hop_len :]
208223 # increase token_hop_len for better speech quality
209224 token_hop_len = min (self .token_max_hop_len , int (token_hop_len * self .stream_scale_factor ))
210- if self .llm_end_dict [this_uuid ] is True and len (self .tts_speech_token_dict [this_uuid ]) < token_hop_len + self .token_overlap_len :
211- break
212225 p .join ()
213226 # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
214227 this_tts_speech_token = torch .tensor (self .tts_speech_token_dict [this_uuid ]).unsqueeze (dim = 0 )
@@ -234,6 +247,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
234247 with self .lock :
235248 self .tts_speech_token_dict .pop (this_uuid )
236249 self .llm_end_dict .pop (this_uuid )
250+ self .token_condition_dict .pop (this_uuid )
237251 self .mel_overlap_dict .pop (this_uuid )
238252 self .hift_cache_dict .pop (this_uuid )
239253 self .flow_cache_dict .pop (this_uuid )
@@ -271,6 +285,7 @@ def __init__(self,
271285 # dict used to store session related variable
272286 self .tts_speech_token_dict = {}
273287 self .llm_end_dict = {}
288+ self .token_condition_dict = {}
274289 self .hift_cache_dict = {}
275290 self .silent_tokens = []
276291
@@ -287,6 +302,10 @@ def load_vllm(self, model_dir):
287302 gpu_memory_utilization = 0.2 )
288303 self .llm .vllm = LLMEngine .from_engine_args (engine_args )
289304 self .llm .lock = threading .Lock ()
305+ self .llm .vllm_step_condition = threading .Condition (self .llm .lock )
306+ self .llm .vllm_step_thread = None
307+ self .llm .vllm_background_error = None
308+ self .llm ._ensure_vllm_runtime ()
290309 del self .llm .llm .model .model .layers
291310
292311 def token2wav (self , token , prompt_token , prompt_feat , embedding , token_offset , uuid , stream = False , finalize = False , speed = 1.0 ):
@@ -334,6 +353,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
334353 this_uuid = str (uuid .uuid1 ())
335354 with self .lock :
336355 self .tts_speech_token_dict [this_uuid ], self .llm_end_dict [this_uuid ] = [], False
356+ self .token_condition_dict [this_uuid ] = threading .Condition (self .lock )
337357 self .hift_cache_dict [this_uuid ] = None
338358 if source_speech_token .shape [1 ] == 0 :
339359 p = threading .Thread (target = self .llm_job , args = (text , prompt_text , llm_prompt_speech_token , llm_embedding , this_uuid ))
@@ -344,10 +364,19 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
344364 token_offset = 0
345365 prompt_token_pad = int (np .ceil (flow_prompt_speech_token .shape [1 ] / self .token_hop_len ) * self .token_hop_len - flow_prompt_speech_token .shape [1 ])
346366 while True :
347- time .sleep (0.1 )
348367 this_token_hop_len = self .token_hop_len + prompt_token_pad if token_offset == 0 else self .token_hop_len
349- if len (self .tts_speech_token_dict [this_uuid ]) - token_offset >= this_token_hop_len + self .flow .pre_lookahead_len :
350- this_tts_speech_token = torch .tensor (self .tts_speech_token_dict [this_uuid ][:token_offset + this_token_hop_len + self .flow .pre_lookahead_len ]).unsqueeze (dim = 0 )
368+ required_token_len = token_offset + this_token_hop_len + self .flow .pre_lookahead_len
369+ with self .lock :
370+ while len (self .tts_speech_token_dict [this_uuid ]) < required_token_len and self .llm_end_dict [this_uuid ] is False :
371+ self .token_condition_dict [this_uuid ].wait ()
372+ if len (self .tts_speech_token_dict [this_uuid ]) >= required_token_len :
373+ this_tts_speech_token_slice = self .tts_speech_token_dict [this_uuid ][:required_token_len ]
374+ elif self .llm_end_dict [this_uuid ] is True :
375+ break
376+ else :
377+ continue
378+ this_tts_speech_token = torch .tensor (this_tts_speech_token_slice ).unsqueeze (dim = 0 )
379+ if this_tts_speech_token .shape [1 ] != 0 :
351380 this_tts_speech = self .token2wav (token = this_tts_speech_token ,
352381 prompt_token = flow_prompt_speech_token ,
353382 prompt_feat = prompt_speech_feat ,
@@ -359,8 +388,6 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
359388 token_offset += this_token_hop_len
360389 self .token_hop_len = min (self .token_max_hop_len , self .token_hop_len * self .stream_scale_factor )
361390 yield {'tts_speech' : this_tts_speech .cpu ()}
362- if self .llm_end_dict [this_uuid ] is True and len (self .tts_speech_token_dict [this_uuid ]) - token_offset < this_token_hop_len + self .flow .pre_lookahead_len :
363- break
364391 p .join ()
365392 # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
366393 this_tts_speech_token = torch .tensor (self .tts_speech_token_dict [this_uuid ]).unsqueeze (dim = 0 )
@@ -370,6 +397,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
370397 embedding = flow_embedding ,
371398 token_offset = token_offset ,
372399 uuid = this_uuid ,
400+ stream = stream ,
373401 finalize = True )
374402 yield {'tts_speech' : this_tts_speech .cpu ()}
375403 else :
@@ -388,9 +416,10 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
388416 with self .lock :
389417 self .tts_speech_token_dict .pop (this_uuid )
390418 self .llm_end_dict .pop (this_uuid )
419+ self .token_condition_dict .pop (this_uuid )
391420 self .hift_cache_dict .pop (this_uuid )
392421 if torch .cuda .is_available ():
393- torch .cuda .empty_cache ()
422+ # torch.cuda.empty_cache()
394423 torch .cuda .current_stream ().synchronize ()
395424
396425
@@ -418,6 +447,7 @@ def __init__(self,
418447 # dict used to store session related variable
419448 self .tts_speech_token_dict = {}
420449 self .llm_end_dict = {}
450+ self .token_condition_dict = {}
421451 self .hift_cache_dict = {}
422452 # FSQ silent and breath token
423453 self .silent_tokens = [1 , 2 , 28 , 29 , 55 , 248 , 494 , 2241 , 2242 , 2322 , 2323 ]