Source code for neuro_fuzzy_toolbox.training.early_stopping

class EarlyStopping:
    """
    Early stopping mechanism for halting the training of a machine learning
    model (``torch.nn.Module``) when no sufficient improvement is observed.

    The instance is meant to be called once per evaluation step with the
    model and its current loss. It keeps track of the lowest loss seen so
    far together with a snapshot of the model state at that point, and
    raises the ``stop`` flag after ``patience`` consecutive calls without a
    sufficient improvement.
    """

    def __init__(self, patience, delta=0, last_state=False):
        """
        Initializes a new EarlyStopping instance.

        Args:
            patience (int): Number of epochs without improvement before
                stopping training.
            delta (float): Minimum improvement required to consider that the
                model has improved. Defaults to ``0``.
            last_state (bool): If ``True``, restores the last model state
                when stopping instead of the best state found during
                training. Defaults to ``False``.
        """
        # Parameters
        self.patience = patience
        self.delta = delta
        # True if the last state is the one to be kept (not the best one)
        self.last_state = last_state

        # Running state
        self._counter = 0
        self._best_loss = None
        self._best_state_dict = None
        self._stop = False

    @staticmethod
    def _snapshot(model):
        """
        Returns an independent copy of the model's current state.

        ``state_dict()`` only returns a shallow copy whose values still
        reference the live parameters, so it has to be deep-copied;
        otherwise later training steps would silently overwrite the stored
        "best" state.
        """
        # Local import so the class does not depend on a module-level import.
        import copy

        return copy.deepcopy(model.state_dict())

    def __call__(self, model, loss, verbose=False):
        """
        Evaluates whether training should be stopped. If the stopping
        criterion is met, the ``stop`` attribute is updated to ``True`` and,
        unless ``last_state`` is set, the best recorded state is loaded back
        into ``model``.

        Args:
            model (torch.nn.Module): Model to evaluate.
            loss (float): Current loss value of the model.
            verbose (bool): If ``True``, prints a warning message when early
                stopping is triggered. Defaults to ``False``.
        """
        # The loss improves when it is at least ``delta`` below the best one.
        # The comparison is written with ``<=`` so that a NaN loss (for which
        # every comparison is false) counts as "no improvement" instead of
        # being recorded as the new best value.
        improved = (
            self._best_loss is None
            or loss + self.delta <= self._best_loss
        )

        if improved:
            self._best_loss = loss
            self._best_state_dict = self._snapshot(model)
            self._counter = 0
            return

        self._counter += 1
        if self._counter >= self.patience:
            self._stop = True
            if verbose:
                print('\nEARLY STOPPING')
            if not self.last_state:
                model.load_state_dict(self._best_state_dict)

    def reset(self):
        """
        Resets the early stopping mechanism to its initial state.

        The configuration (``patience``, ``delta``, ``last_state``) is kept;
        only the counter, the best loss, the stored state and the stop flag
        are cleared.
        """
        self._counter = 0
        self._best_loss = None
        self._best_state_dict = None
        self._stop = False

    @property
    def stop(self):
        """
        Indicates whether training should be stopped.

        Returns:
            bool: ``True`` if the stopping criterion has been met, ``False``
            otherwise.
        """
        return self._stop