Commit d06884ee authored by mova's avatar mova
Browse files

fix early stopping criterion

parent cc21332e
......@@ -6,10 +6,22 @@ from .train_state import TrainState
def early_stopping(train_state: TrainState) -> bool:
"""Compare the last conf.training.early_stopping validation losses
with the validation losses before that.
If the minimum has not been reduced by
conf.training.early_stopping_improvement, stop the training"""
if len(train_state.state.val_losses) < conf.training.early_stopping + 1:
return False
relative_improvement = 1 - (
min(train_state.state.val_losses) / train_state.state.min_val_loss
min(train_state.state.val_losses[-conf.training.early_stopping :])
/ min(train_state.state.val_losses[: -conf.training.early_stopping])
)
logger.info(
f"""\
Relative Improvement in the last {conf.training.early_stopping} \
validation steps: {relative_improvement*100}%"""
)
if relative_improvement < conf.training.early_stopping_improvement:
train_state.holder.save_checkpoint()
train_state.writer.flush()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment