Commit 45bea717 authored by mova's avatar mova
Browse files

rename config entries for early stopping

parent d5bf7b61
......@@ -45,8 +45,9 @@ model:
training:
events_processed_before_validation: 100000
validation_interval: '${div:${training.events_processed_before_validation},${loader.batch_size}}'
early_stopping: 10
early_stopping_improvement: 0.02
early_stopping:
validation_steps: 10
improvement: 0.02
yvar : "energy"
loss:
name: L1Loss
......
......@@ -6,23 +6,24 @@ from .train_state import TrainState
def early_stopping(train_state: TrainState) -> bool:
"""Compare the last conf.training.early_stopping validation losses
with the validation losses before that.
"""Compare the last `conf.training.early_stopping.validation_steps`
validation losses with the validation losses before that.
If the minimum has not been reduced by
conf.training.early_stopping_improvement, stop the training"""
if len(train_state.state.val_losses) < conf.training.early_stopping + 1:
`conf.training.early_stopping.improvement`, stop the training"""
valsteps = conf.training.early_stopping.validation_steps
if len(train_state.state.val_losses) < valsteps + 1:
return False
relative_improvement = 1 - (
min(train_state.state.val_losses[-conf.training.early_stopping :])
/ min(train_state.state.val_losses[: -conf.training.early_stopping])
min(train_state.state.val_losses[-valsteps:])
/ min(train_state.state.val_losses[:-valsteps])
)
logger.info(
f"""\
Relative Improvement in the last {conf.training.early_stopping} \
Relative Improvement in the last {valsteps} \
validation steps: {relative_improvement*100}%"""
)
if relative_improvement < conf.training.early_stopping_improvement:
if relative_improvement < conf.training.early_stopping.improvement:
train_state.holder.save_checkpoint()
train_state.writer.flush()
train_state.writer.close()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment