1- import os
2- import sys
3-
41import torch
52from torch .nn import functional as F
6- from torch_scatter import scatter_max
3+ from torch_scatter import scatter_add , scatter_mean , scatter_max
74import networkx as nx
85from rdkit import Chem
9- from rdkit .Chem import RDConfig , Descriptors
6+ from rdkit .Chem import Descriptors
107
118from torchdrug import utils
129from torchdrug .layers import functional
@@ -23,6 +20,8 @@ def area_under_roc(pred, target):
2320 pred (Tensor): predictions of shape :math:`(n,)`
2421 target (Tensor): binary targets of shape :math:`(n,)`
2522 """
23+ if target .dtype != torch .long :
24+ raise TypeError ("Expect `target` to be torch.long, but found %s" % target .dtype )
2625 order = pred .argsort (descending = True )
2726 target = target [order ]
2827 hit = target .cumsum (0 )
@@ -40,6 +39,8 @@ def area_under_prc(pred, target):
4039 pred (Tensor): predictions of shape :math:`(n,)`
4140 target (Tensor): binary targets of shape :math:`(n,)`
4241 """
42+ if target .dtype != torch .long :
43+ raise TypeError ("Expect `target` to be torch.long, but found %s" % target .dtype )
4344 order = pred .argsort (descending = True )
4445 target = target [order ]
4546 precision = target .cumsum (0 ) / torch .arange (1 , len (target ) + 1 , device = target .device )
@@ -178,13 +179,103 @@ def chemical_validity(pred):
178179 validity .append (1 if mol else 0 )
179180
180181 return torch .tensor (validity , dtype = torch .float , device = pred .device )
182+
183+
184+ @R .register ("metrics.accuracy" )
185+ def accuracy (pred , target ):
186+ """
187+ Compute classification accuracy over sets with equal size.
188+
189+ Suppose there are :math:`N` sets and :math:`C` categories.
190+
191+ Parameters:
192+ pred (Tensor): prediction of shape :math:`(N, C)`
193+ target (Tensor): target of shape :math:`(N,)`
194+ """
195+ return (pred .argmax (dim = - 1 ) == target ).float ().mean ()
196+
197+
198+ @R .register ("metrics.mcc" )
199+ def matthews_corrcoef (pred , target , eps = 1e-6 ):
200+ """
201+ Matthews correlation coefficient between target and prediction.
202+
203+ Definition follows matthews_corrcoef for K classes in sklearn.
204+ For details, see: 'https://scikit-learn.org/stable/modules/model_evaluation.html#matthews-corrcoef'
205+
206+ Parameters:
207+ pred (Tensor): prediction of shape :math: `(N,)`
208+ target (Tensor): target of shape :math: `(N,)`
209+ """
210+ num_class = pred .size (- 1 )
211+ pred = pred .argmax (- 1 )
212+ ones = torch .ones (len (target ), device = pred .device )
213+ confusion_matrix = scatter_add (ones , target * num_class + pred , dim = 0 , dim_size = num_class ** 2 )
214+ confusion_matrix = confusion_matrix .view (num_class , num_class )
215+ t = confusion_matrix .sum (dim = 1 )
216+ p = confusion_matrix .sum (dim = 0 )
217+ c = confusion_matrix .trace ()
218+ s = confusion_matrix .sum ()
219+ return (c * s - t @ p ) / ((s * s - p @ p ) * (s * s - t @ t ) + eps ).sqrt ()
220+
221+
222+ @R .register ("metrics.pearsonr" )
223+ def pearsonr (pred , target ):
224+ """
225+ Pearson correlation between target and prediction.
226+ Mimics `scipy.stats.pearsonr`.
227+
228+ Parameters:
229+ pred (Tensor): prediction of shape :math: `(N,)`
230+ target (Tensor): target of shape :math: `(N,)`
231+ """
232+ pred_mean = pred .float ().mean ()
233+ target_mean = target .float ().mean ()
234+ pred_centered = pred - pred_mean
235+ target_centered = target - target_mean
236+ pred_normalized = pred_centered / pred_centered .norm (2 )
237+ target_normalized = target_centered / target_centered .norm (2 )
238+ pearsonr = pred_normalized @ target_normalized
239+ return pearsonr
240+
241+
242+ @R .register ("metrics.spearmanr" )
243+ def spearmanr (pred , target , eps = 1e-6 ):
244+ """
245+ Spearman correlation between target and prediction.
246+ Implement in PyTorch, but non-diffierentiable. (validation metric only)
247+
248+ Parameters:
249+ pred (Tensor): prediction of shape :math: `(N,)`
250+ target (Tensor): target of shape :math: `(N,)`
251+ """
252+
253+ def get_ranking (input ):
254+ input_set , input_inverse = input .unique (return_inverse = True )
255+ order = input_inverse .argsort ()
256+ ranking = torch .zeros (len (input_inverse ), device = input .device )
257+ ranking [order ] = torch .arange (1 , len (input ) + 1 , dtype = torch .float , device = input .device )
258+
259+ # for elements that have the same value, replace their rankings with the mean of their rankings
260+ mean_ranking = scatter_mean (ranking , input_inverse , dim = 0 , dim_size = len (input_set ))
261+ ranking = mean_ranking [input_inverse ]
262+ return ranking
263+
264+ pred = get_ranking (pred )
265+ target = get_ranking (target )
266+ covariance = (pred * target ).mean () - pred .mean () * target .mean ()
267+ pred_std = pred .std (unbiased = False )
268+ target_std = target .std (unbiased = False )
269+ spearmanr = covariance / (pred_std * target_std + eps )
270+ return spearmanr
181271
182272
273+ @R .register ("metrics.variadic_accuracy" )
183274def variadic_accuracy (input , target , size ):
184275 """
185276 Compute classification accuracy over variadic sizes of categories.
186277
187- Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math`B`.
278+ Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math: `B`.
188279
189280 Parameters:
190281 input (Tensor): prediction of shape :math:`(B,)`
@@ -196,4 +287,4 @@ def variadic_accuracy(input, target, size):
196287 input_class = scatter_max (input , index2graph )[1 ]
197288 target_index = target + size .cumsum (0 ) - size
198289 accuracy = (input_class == target_index ).float ()
199- return accuracy
290+ return accuracy
0 commit comments