Skip to content

Commit c414f83

Browse files
authored
fixed nnz value (#17)
fixes for the rspmm kernels
1 parent 871add0 commit c414f83

3 files changed

Lines changed: 8 additions & 5 deletions

File tree

ultra/rspmm/rspmm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,10 @@ def generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="
181181
def load_extension(name, sources, extra_cflags=None, extra_cuda_cflags=None, **kwargs):
182182
if extra_cflags is None:
183183
extra_cflags = ["-Ofast"]
184-
if torch.backends.openmp.is_available():
184+
# PyTorch 2.2.1+ on Apple Silicon is now compiled by default with OpenMP
185+
# However, installing OpenMP on macs properly and wiring it together to the compiler is tedious
186+
# So on macs we turn off OpenMP (as the default behavior in all torch < 2.2.1 versions)
187+
if torch.backends.openmp.is_available() and not sys.platform.startswith('darwin'):
185188
extra_cflags += ["-fopenmp", "-DAT_PARALLEL_OPENMP"]
186189
else:
187190
extra_cflags.append("-DAT_PARALLEL_NATIVE")

ultra/rspmm/source/rspmm.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ Tensor rspmm_forward_cpu(const Tensor &edge_index_, const Tensor &edge_type_, co
135135
const Tensor relation = relation_.contiguous();
136136
const Tensor input = input_.contiguous();
137137

138-
int64_t nnz = edge_index.size(0);
138+
int64_t nnz = edge_index.size(1);
139139
int64_t num_row = input.size(0);
140140
int64_t dim = input.size(1);
141141
Tensor output = at::empty({num_row, dim}, input.options());
@@ -183,7 +183,7 @@ std::tuple<Tensor, Tensor, Tensor> rspmm_backward_cpu(
183183
const Tensor output = output_.contiguous();
184184
const Tensor output_grad = output_grad_.contiguous();
185185

186-
int64_t nnz = edge_index.size(0);
186+
int64_t nnz = edge_index.size(1);
187187
int64_t num_row = input.size(0);
188188
int64_t dim = input.size(1);
189189
Tensor weight_grad = at::zeros_like(edge_weight);

ultra/rspmm/source/rspmm.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ Tensor rspmm_forward_cuda(const Tensor &edge_index_, const Tensor &edge_type_, c
230230
const Tensor relation = relation_.contiguous();
231231
const Tensor input = input_.contiguous();
232232

233-
int64_t nnz = edge_index.size(0);
233+
int64_t nnz = edge_index.size(1);
234234
int64_t num_row = input.size(0);
235235
int64_t dim = input.size(1);
236236
Tensor output = at::empty({num_row, dim}, input.options());
@@ -289,7 +289,7 @@ std::tuple<Tensor, Tensor, Tensor> rspmm_backward_cuda(
289289
const Tensor output = output_.contiguous();
290290
const Tensor output_grad = output_grad_.contiguous();
291291

292-
int64_t nnz = edge_index.size(0);
292+
int64_t nnz = edge_index.size(1);
293293
int64_t num_row = input.size(0);
294294
int64_t dim = input.size(1);
295295
Tensor weight_grad = at::zeros_like(edge_weight);

0 commit comments

Comments
 (0)