serotiny.datamodules.dataframe.dataframe_datamodule module#

class serotiny.datamodules.dataframe.dataframe_datamodule.DataframeDatamodule(path: UPath | str, transforms: Dict, split_column: str | None = None, columns: Sequence[str] | None = None, split_map: Dict | None = None, just_inference: bool = False, cache_dir: UPath | str | None = None, subsample: Dict | None = None, seed: int = 42, **dataloader_kwargs)[source]#

Bases: LightningDataModule

A PyTorch Lightning datamodule based on dataframes. It can either use a single dataframe file containing a column on which a train/val/test split can be made, or three dataframe files, one per split (train, val, test).

Additionally, if the datamodule will only be used for prediction/testing, the just_inference flag can be set to True, in which case the splits are ignored and the whole dataset is used.

The predict dataloader is always constructed from the whole dataset, regardless of the value of just_inference.
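A minimal sketch of the two modes described above. The file path, the split column name, and the empty transforms dict are illustrative assumptions, not values prescribed by the library:

    from serotiny.datamodules.dataframe import DataframeDatamodule

    # Training mode: a single dataframe file plus a column defining the splits.
    dm = DataframeDatamodule(
        path="data/manifest.csv",  # hypothetical manifest file
        transforms={},             # assuming no per-split transforms are needed
        split_column="split",      # hypothetical column with train/val/test labels
    )
    train_loader = dm.train_dataloader()

    # Inference mode: splits are ignored and the whole dataset is used.
    dm_infer = DataframeDatamodule(
        path="data/manifest.csv",
        transforms={},
        just_inference=True,
    )
    predict_loader = dm_infer.predict_dataloader()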

Parameters:
  • path (Union[UPath, str]) – Path to a dataframe file

  • transforms (Dict) – Transform specifications for each given split (see the sketch after this list).

  • split_column (Optional[str] = None) – Name of a column in the dataset which can be used to create train, val, test splits.

  • columns (Optional[Sequence[str]] = None) – List of columns to load from the dataset, in case it’s a parquet file. If None, load everything.

  • split_map (Optional[Dict] = None) – TODO: document this argument

  • just_inference (bool = False) – Whether this datamodule will be used for just inference (testing/prediction). If so, the splits are ignored and the whole dataset is used.

  • cache_dir (Optional[Union[UPath, str]] = None) – Path to a directory in which to store cached transformed inputs, to accelerate batch loading.

  • subsample (Optional[Dict] = None) – Dictionary with a key per split (“train”, “val”, “test”) mapping to the number of samples of that split to use per epoch. If None (default), use all the samples in each split per epoch (see the sketch after this list).

  • dataloader_kwargs – Additional keyword arguments passed to the torch.utils.data.DataLoader class when it is instantiated (aside from shuffle, which is only applied to the train dataloader). These include num_workers, batch_size, etc. See the PyTorch docs for more info: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
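A sketch showing how the per-split arguments above fit together. The transform values are placeholders (the exact per-split transform specification format is defined by serotiny), and the path and column names are assumptions:

    from serotiny.datamodules.dataframe import DataframeDatamodule

    dm = DataframeDatamodule(
        path="data/manifest.parquet",    # hypothetical dataframe file
        transforms={                     # one spec per split; the spec format
            "train": ...,                # is library-defined, shown here as
            "val": ...,                  # placeholders
            "test": ...,
        },
        split_column="split",            # hypothetical split column
        columns=["image_path", "label"], # hypothetical parquet columns to load
        subsample={"train": 10_000},     # use 10k train samples per epoch
        # any remaining kwargs are forwarded to torch.utils.data.DataLoader
        batch_size=64,
        num_workers=8,
    )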

get_dataloader(split)[source]#
get_dataset(split)[source]#
make_dataloader(split)[source]#
predict_dataloader()[source]#
test_dataloader()[source]#
train_dataloader()[source]#
val_dataloader()[source]#
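The Lightning hooks above cover the standard splits, while get_dataset and get_dataloader take a split name directly. A hedged usage sketch, assuming the split names are "train", "val" and "test":

    train_ds = dm.get_dataset("train")     # the underlying dataset for a split
    train_dl = dm.get_dataloader("train")  # a dataloader for that split

    # The Lightning hooks are the standard entry points for a Trainer:
    for batch in dm.train_dataloader():
        ...  # e.g. a manual training loop or sanity check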