
Commit 4a4ee3d

support muon optimizer
1 parent 51d3d3d commit 4a4ee3d

10 files changed: 316 additions & 28 deletions


basics/base_task.py

Lines changed: 1 addition & 1 deletion
@@ -307,7 +307,7 @@ def build_optimizer(self, model):
         optimizer = build_object_from_class_name(
             optimizer_args['optimizer_cls'],
             torch.optim.Optimizer,
-            model.parameters(),
+            model if optimizer_args['optimizer_cls'] == 'modules.optimizer.muon.Muon_AdamW' else model.parameters(),
             **optimizer_args
         )
         return optimizer
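The Muon_AdamW optimizer is handed the whole model rather than model.parameters(), presumably because the Muon/AdamW split needs parameter names and shapes (Muon is typically applied only to 2-D hidden weight matrices, with embeddings, biases and other small tensors left to AdamW). The actual modules/optimizer/muon.py is not part of this excerpt; the sketch below only illustrates such a split, and every name other than the model itself is an assumption.

# Hypothetical sketch only; the real modules/optimizer/muon.py is not shown in this diff.
# It illustrates why Muon_AdamW takes the model object instead of an iterator of tensors:
# deciding which update rule a parameter gets requires its name and shape.
from torch import nn

def split_params_for_muon(model: nn.Module):
    muon_params, adamw_params = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # Common heuristic: 2-D hidden weight matrices go to Muon,
        # embeddings, biases, norms and other 1-D tensors go to AdamW.
        if p.ndim >= 2 and 'embed' not in name.lower():
            muon_params.append(p)
        else:
            adamw_params.append(p)
    return muon_params, adamw_params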

configs/acoustic.yaml

Lines changed: 7 additions & 2 deletions
@@ -104,10 +104,15 @@ lambda_aux_mel_loss: 0.2
 # train and eval
 num_sanity_val_steps: 1
 optimizer_args:
+  optimizer_cls: modules.optimizer.muon.Muon_AdamW
   lr: 0.0006
+  muon_args:
+    weight_decay: 0.1
+  adamw_args:
+    weight_decay: 0.0
 lr_scheduler_args:
-  step_size: 10000
-  gamma: 0.75
+  step_size: 5000
+  gamma: 0.8
 max_batch_frames: 50000
 max_batch_size: 64
 dataset_size_key: 'lengths'
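For orientation, the nested muon_args/adamw_args end up as plain dict keyword arguments on the optimizer constructor. The repository's hparams loading and its build_object_from_class_name helper are not shown here, so the following is only an illustration of that mapping:

# Illustration only: how the optimizer_args above would typically reach the
# optimizer constructor once the YAML is loaded into a plain dict.
optimizer_args = {
    'optimizer_cls': 'modules.optimizer.muon.Muon_AdamW',
    'lr': 0.0006,
    'muon_args': {'weight_decay': 0.1},    # decay applied on the Muon-managed matrices
    'adamw_args': {'weight_decay': 0.0},   # no decay on the AdamW-managed parameters
}
# With the base_task.py change above, construction is roughly:
#   Muon_AdamW(model, lr=0.0006, muon_args={'weight_decay': 0.1},
#              adamw_args={'weight_decay': 0.0}, ...)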

configs/templates/config_acoustic.yaml

Lines changed: 7 additions & 3 deletions
@@ -101,11 +101,15 @@ shallow_diffusion_args:
 lambda_aux_mel_loss: 0.2

 optimizer_args:
+  optimizer_cls: modules.optimizer.muon.Muon_AdamW
   lr: 0.0006
+  muon_args:
+    weight_decay: 0.1
+  adamw_args:
+    weight_decay: 0.0
 lr_scheduler_args:
-  scheduler_cls: torch.optim.lr_scheduler.StepLR
-  step_size: 10000
-  gamma: 0.75
+  step_size: 5000
+  gamma: 0.8
 max_batch_frames: 50000
 max_batch_size: 64
 max_updates: 160000
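The scheduler change halves the step size while raising gamma, which still decays faster overall: 0.8 squared is 0.64 per 10000 steps versus 0.75 before. Over the 160000 max_updates in this template, a quick check of where the two StepLR schedules end up:

# Rough comparison of the old and new StepLR schedules (illustrative arithmetic only).
base_lr, max_updates = 0.0006, 160000

old_lr = base_lr * 0.75 ** (max_updates // 10000)  # ~6.0e-6 after 160k updates
new_lr = base_lr * 0.8 ** (max_updates // 5000)    # ~4.8e-7 after 160k updates
print(old_lr, new_lr)  # the new schedule reaches a noticeably lower final LR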

configs/templates/config_variance.yaml

Lines changed: 9 additions & 5 deletions
@@ -67,8 +67,8 @@ enc_ffn_kernel_size: 3
 use_rope: true
 hidden_size: 256
 dur_prediction_args:
-  arch: fs2
-  hidden_size: 512
+  arch: resnet
+  hidden_size: 256
   dropout: 0.1
   num_layers: 5
   kernel_size: 3
@@ -123,11 +123,15 @@ lambda_pitch_loss: 1.0
 lambda_var_loss: 1.0

 optimizer_args:
+  optimizer_cls: modules.optimizer.muon.Muon_AdamW
   lr: 0.0006
+  muon_args:
+    weight_decay: 0.1
+  adamw_args:
+    weight_decay: 0.0
 lr_scheduler_args:
-  scheduler_cls: torch.optim.lr_scheduler.StepLR
-  step_size: 10000
-  gamma: 0.75
+  step_size: 5000
+  gamma: 0.8
 max_batch_frames: 80000
 max_batch_size: 48
 max_updates: 160000

configs/variance.yaml

Lines changed: 9 additions & 4 deletions
@@ -40,8 +40,8 @@ rel_pos: true
 hidden_size: 256

 dur_prediction_args:
-  arch: fs2
-  hidden_size: 512
+  arch: resnet
+  hidden_size: 256
   dropout: 0.1
   num_layers: 5
   kernel_size: 3
@@ -114,10 +114,15 @@ diff_speedup: 10
 # train and eval
 num_sanity_val_steps: 1
 optimizer_args:
+  optimizer_cls: modules.optimizer.muon.Muon_AdamW
   lr: 0.0006
+  muon_args:
+    weight_decay: 0.1
+  adamw_args:
+    weight_decay: 0.0
 lr_scheduler_args:
-  step_size: 10000
-  gamma: 0.75
+  step_size: 5000
+  gamma: 0.8
 max_batch_frames: 80000
 max_batch_size: 48
 dataset_size_key: 'lengths'

modules/fastspeech/tts_modules.py

Lines changed: 28 additions & 11 deletions
@@ -62,7 +62,7 @@ class DurationPredictor(torch.nn.Module):
     """

     def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3,
-                 dropout_rate=0.1, offset=1.0, dur_loss_type='mse'):
+                 dropout_rate=0.1, offset=1.0, dur_loss_type='mse', arch='resnet'):
         """Initialize duration predictor module.
        Args:
            in_dims (int): Input dimension.
@@ -76,16 +76,29 @@ def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3,
         self.offset = offset
         self.conv = torch.nn.ModuleList()
         self.kernel_size = kernel_size
+        self.use_resnet = (arch == 'resnet')
         for idx in range(n_layers):
             in_chans = in_dims if idx == 0 else n_chans
-            self.conv.append(torch.nn.Sequential(
-                torch.nn.Identity(),  # this is a placeholder for ConstantPad1d which is now merged into Conv1d
-                torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=kernel_size // 2),
-                torch.nn.ReLU(),
-                LayerNorm(n_chans, dim=1),
-                torch.nn.Dropout(dropout_rate)
-            ))
-
+            if self.use_resnet:
+                self.conv.append(nn.Sequential(
+                    LayerNorm(in_chans, dim=1),
+                    nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=kernel_size // 2),
+                    nn.ReLU(),
+                    nn.Conv1d(n_chans, n_chans, 1),
+                    nn.Dropout(dropout_rate)
+                ))
+            else:
+                self.conv.append(nn.Sequential(
+                    nn.Identity(),  # this is a placeholder for ConstantPad1d which is now merged into Conv1d
+                    nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=kernel_size // 2),
+                    nn.ReLU(),
+                    LayerNorm(n_chans, dim=1),
+                    nn.Dropout(dropout_rate)
+                ))
+        if self.use_resnet and in_dims != n_chans:
+            self.res_conv = nn.Conv1d(in_dims, n_chans, 1)
+        else:
+            self.res_conv = None
         self.loss_type = dur_loss_type
         if self.loss_type in ['mse', 'huber']:
             self.out_dims = 1
@@ -121,8 +134,12 @@ def forward(self, xs, x_masks=None, infer=True):
         xs = xs.transpose(1, -1)  # (B, idim, Tmax)
         masks = 1 - x_masks.float()
         masks_ = masks[:, None, :]
-        for f in self.conv:
-            xs = f(xs)  # (B, C, Tmax)
+        for idx, f in enumerate(self.conv):
+            if self.use_resnet:
+                residual = self.res_conv(xs) if idx == 0 and self.res_conv is not None else xs
+                xs = residual + f(xs)
+            else:
+                xs = f(xs)
         if x_masks is not None:
             xs = xs * masks_
         xs = self.linear(xs.transpose(1, -1))  # [B, T, C]
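The new 'resnet' branch builds each layer as a pre-norm residual block: LayerNorm over channels, a kernel_size convolution, ReLU, a 1x1 convolution and dropout, with the layer input added back as a skip connection (projected once by res_conv when in_dims != n_chans). A minimal standalone sketch of one such block, with a plain torch LayerNorm wrapper standing in for the repository's own LayerNorm helper:

import torch
from torch import nn


class ChannelLayerNorm(nn.Module):
    """LayerNorm over the channel dim of (B, C, T) tensors, standing in for the
    repository's LayerNorm(..., dim=1) helper."""

    def __init__(self, channels):
        super().__init__()
        self.norm = nn.LayerNorm(channels)

    def forward(self, x):
        return self.norm(x.transpose(1, 2)).transpose(1, 2)


class ResConvBlockSketch(nn.Module):
    """One pre-norm residual layer as built in the 'resnet' branch above."""

    def __init__(self, in_chans, n_chans, kernel_size=3, dropout_rate=0.1):
        super().__init__()
        self.body = nn.Sequential(
            ChannelLayerNorm(in_chans),
            nn.Conv1d(in_chans, n_chans, kernel_size, padding=kernel_size // 2),
            nn.ReLU(),
            nn.Conv1d(n_chans, n_chans, 1),
            nn.Dropout(dropout_rate),
        )
        # 1x1 projection for the skip path when channel counts differ,
        # mirroring res_conv in the diff (used only for the first layer there).
        self.res = nn.Conv1d(in_chans, n_chans, 1) if in_chans != n_chans else nn.Identity()

    def forward(self, x):  # x: (B, C, T)
        return self.res(x) + self.body(x)


x = torch.randn(2, 256, 100)                   # (batch, hidden_size, frames)
print(ResConvBlockSketch(256, 256)(x).shape)   # torch.Size([2, 256, 100])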

modules/fastspeech/variance_encoder.py

Lines changed: 2 additions & 1 deletion
@@ -46,7 +46,8 @@ def __init__(self, vocab_size):
             dropout_rate=dur_hparams['dropout'],
             kernel_size=dur_hparams['kernel_size'],
             offset=dur_hparams['log_offset'],
-            dur_loss_type=dur_hparams['loss_type']
+            dur_loss_type=dur_hparams['loss_type'],
+            arch=dur_hparams['arch']
         )

     def forward(
Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+from torch import Tensor
+from torch.optim import Optimizer
+from torch.optim.optimizer import ParamsT
+from dataclasses import dataclass
+from typing import Any, Dict, List, Type, Callable, Optional, Iterable
+
+
+@dataclass
+class OptimizerSpec:
+    """Spec for creating an optimizer that is part of a `ChainedOptimizer`."""
+
+    class_type: Type[Optimizer]
+    init_args: Dict[str, Any]
+    param_filter: Optional[Callable[[Tensor], bool]]
+
+
+class ChainedOptimizer(Optimizer):
+    """
+    A wrapper around multiple optimizers that allows for chaining them together.
+    The optimizers are applied in the order they are passed in the constructor.
+    Each optimizer is responsible for updating a subset of the parameters, which
+    is determined by the `param_filter` function. If no optimizer is found for a
+    parameter group, an exception is raised.
+    """
+
+    def __init__(
+        self,
+        params: ParamsT,
+        optimizer_specs: List[OptimizerSpec],
+        lr: float,
+        weight_decay: float = 0.0,
+        optimizer_selection_callback: Optional[Callable[[Tensor, int], None]] = None,
+        **common_kwargs,
+    ):
+        self.optimizer_specs = optimizer_specs
+        self.optimizer_selection_callback = optimizer_selection_callback
+        self.optimizers: List[Optimizer] = []
+        defaults = dict(lr=lr, weight_decay=weight_decay)
+        super().__init__(params, defaults)
+
+        # Split the params for each optimizer
+        params_for_optimizers = [[] for _ in optimizer_specs]
+        for param_group in self.param_groups:
+            params = param_group["params"]
+            indices = param_group["optimizer_and_param_group_indices"] = set()
+            for param in params:
+                assert isinstance(param, Tensor), f"Expected a Tensor, got {type(param)}"
+                for index, spec in enumerate(optimizer_specs):
+                    if spec.param_filter is None or spec.param_filter(param):
+                        if self.optimizer_selection_callback is not None:
+                            self.optimizer_selection_callback(param, index)
+                        params_for_optimizers[index].append(param)
+                        indices.add((index, 0))
+                        break
+
+        # Initialize the optimizers
+        for spec, selected_params in zip(optimizer_specs, params_for_optimizers):
+            optimizer_args = {
+                'lr': lr,
+                'weight_decay': weight_decay,
+            }
+            optimizer_args.update(common_kwargs)
+            optimizer_args.update(spec.init_args)
+            optimizer = spec.class_type(selected_params, **optimizer_args)
+            self.optimizers.append(optimizer)
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {
+            "optimizers": [opt.state_dict() for opt in self.optimizers],
+            **super().state_dict(),
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        optimizers = state_dict.pop("optimizers")
+        super().load_state_dict(state_dict)
+        for i in range(len(self.optimizers)):
+            self.optimizers[i].load_state_dict(optimizers[i])
+
+    def zero_grad(self, set_to_none: bool = True) -> None:
+        for opt in self.optimizers:
+            opt.zero_grad(set_to_none=set_to_none)
+
+    def _copy_lr_to_optimizers(self) -> None:
+        for param_group in self.param_groups:
+            indices = param_group["optimizer_and_param_group_indices"]
+            for optimizer_idx, param_group_idx in indices:
+                self.optimizers[optimizer_idx].param_groups[param_group_idx]["lr"] = param_group["lr"]
+
+    def step(self, closure=None) -> None:
+        self._copy_lr_to_optimizers()
+        for opt in self.optimizers:
+            opt.step(closure)
+
+    def add_param_group(self, param_group: Dict[str, Any]) -> None:
+        super().add_param_group(param_group)
+
+        # If optimizer has not been initialized, skip adding the param groups
+        if not self.optimizers:
+            return
+
+        # Split the params for each optimizer
+        params_for_optimizers = [[] for _ in self.optimizer_specs]
+        params = param_group["params"]
+        indices = param_group["optimizer_and_param_group_indices"] = set()
+        for param in params:
+            assert isinstance(param, Tensor), f"Expected a Tensor, got {type(param)}"
+            found_optimizer = False
+            for index, spec in enumerate(self.optimizer_specs):
+                if spec.param_filter is None or spec.param_filter(param):
+                    if self.optimizer_selection_callback is not None:
+                        self.optimizer_selection_callback(param, index)
+                    params_for_optimizers[index].append(param)
+                    indices.add((index, len(self.optimizers[index].param_groups)))
+                    found_optimizer = True
+                    break
+            if not found_optimizer:
+                raise ValueError("No valid optimizer found for the given parameter group")
+
+        # Add the selected param group to the optimizers
+        for optimizer, selected_params in zip(self.optimizers, params_for_optimizers):
+            if selected_params:
+                optimizer.add_param_group({"params": selected_params})
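A short usage sketch of how ChainedOptimizer can be driven. Muon itself is not shown in this excerpt, so the example chains plain torch.optim.SGD and torch.optim.AdamW with a shape-based param_filter, mirroring the muon_args/adamw_args split in the configs above; the model and hyperparameters are illustrative, and OptimizerSpec/ChainedOptimizer are assumed importable from the module added here.

import torch
from torch import nn
from torch.optim import SGD, AdamW

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

specs = [
    # 2-D weight matrices go to the first optimizer (Muon would take this role).
    OptimizerSpec(class_type=SGD,
                  init_args={'weight_decay': 0.1, 'momentum': 0.9},
                  param_filter=lambda p: p.ndim >= 2),
    # Everything else (biases, norms, ...) falls through to AdamW; filter=None accepts all.
    OptimizerSpec(class_type=AdamW,
                  init_args={'weight_decay': 0.0},
                  param_filter=None),
]

opt = ChainedOptimizer(model.parameters(), specs, lr=6e-4)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), y)
opt.zero_grad()
loss.backward()
opt.step()  # copies the (possibly scheduler-updated) lr to both inner optimizers, then steps each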
