diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py index 7b0c462c3a2f7e17956c00906cb39582bb2e3db7..1fcbd1cf97442f1ea440039a0bb6769473b957f3 100644 --- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py +++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py @@ -447,7 +447,7 @@ def main(): #Get prediction values feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)#return [batchsize,seq_len,lat,lon,channel] - assert gen_images.shape[1] = sequence_length-1 #The generate images seq_len should be sequence_len -1, since the last one is not used for comparing with groud truth + assert gen_images.shape[1] == sequence_length-1 #The generate images seq_len should be sequence_len -1, since the last one is not used for comparing with groud truth print("gen_images 20200822:",np.array(gen_images).shape) #Loop in batch size for i in range(args.batch_size):