@@ -238,7 +238,7 @@ test_set = torchvision.datasets.MNIST(root='./sample_data', download=True, trans
238238test_loader = torch.utils.data.DataLoader(test_set, batch_size = 24 )
239239
240240# Load a batch of inputs and outputs to use for XAI evaluation.
241- x_batch, y_batch = iter (test_loader).next( )
241+ x_batch, y_batch = next ( iter (test_loader))
242242x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()
243243```
244244</details >
@@ -281,11 +281,8 @@ import captum
281281from captum.attr import Saliency, IntegratedGradients
282282
283283# Generate Integrated Gradients attributions of the first batch of the test set.
284- a_batch_saliency = Saliency(model).attribute(inputs = x_batch, target = y_batch, abs = True ).sum(axis = 1 ).cpu().numpy()
285- a_batch_intgrad = IntegratedGradients(model).attribute(inputs = x_batch, target = y_batch, baselines = torch.zeros_like(x_batch)).sum(axis = 1 ).cpu().numpy()
286-
287- # Save x_batch and y_batch as numpy arrays that will be used to call metric instances.
288- x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()
284+ a_batch_saliency = Saliency(model).attribute(inputs = torch.tensor(x_batch, dtype = torch.float32), target = torch.tensor(y_batch, dtype = torch.int64), abs = True ).sum(axis = 1 ).cpu().cpu().numpy()
285+ a_batch_intgrad = IntegratedGradients(model).attribute(inputs = torch.tensor(x_batch, dtype = torch.float32), target = torch.tensor(y_batch, dtype = torch.int64)).sum(axis = 1 ).cpu().numpy()
289286
290287# Quick assert.
291288assert [isinstance (obj, np.ndarray) for obj in [x_batch, y_batch, a_batch_saliency, a_batch_intgrad]]
@@ -318,9 +315,8 @@ it first needs to be instantiated:
318315``` python
319316metric = quantus.MaxSensitivity(nr_samples = 10 ,
320317 lower_bound = 0.2 ,
321- norm_numerator = quantus.fro_norm,
322- norm_denominator = quantus.fro_norm,
323- perturb_func = quantus.uniform_noise,
318+ norm_numerator = quantus.norm_func.fro_norm,
319+ norm_denominator = quantus.norm_func.fro_norm,
324320 similarity_func = quantus.difference,
325321 abs = True ,
326322 normalise = True )
0 commit comments