lfxai.utils.datasets module

class CIFAR10Pair(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)

Bases: CIFAR10

Generate mini-batch pairs on the CIFAR10 training set.
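The docstring does not define what a "pair" is. A common reading, assumed here and not confirmed by the source, is the contrastive-learning setup where one image goes through a random transform twice to form a positive pair of views. The NumPy-only sketch below illustrates that idea. `PairViews` and `make_random_flip` are hypothetical names, not part of lfxai. The `(N, H, W, C)` image layout mirrors how torchvision's `CIFAR10.data` stores images.

```python
from typing import Callable, Sequence, Tuple

import numpy as np


class PairViews:
    """Hypothetical pair dataset: two independently transformed views of one image."""

    def __init__(self, images: np.ndarray, targets: Sequence[int],
                 transform: Callable[[np.ndarray], np.ndarray]):
        self.images = images  # (N, H, W, C), like torchvision's CIFAR10.data
        self.targets = targets
        self.transform = transform

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, int]:
        img = self.images[idx]
        # Calling a random transform twice yields a "positive pair" of views.
        views = np.stack([self.transform(img), self.transform(img)])
        return views, self.targets[idx]


def make_random_flip(seed: int = 0) -> Callable[[np.ndarray], np.ndarray]:
    """Return a transform that flips an (H, W, C) image horizontally with probability 0.5."""
    rng = np.random.default_rng(seed)

    def flip(img: np.ndarray) -> np.ndarray:
        # Reversing axis 1 flips the width dimension.
        return img[:, ::-1] if rng.random() < 0.5 else img

    return flip


images = np.random.default_rng(1).integers(0, 256, size=(4, 32, 32, 3)).astype(np.uint8)
pairs = PairViews(images, [0, 1, 2, 3], make_random_flip())
views, target = pairs[1]  # views.shape == (2, 32, 32, 3), target == 1
```

Stacking the two views along a new leading axis means a batch of size B arrives as `(B, 2, H, W, C)`. A contrastive loss can then split it into the two halves of each pair.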

class DSprites(root='/home/docs/checkouts/readthedocs.org/user_builds/lfxai/checkouts/latest/src/lfxai/utils/../data/dsprites/', **kwargs)

Bases: DisentangledDataset

DSprites dataset from [1], a disentanglement test set of procedurally generated 2D sprites built from 6 disentangled latent factors that control the color, shape, scale, rotation and (x, y) position of a sprite. Color takes a single value, so only the five factors in lat_names actually vary. All possible combinations of the latents are present. The ordering along the first dimension is fixed and can be mapped back to the exact latent values that generated each image. Pixel outputs are different. No noise is added.
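The mapping from image order back to latent values can be sketched as a mixed-radix conversion. This sketch assumes that images are stored in row-major order over the factors in `lat_sizes` (shape, scale, orientation, posX, posY, with posY varying fastest), matching the `co1sh3sc6or40x32y32` ordering in the archive name. It also assumes that the `lat_values` arrays listed below are evenly spaced grids, which their first entries (2π/39 ≈ 0.16110732 and 1/31 ≈ 0.03225806) are consistent with. The constants and function names here are hypothetical helpers, not lfxai API.

```python
import numpy as np

# Factor names and sizes as listed in DSprites.lat_names / DSprites.lat_sizes.
# color has a single value, so omitting it does not change the strides below.
LAT_NAMES = ("shape", "scale", "orientation", "posX", "posY")
LAT_SIZES = np.array([3, 6, 40, 32, 32])

# Assumed reconstruction of DSprites.lat_values as evenly spaced grids.
LAT_VALUES = {
    "shape": np.array([1.0, 2.0, 3.0]),
    "scale": np.linspace(0.5, 1.0, 6),
    "orientation": np.linspace(0.0, 2 * np.pi, 40),
    "posX": np.linspace(0.0, 1.0, 32),
    "posY": np.linspace(0.0, 1.0, 32),
}

# Stride of each factor in the flat image index; posY varies fastest (stride 1).
LAT_BASES = np.concatenate((LAT_SIZES[::-1].cumprod()[::-1][1:], np.array([1])))


def latents_to_index(classes) -> int:
    """Map one vector of latent class indices to its flat image index."""
    return int(np.dot(np.asarray(classes), LAT_BASES))


def index_to_latents(index: int) -> np.ndarray:
    """Inverse of latents_to_index: mixed-radix decomposition of the index."""
    return (index // LAT_BASES) % LAT_SIZES


def latents_to_values(classes) -> np.ndarray:
    """Look up the real-valued latent for each class index."""
    return np.array([LAT_VALUES[name][c] for name, c in zip(LAT_NAMES, classes)])


idx = latents_to_index([2, 5, 39, 31, 31])  # 737279, the last of 3*6*40*32*32 images
classes = index_to_latents(idx)             # recovers [2, 5, 39, 31, 31]
values = latents_to_values(classes)         # shape 3, scale 1.0, angle 2*pi, posX 1.0, posY 1.0
```

One consequence of these grids: orientation includes both 0 and 2π, so its first and last bins describe the same angle.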

Parameters:

root : string

Root directory of the dataset.

References

[1] Higgins, I., Matthey, L., Pal, A., Burgess, C., Glorot, X., Botvinick, M., … & Lerchner, A. (2017). beta-VAE: Learning basic visual concepts with a constrained variational framework. In International Conference on Learning Representations.

background_color = 0
download()

Download the dataset.

files = {'train': 'dsprite_train.npz'}
img_size = (1, 64, 64)
lat_names = ('shape', 'scale', 'orientation', 'posX', 'posY')
lat_sizes = array([ 3,  6, 40, 32, 32])
lat_values = {
    'color': array([1.]),
    'orientation': array([0.        , 0.16110732, 0.32221463, 0.48332195, 0.64442926,
        0.80553658, 0.96664389, 1.12775121, 1.28885852, 1.44996584,
        1.61107316, 1.77218047, 1.93328779, 2.0943951 , 2.25550242,
        2.41660973, 2.57771705, 2.73882436, 2.89993168, 3.061039  ,
        3.22214631, 3.38325363, 3.54436094, 3.70546826, 3.86657557,
        4.02768289, 4.1887902 , 4.34989752, 4.51100484, 4.67211215,
        4.83321947, 4.99432678, 5.1554341 , 5.31654141, 5.47764873,
        5.63875604, 5.79986336, 5.96097068, 6.12207799, 6.28318531]),
    'posX': array([0.        , 0.03225806, 0.06451613, 0.09677419, 0.12903226,
        0.16129032, 0.19354839, 0.22580645, 0.25806452, 0.29032258,
        0.32258065, 0.35483871, 0.38709677, 0.41935484, 0.4516129 ,
        0.48387097, 0.51612903, 0.5483871 , 0.58064516, 0.61290323,
        0.64516129, 0.67741935, 0.70967742, 0.74193548, 0.77419355,
        0.80645161, 0.83870968, 0.87096774, 0.90322581, 0.93548387,
        0.96774194, 1.        ]),
    'posY': array([0.        , 0.03225806, 0.06451613, 0.09677419, 0.12903226,
        0.16129032, 0.19354839, 0.22580645, 0.25806452, 0.29032258,
        0.32258065, 0.35483871, 0.38709677, 0.41935484, 0.4516129 ,
        0.48387097, 0.51612903, 0.5483871 , 0.58064516, 0.61290323,
        0.64516129, 0.67741935, 0.70967742, 0.74193548, 0.77419355,
        0.80645161, 0.83870968, 0.87096774, 0.90322581, 0.93548387,
        0.96774194, 1.        ]),
    'scale': array([0.5, 0.6, 0.7, 0.8, 0.9, 1. ]),
    'shape': array([1., 2., 3.])}
urls = {'train': 'https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true'}
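download() is documented only as "Download the dataset." What follows is a minimal sketch of a download-once-then-load step. It assumes the archive at `urls['train']` is saved under the name in `files['train']` inside `root`. The helper name `download_and_load` is hypothetical. The `imgs` key follows the layout of DeepMind's dSprites archive and should be treated as an assumption if your copy differs.

```python
import os
import urllib.request

import numpy as np

# Values copied from DSprites.urls and DSprites.files.
URL = ("https://github.com/deepmind/dsprites-dataset/blob/master/"
       "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true")
FILENAME = "dsprite_train.npz"


def download_and_load(root: str) -> np.ndarray:
    """Hypothetical helper: fetch the archive once, then return its image array."""
    os.makedirs(root, exist_ok=True)
    path = os.path.join(root, FILENAME)
    if not os.path.isfile(path):
        # Only touches the network when no cached copy exists under root.
        urllib.request.urlretrieve(URL, path)
    # Assumption: images are stored under the "imgs" key of the .npz archive.
    with np.load(path) as archive:
        return archive["imgs"]
```

Checking for the file before calling `urlretrieve` makes repeated construction of the dataset cheap, because the roughly 737k images are fetched only on first use.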
class DisentangledDataset(root, transforms_list=[], logger=<Logger lfxai.utils.datasets (WARNING)>)

Bases: Dataset, ABC

Base class for disentangled VAE datasets.

Parameters:

root : string

Root directory of the dataset.

transforms_list : list

List of torchvision.transforms to apply to the data when loading it.

abstract download()

Download the dataset.
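The base class contract (a root directory, a list of transforms applied in order, and an abstract download()) can be sketched without torch as follows. `DisentangledBase` and `ToyDataset` are hypothetical stand-ins, not lfxai classes. The download-if-root-is-missing behaviour is an assumption. The sketch defaults `transforms_list` to `None` rather than the documented `[]`, because Python evaluates a default argument once, so a mutable `[]` would be shared by every instance that relies on it.

```python
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional


class DisentangledBase(ABC):
    """Hypothetical, torch-free stand-in for DisentangledDataset."""

    def __init__(self, root: str,
                 transforms_list: Optional[List[Callable[[Any], Any]]] = None):
        self.root = root
        # None instead of []: a mutable default list would be shared by every instance.
        self.transforms_list = list(transforms_list) if transforms_list else []
        # Assumed behaviour: fetch the data only when the root directory is missing.
        if not os.path.isdir(root):
            self.download()

    def apply_transforms(self, x: Any) -> Any:
        # Apply the transforms in list order, as torchvision.transforms.Compose does.
        for transform in self.transforms_list:
            x = transform(x)
        return x

    @abstractmethod
    def download(self) -> None:
        """Subclasses fetch their raw files into self.root."""


class ToyDataset(DisentangledBase):
    """Minimal concrete subclass used only to exercise the base class."""

    downloaded = False

    def download(self) -> None:
        # A real subclass would write files into self.root here.
        self.downloaded = True


ds = ToyDataset("missing_root", transforms_list=[lambda x: x + 1, lambda x: x * 2])
result = ds.apply_transforms(3)  # (3 + 1) * 2 == 8
```

Because download() is abstract, instantiating `DisentangledBase` directly raises `TypeError`. This forces each concrete dataset, such as DSprites, to say where its data comes from.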

class ECG5000(dir: Path, train: bool = True, random_seed: int = 42, experiment: str = 'features')

Bases: Dataset

download()

Download the dataset.

class MaskedMNIST(root: str, train: bool = True, masks: Optional[Tensor] = None)

Bases: MNIST