lfxai.models.losses module

class BaseVAELoss(record_loss_every=50, rec_dist='bernoulli', steps_anneal=0)

Bases: ABC

Base class for losses.

Parameters:

record_loss_every: int, optional

Number of steps between each recording of the loss.

rec_dist: {“bernoulli”, “gaussian”, “laplace”}, optional

Reconstruction distribution of the likelihood on each pixel. Implicitly defines the reconstruction loss: Bernoulli corresponds to binary cross entropy (BCE), Gaussian corresponds to MSE, and Laplace corresponds to L1.

steps_anneal: int, optional

Number of annealing steps during which the regularisation is gradually added.

_pre_call(is_train, storer)
class BetaHLoss(beta=4, **kwargs)

Bases: BaseVAELoss

Compute the beta-VAE loss as in [1].

Parameters:

beta: float, optional

Weight of the KL divergence.

kwargs:

Additional arguments for BaseVAELoss, e.g. rec_dist.

References:

[1] Higgins, Irina, et al. “beta-VAE: Learning basic visual concepts with a constrained variational framework.” (2016).
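
The resulting objective is the reconstruction term plus a beta-weighted KL term, with the regularisation weight optionally annealed in. A minimal sketch of the computation using the module-level helpers documented below (anneal_reg and step are illustrative names, not part of the documented API):

    # beta-VAE objective: reconstruction term plus beta-weighted KL term
    rec_loss = _reconstruction_loss(data, recon_data,
                                    distribution="bernoulli", storer=storer)
    kl_loss = _kl_normal_loss(mean, logvar, storer=storer)
    anneal_reg = linear_annealing(0, 1, step, steps_anneal)  # ramp in the KL
    loss = rec_loss + anneal_reg * beta * kl_loss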

class BtcvaeLoss(n_data, alpha=1.0, beta=6.0, gamma=1.0, is_mss=True, **kwargs)

Bases: BaseVAELoss

Compute the decomposed KL loss with either minibatch weighted sampling or minibatch stratified sampling according to [1].

Parameters:

n_data: int

Number of data points in the training set.

alpha: float

Weight of the mutual information term.

beta: float

Weight of the total correlation term.

gamma: float

Weight of the dimension-wise KL term.

is_mss: bool

Whether to use minibatch stratified sampling instead of minibatch weighted sampling.

kwargs:

Additional arguments for BaseVAELoss, e.g. rec_dist.

References:

[1] Chen, Tian Qi, et al. “Isolating sources of disentanglement in variational autoencoders.” Advances in Neural Information Processing Systems. 2018.

_get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=True)
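
This helper returns the log densities log p(z), log q(z), log prod_i q(z_i) and log q(z|x) needed for the decomposition. A minimal sketch of how BtcvaeLoss plausibly combines them (rec_loss and the weights alpha, beta, gamma are as documented above; this is an illustration, not the exact implementation):

    log_pz, log_qz, log_prod_qzi, log_q_zCx = _get_log_pz_qz_prodzi_qzCx(
        latent_sample, latent_dist, n_data, is_mss=is_mss)

    mi_loss = (log_q_zCx - log_qz).mean()        # mutual information I[z; x]
    tc_loss = (log_qz - log_prod_qzi).mean()     # total correlation TC(z)
    dw_kl_loss = (log_prod_qzi - log_pz).mean()  # dimension-wise KL

    loss = rec_loss + alpha * mi_loss + beta * tc_loss + gamma * dw_kl_loss
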
_kl_normal_loss(mean, logvar, storer=None)

Calculates the KL divergence between a normal distribution with diagonal covariance and a unit normal distribution.

Parameters:

mean: torch.Tensor

Mean of the normal distribution. Shape: (batch_size, latent_dim).

logvar: torch.Tensor

Diagonal log variance of the normal distribution. Shape: (batch_size, latent_dim).

storer: dict

Dictionary in which to store important variables for visualisation.
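
In closed form, the KL per latent dimension is 0.5 * (mu^2 + sigma^2 - log sigma^2 - 1). A minimal self-contained sketch of the computation:

    import torch

    def kl_normal_loss(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        # Per-dimension KL(N(mean, diag(var)) || N(0, I)), averaged over the batch
        latent_kl = 0.5 * (-1 - logvar + mean.pow(2) + logvar.exp()).mean(dim=0)
        # Total KL is the sum over latent dimensions
        return latent_kl.sum()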

_permute_dims(latent_sample)

Implementation of Algorithm 1 in [1]. Randomly permutes the sample from q(z) across the batch, independently for each latent dimension.

Parameters:

latent_sample: torch.Tensor

Sample from the latent distribution, obtained via the reparameterisation trick. Shape: (batch_size, latent_dim).

References:

[1] Kim, Hyunjik, and Andriy Mnih. “Disentangling by factorising.” arXiv preprint arXiv:1802.05983 (2018).
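
A minimal sketch of the permutation: each latent dimension is shuffled across the batch with an independent random permutation, which breaks dependencies between dimensions while preserving each marginal q(z_i):

    import torch

    def permute_dims(latent_sample: torch.Tensor) -> torch.Tensor:
        # latent_sample has shape (batch_size, latent_dim)
        perm = torch.zeros_like(latent_sample)
        batch_size, dim_z = perm.size()
        for z in range(dim_z):
            # Independent permutation of the batch for each dimension
            pi = torch.randperm(batch_size, device=latent_sample.device)
            perm[:, z] = latent_sample[pi, z]
        return perm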

_reconstruction_loss(data, recon_data, distribution='bernoulli', storer=None)

Calculates the per-image reconstruction loss for a batch of data, i.e. the negative log likelihood.

Parameters:

data: torch.Tensor

Input data (e.g. a batch of images). Shape: (batch_size, n_chan, height, width).

recon_data: torch.Tensor

Reconstructed data. Shape: (batch_size, n_chan, height, width).

distribution: {“bernoulli”, “gaussian”, “laplace”}

Distribution of the likelihood on each pixel. Implicitly defines the loss. Bernoulli corresponds to a binary cross entropy (BCE) loss and is the most commonly used; its drawback is that it does not penalise the pairs (0.1, 0.2) and (0.4, 0.5) in the same way, which might not be optimal. Gaussian corresponds to MSE and is sometimes used, but is hard to train because it ends up focusing on the few pixels that are very wrong. Laplace corresponds to L1 and partially solves the issue of MSE.

storer: dict

Dictionary in which to store important variables for visualisation.

Returns:

loss: torch.Tensor

Per-image cross entropy (i.e. normalised per batch but not per pixel and channel).
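
A minimal sketch of how the three likelihood choices map onto PyTorch losses (pixel-value scaling details of the actual implementation may differ):

    import torch.nn.functional as F

    def reconstruction_loss(data, recon_data, distribution="bernoulli"):
        batch_size = data.size(0)
        if distribution == "bernoulli":
            # Negative Bernoulli log likelihood = binary cross entropy
            loss = F.binary_cross_entropy(recon_data, data, reduction="sum")
        elif distribution == "gaussian":
            # Negative Gaussian log likelihood (up to constants) = MSE
            loss = F.mse_loss(recon_data, data, reduction="sum")
        elif distribution == "laplace":
            # Negative Laplace log likelihood (up to constants) = L1
            loss = F.l1_loss(recon_data, data, reduction="sum")
        else:
            raise ValueError(f"Unknown distribution: {distribution}")
        # Normalise per batch but not per pixel and channel
        return loss / batch_size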

get_loss_f(loss_name, **kwargs_parse)

Return the correct loss function given the argparse arguments.
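
A hypothetical usage example; the loss name and keyword arguments below are assumptions for illustration, not a confirmed interface:

    # Hypothetical: select the beta-VAE loss from parsed CLI arguments
    loss_f = get_loss_f("betaH", rec_dist="bernoulli", steps_anneal=0)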

linear_annealing(init, fin, step, annealing_steps)

Linear annealing of a parameter.
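
A minimal sketch, assuming the parameter ramps linearly from init to fin over annealing_steps training steps and is held at fin afterwards:

    def linear_annealing(init, fin, step, annealing_steps):
        # No annealing requested: use the final value immediately
        if annealing_steps == 0:
            return fin
        delta = fin - init
        # Ramp linearly with the step count, then clamp at the final value
        return min(init + delta * step / annealing_steps, fin)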