vanilla_vae_model.py

    import collections
    import functools
    import itertools
    from collections import OrderedDict
    import numpy as np
    import tensorflow as tf
    from tensorflow.python.util import nest
    from video_prediction import ops, flow_ops
    from video_prediction.models import BaseVideoPredictionModel
    from video_prediction.models import networks
    from video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
    from video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
    from video_prediction.utils import tf_utils
    from datetime import datetime
    from pathlib import Path
    from video_prediction.layers import layer_def as ld
    
    class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
        def __init__(self, mode='train', aggregate_nccl=None, hparams_dict=None,
                     hparams=None, **kwargs):
            super(VanillaVAEVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs)
            self.mode = mode
            self.learning_rate = self.hparams.lr
            self.nz = self.hparams.nz
            self.aggregate_nccl = aggregate_nccl
            self.gen_images_enc = None
            self.train_op = None
            self.summary_op = None
            self.recon_loss = None
            self.latent_loss = None
            self.total_loss = None
    
        def get_default_hparams_dict(self):
            """
            The keys of this dict define valid hyperparameters for instances of
            this class. A class inheriting from this one should override this
            method if it has a different set of hyperparameters.
    
            Returns:
                A dict with the following hyperparameters.
    
                batch_size: batch size for training.
                lr: learning rate. if decay steps is non-zero, this is the
                    learning rate for steps <= decay_step.
                end_lr: learning rate for steps >= end_decay_step if decay_steps
                    is non-zero, ignored otherwise.
                decay_steps: (decay_step, end_decay_step) tuple.
                max_steps: number of training steps.
            beta1: exponential decay rate for Adam's first-moment (momentum) estimates.
            beta2: exponential decay rate for Adam's second-moment estimates.
                context_frames: the number of ground-truth frames to pass in at
                    start. Must be specified during instantiation.
                sequence_length: the number of frames in the video sequence,
                    including the context frames, so this model predicts
                    `sequence_length - context_frames` future frames. Must be
                    specified during instantiation.
            """
            default_hparams = super(VanillaVAEVideoPredictionModel, self).get_default_hparams_dict()
            print ("default hparams",default_hparams)
            hparams = dict(
                batch_size=16,
                lr=0.001,
                end_lr=0.0,
                decay_steps=(200000, 300000),
                lr_boundaries=(0,),
                max_steps=350000,
                nz=10,
                context_frames=-1,
                sequence_length=-1,
                clip_length=10,  # Bing: TODO clarify what clip_length controls; original value is 10
            )
            return dict(itertools.chain(default_hparams.items(), hparams.items()))
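
        # Minimal usage sketch (an assumed call pattern, not taken from this repo's
        # run scripts): the defaults above can be overridden through the
        # `hparams_dict` constructor argument, e.g.
        #
        #   model = VanillaVAEVideoPredictionModel(
        #       mode="train",
        #       hparams_dict={"lr": 1e-4, "nz": 16,
        #                     "context_frames": 10, "sequence_length": 20})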
    
        def build_graph(self, x):
            # x is a dict of input tensors; x["images"] is expected to have shape
            # [batch, sequence_length, height, width, channels]
            # (e.g. [None, 20, 64, 64, 3] in the earlier placeholder-based setup).
            self.x = x["images"]
            self.global_step = tf.train.get_or_create_global_step()
            original_global_variables = tf.global_variables()
            self.x_hat, self.z_log_sigma_sq, self.z_mu = self.vae_arc_all()
            # Loss
            # Reconstruction loss: per-pixel MSE on the first channel, comparing the
            # prediction for frame t (x_hat[:, t]) with the ground truth at t + 1
            # (x[:, t+1]). An earlier binary cross-entropy formulation is kept here
            # for reference:
            #     epsilon = 1e-10
            #     recon_loss = -tf.reduce_sum(
            #         self.x[:, 1:, :, :, :] * tf.log(epsilon + self.x_hat[:, :-1, :, :, :]) +
            #         (1 - self.x[:, 1:, :, :, :]) * tf.log(epsilon + 1 - self.x_hat[:, :-1, :, :, :]),
            #         axis=1)
            #     self.recon_loss = tf.reduce_mean(recon_loss)
            self.recon_loss = tf.reduce_mean(tf.square(self.x[:, 1:, :, :, 0] - self.x_hat[:, :-1, :, :, 0]))
    
            # Latent loss
            # KL divergence: measure the difference between two distributions
            # Here we measure the divergence between
            # the latent distribution and N(0, 1)
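            # Closed form for the KL term of a diagonal Gaussian posterior
            # N(mu, sigma^2) against a standard normal prior (standard identity,
            # stated here for reference):
            #   KL = -0.5 * sum_j (1 + log(sigma_j^2) - mu_j^2 - sigma_j^2)
            # Note that z_mu and z_log_sigma_sq are stacked over time (axis 1) in
            # vae_arc_all, so the reduce_sum below runs over the time axis of those
            # stacked tensors rather than over the latent dimension.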
            latent_loss = -0.5 * tf.reduce_sum(
                1 + self.z_log_sigma_sq - tf.square(self.z_mu) -
                tf.exp(self.z_log_sigma_sq), axis = 1)
            self.latent_loss = tf.reduce_mean(latent_loss)
            self.total_loss = self.recon_loss + self.latent_loss
            self.train_op = tf.train.AdamOptimizer(
                learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step)
            # Build a saver
            #self.saver = tf.train.Saver(tf.global_variables())
            self.losses = {
                'recon_loss': self.recon_loss,
                'latent_loss': self.latent_loss,
                'total_loss': self.total_loss,
            }
    
            # Summary ops (all scalar summaries are registered with the default
            # graph and merged below).
            self.loss_summary = tf.summary.scalar("recon_loss", self.recon_loss)
            self.loss_summary = tf.summary.scalar("latent_loss", self.latent_loss)
            self.loss_summary = tf.summary.scalar("total_loss", self.total_loss)
            self.summary_op = tf.summary.merge_all()
            # H(x, x_hat) = -\Sigma [x*log(x_hat) + (1-x)*log(1-x_hat)]
            # self.ckpt = tf.train.Checkpoint(model=self.vae_arc2())
            # self.manager = tf.train.CheckpointManager(self.ckpt,self.checkpoint_dir,max_to_keep=3)
            self.outputs = {}
            self.outputs["gen_images"] = self.x_hat
            global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
            self.saveable_variables = [self.global_step] + global_variables
            return
    
    
        @staticmethod
        def vae_arc3(x, l_name=0, nz=16):
            seq_name = "sq_" + str(l_name) + "_"
            
            conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1")
            print("Encode_1_shape", conv1.shape)  # (?,2,2,8)
            # conv2
            conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2")  # (?,2,2,8)
            print("Encode_2_shape", conv2.shape)
            # conv3
            conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3")  # (?,1,1,8)
            print("Encode_3_shape", conv3.shape)
            # flatten
            conv4 = tf.layers.Flatten()(conv3)
            print("Encode_4_shape (flattened)", conv4.shape)
            conv3_shape = conv3.get_shape().as_list()
            print("conv3_shape", conv3_shape)
            
            z_mu = ld.fc_layer(conv4, hiddens = nz, idx = seq_name + "enc_fc4_m")
            z_log_sigma_sq = ld.fc_layer(conv4, hiddens = nz, idx = seq_name + "enc_fc4_sigma")
            eps = tf.random_normal(shape = tf.shape(z_log_sigma_sq), mean = 0, stddev = 1, dtype = tf.float32)
            z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps
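            # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
            # where sigma = sqrt(exp(log(sigma^2))); sampling stays differentiable,
            # so gradients can flow back into z_mu and z_log_sigma_sq.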
            
            z2 = ld.fc_layer(z, hiddens = conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx = seq_name + "decode_fc1")
            
            
            z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]])
            # conv5: args are inputs, kernel_size, stride, num_features, followed by the layer name
            conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, seq_name + "decode_5")  # (16,1,1,8)
            print("Decode 5 shape", conv5.shape)
            # conv6
            conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, seq_name + "decode_6")  # (16,1,1,8)
            
            # x_1
            x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8")  # set activation to linear
            print("X_hat", x_hat.shape)
            return x_hat, z_mu, z_log_sigma_sq, z
    
        def vae_arc_all(self):
            X = []
            z_log_sigma_sq_all = []
            z_mu_all = []
            for i in range(20):  # frame count is hardcoded to 20 here rather than taken from hparams.sequence_length
                q, z_mu, z_log_sigma_sq, z = VanillaVAEVideoPredictionModel.vae_arc3(self.x[:, i, :, :, :], l_name=i, nz=self.nz)
                X.append(q)
                z_log_sigma_sq_all.append(z_log_sigma_sq)
                z_mu_all.append(z_mu)
            x_hat = tf.stack(X, axis = 1)
            z_log_sigma_sq_all = tf.stack(z_log_sigma_sq_all, axis = 1)
            z_mu_all = tf.stack(z_mu_all, axis = 1)
            print("X_hat", x_hat.shape)
            print("zlog_sigma_sq_all", z_log_sigma_sq_all.shape)
            return x_hat, z_log_sigma_sq_all, z_mu_all
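
    # End-to-end sketch (assumed surrounding pipeline; the input pipeline and session
    # handling are assumed to live in the training run scripts, not in this file).
    # It only illustrates how build_graph() and the exposed train_op / losses could
    # be driven; the feed shape follows the [batch, 20, 64, 64, 3] example noted in
    # build_graph, and image_batch stands for a NumPy array matching that shape.
    #
    #   inputs = {"images": tf.placeholder(tf.float32, [None, 20, 64, 64, 3])}
    #   model = VanillaVAEVideoPredictionModel(
    #       mode="train",
    #       hparams_dict={"context_frames": 10, "sequence_length": 20})
    #   model.build_graph(inputs)
    #   with tf.Session() as sess:
    #       sess.run(tf.global_variables_initializer())
    #       _, losses = sess.run([model.train_op, model.losses],
    #                            feed_dict={inputs["images"]: image_batch})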