serotiny.datamodules.dataframe.dataframe_datamodule module
- class serotiny.datamodules.dataframe.dataframe_datamodule.DataframeDatamodule(path: UPath | str, transforms: Dict, split_column: str | None = None, columns: Sequence[str] | None = None, split_map: Dict | None = None, just_inference: bool = False, cache_dir: UPath | str | None = None, subsample: Dict | None = None, seed: int = 42, **dataloader_kwargs)
Bases: LightningDataModule
A PyTorch Lightning datamodule based on dataframes. It can use either a single dataframe file containing a column from which train, val and test splits can be made, or three dataframe files, one for each split (train, val, test).
If the datamodule will only be used for prediction or testing, set just_inference to True: the splits are then ignored and the whole dataset is used.
The prediction dataloader (predict_dataloader) is always constructed from the whole dataset, regardless of the value of just_inference.
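As a rough illustration of the split behaviour described above, the sketch below groups the rows of a single dataframe by its split column, and uses the whole dataset for testing and prediction when just_inference is True. It is not serotiny's implementation: rows are modelled as plain dicts rather than a dataframe, `split_dataframe_rows` is a hypothetical helper, the split column is assumed to hold the literal values "train", "val" and "test" (split_map, which is undocumented, is not modelled), and the three-file layout is not covered.

```python
from typing import Dict, List, Optional

# Hypothetical stand-in for a dataframe: one dict per row (column -> value).
Row = Dict[str, object]


def split_dataframe_rows(
    rows: List[Row],
    split_column: Optional[str] = None,
    just_inference: bool = False,
) -> Dict[str, List[Row]]:
    """Hypothetical helper, not serotiny's code: assign rows to splits."""
    # Prediction always uses the whole dataset, whatever just_inference is.
    splits: Dict[str, List[Row]] = {"predict": list(rows)}
    if just_inference:
        # Splits are ignored; testing also runs on the whole dataset.
        splits["test"] = list(rows)
        return splits
    if split_column is None:
        raise ValueError("split_column is required when just_inference is False")
    # Assumption: the split column holds the literal values "train", "val", "test".
    for name in ("train", "val", "test"):
        splits[name] = [row for row in rows if row[split_column] == name]
    return splits


rows = [
    {"id": 0, "split": "train"},
    {"id": 1, "split": "val"},
    {"id": 2, "split": "test"},
    {"id": 3, "split": "train"},
]
print([row["id"] for row in split_dataframe_rows(rows, "split")["train"]])  # [0, 3]
print(sorted(split_dataframe_rows(rows, just_inference=True)))  # ['predict', 'test']
```

Keeping "predict" populated unconditionally mirrors the statement that the prediction dataloader ignores just_inference.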
- Parameters:
path (Union[Path, str]) – Path to a dataframe file.
transforms (Dict) – Transform specifications for each split.
split_column (Optional[str] = None) – Name of a column in the dataset which can be used to create the train, val and test splits.
columns (Optional[Sequence[str]] = None) – List of columns to load from the dataset, if it is a Parquet file. If None, all columns are loaded.
split_map (Optional[Dict] = None) – TODO: document this argument
just_inference (bool = False) – Whether this datamodule will be used only for inference (testing/prediction). If so, the splits are ignored and the whole dataset is used.
cache_dir (Optional[Union[Path, str]] = None) – Path to a directory in which to store cached transformed inputs, to accelerate batch loading.
subsample (Optional[Dict] = None) – Dictionary mapping each split (“train”, “val”, “test”) to the number of samples from that split to use per epoch. If None (default), all samples in each split are used every epoch.
dataloader_kwargs – Additional keyword arguments passed to the torch.utils.data.DataLoader class when it is instantiated, such as num_workers, batch_size and shuffle; shuffle is only applied to the train dataloader. See the PyTorch docs for more details on these arguments: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
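The two per-split behaviours documented for dataloader_kwargs and subsample can be sketched with the standard library alone. Both helpers, `loader_kwargs_for_split` and `epoch_indices`, are hypothetical and do not reproduce serotiny's code; in particular, falling back to shuffle=False when it is not given (DataLoader's own default) and drawing the per-epoch subsample without replacement are assumptions, not documented behaviour.

```python
import random
from typing import Any, Dict, List, Optional


def loader_kwargs_for_split(split: str, dataloader_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: forward the DataLoader kwargs for one split,
    letting `shuffle` take effect only for the train split."""
    kwargs = dict(dataloader_kwargs)  # copy so the caller's dict is not mutated
    # Assumption: without an explicit value, use DataLoader's default (False).
    shuffle = kwargs.pop("shuffle", False)
    kwargs["shuffle"] = bool(shuffle) and split == "train"
    return kwargs


def epoch_indices(
    n_rows: int,
    split: str,
    subsample: Optional[Dict[str, int]],
    rng: random.Random,
) -> List[int]:
    """Hypothetical helper: row indices used for one epoch of `split`."""
    if subsample is None or split not in subsample:
        return list(range(n_rows))  # no subsampling: every row, every epoch
    k = min(subsample[split], n_rows)
    # Assumption: a fresh subset drawn without replacement each epoch.
    return rng.sample(range(n_rows), k)


if __name__ == "__main__":
    user_kwargs = {"batch_size": 32, "num_workers": 4, "shuffle": True}
    print(loader_kwargs_for_split("val", user_kwargs))
    # {'batch_size': 32, 'num_workers': 4, 'shuffle': False}
    rng = random.Random(42)
    print(len(epoch_indices(1000, "train", {"train": 100}, rng)))  # 100
```

Drawing a new subset every epoch means that, over many epochs, the model can still see most of a large training split while each epoch stays short.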