Skip to content

Commit fb21fd2

Browse files
jasagiriclaude
authored and committed
feat: add Apple Silicon (MPS) support for macOS ARM64
Introduce a device abstraction layer (cosyvoice/utils/device.py) that unifies CUDA, MPS, and CPU device management. Replace all hardcoded CUDA-specific code paths in the inference pipeline with device-agnostic alternatives, enabling CosyVoice to run natively on Apple Silicon Macs. Key changes: - Device abstraction: get_device(), get_stream_context(), get_autocast_context(), empty_cache() - model.py: Replace CUDA device init, streams, AMP, and cache clearing across CosyVoiceModel, CosyVoice2Model, CosyVoice3Model - cosyvoice.py: MPS-aware feature gates (TRT/vLLM require CUDA, JIT/fp16 require any GPU) - frontend.py: CoreMLExecutionProvider support for ONNX Runtime - common.py: Guard torch.cuda.manual_seed_all for non-CUDA environments - requirements.txt: Remove CUDA-only index URLs, loosen PyTorch version - setup_macos.sh: One-command setup script for Apple Silicon Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ace7c47 commit fb21fd2

9 files changed

Lines changed: 211 additions & 33 deletions

File tree

README.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,32 @@
108108
sudo yum install sox sox-devel
109109
```
110110

111+
### macOS Apple Silicon (M1/M2/M3/M4)
112+
113+
For Apple Silicon Macs, use the dedicated setup script:
114+
115+
``` sh
116+
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
117+
cd CosyVoice
118+
bash setup_macos.sh
119+
```
120+
121+
Or manually:
122+
123+
``` sh
124+
conda create -n cosyvoice -y python=3.10
125+
conda activate cosyvoice
126+
conda install -c conda-forge pynini==2.1.5 -y
127+
pip install torch torchaudio
128+
pip install -r requirements.txt
129+
```
130+
131+
**Apple Silicon notes:**
132+
- Inference runs on MPS (Metal Performance Shaders) — faster than CPU
133+
- TensorRT and vLLM are not available (CUDA-only)
134+
- Training with DeepSpeed/DDP is not supported
135+
- For CUDA environments (Linux), use `pip install -r requirements-cuda.txt` instead
136+
111137
### Model download
112138

113139
We strongly recommend that you download our pretrained `Fun-CosyVoice3-0.5B` `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.

cosyvoice/cli/cosyvoice.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
2323
from cosyvoice.utils.file_utils import logging
2424
from cosyvoice.utils.class_utils import get_model_type
25+
from cosyvoice.utils.device import is_cuda, is_gpu_available
2526

2627

2728
class CosyVoice:
@@ -44,9 +45,12 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_co
4445
'{}/spk2info.pt'.format(model_dir),
4546
configs['allowed_special'])
4647
self.sample_rate = configs['sample_rate']
47-
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
48-
load_jit, load_trt, fp16 = False, False, False
49-
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
48+
if not is_cuda() and load_trt:
49+
load_trt = False
50+
logging.warning('TensorRT requires CUDA, disabling load_trt')
51+
if not is_gpu_available() and (load_jit or fp16):
52+
load_jit, fp16 = False, False
53+
logging.warning('no GPU device, disabling load_jit/fp16')
5054
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
5155
self.model.load('{}/llm.pt'.format(model_dir),
5256
'{}/flow.pt'.format(model_dir),
@@ -156,9 +160,16 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, f
156160
'{}/spk2info.pt'.format(model_dir),
157161
configs['allowed_special'])
158162
self.sample_rate = configs['sample_rate']
159-
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or load_vllm is True or fp16 is True):
160-
load_jit, load_trt, load_vllm, fp16 = False, False, False, False
161-
logging.warning('no cuda device, set load_jit/load_trt/load_vllm/fp16 to False')
163+
if not is_cuda():
164+
if load_trt:
165+
load_trt = False
166+
logging.warning('TensorRT requires CUDA, disabling load_trt')
167+
if load_vllm:
168+
load_vllm = False
169+
logging.warning('vLLM requires CUDA, disabling load_vllm')
170+
if not is_gpu_available() and (load_jit or fp16):
171+
load_jit, fp16 = False, False
172+
logging.warning('no GPU device, disabling load_jit/fp16')
162173
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
163174
self.model.load('{}/llm.pt'.format(model_dir),
164175
'{}/flow.pt'.format(model_dir),
@@ -206,9 +217,12 @@ def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_c
206217
'{}/spk2info.pt'.format(model_dir),
207218
configs['allowed_special'])
208219
self.sample_rate = configs['sample_rate']
209-
if torch.cuda.is_available() is False and (load_trt is True or fp16 is True):
210-
load_trt, fp16 = False, False
211-
logging.warning('no cuda device, set load_trt/fp16 to False')
220+
if not is_cuda() and load_trt:
221+
load_trt = False
222+
logging.warning('TensorRT requires CUDA, disabling load_trt')
223+
if not is_gpu_available() and fp16:
224+
fp16 = False
225+
logging.warning('no GPU device, disabling fp16')
212226
self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
213227
self.model.load('{}/llm.pt'.format(model_dir),
214228
'{}/flow.pt'.format(model_dir),

cosyvoice/cli/frontend.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import inflect
2626
from cosyvoice.utils.file_utils import logging, load_wav
2727
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
28+
from cosyvoice.utils.device import get_device
2829

2930

3031
class CosyVoiceFrontEnd:
@@ -38,14 +39,19 @@ def __init__(self,
3839
allowed_special: str = 'all'):
3940
self.tokenizer = get_tokenizer()
4041
self.feat_extractor = feat_extractor
41-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42+
self.device = get_device()
4243
option = onnxruntime.SessionOptions()
4344
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
4445
option.intra_op_num_threads = 1
4546
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
47+
if torch.cuda.is_available():
48+
tokenizer_providers = ["CUDAExecutionProvider"]
49+
elif "CoreMLExecutionProvider" in onnxruntime.get_available_providers():
50+
tokenizer_providers = ["CoreMLExecutionProvider"]
51+
else:
52+
tokenizer_providers = ["CPUExecutionProvider"]
4653
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
47-
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
48-
"CPUExecutionProvider"])
54+
providers=tokenizer_providers)
4955
if os.path.exists(spk2info):
5056
self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)
5157
else:

cosyvoice/cli/model.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from cosyvoice.utils.common import fade_in_out
2525
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
2626
from cosyvoice.utils.common import TrtContextWrapper
27+
from cosyvoice.utils.device import get_device, get_stream_context, get_autocast_context, empty_cache
2728

2829

2930
class CosyVoiceModel:
@@ -33,7 +34,7 @@ def __init__(self,
3334
flow: torch.nn.Module,
3435
hift: torch.nn.Module,
3536
fp16: bool = False):
36-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
37+
self.device = get_device()
3738
self.llm = llm
3839
self.flow = flow
3940
self.hift = hift
@@ -52,7 +53,7 @@ def __init__(self,
5253
# rtf and decoding related
5354
self.stream_scale_factor = 1
5455
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
55-
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
56+
self.llm_context = get_stream_context(self.device)
5657
self.lock = threading.Lock()
5758
# dict used to store session related variable
5859
self.tts_speech_token_dict = {}
@@ -100,7 +101,7 @@ def get_trt_kwargs(self):
100101

101102
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
102103
cur_silent_token_num, max_silent_token_num = 0, 5
103-
with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
104+
with self.llm_context, get_autocast_context(self.fp16 is True and hasattr(self.llm, 'vllm') is False, self.device):
104105
if isinstance(text, Generator):
105106
assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
106107
token_generator = self.llm.inference_bistream(text=text,
@@ -133,7 +134,7 @@ def vc_job(self, source_speech_token, uuid):
133134
self.llm_end_dict[uuid] = True
134135

135136
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
136-
with torch.cuda.amp.autocast(self.fp16):
137+
with get_autocast_context(self.fp16, self.device):
137138
tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
138139
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
139140
prompt_token=prompt_token.to(self.device),
@@ -237,9 +238,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
237238
self.mel_overlap_dict.pop(this_uuid)
238239
self.hift_cache_dict.pop(this_uuid)
239240
self.flow_cache_dict.pop(this_uuid)
240-
if torch.cuda.is_available():
241-
torch.cuda.empty_cache()
242-
torch.cuda.current_stream().synchronize()
241+
empty_cache(self.device)
243242

244243

245244
class CosyVoice2Model(CosyVoiceModel):
@@ -249,7 +248,7 @@ def __init__(self,
249248
flow: torch.nn.Module,
250249
hift: torch.nn.Module,
251250
fp16: bool = False):
252-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
251+
self.device = get_device()
253252
self.llm = llm
254253
self.flow = flow
255254
self.hift = hift
@@ -266,7 +265,7 @@ def __init__(self,
266265
# speech fade in out
267266
self.speech_window = np.hamming(2 * self.source_cache_len)
268267
# rtf and decoding related
269-
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
268+
self.llm_context = get_stream_context(self.device)
270269
self.lock = threading.Lock()
271270
# dict used to store session related variable
272271
self.tts_speech_token_dict = {}
@@ -290,7 +289,7 @@ def load_vllm(self, model_dir):
290289
del self.llm.llm.model.model.layers
291290

292291
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
293-
with torch.cuda.amp.autocast(self.fp16):
292+
with get_autocast_context(self.fp16, self.device):
294293
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
295294
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
296295
prompt_token=prompt_token.to(self.device),
@@ -389,9 +388,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
389388
self.tts_speech_token_dict.pop(this_uuid)
390389
self.llm_end_dict.pop(this_uuid)
391390
self.hift_cache_dict.pop(this_uuid)
392-
if torch.cuda.is_available():
393-
torch.cuda.empty_cache()
394-
torch.cuda.current_stream().synchronize()
391+
empty_cache(self.device)
395392

396393

397394
class CosyVoice3Model(CosyVoice2Model):
@@ -401,7 +398,7 @@ def __init__(self,
401398
flow: torch.nn.Module,
402399
hift: torch.nn.Module,
403400
fp16: bool = False):
404-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
401+
self.device = get_device()
405402
self.llm = llm
406403
self.flow = flow
407404
self.hift = hift
@@ -413,7 +410,7 @@ def __init__(self,
413410
self.stream_scale_factor = 2
414411
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
415412
# rtf and decoding related
416-
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
413+
self.llm_context = get_stream_context(self.device)
417414
self.lock = threading.Lock()
418415
# dict used to store session related variable
419416
self.tts_speech_token_dict = {}
@@ -423,7 +420,7 @@ def __init__(self,
423420
self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]
424421

425422
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
426-
with torch.cuda.amp.autocast(self.fp16):
423+
with get_autocast_context(self.fp16, self.device):
427424
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
428425
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
429426
prompt_token=prompt_token.to(self.device),

cosyvoice/utils/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ def set_all_random_seed(seed):
182182
random.seed(seed)
183183
np.random.seed(seed)
184184
torch.manual_seed(seed)
185-
torch.cuda.manual_seed_all(seed)
185+
if torch.cuda.is_available():
186+
torch.cuda.manual_seed_all(seed)
186187

187188

188189
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:

cosyvoice/utils/device.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Unified device management for CUDA, MPS (Apple Silicon), and CPU backends."""
15+
16+
import random
17+
from contextlib import nullcontext
18+
19+
import numpy as np
20+
import torch
21+
22+
23+
def get_device() -> torch.device:
    """Select the best available compute backend: cuda, then mps, then cpu."""
    if torch.cuda.is_available():
        device_name = 'cuda'
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device_name = 'mps'
    else:
        device_name = 'cpu'
    return torch.device(device_name)
30+
31+
32+
def is_cuda() -> bool:
    """Return True when at least one CUDA device is usable."""
    return bool(torch.cuda.is_available())
34+
35+
36+
def is_mps() -> bool:
    """Return True when the Metal (MPS) backend exists and is usable."""
    mps_backend = getattr(torch.backends, 'mps', None)
    return mps_backend is not None and mps_backend.is_available()
38+
39+
40+
def is_gpu_available() -> bool:
    """Return True when any GPU backend (CUDA or MPS) can be used."""
    return True if is_cuda() else is_mps()
42+
43+
44+
def get_stream_context(device: torch.device):
    """Return a context scoping work to a fresh CUDA stream; no-op elsewhere.

    Non-CUDA backends (mps/cpu) have no stream API here, so they get a
    nullcontext and run on the default execution path.
    """
    if device.type != 'cuda':
        return nullcontext()
    return torch.cuda.stream(torch.cuda.Stream(device))
49+
50+
51+
def get_autocast_context(enabled: bool, device: torch.device):
    """Return the appropriate autocast (mixed-precision) context for the device.

    Args:
        enabled: whether mixed precision was requested (e.g. fp16 inference).
        device: target device; only cuda/mps get a real autocast context.

    Returns:
        A context manager: an autocast context for cuda/mps when enabled,
        otherwise a nullcontext (cpu, or mixed precision disabled).
    """
    if not enabled:
        return nullcontext()
    if device.type == 'cuda':
        # torch.cuda.amp.autocast is deprecated (FutureWarning since torch
        # 2.4); torch.amp.autocast is the supported spelling and behaves the
        # same (default dtype on CUDA is float16).
        return torch.amp.autocast('cuda', enabled=True)
    if device.type == 'mps':
        # MPS autocast needs an explicit dtype; fp16 mirrors the CUDA default.
        return torch.autocast(device_type='mps', dtype=torch.float16)
    return nullcontext()
60+
61+
62+
def empty_cache(device: torch.device):
    """Release cached allocator memory and wait for in-flight work to finish."""
    kind = device.type
    if kind == 'cuda':
        torch.cuda.empty_cache()
        torch.cuda.current_stream().synchronize()
    elif kind == 'mps':
        # Older torch builds may lack these mps helpers, hence the guards.
        cache_fn = getattr(torch.mps, 'empty_cache', None)
        if cache_fn is not None:
            cache_fn()
        sync_fn = getattr(torch.mps, 'synchronize', None)
        if sync_fn is not None:
            sync_fn()
72+
73+
74+
def set_all_random_seed(seed: int):
    """Seed the python, numpy and torch RNGs (all CUDA devices when present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

requirements-cuda.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# CUDA-specific requirements (Linux with NVIDIA GPU)
2+
# Install with: pip install -r requirements-cuda.txt
3+
--extra-index-url https://download.pytorch.org/whl/cu121
4+
--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
5+
-r requirements.txt

requirements.txt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
--extra-index-url https://download.pytorch.org/whl/cu121
2-
--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ # https://github.com/microsoft/onnxruntime/issues/21684
31
conformer==0.3.2
42
deepspeed==0.15.1; sys_platform == 'linux'
53
diffusers==0.29.0
@@ -33,8 +31,8 @@ tensorboard==2.14.0
3331
tensorrt-cu12==10.13.3.9; sys_platform == 'linux'
3432
tensorrt-cu12-bindings==10.13.3.9; sys_platform == 'linux'
3533
tensorrt-cu12-libs==10.13.3.9; sys_platform == 'linux'
36-
torch==2.3.1
37-
torchaudio==2.3.1
34+
torch>=2.3.1
35+
torchaudio>=2.3.1
3836
transformers==4.51.3
3937
x-transformers==2.11.24
4038
uvicorn==0.30.0

0 commit comments

Comments
 (0)