Welcome to lfxai’s documentation!
Label-Free XAI


Code Author: Jonathan Crabbé (jc2133@cam.ac.uk)
This repository contains the implementation of LFXAI, a framework to explain the latent representations of unsupervised black-box models using standard feature importance and example-based methods. For more details, please read our ICML 2022 paper: ‘Label-Free Explainability for Unsupervised Models’.
1. Installation
From PyPI
pip install lfxai
From repository:
Clone the repository
Create a new virtual environment with Python 3.8
Run the following command from the repository folder:
pip install .
When the packages are installed, you are ready to explain unsupervised models.
2. Toy example
Below, you can find a toy demonstration where we compute label-free feature and example importance with an MNIST autoencoder. The relevant code can be found in the explanations folder.
import torch
from pathlib import Path
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torch.nn import MSELoss
from captum.attr import IntegratedGradients
from lfxai.models.images import AutoEncoderMnist, EncoderMnist, DecoderMnist
from lfxai.models.pretext import Identity
from lfxai.explanations.features import attribute_auxiliary
from lfxai.explanations.examples import SimplEx
# Select torch device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load data
data_dir = Path.cwd() / "data/mnist"
train_dataset = MNIST(data_dir, train=True, download=True)
test_dataset = MNIST(data_dir, train=False, download=True)
train_dataset.transform = transforms.Compose([transforms.ToTensor()])
test_dataset.transform = transforms.Compose([transforms.ToTensor()])
train_loader = DataLoader(train_dataset, batch_size=100)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
# Get a model
encoder = EncoderMnist(encoded_space_dim=10)
decoder = DecoderMnist(encoded_space_dim=10)
model = AutoEncoderMnist(encoder, decoder, latent_dim=10, input_pert=Identity())
model.to(device)
# Get label-free feature importance
baseline = torch.zeros((1, 1, 28, 28)).to(device) # black image as baseline
attr_method = IntegratedGradients(model)
feature_importance = attribute_auxiliary(encoder, test_loader, device, attr_method, baseline)
# Get label-free example importance
train_subset = Subset(train_dataset, indices=list(range(500))) # Limit the number of training examples
train_subloader = DataLoader(train_subset, batch_size=500)
attr_method = SimplEx(model, loss_f=MSELoss())
example_importance = attr_method.attribute_loader(device, train_subloader, test_loader)
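To give a sense of how these outputs can be used, here is a short follow-up sketch. The shapes assumed below (a saliency array shaped like the test images and an example-importance matrix of shape (n_test, n_train)) are not guaranteed by the snippet above, so treat them as assumptions:
import numpy as np
# Assumed shape of feature_importance: (n_test, 1, 28, 28); average the absolute saliency over the test set
mean_saliency = np.abs(feature_importance).mean(axis=0).squeeze()
print("Mean saliency map shape:", mean_saliency.shape)
# Assumed shape of example_importance: (n_test, n_train); rank training examples for the first test image
top_train = torch.topk(torch.as_tensor(example_importance)[0], k=5).indices
print("Most important training examples for test image 0:", top_train.tolist())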
3. Reproducing the paper results
MNIST experiments
In the experiments folder, run the following script
python -m mnist --name experiment_name
where experiment_name can take the following values:
| experiment_name | description |
|---|---|
| consistency_features | Consistency check for label-free feature importance |
| consistency_examples | Consistency check for label-free example importance |
| roar_test | ROAR test for label-free feature importance |
| pretext | Pretext task sensitivity |
| disvae | Challenging assumptions with disentangled VAEs |
The resulting plots and data are saved here.
ECG5000 experiments
Run the following script
python -m ecg5000 --name experiment_name
where experiment_name can take the following values:
| experiment_name | description |
|---|---|
| consistency_features | Consistency check for label-free feature importance |
| consistency_examples | Consistency check for label-free example importance |
The resulting plots and data are saved here.
CIFAR10 experiments
Run the following script
python -m cifar10
The experiment can be selected by changing the experiment_name parameter in this file. The parameter can take the following values:
| experiment_name | description |
|---|---|
| consistency_features | Consistency check for label-free feature importance |
| consistency_examples | Consistency check for label-free example importance |
The resulting plots and data are saved here.
dSprites experiment
Run the following script
python -m dsprites
The experiment needs several hours to run since several VAEs are trained. The resulting plots and data are saved here.
4. Citing
If you use this code, please cite the associated paper:
@InProceedings{pmlr-v162-crabbe22a,
title = {Label-Free Explainability for Unsupervised Models},
author = {Crabb{\'e}, Jonathan and van der Schaar, Mihaela},
booktitle = {Proceedings of the 39th International Conference on Machine Learning},
pages = {4391--4420},
year = {2022},
editor = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan},
volume = {162},
series = {Proceedings of Machine Learning Research},
month = {17--23 Jul},
publisher = {PMLR},
pdf = {https://proceedings.mlr.press/v162/crabbe22a/crabbe22a.pdf},
url = {https://proceedings.mlr.press/v162/crabbe22a.html},
abstract = {Unsupervised black-box models are challenging to interpret. Indeed, most existing explainability methods require labels to select which component(s) of the black-box’s output to interpret. In the absence of labels, black-box outputs often are representation vectors whose components do not correspond to any meaningful quantity. Hence, choosing which component(s) to interpret in a label-free unsupervised/self-supervised setting is an important, yet unsolved problem. To bridge this gap in the literature, we introduce two crucial extensions of post-hoc explanation techniques: (1) label-free feature importance and (2) label-free example importance that respectively highlight influential features and training examples for a black-box to construct representations at inference time. We demonstrate that our extensions can be successfully implemented as simple wrappers around many existing feature and example importance methods. We illustrate the utility of our label-free explainability paradigm through a qualitative and quantitative comparison of representation spaces learned by various autoencoders trained on distinct unsupervised tasks.}
}
API documentation
Models
lfxai.models.images module
- class AutoEncoderMnist(encoder: EncoderMnist, decoder: DecoderMnist, latent_dim: int, input_pert: callable, name: str = 'model', loss_f: callable = MSELoss())
Bases:
Module
- fit(device: device, train_loader: DataLoader, test_loader: DataLoader, save_dir: Path, n_epoch: int = 30, patience: int = 10, checkpoint_interval: int = -1) None
- forward(x)
Forward pass of model.
Parameters:
- x : torch.Tensor
Batch of data. Shape (batch_size, n_chan, height, width)
- load_metadata(directory: Path) dict
Load the metadata of a training directory.
Parameters:
- directory : pathlib.Path
Path to folder where model is saved. For example ‘./experiments/mnist’.
- save(directory: Path) None
Save a model and corresponding metadata.
Parameters:
- directory : pathlib.Path
Path to the directory where to save the data.
- save_metadata(directory: Path, **kwargs) None
Save the metadata of a training directory.
Parameters:
- directory : pathlib.Path
Path to folder where to save the model. For example ‘./experiments/mnist’.
- kwargs:
Additional arguments to json.dump
- test_epoch(device: device, dataloader: DataLoader)
- train_epoch(device: device, dataloader: DataLoader, optimizer: Optimizer) ndarray
- training: bool
- class BetaTcVaeMnist(latent_dims: int = 10, beta: int = 1)
Bases:
Module
- fit(device: device, train_loader: DataLoader, test_loader: DataLoader, n_epoch: int = 30) None
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- latent_sample(mu, logvar)
- loss(recon_x: Tensor, x: Tensor, mu: Tensor, logvar: Tensor, z: Tensor, dataset_size: int) Tensor
- test_epoch(device: device, dataloader: DataLoader)
- train_epoch(device: device, dataloader: DataLoader, optimizer: Optimizer) ndarray
- training: bool
- class BetaVaeMnist(latent_dims: int = 10, beta: int = 1)
Bases:
Module
- fit(device: device, train_loader: DataLoader, test_loader: DataLoader, n_epoch: int = 30) None
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- latent_sample(mu, logvar)
- loss(recon_x: Tensor, x: Tensor, mu: Tensor, logvar: Tensor, dataset_size: int) Tensor
- test_epoch(device: device, dataloader: DataLoader) ndarray
- train_epoch(device: device, dataloader: DataLoader, optimizer: Optimizer) ndarray
- training: bool
- class ClassifierLatent(latent_dims: int)
Bases:
Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class ClassifierMnist(encoder: EncoderMnist, latent_dim: int, name: str)
Bases:
Module
- fit(device: device, train_loader: DataLoader, test_loader: DataLoader, save_dir: Path, n_epoch: int = 30, patience: int = 10) None
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- load_metadata(directory: Path) dict
Load the metadata of a training directory.
Parameters:
- directory : pathlib.Path
Path to folder where model is saved. For example ‘./experiments/mnist’.
- save(directory: Path) None
Save a model and corresponding metadata.
Parameters:
- directory : pathlib.Path
Path to the directory where to save the data.
- save_metadata(directory: Path, **kwargs) None
Save the metadata of a training directory.
Parameters:
- directory : pathlib.Path
Path to folder where to save the model. For example ‘./experiments/mnist’.
- kwargs:
Additional arguments to json.dump
- test_epoch(device: device, dataloader: DataLoader)
- train_epoch(device: device, dataloader: DataLoader, optimizer: Optimizer) ndarray
- training: bool
- class DecoderBurgess(img_size, latent_dim=10)
Bases:
Module
- forward(z)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class DecoderMnist(encoded_space_dim)
Bases:
Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class EncoderBurgess(img_size, latent_dim=10)
Bases:
Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- mu(x)
- training: bool
- class EncoderMnist(encoded_space_dim)
Bases:
Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class SimCLR(base_encoder, projection_dim=128)
Bases:
Module
- fit(args: DictConfig, device: device) None
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- static get_color_distortion(s=0.5)
- static get_lr(step, total_steps, lr_max, lr_min)
Compute learning rate according to cosine annealing schedule.
- static nt_xent(x, t=0.5)
- training: bool
- class VAE(img_size: tuple, encoder: EncoderBurgess, decoder: DecoderBurgess, latent_dim: int, loss_f: BaseVAELoss, name: str = 'model')
Bases:
Module
- fit(device: device, train_loader: DataLoader, test_loader: DataLoader, save_dir: Path, n_epoch: int = 30, patience: int = 10) None
- forward(x)
Forward pass of model.
Parameters:
- x : torch.Tensor
Batch of data. Shape (batch_size, n_chan, height, width)
- load_metadata(directory: Path) dict
Load the metadata of a training directory.
Parameters:
- directory : pathlib.Path
Path to folder where model is saved. For example ‘./experiments/mnist’.
- reparameterize(mean, logvar)
Samples from a normal distribution using the reparameterization trick.
Parameters:
- mean : torch.Tensor
Mean of the normal distribution. Shape (batch_size, latent_dim)
- logvar : torch.Tensor
Diagonal log variance of the normal distribution. Shape (batch_size, latent_dim)
- sample_latent(x)
Returns a sample from the latent distribution.
Parameters:
- x : torch.Tensor
Batch of data. Shape (batch_size, n_chan, height, width)
- save(directory: Path) None
Save a model and corresponding metadata.
Parameters:
- directory : pathlib.Path
Path to the directory where to save the data.
- save_metadata(directory: Path, **kwargs) None
Save the metadata of a training directory.
Parameters:
- directory : pathlib.Path
Path to folder where to save the model. For example ‘./experiments/mnist’.
- kwargs:
Additional arguments to json.dump
- test_epoch(device: device, dataloader: DataLoader)
- train_epoch(device: device, dataloader: DataLoader, optimizer: Optimizer) ndarray
- training: bool
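For reference, the reparameterize method above implements the standard reparameterisation trick. A generic sketch (not the library's exact code) looks like this:
import torch
def reparameterize_sketch(mean, logvar):
    # z = mean + sigma * eps with eps ~ N(0, I), so gradients flow through mean and logvar
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mean + eps * std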
- class VarDecoderMnist(c: int = 64, latent_dims: int = 10)
Bases:
Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class VarEncoderMnist(c: int = 64, latent_dims: int = 10)
Bases:
Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- mu(x)
- training: bool
- init_vae(img_size, latent_dim, loss_f, name)
Return an instance of a VAE with encoder and decoder from model_type.
- log_density_gaussian(x: Tensor, mu: Tensor, logvar: Tensor)
Computes the log pdf of the Gaussian with parameters mu and logvar at x
- Parameters
x – (Tensor) Point at which the Gaussian PDF is to be evaluated
mu – (Tensor) Mean of the Gaussian distribution
logvar – (Tensor) Log variance of the Gaussian distribution
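For reference, the quantity evaluated by log_density_gaussian is the standard diagonal-Gaussian log density, applied element-wise to x, mu and logvar:
\log \mathcal{N}(x \mid \mu, \sigma^2) = -\tfrac{1}{2}\left(\log(2\pi) + \log\sigma^2 + \frac{(x-\mu)^2}{\sigma^2}\right), \qquad \log\sigma^2 = \mathrm{logvar}.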
lfxai.models.time_series module
- class AutoencoderCNN(embedding_dim: int = 64, name: str = 'model', loss_f: callable = L1Loss())
Bases:
Module
- fit(device: device, train_loader: DataLoader, test_loader: DataLoader, save_dir: Path, n_epoch: int = 30, patience: int = 10, checkpoint_interval: int = -1) None
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- save(directory: Path) None
Save a model and corresponding metadata.
Parameters:
- directory : pathlib.Path
Path to the directory where to save the data.
- test_epoch(device: device, dataloader: DataLoader)
- train_epoch(device: device, dataloader: DataLoader, optimizer: Optimizer) ndarray
- training: bool
- class Decoder(seq_len: int, n_features: int = 1, input_dim: int = 64)
Bases:
Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class DecoderCNN(encoded_space_dim)
Bases:
Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class Encoder(seq_len: int, n_features: int = 1, embedding_dim: int = 64)
Bases:
Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class EncoderCNN(encoded_space_dim)
Bases:
Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class RecurrentAutoencoder(seq_len: int, n_features: int, embedding_dim: int = 64, name: str = 'model', loss_f: callable = L1Loss())
Bases:
Module
- fit(device: device, train_loader: DataLoader, test_loader: DataLoader, save_dir: Path, n_epoch: int = 30, patience: int = 10, checkpoint_interval: int = -1) None
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- save(directory: Path) None
Save a model and corresponding metadata.
Parameters:
- directory : pathlib.Path
Path to the directory where to save the data.
- test_epoch(device: device, dataloader: DataLoader)
- train_epoch(device: device, dataloader: DataLoader, optimizer: Optimizer) ndarray
- training: bool
lfxai.models.losses module
- class BaseVAELoss(record_loss_every=50, rec_dist='bernoulli', steps_anneal=0)
Bases:
ABC
Base class for losses.
Parameters:
- record_loss_every: int, optional
Every how many steps to record the loss.
- rec_dist: {"bernoulli", "gaussian", "laplace"}, optional
Reconstruction distribution of the likelihood for each pixel. Implicitly defines the reconstruction loss: Bernoulli corresponds to binary cross entropy (BCE), Gaussian to MSE, and Laplace to L1.
- steps_anneal: int, optional
Number of annealing steps during which the regularisation is gradually added.
- _pre_call(is_train, storer)
- class BetaHLoss(beta=4, **kwargs)
Bases:
BaseVAELoss
Compute the Beta-VAE loss as in [1]
Parameters:
- beta : float, optional
Weight of the KL divergence.
- kwargs:
Additional arguments for BaseLoss, e.g. rec_dist.
References:
[1] Higgins, Irina, et al. “beta-vae: Learning basic visual concepts with a constrained variational framework.” (2016).
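Following Higgins et al. [1], the β-VAE objective summarised by this loss is a reconstruction term plus a KL term weighted by beta:
\mathcal{L}_{\beta\text{-VAE}}(x) = -\,\mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right] + \beta\,\mathrm{KL}\!\left(q_\phi(z \mid x)\,\|\,p(z)\right).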
- class BtcvaeLoss(n_data, alpha=1.0, beta=6.0, gamma=1.0, is_mss=True, **kwargs)
Bases:
BaseVAELoss
Compute the decomposed KL loss with either minibatch weighted sampling or minibatch stratified sampling according to [1]
Parameters:
- n_data: int
Number of data points in the training set.
- alpha : float
Weight of the mutual information term.
- beta : float
Weight of the total correlation term.
- gamma : float
Weight of the dimension-wise KL term.
- is_mss : bool
Whether to use minibatch stratified sampling instead of minibatch weighted sampling.
- kwargs:
Additional arguments for BaseLoss, e.g. rec_dist.
References:
[1] Chen, Tian Qi, et al. “Isolating sources of disentanglement in variational autoencoders.” Advances in Neural Information Processing Systems. 2018.
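Chen et al. [1] decompose the aggregate KL term into a mutual-information term, a total-correlation term and a dimension-wise KL term; alpha, beta and gamma weight these three terms respectively:
\mathbb{E}_{p(x)}\!\left[\mathrm{KL}\!\left(q(z \mid x)\,\|\,p(z)\right)\right] = \underbrace{I_q(x; z)}_{\alpha} + \underbrace{\mathrm{KL}\!\left(q(z)\,\|\,\textstyle\prod_j q(z_j)\right)}_{\beta\ (\text{total correlation})} + \underbrace{\textstyle\sum_j \mathrm{KL}\!\left(q(z_j)\,\|\,p(z_j)\right)}_{\gamma\ (\text{dimension-wise KL})}.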
- _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=True)
- _kl_normal_loss(mean, logvar, storer=None)
Calculates the KL divergence between a normal distribution with diagonal covariance and a unit normal distribution.
Parameters:
- mean : torch.Tensor
Mean of the normal distribution. Shape (batch_size, latent_dim).
- logvar : torch.Tensor
Diagonal log variance of the normal distribution. Shape (batch_size, latent_dim)
- storer : dict
Dictionary in which to store important variables for visualisation.
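The closed-form divergence evaluated by this helper, for a diagonal Gaussian against a standard normal prior, is:
\mathrm{KL}\!\left(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2))\,\|\,\mathcal{N}(0, I)\right) = \tfrac{1}{2}\sum_{d=1}^{D}\left(\mu_d^2 + \sigma_d^2 - \log\sigma_d^2 - 1\right).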
- _permute_dims(latent_sample)
Implementation of Algorithm 1 in ref [1]. Randomly permutes the sample from q(z) (latent_dist) across the batch for each of the latent dimensions (mean and log_var).
Parameters:
- latent_sample: torch.Tensor
Sample from the latent distribution obtained with the reparameterisation trick. Shape (batch_size, latent_dim).
References:
[1] Kim, Hyunjik, and Andriy Mnih. “Disentangling by factorising.” arXiv preprint arXiv:1802.05983 (2018).
- _reconstruction_loss(data, recon_data, distribution='bernoulli', storer=None)
Calculates the per-image reconstruction loss for a batch of data, i.e. the negative log-likelihood.
Parameters:
- data : torch.Tensor
Input data (e.g. batch of images). Shape (batch_size, n_chan, height, width).
- recon_data : torch.Tensor
Reconstructed data. Shape (batch_size, n_chan, height, width).
- distribution : {"bernoulli", "gaussian", "laplace"}
Distribution of the likelihood for each pixel; implicitly defines the loss. Bernoulli corresponds to a binary cross entropy (BCE) loss and is the most commonly used, but it does not penalise (0.1, 0.2) and (0.4, 0.5) in the same way, which might not be optimal. The Gaussian distribution corresponds to MSE and is sometimes used, but it is hard to train because it ends up focusing on only a few pixels that are very wrong. The Laplace distribution corresponds to L1 and partially solves the issue of MSE.
- storer : dict
Dictionary in which to store important variables for visualisation.
Returns:
- loss : torch.Tensor
Per-image cross entropy (i.e. normalised per batch but not per pixel and channel).
- get_loss_f(loss_name, **kwargs_parse)
Return the correct loss function given the argparse arguments.
- linear_annealing(init, fin, step, annealing_steps)
Linear annealing of a parameter.
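As a hypothetical usage sketch (the schedule values below are illustrative assumptions, not documented defaults), linear_annealing can be used to ramp a regularisation weight over the first training steps:
from lfxai.models.losses import linear_annealing
# Assumed behaviour: weight = init + (fin - init) * min(step / annealing_steps, 1)
for step in (0, 5000, 10000, 20000):
    print(step, linear_annealing(init=0.0, fin=1.0, step=step, annealing_steps=10000))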
Explanations
lfxai.explanations.examples module
- class ExampleBasedExplainer(model: Module, X_train: Tensor, loss_f: callable, **kwargs)
Bases:
ABC
- abstract attribute(X_test: Tensor, train_idx: list, **kwargs) Tensor
- Parameters
X_test –
train_idx –
**kwargs –
Returns:
- class InfluenceFunctions(model: Module, loss_f: callable, save_dir: Path, X_train: Optional[Tensor] = None)
Bases:
ExampleBasedExplainer, ABC
- attribute(X_test: Tensor, train_idx: list, batch_size: int = 1, damp: float = 0.001, scale: float = 1000, recursion_depth: int = 100, **kwargs) Tensor
Code adapted from https://github.com/ahmedmalaa/torch-influence-functions/. This function applies the stochastic estimation approach to evaluating the influence function, based on the power-series approximation of matrix inversion. Recall that the exact inverse Hessian H^-1 can be computed as H^-1 = sum_{j=0}^infty (I - H)^j. This series converges if all the eigenvalues of H are less than 1.
- Returns
list of torch tensors, contains product of Hessian and v.
- Return type
return_grads
- attribute_loader(device: device, train_loader: DataLoader, test_loader: DataLoader, train_loader_replacement: DataLoader, recursion_depth: int = 100, damp: float = 0.001, scale: float = 1000, **kwargs) Tensor
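The truncated power series above is typically evaluated recursively: writing v for the vector whose inverse-Hessian-vector product is needed, the estimate after J steps (recursion_depth) follows the recursion below, with damp and scale stabilising the iteration in practice. This is a generic description of the stochastic estimation scheme rather than a line-by-line account of the lfxai code:
H^{-1}v = \sum_{j=0}^{\infty} (I - H)^{j}\, v \;\approx\; h_J, \qquad h_0 = v, \qquad h_j = v + (I - H)\,h_{j-1}.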
- class NearestNeighbours(model: Module, loss_f: Optional[callable] = None, X_train: Optional[Tensor] = None)
Bases:
ExampleBasedExplainer, ABC
- attribute(X_test: Tensor, train_idx: list, batch_size: int = 500, **kwargs) Tensor
- Parameters
X_test –
train_idx –
**kwargs –
Returns:
- attribute_loader(device: device, train_loader: DataLoader, test_loader: DataLoader, batch_size: int = 50, **kwargs) Tensor
- class SimplEx(model: Module, loss_f: Optional[callable] = None, X_train: Optional[Tensor] = None)
Bases:
ExampleBasedExplainer, ABC
- attribute(X_test: Tensor, train_idx: list, learning_rate: float = 1, batch_size: int = 50, **kwargs) Tensor
- Parameters
X_test –
train_idx –
**kwargs –
Returns:
- attribute_loader(device: device, train_loader: DataLoader, test_loader: DataLoader, batch_size: int = 50, **kwargs) Tensor
- static compute_weights(batch_representations: Tensor, train_representations: Tensor, n_epoch: int = 1000) Tensor
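Conceptually, and stated here as background on the SimplEx method rather than as the exact lfxai objective, compute_weights fits for each test representation a convex combination of the training representations, and the fitted weights serve as example-importance scores:
w^{\star} = \operatorname*{arg\,min}_{w \in \Delta_N} \Big\| h(x^{\text{test}}) - \sum_{i=1}^{N} w_i\, h(x_i^{\text{train}}) \Big\|_2^2, \qquad \Delta_N = \Big\{ w \in \mathbb{R}_{\ge 0}^{N} : \textstyle\sum_i w_i = 1 \Big\}.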
- class TracIn(model: Module, loss_f: callable, save_dir: Path, X_train: Optional[Tensor] = None)
Bases:
ExampleBasedExplainer, ABC
- attribute(X_test: Tensor, train_idx: list, learning_rate: float = 1, **kwargs) Tensor
- Parameters
X_test –
train_idx –
**kwargs –
Returns:
- attribute_loader(device: device, train_loader: DataLoader, test_loader: DataLoader, **kwargs) Tensor
lfxai.explanations.features module
- class AuxiliaryFunction(black_box: Module, base_features: Tensor)
Bases:
Module
- forward(input_features: Tensor) Tensor
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- attribute_auxiliary(encoder: Module, data_loader: DataLoader, device: device, attr_method: Attribution, baseline=None) ndarray
- attribute_individual_dim(encoder: callable, dim_latent: int, data_loader: DataLoader, device: device, attr_method: Attribution, baseline: Tensor) ndarray
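These wrappers follow the label-free feature importance construction of the paper: given a black-box encoder f and a standard attribution method a (e.g. Integrated Gradients), importance is computed for the scalar auxiliary function obtained by projecting the representation of a perturbed input onto the representation of the original input, which is what the AuxiliaryFunction module above appears to implement:
b_i(f, x) = a_i(g_x, x), \qquad g_x(\tilde{x}) = \big\langle f(\tilde{x}),\, f(x) \big\rangle.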