Skip to content
Snippets Groups Projects
Commit 5fbdaf71 authored by stadtler1's avatar stadtler1
Browse files

Merge branch 'bing_issue#010_remove_hickle_split_data' into bing_issue#009_clean_up_postprocessing

parents f97bcafb bd656226
Branches
Tags
No related merge requests found
......@@ -160,11 +160,11 @@ def make_dataset_iterator(train_dataset, val_dataset, batch_size ):
def plot_train(train_losses,val_losses,output_dir):
epochs = list(range(len(train_losses)))
plt.plot(epochs, train_losses, 'g', label='Training loss')
plt.plot(epochs, val_losses, 'b', label='validation loss')
iterations = list(range(len(train_losses)))
plt.plot(iterations, train_losses, 'g', label='Training loss')
plt.plot(iterations, val_losses, 'b', label='validation loss')
plt.title('Training and Validation loss')
plt.xlabel('Epochs')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.savefig(os.path.join(output_dir,'plot_train.png'))
......
......@@ -69,6 +69,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
#print("self.x_hat_context_frames,",self.x_hat_context_frames)
#self.context_frames_loss = tf.reduce_mean(
# tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
# This is the loss function (RMSE):
self.total_loss = tf.reduce_mean(
tf.square(self.x[:, self.context_frames:, :, :, 0] - self.x_hat_context_frames[:, (self.context_frames-1):-1, :, :, 0]))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment