lfxai.models.losses module
- class BaseVAELoss(record_loss_every=50, rec_dist='bernoulli', steps_anneal=0)
Bases:
ABC
Base class for losses.
Parameters:
- record_loss_every: int, optional
Interval (in training steps) at which to record the loss.
- rec_dist: {“bernoulli”, “gaussian”, “laplace”}, optional
Reconstruction distribution of the likelihood on each pixel. Implicitly defines the reconstruction loss: Bernoulli corresponds to binary cross entropy (BCE), Gaussian corresponds to MSE, Laplace corresponds to L1.
- steps_anneal: int, optional
Number of annealing steps over which the regularisation term is gradually introduced.
- _pre_call(is_train, storer)
- class BetaHLoss(beta=4, **kwargs)
Bases:
BaseVAELoss
Compute the Beta-VAE loss as in [1].
Parameters:
- beta: float, optional
Weight of the KL divergence.
- kwargs:
Additional arguments for BaseVAELoss, e.g. rec_dist.
References:
[1] Higgins, Irina, et al. “beta-vae: Learning basic visual concepts with a constrained variational framework.” (2016).
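The β-VAE objective combines the reconstruction term with a β-weighted KL divergence. A minimal sketch of that computation, assuming a Bernoulli reconstruction term and a diagonal-Gaussian posterior (the function name and signature here are illustrative, not the class's actual API):

```python
import torch
import torch.nn.functional as F

def beta_vae_loss(recon, data, mean, logvar, beta=4.0):
    """Sketch of the beta-VAE objective: reconstruction + beta * KL."""
    batch_size = data.size(0)
    # Bernoulli likelihood -> binary cross entropy, summed over pixels
    # and averaged over the batch.
    rec = F.binary_cross_entropy(recon, data, reduction="sum") / batch_size
    # KL(q(z|x) || N(0, I)) for a diagonal-Gaussian posterior.
    kl = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp()).sum(dim=1).mean()
    return rec + beta * kl
```

With beta=1 this reduces to the standard VAE evidence lower bound; larger beta values pressure the posterior towards the factorised prior, which is what encourages disentanglement in [1].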
- class BtcvaeLoss(n_data, alpha=1.0, beta=6.0, gamma=1.0, is_mss=True, **kwargs)
Bases:
BaseVAELoss
Compute the decomposed KL loss with either minibatch weighted sampling or minibatch stratified sampling according to [1].
Parameters:
- n_data: int
Number of samples in the training set.
- alpha: float
Weight of the mutual information term.
- beta: float
Weight of the total correlation term.
- gamma: float
Weight of the dimension-wise KL term.
- is_mss: bool
Whether to use minibatch stratified sampling instead of minibatch weighted sampling.
- kwargs:
Additional arguments for BaseVAELoss, e.g. rec_dist.
References:
[1] Chen, Tian Qi, et al. “Isolating sources of disentanglement in variational autoencoders.” Advances in Neural Information Processing Systems. 2018.
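The decomposition in [1] splits the expected KL term into the three components weighted by alpha, beta, and gamma above. A sketch of that split, assuming the batch-averaged log densities have already been estimated (e.g. by a helper like `_get_log_pz_qz_prodzi_qzCx`); the function name is illustrative:

```python
def decomposed_kl_terms(log_q_zCx, log_qz, log_prod_qzi, log_pz):
    """Split E[KL(q(z|x) || p(z))] into its three components."""
    mi = log_q_zCx - log_qz        # index-code mutual information
    tc = log_qz - log_prod_qzi     # total correlation
    dw_kl = log_prod_qzi - log_pz  # dimension-wise KL
    return mi, tc, dw_kl
```

The three terms telescope back to the full KL, `log_q_zCx - log_pz`, so re-weighting them with alpha, beta, and gamma generalises the β-VAE objective while isolating the total correlation term that matters for disentanglement.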
- _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=True)
- _kl_normal_loss(mean, logvar, storer=None)
Calculates the KL divergence between a normal distribution with diagonal covariance and a unit normal distribution.
Parameters:
- mean: torch.Tensor
Mean of the normal distribution. Shape: (batch_size, latent_dim).
- logvar: torch.Tensor
Diagonal log variance of the normal distribution. Shape: (batch_size, latent_dim).
- storer: dict
Dictionary in which to store important variables for visualisation.
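This KL divergence has a closed form for a diagonal-Gaussian posterior against a unit-normal prior. A minimal sketch (without the storer bookkeeping; the function name mirrors the one documented above but the body is an assumption):

```python
import torch

def kl_normal_loss(mean, logvar):
    """KL(N(mean, diag(exp(logvar))) || N(0, I)).

    Computed per latent dimension, averaged over the batch,
    then summed over dimensions.
    """
    latent_kl = 0.5 * (-1 - logvar + mean.pow(2) + logvar.exp()).mean(dim=0)
    return latent_kl.sum()
```

Keeping the per-dimension values before the final sum is what allows a storer to log how much each latent dimension diverges from the prior.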
- _permute_dims(latent_sample)
Implementation of Algorithm 1 in ref [1]. Randomly permutes the sample from q(z) (latent_dist) across the batch for each of the latent dimensions (mean and log_var).
Parameters:
- latent_sample: torch.Tensor
Sample from the latent distribution, obtained with the reparameterisation trick. Shape: (batch_size, latent_dim).
References:
[1] Kim, Hyunjik, and Andriy Mnih. “Disentangling by factorising.” arXiv preprint arXiv:1802.05983 (2018).
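Algorithm 1 of [1] breaks dependencies between latent dimensions by shuffling each dimension independently across the batch. A plausible sketch of that permutation (the loop-based form is an assumption; an implementation may vectorise it):

```python
import torch

def permute_dims(latent_sample):
    """Shuffle each latent dimension independently across the batch."""
    perm = torch.zeros_like(latent_sample)
    batch_size, dim_z = latent_sample.shape
    for z in range(dim_z):
        # draw a fresh random permutation of the batch for this dimension
        pi = torch.randperm(batch_size)
        perm[:, z] = latent_sample[pi, z]
    return perm
```

Because each column is permuted with its own random ordering, the marginals q(z_i) are preserved while the joint dependence between dimensions is destroyed, which is exactly what the FactorVAE discriminator needs to contrast against.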
- _reconstruction_loss(data, recon_data, distribution='bernoulli', storer=None)
Calculates the per-image reconstruction loss for a batch of data, i.e. the negative log-likelihood.
Parameters:
- data: torch.Tensor
Input data (e.g. batch of images). Shape: (batch_size, n_chan, height, width).
- recon_data: torch.Tensor
Reconstructed data. Shape: (batch_size, n_chan, height, width).
- distribution: {“bernoulli”, “gaussian”, “laplace”}
Distribution of the likelihood on each pixel. Implicitly defines the loss. Bernoulli corresponds to a binary cross entropy (BCE) loss and is the most commonly used; its drawback is that it does not penalise (0.1, 0.2) and (0.4, 0.5) the same way, which might not be optimal. Gaussian corresponds to MSE and is sometimes used, but is hard to train because it ends up focusing on a few pixels that are very wrong. Laplace corresponds to L1 and partially solves that issue of MSE.
- storer: dict
Dictionary in which to store important variables for visualisation.
Returns:
- loss: torch.Tensor
Per-image cross entropy (i.e. normalised over the batch, but not over pixels and channels).
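The three distribution choices map directly onto three standard PyTorch losses. A minimal sketch of that dispatch, assuming losses are summed over pixels and normalised only over the batch as described above (the body is an assumption, not the library's exact implementation):

```python
import torch
import torch.nn.functional as F

def reconstruction_loss(data, recon_data, distribution="bernoulli"):
    """Per-image negative log-likelihood under the chosen pixel distribution."""
    batch_size = data.size(0)
    if distribution == "bernoulli":
        loss = F.binary_cross_entropy(recon_data, data, reduction="sum")
    elif distribution == "gaussian":
        loss = F.mse_loss(recon_data, data, reduction="sum")
    elif distribution == "laplace":
        loss = F.l1_loss(recon_data, data, reduction="sum")
    else:
        raise ValueError(f"Unknown distribution: {distribution}")
    # normalise over the batch, but not over pixels and channels
    return loss / batch_size
```

Note that the Bernoulli/BCE loss does not reach zero even for a perfect reconstruction of non-binary targets, since BCE of a probability against itself equals its entropy.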
- get_loss_f(loss_name, **kwargs_parse)
Return the correct loss function given the argparse arguments.
- linear_annealing(init, fin, step, annealing_steps)
Linear annealing of a parameter.
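A plausible implementation matching the signature above: the value ramps linearly from `init` to `fin` over `annealing_steps` training steps and then stays at `fin` (the zero-step special case is an assumption):

```python
def linear_annealing(init, fin, step, annealing_steps):
    """Linearly interpolate from `init` to `fin` over `annealing_steps`
    training steps, then stay at `fin`."""
    if annealing_steps == 0:
        # no annealing requested: use the final value immediately
        return fin
    delta = fin - init
    return min(init + delta * step / annealing_steps, fin)
```

This is the helper the losses above use (via steps_anneal) to fade in the regularisation weight at the start of training instead of applying it at full strength from step zero.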