Commit 03702032 authored by mova

train: move validation and early_stopping to other files

parent a5caad3e
early_stopping.py (new file)

import sys

from ..config import conf
from ..utils.logger import logger
from .train_state import TrainState


def early_stopping(train_state: TrainState) -> None:
    if (
        train_state.state["grad_step"] != 0
        and train_state.state["grad_step"] % conf.training.validation_interval == 0
    ):
        # don't stop during the first epochs
        if len(train_state.state.val_losses) < conf.training.early_stopping:
            return
        # compare the most recent validation losses
        recent_losses = train_state.state.val_losses[
            -conf.training.early_stopping :
        ]
        relative_improvement = 1 - (min(recent_losses) / recent_losses[0])
        if relative_improvement < conf.training.early_stopping_improvement:
            train_state.holder.save_models()
            train_state.writer.flush()
            train_state.writer.close()
            logger.warning("Early stopping criterion fulfilled")
            if hasattr(train_state, "loader"):
                train_state.loader.qfseq.drain_seq()
            sys.exit()
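For reference, the stopping rule above can be traced by hand. A minimal standalone sketch, where the window size and threshold are made-up values standing in for conf.training.early_stopping and conf.training.early_stopping_improvement:

# Standalone sketch of the criterion; both config values are hypothetical.
early_stopping_n = 5          # stands in for conf.training.early_stopping
improvement_threshold = 0.05  # stands in for conf.training.early_stopping_improvement

val_losses = [1.00, 0.80, 0.78, 0.77, 0.77, 0.76]
recent_losses = val_losses[-early_stopping_n:]  # [0.80, 0.78, 0.77, 0.77, 0.76]
# fraction by which the best recent loss improves on the oldest one in the window
relative_improvement = 1 - (min(recent_losses) / recent_losses[0])  # 1 - 0.76/0.80 = 0.05
# training stops once the window no longer improves by at least the threshold
should_stop = relative_improvement < improvement_threshold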
train.py

import sys
import time
from copy import deepcopy

import torch
import torch_geometric
@@ -10,11 +8,12 @@ from tqdm import tqdm
from ..config import conf, device
from ..io.queued_dataset import QueuedDataLoader
from ..monitor import setup_experiment, setup_writer
from ..utils.check_for_nans import check_chain_for_nans
from ..utils.logger import logger
from ..utils.move_batch_to_device import move_batch_to_device
from .early_stopping import early_stopping
from .holder import model_holder
from .train_state import TrainState
from .validate import validate
def training_step(
@@ -30,70 +29,6 @@ def training_step(
    train_state.holder.optim.step()
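The body of training_step is collapsed in this view; only the closing optimizer call is visible. For orientation, a generic PyTorch gradient step of the shape the surrounding code implies (a sketch under assumptions, not this project's actual implementation; the model/lossf usage mirrors validate below):

# Sketch only: apart from train_state.holder.optim.step(), every line here
# is an assumption based on the holder attributes used elsewhere in the diff.
def training_step_sketch(batch, train_state: TrainState) -> torch.Tensor:
    train_state.holder.optim.zero_grad()
    prediction = torch.squeeze(train_state.holder.model(batch).T)
    loss = train_state.holder.lossf(prediction, batch.y.float())
    loss.backward()
    train_state.holder.optim.step()
    return loss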
def validate(train_state: TrainState) -> None:
    if train_state.state["grad_step"] % conf.training.validation_interval == 0:
        check_chain_for_nans((train_state.holder.model,))
        losses = []
        for batch in tqdm(
            train_state.loader.validation_batches, postfix="validating"
        ):
            batch = batch.to(device)
            prediction = torch.squeeze(train_state.holder.model(batch).T)
            losses.append(train_state.holder.lossf(prediction, batch.y.float()))
        mean_loss = torch.mean(torch.tensor(losses))
        train_state.state.val_losses.append(float(mean_loss))
        train_state.writer.add_scalar(
            "val_loss", mean_loss, train_state.state["grad_step"]
        )
        train_state.experiment.log_metric(
            "val_loss", mean_loss, train_state.state["grad_step"]
        )
        mean_loss = float(mean_loss)
        if (
            not hasattr(train_state.state, "min_val_loss")
            or train_state.state.min_val_loss > mean_loss
        ):
            train_state.state.min_val_loss = mean_loss
            train_state.state.best_grad_step = train_state.state["grad_step"]
            train_state.holder.best_model_state = deepcopy(
                train_state.holder.model.state_dict()
            )
    assert train_state.state is train_state.holder.state
    if (
        train_state.state["grad_step"] != 0
        and train_state.state["grad_step"] % conf.training.checkpoint_interval == 0
    ):
        train_state.holder.save_models()


def early_stopping(train_state: TrainState) -> None:
    if (
        train_state.state["grad_step"] != 0
        and train_state.state["grad_step"] % conf.training.validation_interval == 0
    ):
        # don't stop during the first epochs
        if len(train_state.state.val_losses) < conf.training.early_stopping:
            return
        # compare the most recent validation losses
        recent_losses = train_state.state.val_losses[
            -conf.training.early_stopping :
        ]
        relative_improvement = 1 - (min(recent_losses) / recent_losses[0])
        if relative_improvement < conf.training.early_stopping_improvement:
            train_state.holder.save_models()
            train_state.writer.flush()
            train_state.writer.close()
            logger.warning("Early stopping criterion fulfilled")
            if hasattr(train_state, "loader"):
                train_state.loader.qfseq.drain_seq()
            sys.exit()
def training_procedure() -> None:
    logger.warning(
        "Starting training with state\n" + OmegaConf.to_yaml(model_holder.state)
...
validate.py (new file)

from copy import deepcopy

import torch
from tqdm import tqdm

from ..config import conf, device
from ..utils.check_for_nans import check_chain_for_nans
from .train_state import TrainState


def validate(train_state: TrainState) -> None:
    if train_state.state["grad_step"] % conf.training.validation_interval == 0:
        check_chain_for_nans((train_state.holder.model,))
        losses = []
        for batch in tqdm(
            train_state.loader.validation_batches, postfix="validating"
        ):
            batch = batch.to(device)
            prediction = torch.squeeze(train_state.holder.model(batch).T)
            losses.append(train_state.holder.lossf(prediction, batch.y.float()))
        mean_loss = torch.mean(torch.tensor(losses))
        train_state.state.val_losses.append(float(mean_loss))
        train_state.writer.add_scalar(
            "val_loss", mean_loss, train_state.state["grad_step"]
        )
        train_state.experiment.log_metric(
            "val_loss", mean_loss, train_state.state["grad_step"]
        )
        mean_loss = float(mean_loss)
        # track the best model seen so far
        if (
            not hasattr(train_state.state, "min_val_loss")
            or train_state.state.min_val_loss > mean_loss
        ):
            train_state.state.min_val_loss = mean_loss
            train_state.state.best_grad_step = train_state.state["grad_step"]
            train_state.holder.best_model_state = deepcopy(
                train_state.holder.model.state_dict()
            )
    assert train_state.state is train_state.holder.state
    # periodic checkpointing, independent of the validation interval
    if (
        train_state.state["grad_step"] != 0
        and train_state.state["grad_step"] % conf.training.checkpoint_interval == 0
    ):
        train_state.holder.save_models()
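Both helpers are written to be called unconditionally once per gradient step: each returns immediately unless grad_step falls on the validation interval. A hypothetical wiring sketch (the actual training_procedure body is collapsed above; the loop structure and loader name are assumptions):

# Hypothetical wiring; only the validate/early_stopping semantics are taken
# from this commit, the loop itself is an assumption.
for batch in train_state.loader:
    training_step(batch, train_state)
    train_state.state["grad_step"] += 1
    validate(train_state)        # appends a mean val_loss every validation_interval steps
    early_stopping(train_state)  # saves models and exits once improvement stalls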