Commit b4388fc

fix(inference): Improve inference consistency and runtime stability for streaming and vLLM backends
1 parent ace7c47 commit b4388fc

4 files changed: 217 additions & 57 deletions


cosyvoice/cli/model.py

Lines changed: 47 additions & 17 deletions
@@ -17,7 +17,6 @@
 import torch
 import numpy as np
 import threading
-import time
 from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
@@ -57,6 +56,7 @@ def __init__(self,
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
+        self.token_condition_dict = {}
         self.mel_overlap_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
@@ -125,12 +125,18 @@ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uui
                 continue
             else:
                 cur_silent_token_num = 0
-            self.tts_speech_token_dict[uuid].append(i)
-        self.llm_end_dict[uuid] = True
+            with self.lock:
+                self.tts_speech_token_dict[uuid].append(i)
+                self.token_condition_dict[uuid].notify()
+        with self.lock:
+            self.llm_end_dict[uuid] = True
+            self.token_condition_dict[uuid].notify()

     def vc_job(self, source_speech_token, uuid):
-        self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
-        self.llm_end_dict[uuid] = True
+        with self.lock:
+            self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
+            self.llm_end_dict[uuid] = True
+            self.token_condition_dict[uuid].notify()

     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16):
@@ -181,6 +187,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.token_condition_dict[this_uuid] = threading.Condition(self.lock)
             self.hift_cache_dict[this_uuid] = None
             self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
             self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
@@ -192,10 +199,18 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
         if stream is True:
             token_hop_len = self.token_min_hop_len
             while True:
-                time.sleep(0.1)
-                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
-                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
-                        .unsqueeze(dim=0)
+                with self.lock:
+                    while len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len and \
+                            self.llm_end_dict[this_uuid] is False:
+                        self.token_condition_dict[this_uuid].wait()
+                    if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                        this_tts_speech_token_slice = self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]
+                    elif self.llm_end_dict[this_uuid] is True:
+                        break
+                    else:
+                        continue
+                this_tts_speech_token = torch.tensor(this_tts_speech_token_slice).unsqueeze(dim=0)
+                if this_tts_speech_token.shape[1] != 0:
                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                      prompt_token=flow_prompt_speech_token,
                                                      prompt_feat=prompt_speech_feat,
@@ -207,8 +222,6 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
                     self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                     # increase token_hop_len for better speech quality
                     token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
-                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
-                    break
             p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
@@ -234,6 +247,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
             with self.lock:
                 self.tts_speech_token_dict.pop(this_uuid)
                 self.llm_end_dict.pop(this_uuid)
+                self.token_condition_dict.pop(this_uuid)
                 self.mel_overlap_dict.pop(this_uuid)
                 self.hift_cache_dict.pop(this_uuid)
                 self.flow_cache_dict.pop(this_uuid)
@@ -271,6 +285,7 @@ def __init__(self,
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
+        self.token_condition_dict = {}
         self.hift_cache_dict = {}
         self.silent_tokens = []

@@ -287,6 +302,10 @@ def load_vllm(self, model_dir):
                                  gpu_memory_utilization=0.2)
         self.llm.vllm = LLMEngine.from_engine_args(engine_args)
         self.llm.lock = threading.Lock()
+        self.llm.vllm_step_condition = threading.Condition(self.llm.lock)
+        self.llm.vllm_step_thread = None
+        self.llm.vllm_background_error = None
+        self.llm._ensure_vllm_runtime()
         del self.llm.llm.model.model.layers

     def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
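
The load_vllm hunk above attaches a step condition, a step-thread slot, and a background-error slot to the LLM wrapper, then calls self.llm._ensure_vllm_runtime(), whose body is not part of this diff. As a rough, hypothetical sketch only (the class name, method bodies, and error handling below are assumptions; only LLMEngine.step() and LLMEngine.has_unfinished_requests() are real vLLM API), a guarded background step loop over those attributes might look like:

import threading

class VllmRuntimeSketch:
    # Hypothetical stand-in for the LLM wrapper; not the repo's implementation.

    def _ensure_vllm_runtime(self):
        # Start the background step thread exactly once; concurrent sessions
        # may call this, so the check happens under the shared lock.
        with self.lock:
            if self.vllm_step_thread is not None:
                return
            self.vllm_step_thread = threading.Thread(target=self._vllm_step_loop, daemon=True)
            self.vllm_step_thread.start()

    def _vllm_step_loop(self):
        try:
            while True:
                with self.lock:
                    # Sleep until a request is queued instead of busy-polling.
                    while not self.vllm.has_unfinished_requests():
                        self.vllm_step_condition.wait()
                # step() advances every in-flight request by one engine iteration.
                self.vllm.step()
        except Exception as err:
            # Record the failure so foreground threads can surface it.
            self.vllm_background_error = err

A submitter would then hold the same lock, enqueue its request, and notify vllm_step_condition so the loop wakes up.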
@@ -334,6 +353,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.token_condition_dict[this_uuid] = threading.Condition(self.lock)
             self.hift_cache_dict[this_uuid] = None
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
@@ -344,10 +364,19 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
             token_offset = 0
             prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
             while True:
-                time.sleep(0.1)
                 this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
-                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
-                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
+                required_token_len = token_offset + this_token_hop_len + self.flow.pre_lookahead_len
+                with self.lock:
+                    while len(self.tts_speech_token_dict[this_uuid]) < required_token_len and self.llm_end_dict[this_uuid] is False:
+                        self.token_condition_dict[this_uuid].wait()
+                    if len(self.tts_speech_token_dict[this_uuid]) >= required_token_len:
+                        this_tts_speech_token_slice = self.tts_speech_token_dict[this_uuid][:required_token_len]
+                    elif self.llm_end_dict[this_uuid] is True:
+                        break
+                    else:
+                        continue
+                this_tts_speech_token = torch.tensor(this_tts_speech_token_slice).unsqueeze(dim=0)
+                if this_tts_speech_token.shape[1] != 0:
                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                      prompt_token=flow_prompt_speech_token,
                                                      prompt_feat=prompt_speech_feat,
@@ -359,8 +388,6 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
                     token_offset += this_token_hop_len
                     self.token_hop_len = min(self.token_max_hop_len, self.token_hop_len * self.stream_scale_factor)
                     yield {'tts_speech': this_tts_speech.cpu()}
-                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
-                    break
             p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
@@ -370,6 +397,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
                                              embedding=flow_embedding,
                                              token_offset=token_offset,
                                              uuid=this_uuid,
+                                             stream=stream,
                                              finalize=True)
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
@@ -388,9 +416,10 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
         with self.lock:
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
+            self.token_condition_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
         if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+            # torch.cuda.empty_cache()
             torch.cuda.current_stream().synchronize()


@@ -418,6 +447,7 @@ def __init__(self,
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
+        self.token_condition_dict = {}
         self.hift_cache_dict = {}
         # FSQ silent and breath token
         self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]
cosyvoice/flow/flow_matching.py

Lines changed: 9 additions & 4 deletions
@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import nullcontext
+
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
@@ -128,9 +130,11 @@ def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
            return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
         else:
             [estimator, stream], trt_engine = self.estimator.acquire_estimator()
-            # NOTE need to synchronize when switching stream
-            torch.cuda.current_stream().synchronize()
-            with stream:
+            stream_context = stream if stream is not None else nullcontext()
+            if stream is not None:
+                # NOTE only synchronize when switching to a dedicated TRT stream.
+                torch.cuda.current_stream().synchronize()
+            with stream_context:
                 estimator.set_input_shape('x', (2, 80, x.size(2)))
                 estimator.set_input_shape('mask', (2, 1, x.size(2)))
                 estimator.set_input_shape('mu', (2, 80, x.size(2)))
@@ -148,7 +152,8 @@ def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
                     estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
                 # run trt engine
                 assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
-            torch.cuda.current_stream().synchronize()
+            if stream is not None:
+                torch.cuda.current_stream().synchronize()
             self.estimator.release_estimator(estimator, stream)
             return x

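The flow_matching.py change makes the TensorRT path tolerate estimators acquired without a dedicated CUDA stream: contextlib.nullcontext stands in when stream is None, and the current-stream synchronization now happens only when there really is a stream switch. This is the standard optional-context-manager idiom; below is a small sketch under the assumption that stream, when present, is a torch.cuda.stream(...) context (the function name and fn parameter are illustrative, not from the repo):

from contextlib import nullcontext

import torch

def run_on_optional_stream(fn, stream=None):
    # Use the dedicated stream if one was provided, else stay on the current one.
    ctx = stream if stream is not None else nullcontext()
    if stream is not None:
        # Synchronize only when actually switching streams; an unconditional
        # sync would stall the single-stream path for no benefit.
        torch.cuda.current_stream().synchronize()
    with ctx:
        return fn()

With stream=None the call runs inline with no extra synchronization, matching the new behavior of forward_estimator when acquire_estimator() hands back a stream-less estimator.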