We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent dc27de3 commit ac69684Copy full SHA for ac69684
1 file changed
dhg/nn/convs/hypergraphs/unignn_conv.py
@@ -62,7 +62,9 @@ def forward(self, X: torch.Tensor, hg: Hypergraph) -> torch.Tensor:
62
# compute the special degree of hyperedges
63
_De = torch.zeros(hg.num_e, device=hg.device)
64
# scatter_reduce() is relay on the torch 1.12.1, which may be updated in the future
65
- _De = _De.scatter_reduce(0, index=hg.v2e_dst, src=hg.D_v.clone()._values()[hg.v2e_src], reduce="mean")
+ _Dv = hg.D_v.clone()._values()[hg.v2e_src]
66
+ _De = _De.scatter_reduce(0, index=hg.v2e_dst, src=_Dv, reduce="sum") / _De.scatter_reduce(
67
+ 0, index=hg.v2e_dst, src=(_Dv != 0).type_as(_De), reduce="sum")
68
_De = _De.pow(-0.5)
69
_De[_De.isinf()] = 1
70
Y = _De.view(-1, 1) * Y
0 commit comments