Commit 2a9e1c83 authored by mova's avatar mova
Browse files

valstep check for nans

parent 260343a7
......@@ -5,7 +5,7 @@ from tqdm import tqdm
from fgsim.config import conf, device
from fgsim.utils.batch_utils import move_batch_to_device
from fgsim.utils.check_for_nans import check_chain_for_nans
from fgsim.utils.check_for_nans import check_chain_for_nans, is_anormal_tensor
from fgsim.utils.logger import logger
from .train_state import TrainState
......@@ -22,6 +22,8 @@ def validate(train_state: TrainState) -> None:
with torch.no_grad():
prediction = torch.squeeze(train_state.holder.model(batch_gpu).T)
loss = train_state.holder.lossf(y=batch_gpu.y, yhat=prediction)
if is_anormal_tensor(loss):
raise ValueError
losses.append(loss)
del batch_gpu
......
......@@ -3,9 +3,13 @@ import torch
from .logger import logger
def is_anormal_tensor(inp: torch.Tensor) -> bool:
    """Return True if *inp* contains at least one NaN or infinite element.

    Args:
        inp: Tensor to inspect.

    Returns:
        True when any element is NaN, +inf or -inf; False otherwise
        (including for an empty tensor).
    """
    # The earlier version tested isinf twice, so NaN values were never
    # detected; both conditions must be checked.
    has_nan = bool(torch.isnan(inp).any())
    has_inf = bool(torch.isinf(inp).any())
    return has_nan or has_inf
def contains_nans(inp, string=""):
if isinstance(inp, torch.Tensor):
res = torch.any(torch.isnan(inp))
res = is_anormal_tensor(inp)
return (res, string)
elif hasattr(inp, "state_dict"):
return contains_nans(inp.state_dict())
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment