11import torch
22from torch import Tensor
3- from unit .conftest import DEVICE
43
54
65def _check_valid_dimensions (n_rows : int , n_cols : int ) -> None :
@@ -37,9 +36,9 @@ def _augment_orthogonal_matrix(orthogonal_matrix: Tensor) -> Tensor:
3736
3837 n_rows = orthogonal_matrix .shape [0 ]
3938 projection = orthogonal_matrix @ orthogonal_matrix .T
40- zero = torch .zeros ([n_rows ], device = DEVICE )
39+ zero = torch .zeros ([n_rows ])
4140 while True :
42- random_vector = torch .randn ([n_rows ], device = DEVICE )
41+ random_vector = torch .randn ([n_rows ])
4342 projected_vector = random_vector - projection @ random_vector
4443 if not torch .allclose (projected_vector , zero ):
4544 break
@@ -70,7 +69,7 @@ def _generate_unitary_matrix(n_rows: int, n_cols: int) -> Tensor:
7069 """Generates a unitary matrix of shape [n_rows, n_cols]."""
7170
7271 _check_valid_dimensions (n_rows , n_cols )
73- partial_matrix = torch .randn ([n_rows , 1 ], device = DEVICE )
72+ partial_matrix = torch .randn ([n_rows , 1 ])
7473 partial_matrix = torch .nn .functional .normalize (partial_matrix , dim = 0 )
7574
7675 unitary_matrix = _complete_orthogonal_matrix (partial_matrix , n_cols )
@@ -83,7 +82,7 @@ def _generate_unitary_matrix_with_positive_column(n_rows: int, n_cols: int) -> T
8382 positive vector.
8483 """
8584 _check_valid_dimensions (n_rows , n_cols )
86- partial_matrix = torch .abs (torch .randn ([n_rows , 1 ], device = DEVICE ))
85+ partial_matrix = torch .abs (torch .randn ([n_rows , 1 ]))
8786 partial_matrix = torch .nn .functional .normalize (partial_matrix , dim = 0 )
8887
8988 unitary_matrix_with_positive_column = _complete_orthogonal_matrix (partial_matrix , n_cols )
@@ -94,7 +93,7 @@ def _generate_diagonal_singular_values(rank: int) -> Tensor:
9493 """
9594 generates a diagonal matrix of positive values sorted in descending order.
9695 """
97- singular_values = torch .abs (torch .randn ([rank ], device = DEVICE ))
96+ singular_values = torch .abs (torch .randn ([rank ]))
9897 singular_values = torch .sort (singular_values , descending = True )[0 ]
9998 S = torch .diag (singular_values )
10099 return S
@@ -108,7 +107,7 @@ def generate_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
108107 _check_valid_rank (n_rows , n_cols , rank )
109108
110109 if rank == 0 :
111- matrix = torch .zeros ([n_rows , n_cols ], device = DEVICE )
110+ matrix = torch .zeros ([n_rows , n_cols ])
112111 else :
113112 U = _generate_unitary_matrix (n_rows , rank )
114113 V = _generate_unitary_matrix (n_cols , rank )
@@ -126,7 +125,7 @@ def generate_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
126125
127126 _check_valid_rank (n_rows , n_cols , rank )
128127 if rank == 0 :
129- matrix = torch .zeros ([n_rows , n_cols ], device = DEVICE )
128+ matrix = torch .zeros ([n_rows , n_cols ])
130129 else :
131130 U = _generate_unitary_matrix_with_positive_column (n_rows , rank )
132131 V = _generate_unitary_matrix (n_cols , rank )
@@ -161,9 +160,7 @@ def generate_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
161160 generate_matrix (n_rows , n_cols , rank ) for n_rows , n_cols , rank in _matrix_dimension_triples
162161]
163162scaled_matrices = [scale * matrix for scale in _scales for matrix in matrices ]
164- zero_rank_matrices = [
165- torch .zeros ([n_rows , n_cols ], device = DEVICE ) for n_rows , n_cols in _zero_rank_matrix_shapes
166- ]
163+ zero_rank_matrices = [torch .zeros ([n_rows , n_cols ]) for n_rows , n_cols in _zero_rank_matrix_shapes ]
167164matrices_2_plus_rows = [matrix for matrix in matrices + zero_rank_matrices if matrix .shape [0 ] >= 2 ]
168165scaled_matrices_2_plus_rows = [
169166 matrix for matrix in scaled_matrices + zero_rank_matrices if matrix .shape [0 ] >= 2
0 commit comments