Welcome to lfxai’s documentation!

Label-Free XAI


Code Author: Jonathan Crabbé (jc2133@cam.ac.uk)

This repository contains the implementation of LFXAI, a framework to explain the latent representations of unsupervised black-box models with standard feature importance and example-based methods. For more details, please read our ICML 2022 paper: ‘Label-Free Explainability for Unsupervised Models’.

1. Installation

From PyPI

pip install lfxai

From the repository:

  1. Clone the repository

  2. Create a new virtual environment with Python 3.8

  3. Run the following command from the repository folder:

pip install .

Once the package is installed, you are ready to explain unsupervised models.

2. Toy example

Below is a toy demonstration where we compute label-free feature and example importance with an MNIST autoencoder. The relevant code can be found in the explanations folder.

import torch
from pathlib import Path
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torch.nn import MSELoss
from captum.attr import IntegratedGradients

from lfxai.models.images import AutoEncoderMnist, EncoderMnist, DecoderMnist
from lfxai.models.pretext import Identity
from lfxai.explanations.features import attribute_auxiliary
from lfxai.explanations.examples import SimplEx

# Select torch device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load data
data_dir = Path.cwd() / "data/mnist"
train_dataset = MNIST(data_dir, train=True, download=True)
test_dataset = MNIST(data_dir, train=False, download=True)
train_dataset.transform = transforms.Compose([transforms.ToTensor()])
test_dataset.transform = transforms.Compose([transforms.ToTensor()])
train_loader = DataLoader(train_dataset, batch_size=100)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

# Get a model
encoder = EncoderMnist(encoded_space_dim=10)
decoder = DecoderMnist(encoded_space_dim=10)
model = AutoEncoderMnist(encoder, decoder, latent_dim=10, input_pert=Identity())
model.to(device)

# Get label-free feature importance
baseline = torch.zeros((1, 1, 28, 28)).to(device) # black image as baseline
attr_method = IntegratedGradients(model)
feature_importance = attribute_auxiliary(encoder, test_loader,
                                         device, attr_method, baseline)

# Get label-free example importance
train_subset = Subset(train_dataset, indices=list(range(500))) # Limit the number of training examples
train_subloader = DataLoader(train_subset, batch_size=500)
attr_method = SimplEx(model, loss_f=MSELoss())
example_importance = attr_method.attribute_loader(device, train_subloader, test_loader)
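
As a quick follow-up, here is a minimal sketch of how these outputs could be inspected; the shapes assumed in the comments are illustrative assumptions, so adapt the indexing if they differ.

import numpy as np
import matplotlib.pyplot as plt

# Visualise the saliency map of the first test image
# (feature_importance is assumed to have shape (n_test, 1, 28, 28))
plt.imshow(np.abs(feature_importance[0, 0]), cmap="hot")
plt.title("Label-free feature importance (test image 0)")
plt.colorbar()
plt.show()

# Rank training examples by importance for the first test image
# (example_importance is assumed to have shape (n_test, n_train))
scores = example_importance.cpu().numpy() if torch.is_tensor(example_importance) else np.asarray(example_importance)
top_train_idx = np.argsort(scores[0])[::-1][:5]
print("Most influential training examples for test image 0:", top_train_idx)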

3. Reproducing the paper results

MNIST experiments

In the experiments folder, run the following script

python -m mnist --name experiment_name

where experiment_name can take the following values:

experiment_name         description
consistency_features    Consistency check for label-free feature importance (paper Section 4.1)
consistency_examples    Consistency check for label-free example importance (paper Section 4.1)
roar_test               ROAR test for label-free feature importance (paper Appendix C)
pretext                 Pretext task sensitivity use case (paper Section 4.2)
disvae                  Challenging assumptions with disentangled VAEs (paper Section 4.3)

The resulting plots and data are saved here.

ECG5000 experiments

Run the following script

python -m ecg5000 --name experiment_name

where experiment_name can take the following values:

experiment_name         description
consistency_features    Consistency check for label-free feature importance (paper Section 4.1)
consistency_examples    Consistency check for label-free example importance (paper Section 4.1)

The resulting plots and data are saved here.

CIFAR10 experiments

Run the following script

python -m cifar10

The experiment can be selected by changing the experiment_name parameter in this file. The parameter can take the following values:

experiment_name         description
consistency_features    Consistency check for label-free feature importance (paper Section 4.1)
consistency_examples    Consistency check for label-free example importance (paper Section 4.1)

The resulting plots and data are saved here.

dSprites experiment

Run the following script

python -m dsprites

The experiment needs several hours to run since several VAEs are trained. The resulting plots and data are saved here.

4. Citing

If you use this code, please cite the associated paper:

@InProceedings{pmlr-v162-crabbe22a,
  title =    {Label-Free Explainability for Unsupervised Models},
  author =       {Crabb{\'e}, Jonathan and van der Schaar, Mihaela},
  booktitle =    {Proceedings of the 39th International Conference on Machine Learning},
  pages =    {4391--4420},
  year =     {2022},
  editor =   {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan},
  volume =   {162},
  series =   {Proceedings of Machine Learning Research},
  month =    {17--23 Jul},
  publisher =    {PMLR},
  pdf =      {https://proceedings.mlr.press/v162/crabbe22a/crabbe22a.pdf},
  url =      {https://proceedings.mlr.press/v162/crabbe22a.html},
  abstract =     {Unsupervised black-box models are challenging to interpret. Indeed, most existing explainability methods require labels to select which component(s) of the black-box’s output to interpret. In the absence of labels, black-box outputs often are representation vectors whose components do not correspond to any meaningful quantity. Hence, choosing which component(s) to interpret in a label-free unsupervised/self-supervised setting is an important, yet unsolved problem. To bridge this gap in the literature, we introduce two crucial extensions of post-hoc explanation techniques: (1) label-free feature importance and (2) label-free example importance that respectively highlight influential features and training examples for a black-box to construct representations at inference time. We demonstrate that our extensions can be successfully implemented as simple wrappers around many existing feature and example importance methods. We illustrate the utility of our label-free explainability paradigm through a qualitative and quantitative comparison of representation spaces learned by various autoencoders trained on distinct unsupervised tasks.}
}

API documentation

Models

Models

lfxai.models.images module

class AutoEncoderMnist(encoder: EncoderMnist, decoder: DecoderMnist, latent_dim: int, input_pert: callable, name: str = 'model', loss_f: callable = MSELoss())

Bases: Module

fit(device: device, train_loader: DataLoader, test_loader: DataLoader, save_dir: Path, n_epoch: int = 30, patience: int = 10, checkpoint_interval: int = -1) None
forward(x)

Forward pass of model.

Parameters:

x : torch.Tensor

Batch of data. Shape (batch_size, n_chan, height, width)

load_metadata(directory: Path) dict

Load the metadata of a training directory.

Parameters:

directory : pathlib.Path

Path to folder where model is saved. For example ‘./experiments/mnist’.

save(directory: Path) None

Save a model and corresponding metadata.

Parameters:

directory : pathlib.Path

Path to the directory where to save the data.

save_metadata(directory: Path, **kwargs) None

Save the metadata of a training directory.

Parameters:

directory: string

Path to folder where to save model. For example ‘./experiments/mnist’.

kwargs:

Additional arguments to json.dump

test_epoch(device: device, dataloader: DataLoader)
train_epoch(device: device, dataloader: DataLoader, optimizer: Optimizer) ndarray
training: bool
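
As a concrete illustration of the fit, save and load_metadata entry points documented above, here is a hedged training sketch; the output directory, number of epochs and batch size are placeholders rather than values prescribed by the library.

from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from lfxai.models.images import AutoEncoderMnist, DecoderMnist, EncoderMnist
from lfxai.models.pretext import Identity

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_set = MNIST("data/mnist", train=True, download=True, transform=transforms.ToTensor())
test_set = MNIST("data/mnist", train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_set, batch_size=100, shuffle=True)
test_loader = DataLoader(test_set, batch_size=100)

save_dir = Path("experiments/mnist_autoencoder")  # placeholder directory
save_dir.mkdir(parents=True, exist_ok=True)

autoencoder = AutoEncoderMnist(EncoderMnist(encoded_space_dim=10), DecoderMnist(encoded_space_dim=10),
                               latent_dim=10, input_pert=Identity(), name="mnist_ae")
autoencoder.to(device)
autoencoder.fit(device, train_loader, test_loader, save_dir, n_epoch=5)  # short run for illustration
autoencoder.save(save_dir)
print(autoencoder.load_metadata(save_dir))  # metadata written alongside the saved model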
class BetaTcVaeMnist(latent_dims: int = 10, beta: int = 1)

Bases: Module

fit(device: device, train_loader: DataLoader, test_loader: DataLoader, n_epoch: int = 30) None
forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

latent_sample(mu, logvar)
loss(recon_x: Tensor, x: Tensor, mu: Tensor, logvar: Tensor, z: Tensor, dataset_size: int) Tensor
test_epoch(device: device, dataloader: DataLoader)
train_epoch(device: device, dataloader: DataLoader, optimizer: Optimizer) ndarray
training: bool
class BetaVaeMnist(latent_dims: int = 10, beta: int = 1)

Bases: Module

fit(device: device, train_loader: DataLoader, test_loader: DataLoader, n_epoch: int = 30) None
forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

latent_sample(mu, logvar)
loss(recon_x: Tensor, x: Tensor, mu: Tensor, logvar: Tensor, dataset_size: int) Tensor
test_epoch(device: device, dataloader: DataLoader) ndarray
train_epoch(device: device, dataloader: DataLoader, optimizer: Optimizer) ndarray
training: bool
class ClassifierLatent(latent_dims: int)

Bases: Module

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class ClassifierMnist(encoder: EncoderMnist, latent_dim: int, name: str)

Bases: Module

fit(device: device, train_loader: DataLoader, test_loader: DataLoader, save_dir: Path, n_epoch: int = 30, patience: int = 10) None
forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

load_metadata(directory: Path) dict

Load the metadata of a training directory.

Parameters:

directory : pathlib.Path

Path to folder where model is saved. For example ‘./experiments/mnist’.

save(directory: Path) None

Save a model and corresponding metadata.

Parameters:

directory : pathlib.Path

Path to the directory where to save the data.

save_metadata(directory: Path, **kwargs) None

Save the metadata of a training directory.

Parameters:

directory: string

Path to folder where to save model. For example ‘./experiments/mnist’.

kwargs:

Additional arguments to json.dump

test_epoch(device: device, dataloader: DataLoader)
train_epoch(device: device, dataloader: DataLoader, optimizer: Optimizer) ndarray
training: bool
class DecoderBurgess(img_size, latent_dim=10)

Bases: Module

forward(z)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class DecoderMnist(encoded_space_dim)

Bases: Module

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class EncoderBurgess(img_size, latent_dim=10)

Bases: Module

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

mu(x)
training: bool
class EncoderMnist(encoded_space_dim)

Bases: Module

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class SimCLR(base_encoder, projection_dim=128)

Bases: Module

fit(args: DictConfig, device: device) None
forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

static get_color_distortion(s=0.5)
static get_lr(step, total_steps, lr_max, lr_min)

Compute learning rate according to cosine annealing schedule.

static nt_xent(x, t=0.5)
training: bool
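
The cosine annealing schedule computed by get_lr follows the usual closed form; the sketch below shows that form under the assumption that step runs from 0 to total_steps (it is not the library's exact code).

import math

def cosine_lr(step, total_steps, lr_max, lr_min):
    # Cosine annealing from lr_max down to lr_min over total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * step / total_steps))

print(cosine_lr(0, 100, 1e-3, 1e-5))    # 1e-3 at the start
print(cosine_lr(100, 100, 1e-3, 1e-5))  # 1e-5 at the end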
class VAE(img_size: tuple, encoder: EncoderBurgess, decoder: DecoderBurgess, latent_dim: int, loss_f: BaseVAELoss, name: str = 'model')

Bases: Module

fit(device: device, train_loader: DataLoader, test_loader: DataLoader, save_dir: Path, n_epoch: int = 30, patience: int = 10) None
forward(x)

Forward pass of model.

Parameters:

x : torch.Tensor

Batch of data. Shape (batch_size, n_chan, height, width)

load_metadata(directory: Path) dict

Load the metadata of a training directory.

Parameters:

directory : pathlib.Path

Path to folder where model is saved. For example ‘./experiments/mnist’.

reparameterize(mean, logvar)

Samples from a normal distribution using the reparameterization trick.

Parameters:

mean : torch.Tensor

Mean of the normal distribution. Shape (batch_size, latent_dim)

logvar : torch.Tensor

Diagonal log variance of the normal distribution. Shape (batch_size, latent_dim)

sample_latent(x)

Returns a sample from the latent distribution.

Parameters:

x : torch.Tensor

Batch of data. Shape (batch_size, n_chan, height, width)

save(directory: Path) None

Save a model and corresponding metadata.

Parameters:

directory : pathlib.Path

Path to the directory where to save the data.

save_metadata(directory: Path, **kwargs) None

Save the metadata of a training directory.

Parameters:

directory: string

Path to folder where to save model. For example ‘./experiments/mnist’.

kwargs:

Additional arguments to json.dump

test_epoch(device: device, dataloader: DataLoader)
train_epoch(device: device, dataloader: DataLoader, optimizer: Optimizer) ndarray
training: bool
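
For reference, VAE.reparameterize and sample_latent rely on the standard reparameterization trick; a self-contained sketch of that computation (not necessarily the exact library code) is shown below.

import torch

def reparameterize_sketch(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # z = mean + sigma * eps with eps ~ N(0, I) and sigma = exp(0.5 * logvar)
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mean + eps * std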
class VarDecoderMnist(c: int = 64, latent_dims: int = 10)

Bases: Module

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class VarEncoderMnist(c: int = 64, latent_dims: int = 10)

Bases: Module

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

mu(x)
training: bool
init_vae(img_size, latent_dim, loss_f, name)

Return an instance of a VAE with encoder and decoder from model_type.

log_density_gaussian(x: Tensor, mu: Tensor, logvar: Tensor)

Computes the log pdf of the Gaussian with parameters mu and logvar at x

Parameters
  • x – (Tensor) Point at which the Gaussian PDF is to be evaluated

  • mu – (Tensor) Mean of the Gaussian distribution

  • logvar – (Tensor) Log variance of the Gaussian distribution
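
To make the log_density_gaussian description concrete, here is a sketch of the standard elementwise Gaussian log-density it evaluates (assuming no reduction over dimensions).

import math

import torch

def log_density_gaussian_sketch(x: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # log N(x; mu, exp(logvar)) evaluated elementwise
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 / logvar.exp())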

lfxai.models.time_series module

class AutoencoderCNN(embedding_dim: int = 64, name: str = 'model', loss_f: callable = L1Loss())

Bases: Module

fit(device: device, train_loader: DataLoader, test_loader: DataLoader, save_dir: Path, n_epoch: int = 30, patience: int = 10, checkpoint_interval: int = -1) None
forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

save(directory: Path) None

Save a model and corresponding metadata.

Parameters:

directory : pathlib.Path

Path to the directory where to save the data.

test_epoch(device: device, dataloader: DataLoader)
train_epoch(device: device, dataloader: DataLoader, optimizer: Optimizer) ndarray
training: bool
class Decoder(seq_len: int, n_features: int = 1, input_dim: int = 64)

Bases: Module

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class DecoderCNN(encoded_space_dim)

Bases: Module

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class Encoder(seq_len: int, n_features: int = 1, embedding_dim: int = 64)

Bases: Module

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class EncoderCNN(encoded_space_dim)

Bases: Module

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class RecurrentAutoencoder(seq_len: int, n_features: int, embedding_dim: int = 64, name: str = 'model', loss_f: callable = L1Loss())

Bases: Module

fit(device: device, train_loader: DataLoader, test_loader: DataLoader, save_dir: Path, n_epoch: int = 30, patience: int = 10, checkpoint_interval: int = -1) None
forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

save(directory: Path) None

Save a model and corresponding metadata.

Parameters:

directory : pathlib.Path

Path to the directory where to save the data.

test_epoch(device: device, dataloader: DataLoader)
train_epoch(device: device, dataloader: DataLoader, optimizer: Optimizer) ndarray
training: bool
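
A brief construction sketch for the two time-series autoencoders above, using the ECG5000 sequence length of 140 time steps; the names are placeholders and the actual data loaders live in the ecg5000 experiment script.

from lfxai.models.time_series import AutoencoderCNN, RecurrentAutoencoder

cnn_ae = AutoencoderCNN(embedding_dim=64, name="cnn_autoencoder")
lstm_ae = RecurrentAutoencoder(seq_len=140, n_features=1, embedding_dim=64, name="lstm_autoencoder")
# Both expose fit(device, train_loader, test_loader, save_dir, ...) with the same
# arguments as the image autoencoders documented above.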

lfxai.models.losses module

class BaseVAELoss(record_loss_every=50, rec_dist='bernoulli', steps_anneal=0)

Bases: ABC

Base class for losses.

Parameters:

record_loss_every: int, optional

Every how many steps to record the loss.

rec_dist: {"bernoulli", "gaussian", "laplace"}, optional

Reconstruction distribution of the likelihood on each pixel. Implicitly defines the reconstruction loss: Bernoulli corresponds to binary cross entropy (BCE), Gaussian corresponds to MSE, and Laplace corresponds to L1.

steps_anneal: int, optional

Number of annealing steps during which the regularisation is gradually added.

_pre_call(is_train, storer)
class BetaHLoss(beta=4, **kwargs)

Bases: BaseVAELoss

Compute the Beta-VAE loss as in [1]

Parameters:

beta : float, optional

Weight of the kl divergence.

kwargs:

Additional arguments for BaseLoss, e.g. rec_dist.

References:

[1] Higgins, Irina, et al. “beta-vae: Learning basic visual concepts with a constrained variational framework.” (2016).

class BtcvaeLoss(n_data, alpha=1.0, beta=6.0, gamma=1.0, is_mss=True, **kwargs)

Bases: BaseVAELoss

Compute the decomposed KL loss with either minibatch weighted sampling or minibatch stratified sampling according to [1]

Parameters:

n_data: int

Number of data points in the training set.

alpha : float

Weight of the mutual information term.

beta : float

Weight of the total correlation term.

gamma : float

Weight of the dimension-wise KL term.

is_mss : bool

Whether to use minibatch stratified sampling instead of minibatch weighted sampling.

kwargs:

Additional arguments for BaseLoss, e.g. rec_dist.

References:

[1] Chen, Tian Qi, et al. “Isolating sources of disentanglement in variational autoencoders.” Advances in Neural Information Processing Systems. 2018.

_get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=True)
_kl_normal_loss(mean, logvar, storer=None)

Calculates the KL divergence between a normal distribution with diagonal covariance and a unit normal distribution.

Parameters:

mean : torch.Tensor

Mean of the normal distribution. Shape (batch_size, latent_dim), where latent_dim is the dimension of the distribution.

logvar : torch.Tensor

Diagonal log variance of the normal distribution. Shape (batch_size, latent_dim)

storer : dict

Dictionary in which to store important variables for visualisation.
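
For reference, the closed-form quantity evaluated by this KL term (a diagonal Gaussian against a unit normal) can be sketched as follows; the exact reduction over dimensions and batch may differ in the library.

import torch

def kl_normal_sketch(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # KL( N(mean, diag(exp(logvar))) || N(0, I) ), summed over latent dimensions
    # and averaged over the batch
    return (-0.5 * (1 + logvar - mean.pow(2) - logvar.exp())).sum(dim=1).mean()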

_permute_dims(latent_sample)

Implementation of Algorithm 1 in ref [1]. Randomly permutes the sample from q(z) (latent_dist) across the batch for each of the latent dimensions (mean and log_var).

Parameters:

latent_sample: torch.Tensor

Sample from the latent distribution obtained with the reparameterisation trick. Shape: (batch_size, latent_dim).

References:

[1] Kim, Hyunjik, and Andriy Mnih. “Disentangling by factorising.” arXiv preprint arXiv:1802.05983 (2018).

_reconstruction_loss(data, recon_data, distribution='bernoulli', storer=None)

Calculates the per-image reconstruction loss (i.e. the negative log-likelihood) for a batch of data.

Parameters:

data : torch.Tensor

Input data (e.g. batch of images). Shape: (batch_size, n_chan, height, width).

recon_data : torch.Tensor

Reconstructed data. Shape: (batch_size, n_chan, height, width).

distribution : {"bernoulli", "gaussian", "laplace"}

Distribution of the likelihood on each pixel. Implicitly defines the loss. Bernoulli corresponds to a binary cross entropy (BCE) loss and is the most commonly used; it has the issue that it does not penalize (0.1, 0.2) and (0.4, 0.5) in the same way, which might not be optimal. Gaussian corresponds to MSE and is sometimes used, but it is hard to train because it ends up focusing on only a few pixels that are very wrong. Laplace corresponds to L1 and partially solves the issue of MSE.

storer : dict

Dictionary in which to store important variables for visualisation.

Returns:

loss : torch.Tensor

Per-image cross entropy (i.e. normalised per batch but not per pixel and channel)

get_loss_f(loss_name, **kwargs_parse)

Return the correct loss function given the argparse arguments.

linear_annealing(init, fin, step, annealing_steps)

Linear annealing of a parameter.
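
To tie these losses to the VAE constructors from lfxai.models.images, here is a hedged sketch based only on the signatures documented above; the image size and beta value are illustrative assumptions.

import torch

from lfxai.models.images import init_vae
from lfxai.models.losses import BetaHLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_f = BetaHLoss(beta=4)  # weight of the KL divergence term
vae = init_vae(img_size=(1, 32, 32), latent_dim=10, loss_f=loss_f, name="beta_vae")
vae.to(device)
# vae.fit(device, train_loader, test_loader, save_dir, n_epoch=30)  # save_dir: pathlib.Path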

Explanations

Explanations

lfxai.explanations.examples module

class ExampleBasedExplainer(model: Module, X_train: Tensor, loss_f: callable, **kwargs)

Bases: ABC

abstract attribute(X_test: Tensor, train_idx: list, **kwargs) Tensor
Parameters
  • X_test

  • train_idx

  • **kwargs

Returns:

class InfluenceFunctions(model: Module, loss_f: callable, save_dir: Path, X_train: Optional[Tensor] = None)

Bases: ExampleBasedExplainer, ABC

attribute(X_test: Tensor, train_idx: list, batch_size: int = 1, damp: float = 0.001, scale: float = 1000, recursion_depth: int = 100, **kwargs) Tensor

Code adapted from https://github.com/ahmedmalaa/torch-influence-functions/. This function applies the stochastic estimation approach to evaluating the influence function, based on the power-series approximation of matrix inversion. Recall that the exact inverse Hessian H^-1 can be computed as H^-1 = sum_{j=0}^{infty} (I - H)^j. This series converges if all the eigenvalues of H are less than 1.

Returns

List of torch tensors containing the product of the Hessian and v.

Return type

return_grads

attribute_loader(device: device, train_loader: DataLoader, test_loader: DataLoader, train_loader_replacement: DataLoader, recursion_depth: int = 100, damp: float = 0.001, scale: float = 1000, **kwargs) Tensor
class NearestNeighbours(model: Module, loss_f: Optional[callable] = None, X_train: Optional[Tensor] = None)

Bases: ExampleBasedExplainer, ABC

attribute(X_test: Tensor, train_idx: list, batch_size: int = 500, **kwargs) Tensor
Parameters
  • X_test

  • train_idx

  • **kwargs

Returns:

attribute_loader(device: device, train_loader: DataLoader, test_loader: DataLoader, batch_size: int = 50, **kwargs) Tensor
class SimplEx(model: Module, loss_f: Optional[callable] = None, X_train: Optional[Tensor] = None)

Bases: ExampleBasedExplainer, ABC

attribute(X_test: Tensor, train_idx: list, learning_rate: float = 1, batch_size: int = 50, **kwargs) Tensor
Parameters
  • X_test

  • train_idx

  • **kwargs

Returns:

attribute_loader(device: device, train_loader: DataLoader, test_loader: DataLoader, batch_size: int = 50, **kwargs) Tensor
static compute_weights(batch_representations: Tensor, train_representations: Tensor, n_epoch: int = 1000) Tensor
class TracIn(model: Module, loss_f: callable, save_dir: Path, X_train: Optional[Tensor] = None)

Bases: ExampleBasedExplainer, ABC

attribute(X_test: Tensor, train_idx: list, learning_rate: float = 1, **kwargs) Tensor
Parameters
  • X_test

  • train_idx

  • **kwargs

Returns:

attribute_loader(device: device, train_loader: DataLoader, test_loader: DataLoader, **kwargs) Tensor
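
Any of the example-based explainers above can be substituted for SimplEx in the toy example from Section 2; a sketch with NearestNeighbours (reusing model, device, train_subloader and test_loader defined there) looks as follows.

from torch.nn import MSELoss

from lfxai.explanations.examples import NearestNeighbours

nn_explainer = NearestNeighbours(model, loss_f=MSELoss())
example_importance = nn_explainer.attribute_loader(device, train_subloader, test_loader)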

lfxai.explanations.features module

class AuxiliaryFunction(black_box: Module, base_features: Tensor)

Bases: Module

forward(input_features: Tensor) Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
attribute_auxiliary(encoder: Module, data_loader: DataLoader, device: device, attr_method: Attribution, baseline=None) ndarray
attribute_individual_dim(encoder: callable, dim_latent: int, data_loader: DataLoader, device: device, attr_method: Attribution, baseline: Tensor) ndarray
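
Finally, attribute_individual_dim produces one attribution map per latent dimension rather than a single aggregated map; a heavily hedged sketch, reusing the MNIST encoder and test_loader from Section 2 and treating the IntegratedGradients choice and black baseline as assumptions, is shown below.

import torch
from captum.attr import IntegratedGradients

from lfxai.explanations.features import attribute_individual_dim

baseline = torch.zeros((1, 1, 28, 28), device=device)  # black image baseline (assumption)
dim_importance = attribute_individual_dim(
    encoder, dim_latent=10, data_loader=test_loader, device=device,
    attr_method=IntegratedGradients(encoder), baseline=baseline,
)
# dim_importance is expected to stack one attribution map per latent dimension.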