lfxai.utils.datasets module
- class CIFAR10Pair(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)
Bases: CIFAR10
Generates mini-batch pairs from the CIFAR10 training set.
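The docstring does not say how a pair is formed. In SimCLR-style contrastive training, a "pair" is usually two independently augmented views of the same image, and the sketch below assumes that. `PairDataset` is a hypothetical, torch-free stand-in written for illustration. It is not lfxai's `CIFAR10Pair`, which subclasses torchvision's `CIFAR10`.

```python
import random
from typing import Any, Callable, List, Sequence, Tuple


class PairDataset:
    """Hypothetical stand-in for a pair dataset: each item yields two views."""

    def __init__(
        self,
        data: Sequence[Any],
        targets: Sequence[int],
        transform: Callable[[Any], Any],
    ) -> None:
        if len(data) != len(targets):
            raise ValueError("data and targets must have the same length")
        self.data = data
        self.targets = targets
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[List[Any], int]:
        img, target = self.data[index], self.targets[index]
        # Apply the (usually random) transform twice to the same image so the
        # two views form a positive pair; the label is returned unchanged.
        views = [self.transform(img), self.transform(img)]
        return views, target


def jitter(x: float) -> float:
    # Toy "augmentation": add small Gaussian noise to a scalar "image".
    return x + random.gauss(0.0, 0.1)


if __name__ == "__main__":
    ds = PairDataset([1.0, 2.0, 3.0], [0, 1, 0], jitter)
    (view_a, view_b), label = ds[0]
    print(len(ds), view_a, view_b, label)
```

Because the transform is called twice rather than once with the result duplicated, a random augmentation produces two different views of the same underlying sample.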
- class DSprites(root='/home/docs/checkouts/readthedocs.org/user_builds/lfxai/checkouts/latest/src/lfxai/utils/../data/dsprites/', **kwargs)
Bases: DisentangledDataset
DSprites dataset from [1], the disentanglement test Sprites dataset. It consists of procedurally generated 2D shapes defined by 6 disentangled latent factors, which control the color, shape, scale, rotation and position of a sprite. All possible combinations of the latents are present. The ordering along dimension 1 is fixed and can be mapped back to the exact latent values that generated each image. Every image has a distinct pixel output. No noise is added.
Notes
The metadata is hard-coded because of an issue loading the original Python 2 metadata under Python 3.
Parameters:
- root : string
Root directory of dataset.
References
- [1] Higgins, I., Matthey, L., Pal, A., Burgess, C., Glorot, X., Botvinick, M., … & Lerchner, A. (2017). beta-VAE: Learning basic visual concepts with a constrained variational framework. In International Conference on Learning Representations.
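Because the ordering along dimension 1 is fixed, a flat image index can be converted to one class index per latent factor and back. The sketch below assumes the images are stored in row-major order over the factors in `lat_sizes` order, with the last factor (`posY`) varying fastest. This is the convention of the original dSprites release, where the constant color factor has size 1 and therefore does not change any stride. The helper names are hypothetical and not part of lfxai.

```python
import math
from typing import List, Sequence

# Factor sizes from DSprites.lat_sizes: shape, scale, orientation, posX, posY.
LAT_SIZES = (3, 6, 40, 32, 32)


def latent_bases(sizes: Sequence[int]) -> List[int]:
    """Stride of each factor, assuming row-major order (last factor varies fastest)."""
    bases = [1] * len(sizes)
    for i in range(len(sizes) - 2, -1, -1):
        bases[i] = bases[i + 1] * sizes[i + 1]
    return bases


def latents_to_index(latents: Sequence[int], sizes: Sequence[int] = LAT_SIZES) -> int:
    """Map one class index per factor to the flat position along dimension 1."""
    if len(latents) != len(sizes):
        raise ValueError("need one index per latent factor")
    for value, size in zip(latents, sizes):
        if not 0 <= value < size:
            raise ValueError(f"latent index {value} out of range for size {size}")
    return sum(v * b for v, b in zip(latents, latent_bases(sizes)))


def index_to_latents(index: int, sizes: Sequence[int] = LAT_SIZES) -> List[int]:
    """Inverse of latents_to_index: peel factors off starting from the fastest one."""
    if not 0 <= index < math.prod(sizes):
        raise ValueError(f"index {index} out of range")
    latents = []
    for size in reversed(sizes):
        index, rem = divmod(index, size)
        latents.append(rem)
    return latents[::-1]
```

For example, `latent_bases(LAT_SIZES)` gives `[245760, 40960, 1024, 32, 1]`, and `latents_to_index([2, 5, 39, 31, 31])` gives `737279`, the last of the 3 × 6 × 40 × 32 × 32 = 737,280 images.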
- background_color = 0
- download()
Download the dataset.
- files = {'train': 'dsprite_train.npz'}
- img_size = (1, 64, 64)
- lat_names = ('shape', 'scale', 'orientation', 'posX', 'posY')
- lat_sizes = array([ 3, 6, 40, 32, 32])
- lat_values = {'color': array([1.]), 'orientation': array([0. , 0.16110732, 0.32221463, 0.48332195, 0.64442926, 0.80553658, 0.96664389, 1.12775121, 1.28885852, 1.44996584, 1.61107316, 1.77218047, 1.93328779, 2.0943951 , 2.25550242, 2.41660973, 2.57771705, 2.73882436, 2.89993168, 3.061039 , 3.22214631, 3.38325363, 3.54436094, 3.70546826, 3.86657557, 4.02768289, 4.1887902 , 4.34989752, 4.51100484, 4.67211215, 4.83321947, 4.99432678, 5.1554341 , 5.31654141, 5.47764873, 5.63875604, 5.79986336, 5.96097068, 6.12207799, 6.28318531]), 'posX': array([0. , 0.03225806, 0.06451613, 0.09677419, 0.12903226, 0.16129032, 0.19354839, 0.22580645, 0.25806452, 0.29032258, 0.32258065, 0.35483871, 0.38709677, 0.41935484, 0.4516129 , 0.48387097, 0.51612903, 0.5483871 , 0.58064516, 0.61290323, 0.64516129, 0.67741935, 0.70967742, 0.74193548, 0.77419355, 0.80645161, 0.83870968, 0.87096774, 0.90322581, 0.93548387, 0.96774194, 1. ]), 'posY': array([0. , 0.03225806, 0.06451613, 0.09677419, 0.12903226, 0.16129032, 0.19354839, 0.22580645, 0.25806452, 0.29032258, 0.32258065, 0.35483871, 0.38709677, 0.41935484, 0.4516129 , 0.48387097, 0.51612903, 0.5483871 , 0.58064516, 0.61290323, 0.64516129, 0.67741935, 0.70967742, 0.74193548, 0.77419355, 0.80645161, 0.83870968, 0.87096774, 0.90322581, 0.93548387, 0.96774194, 1. ]), 'scale': array([0.5, 0.6, 0.7, 0.8, 0.9, 1. ]), 'shape': array([1., 2., 3.])}
- urls = {'train': 'https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true'}
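The printed `lat_values` arrays appear to be evenly spaced grids whose lengths match `lat_sizes`. This can be checked by regenerating them. The sketch below infers each grid's bounds from the first and last printed entries and assumes endpoint-inclusive spacing, as with `numpy.linspace`. A small pure-Python `linspace` stands in for the NumPy function so that the example has no dependencies. The function names are illustrative only.

```python
import math
from typing import Dict, List


def linspace(start: float, stop: float, num: int) -> List[float]:
    """Pure-Python equivalent of numpy.linspace(start, stop, num) with endpoint=True."""
    if num < 1:
        raise ValueError("num must be at least 1")
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def dsprites_lat_values() -> Dict[str, List[float]]:
    # Bounds inferred from the printed lat_values; sizes taken from lat_sizes.
    return {
        "color": [1.0],
        "shape": [1.0, 2.0, 3.0],
        "scale": linspace(0.5, 1.0, 6),
        "orientation": linspace(0.0, 2 * math.pi, 40),  # radians, 0 to 2*pi inclusive
        "posX": linspace(0.0, 1.0, 32),
        "posY": linspace(0.0, 1.0, 32),
    }
```

Under this reading, the orientation step is 2π/39 ≈ 0.16110732 and the position step is 1/31 ≈ 0.03225806. Both values match the second entries of the printed arrays.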
- class DisentangledDataset(root, transforms_list=[], logger=<Logger lfxai.utils.datasets (WARNING)>)
Bases: Dataset, ABC
Base class for disentangled VAE datasets.
Parameters:
- root : string
Root directory of dataset.
- transforms_list : list
List of torchvision.transforms to apply to the data when loading it.
- abstract download()
Download the dataset.
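Because `DisentangledDataset` derives from `ABC` and declares `download()` abstract, it cannot be instantiated directly. Each concrete dataset, such as `DSprites`, must implement `download()`. The torch-free sketch below illustrates that pattern and applies `transforms_list` in order, as `torchvision.transforms.Compose` would. `DisentangledDatasetSketch` and `ToyDataset` are hypothetical. The sketch does not reproduce what lfxai's `__getitem__` actually returns.

```python
import abc
import logging
from typing import Any, Callable, List, Optional, Sequence


class DisentangledDatasetSketch(abc.ABC):
    """Hypothetical abstract dataset: subclasses must implement download()."""

    def __init__(
        self,
        root: str,
        transforms_list: Optional[List[Callable[[Any], Any]]] = None,
        logger: logging.Logger = logging.getLogger(__name__),
    ) -> None:
        self.root = root
        # Copy into a fresh list so instances never share one mutable default.
        self.transforms_list = list(transforms_list or [])
        self.logger = logger
        self.data: Sequence[Any] = []

    def transform(self, x: Any) -> Any:
        # Apply each transform in order, like torchvision.transforms.Compose.
        for t in self.transforms_list:
            x = t(x)
        return x

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Any:
        return self.transform(self.data[idx])

    @abc.abstractmethod
    def download(self) -> None:
        """Fetch the dataset into self.root."""


class ToyDataset(DisentangledDatasetSketch):
    def download(self) -> None:
        # Stand-in for a real download: populate in-memory data.
        self.logger.info("populating toy data under %s", self.root)
        self.data = [1, 2, 3]
```

The sketch defaults `transforms_list` to `None` instead of `[]`. This avoids Python's shared-mutable-default pitfall while keeping the same behavior for callers who pass a list.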
- class ECG5000(dir: Path, train: bool = True, random_seed: int = 42, experiment: str = 'features')
Bases: Dataset
- download()
Download the dataset.
- class MaskedMNIST(root: str, train: bool = True, masks: Optional[Tensor] = None)
Bases: MNIST