serotiny.models.basic_model module

class serotiny.models.basic_model.BasicModel(network: torch.nn.modules.module.Module | None = None, loss: torch.nn.modules.loss._Loss | None = None, x_label: str = 'x', y_label: str = 'y', optimizer: torch.optim.optimizer.Optimizer = torch.optim.adam.Adam, squeeze_y: bool = False, save_predictions: typing.Callable | None = None, fields_to_log: typing.Sequence | None = None, pretrained_weights: str | None = None, **kwargs)

Bases: BaseModel

A minimal PyTorch Lightning wrapper around generic PyTorch models.

Parameters:
  • network (Optional[nn.Module] = None) – The network to wrap

  • loss (Optional[Loss] = None) – The loss function to optimize for

  • x_label (str = 'x') – The key used to retrieve the input from dataloader batches

  • y_label (str = 'y') – The key used to retrieve the target from dataloader batches

  • optimizer (torch.optim.Optimizer = torch.optim.Adam) – The optimizer class

  • save_predictions (Optional[Callable] = None) – A function that saves the results of the serotiny predict command

  • fields_to_log (Optional[Union[Sequence, Dict]] = None) – Batch fields to store with the outputs. Use a dict to choose the logged fields per stage (train, val, test, prediction). If a list is used, it is assumed to apply to the test and prediction stages only.

  • pretrained_weights (Optional[str] = None) – Path to pretrained weights. If network is not None, the weights are loaded into it via network.load_state_dict; otherwise the file is loaded via torch.load.
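The list-versus-dict handling of fields_to_log can be illustrated with a small plain-Python sketch. It encodes the documented rule that a plain list applies to the test and prediction stages only, and additionally assumes that a dict maps stage names directly to field lists. The helper normalize_fields_to_log and the stage keys "train", "val", "test" and "predict" are hypothetical names chosen for illustration; this is not serotiny's own implementation.

```python
from collections import abc
from typing import Dict, List, Mapping, Optional, Sequence, Union

# Hypothetical stage names; the keys serotiny uses internally may differ.
STAGES = ("train", "val", "test", "predict")


def normalize_fields_to_log(
    fields_to_log: Optional[Union[Sequence[str], Mapping[str, Sequence[str]]]],
) -> Dict[str, List[str]]:
    """Map every stage name to the list of batch fields logged during that stage."""
    # By default, no fields are logged in any stage.
    result: Dict[str, List[str]] = {stage: [] for stage in STAGES}
    if fields_to_log is None:
        return result
    if isinstance(fields_to_log, str):
        # A bare string is also a Sequence; reject it instead of splitting it into characters.
        raise TypeError("fields_to_log must be a list or dict of field names")
    if isinstance(fields_to_log, abc.Mapping):
        # Assumption: a dict maps stage names directly to their field lists.
        for stage, fields in fields_to_log.items():
            if stage not in result:
                raise ValueError(f"unknown stage: {stage!r}")
            result[stage] = list(fields)
        return result
    # Documented behavior: a plain list applies to test and prediction only.
    for stage in ("test", "predict"):
        result[stage] = list(fields_to_log)
    return result


print(normalize_fields_to_log(["id", "filename"]))
# {'train': [], 'val': [], 'test': ['id', 'filename'], 'predict': ['id', 'filename']}
```

Normalizing to one mapping up front means the per-stage logging code only ever has to look up a single key, whichever form the caller passed.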

forward(x, **kwargs)
parse_batch(batch)
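The x_label and y_label parameters determine which entries of a batch are treated as the input and the target. The following is a minimal plain-Python sketch of that key lookup, assuming batches are dictionaries and that parsing yields an (input, target) pair. The class BatchKeys is hypothetical and does not reproduce serotiny's parse_batch, whose exact return value is not documented here.

```python
from typing import Any, Mapping, Tuple


class BatchKeys:
    """Hypothetical stand-in for the key lookup configured by x_label and y_label."""

    def __init__(self, x_label: str = "x", y_label: str = "y") -> None:
        self.x_label = x_label
        self.y_label = y_label

    def parse_batch(self, batch: Mapping[str, Any]) -> Tuple[Any, Any]:
        # Fetch the input and target under the configured keys; a missing key
        # raises KeyError, which exposes a mismatch with the dataloader early.
        return batch[self.x_label], batch[self.y_label]


# Usage: a dataloader whose batches store inputs under "image" and targets under "label".
keys = BatchKeys(x_label="image", y_label="label")
x, y = keys.parse_batch({"image": [0.1, 0.2], "label": 1, "id": "cell_07"})
print(x, y)  # [0.1, 0.2] 1
```

Extra entries such as "id" are simply ignored by the lookup; in BasicModel, fields like these are what fields_to_log selects for storage alongside the outputs.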