Gifsplanation via Latent Shift: A Simple Autoencoder Approach to Counterfactual Generation for Chest X-rays
Joseph Paul Cohen, Rupert Brooks, Sovann En, Evan Zucker, Anuj Pareek, Matthew P. Lungren, Akshay Chaudhari
arXiv e-Print archive, 2021
Keywords: cs.CV, cs.AI, eess.IV
Abstract: Motivation: Traditional image attribution methods struggle to satisfactorily
explain predictions of neural networks. Prediction explanation is important,
especially in medical imaging, for avoiding the unintended consequences of
deploying AI systems when false positive predictions can impact patient care.
Thus, there is a pressing need to develop improved models for model
explainability and introspection. Specific problem: A new approach is to
transform input images to increase or decrease features which cause the
prediction. However, current approaches are difficult to implement as they are
monolithic or rely on GANs. These hurdles prevent wide adoption. Our approach:
Given an arbitrary classifier, we propose a simple autoencoder and gradient
update (Latent Shift) that can transform the latent representation of a
specific input image to exaggerate or curtail the features used for prediction.
We use this method to study chest X-ray classifiers and evaluate their
performance. We conduct a reader study with two radiologists assessing 240
chest X-ray predictions to identify which ones are false positives (half are)
using traditional attribution maps or our proposed method. Results: We found
low overlap with ground truth pathology masks for models with reasonably high
accuracy. However, the results from our reader study indicate that these models
are generally looking at the correct features. We also found that the Latent
Shift explanation allows a user to have more confidence in true positive
predictions compared to traditional approaches (0.15$\pm$0.95 in a 5 point
scale with p=0.01) with only a small increase in false positive predictions
(0.04$\pm$1.06 with p=0.57).
Accompanying webpage: https://mlmed.org/gifsplanation
Source code: https://github.com/mlmed/gifsplanation
**Background:** The goal of this work is to identify the image features that are relevant to a neural network's prediction and to convey that information to the user by displaying a counterfactual image animation.
**The Latent Shift Method:** This method works with any pretrained encoder/decoder and any classifier that is differentiable; no special considerations are needed during model training. The approach is essentially the opposite of an adversarial attack while using the same underlying idea: they want to perturb the input image so that the classifier reduces its prediction. If they just compute $\frac{\partial f}{\partial x}$ and move the pixels directly, they get an imperceptible difference, as in an adversarial attack. By pushing the perturbation through a decoder, they regularize the transformation so that it only yields valid-looking images.
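To make the contrast concrete, here is a rough side-by-side of the two update rules, with $E$ the encoder, $D$ the decoder, $f$ the classifier, and $\lambda$ a step size (notation introduced here for illustration):

$$x'_{\text{adv}} = x - \lambda \frac{\partial f(x)}{\partial x} \qquad \text{vs.} \qquad x' = D\!\left(z - \lambda \frac{\partial f(D(z))}{\partial z}\right), \quad z = E(x)$$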
The encoder takes the input image and encodes it into a latent representation $z$. The decoder then reconstructs the image and feeds it into the classifier. The gradient of the classifier's output is computed with respect to $z$. Subtracting a scaled version of the gradient from $z$ and decoding again generates a counterfactual.
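A minimal PyTorch sketch of a single step is below. The `ae.encode`/`ae.decode` interface and the `clf` classifier are assumed placeholders, not the API of the released gifsplanation code.

```python
import torch

def latent_shift_step(x, ae, clf, lam):
    """One Latent Shift step (illustrative sketch, not the official code).

    x:   input image tensor, e.g. shape (1, 1, H, W)
    ae:  autoencoder exposing .encode(x) -> z and .decode(z) -> x_hat
    clf: differentiable classifier mapping an image to a pathology score
    lam: step size lambda; larger values move further along the gradient
    """
    # Encode the input and track gradients with respect to the latent z only
    z = ae.encode(x).detach().requires_grad_(True)
    # Classify the reconstruction and differentiate the score w.r.t. z
    score = clf(ae.decode(z))
    grad = torch.autograd.grad(score.sum(), z)[0]
    # Subtract the gradient from z and decode to obtain a counterfactual frame
    with torch.no_grad():
        x_ctf = ae.decode(z - lam * grad)
    return x_ctf
```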
https://i.imgur.com/iuZGUTH.gif
They found that reducing the prediction by about 30% produces good-looking images. So they perform an iterative search along the vector defined by the gradient in the latent space until the prediction is reduced by 30%, roughly as sketched below.
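The following sketch of that search reuses `latent_shift_step` from above. The step size and iteration cap are illustrative assumptions; the paper only specifies the roughly 30% stopping criterion.

```python
def build_counterfactual_sequence(x, ae, clf, target_drop=0.30,
                                  lam_step=10.0, max_steps=50):
    """Walk along the latent gradient until the prediction drops by ~30%."""
    with torch.no_grad():
        # Prediction on the unperturbed reconstruction D(E(x))
        base = clf(ae.decode(ae.encode(x))).item()
    frames, lam = [], 0.0
    for _ in range(max_steps):
        lam += lam_step
        frame = latent_shift_step(x, ae, clf, lam)
        frames.append(frame)
        with torch.no_grad():
            if clf(frame).item() <= (1.0 - target_drop) * base:
                break  # prediction reduced by the target amount
    return frames
```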
From this sequence, a 2D image similar to a traditional attribution map can be constructed by taking the maximum pixel-wise difference between every image in the sequence and the unperturbed reconstruction.
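A short sketch of that reduction, continuing the snippets above (taking the absolute difference is my assumption of how "maximum pixel-wise difference" is computed):

```python
def latent_shift_2d_map(frames, x_recon):
    """Collapse the counterfactual sequence into a single 2D attribution map.

    frames:  list of counterfactual reconstructions (tensors of equal shape)
    x_recon: the unperturbed reconstruction D(E(x))
    """
    # Pixel-wise maximum absolute difference over all frames in the sequence
    diffs = torch.stack([(f - x_recon).abs() for f in frames])
    return diffs.max(dim=0).values
```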
https://i.imgur.com/V3PCgXZ.png
The results look great!
https://i.imgur.com/DBki84c.gif
https://i.imgur.com/kFfQNKD.gif
In order to validate whether this approach can help spot false positive predictions, two radiologists were asked to evaluate how confident they were in the model's predictions. For each image, radiologists viewed the prediction explained in two ways: using traditional methods or using the Latent Shift images. The traditional methods include the input gradient, guided backprop, and integrated gradients; the Latent Shift counterfactual includes the animation as well as the 2D version.
https://i.imgur.com/TlUBhzL.png
Ideally, for true positives the confidence ratings would all be 5, and for false positives they would all be 1.
What they observe, however, is that many false positives still elicit high confidence in the model's predictions, though not as much as the true positives. Comparing the two methods, they find that for true positives the Latent Shift counterfactuals yield a significant increase in confidence, which is good.
> 0.15±0.95 confidence increase using the Latent Shift method (p=0.01).
For false positives they find an increase in confidence but it is not significant.
> 0.04±1.06 increase which is not significant (p=0.57)
**Conclusions:**
- Latent Shift's ability to generate counterfactuals is pretty good!
- Vanilla autoencoders are sufficient for some pathologies.
- StyleGAN and higher quality models should improve performance.
- IoU analysis may not be the best fit.
- Explainable AI methods can have an impact on user confidence in the model.
(Disclaimer: I am the author of this work)
Project Website: https://mlmed.org/gifsplanation/