Skip to content

Commit 2672003

Browse files
committed
refactor annotation
1 parent d82f2c6 commit 2672003

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

mambular/arch_utils/layer_utils/batch_ensemble_layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def reset_parameters(self, scaling_init: Literal["ones", "random-signs", "normal
167167
if self.bias is not None:
168168
nn.init.zeros_(self.bias)
169169

170-
def forward(self, x: torch.Tensor, hidden: torch.Tensor | None = None) -> torch.Tensor:
170+
def forward(self, x: torch.Tensor, hidden: torch.Tensor = None) -> torch.Tensor: # type: ignore
171171
"""Forward pass for the BatchEnsembleRNNLayer.
172172
173173
Parameters

mambular/arch_utils/transformer_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def __init__(
168168
else:
169169
raise ValueError(f"Invalid activation '{activation}'. Choose from 'relu' or 'gelu'.")
170170

171-
def forward(self, src, src_mask: torch.Tensor | None = None):
171+
def forward(self, src, src_mask: torch.Tensor = None): # type: ignore
172172
"""Pass the input through the encoder layer.
173173
174174
Parameters
@@ -313,7 +313,7 @@ def __init__(
313313

314314
self.ensemble_projections = ensemble_projections
315315

316-
def forward(self, x, mask: torch.Tensor | None = None):
316+
def forward(self, x, mask: torch.Tensor = None): # type: ignore
317317
"""Pass the input through the encoder layers in turn.
318318
319319
Parameters

0 commit comments

Comments (0)