Skip to content

Commit 0b57504

Browse files
Update README.md
1 parent b36dd26 commit 0b57504

1 file changed

Lines changed: 5 additions & 9 deletions

File tree

README.md

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ test_set = torchvision.datasets.MNIST(root='./sample_data', download=True, trans
238238
test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)
239239

240240
# Load a batch of inputs and outputs to use for XAI evaluation.
241-
x_batch, y_batch = iter(test_loader).next()
241+
x_batch, y_batch = next(iter(test_loader))
242242
x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()
243243
```
244244
</details>
@@ -281,11 +281,8 @@ import captum
281281
from captum.attr import Saliency, IntegratedGradients
282282

283283
# Generate Integrated Gradients attributions of the first batch of the test set.
284-
a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1).cpu().numpy()
285-
a_batch_intgrad = IntegratedGradients(model).attribute(inputs=x_batch, target=y_batch, baselines=torch.zeros_like(x_batch)).sum(axis=1).cpu().numpy()
286-
287-
# Save x_batch and y_batch as numpy arrays that will be used to call metric instances.
288-
x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()
284+
a_batch_saliency = Saliency(model).attribute(inputs=torch.tensor(x_batch, dtype=torch.float32), target=torch.tensor(y_batch, dtype=torch.int64), abs=True).sum(axis=1).cpu().cpu().numpy()
285+
a_batch_intgrad = IntegratedGradients(model).attribute(inputs=torch.tensor(x_batch, dtype=torch.float32), target=torch.tensor(y_batch, dtype=torch.int64)).sum(axis=1).cpu().numpy()
289286

290287
# Quick assert.
291288
assert [isinstance(obj, np.ndarray) for obj in [x_batch, y_batch, a_batch_saliency, a_batch_intgrad]]
@@ -318,9 +315,8 @@ it first needs to be instantiated:
318315
```python
319316
metric = quantus.MaxSensitivity(nr_samples=10,
320317
lower_bound=0.2,
321-
norm_numerator=quantus.fro_norm,
322-
norm_denominator=quantus.fro_norm,
323-
perturb_func=quantus.uniform_noise,
318+
norm_numerator=quantus.norm_func.fro_norm,
319+
norm_denominator=quantus.norm_func.fro_norm,
324320
similarity_func=quantus.difference,
325321
abs=True,
326322
normalise=True)

0 commit comments

Comments
 (0)