serotiny.models.basic_model module
- class serotiny.models.basic_model.BasicModel(network: torch.nn.modules.module.Module | None = None, loss: torch.nn.modules.loss._Loss | None = None, x_label: str = 'x', y_label: str = 'y', optimizer: torch.optim.optimizer.Optimizer = <class 'torch.optim.adam.Adam'>, squeeze_y: bool = False, save_predictions: typing.Callable | None = None, fields_to_log: typing.Sequence | None = None, pretrained_weights: str | None = None, **kwargs)
Bases: BaseModel
A minimal PyTorch Lightning wrapper around generic PyTorch models.
- Parameters:
network (Optional[nn.Module] = None) – The network to wrap
loss (Optional[Loss] = None) – The loss function to optimize for
x_label (str = “x”) – The key used to retrieve the input from dataloader batches
y_label (str = “y”) – The key used to retrieve the target from dataloader batches
optimizer (torch.optim.Optimizer = torch.optim.Adam) – The optimizer class to use
squeeze_y (bool = False) – Whether to squeeze the target tensor (drop singleton dimensions) before computing the loss
save_predictions (Optional[Callable] = None) – A function to save the results of serotiny predict
fields_to_log (Optional[Union[Sequence, Dict]] = None) – Batch fields to store with the outputs. Use a dict keyed by training stage (train, val, test, prediction) to log different fields per stage; if a list is used, it is assumed to apply to test and prediction only
pretrained_weights (Optional[str] = None) – Path to pretrained weights. If network is not None, the weights are loaded via network.load_state_dict; otherwise the full model is loaded via torch.load.
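As a minimal usage sketch, assuming dict-style batches keyed by x_label / y_label (the toy MLP, dataset, and Trainer settings below are illustrative assumptions, not part of serotiny):

```python
# Minimal sketch: wrap a plain torch module in BasicModel and train it.
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from serotiny.models.basic_model import BasicModel


class DictDataset(Dataset):
    """Toy dataset yielding dict batches keyed by x_label / y_label."""

    def __init__(self, n=256):
        self.x = torch.randn(n, 16)
        self.y = torch.randn(n, 1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # BasicModel retrieves inputs and targets from batches by key
        return {"x": self.x[idx], "y": self.y[idx]}


model = BasicModel(
    network=nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1)),
    loss=nn.MSELoss(),
    x_label="x",
    y_label="y",
    optimizer=torch.optim.Adam,  # passed as a class, per the signature
)

trainer = pl.Trainer(max_epochs=1, accelerator="cpu", logger=False)
trainer.fit(model, DataLoader(DictDataset(), batch_size=32))
```

Note that optimizer is passed as a class rather than an instance, presumably so the wrapper can instantiate it with the wrapped network's parameters; to log different batch fields per training stage, pass fields_to_log as a dict instead of a list.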