train_pred.py
    #import libraries
    import os
    import sys
    import pathlib
    import time
    import matplotlib
    matplotlib.use('Agg')
    from matplotlib import pyplot as plt
    import numpy as np
    import h5py
    import datetime
    import data
    import spectral as sp
    import scipy.ndimage as ndi
    import horovod.tensorflow as hvd
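    # restrict the visible GPUs before TensorFlow is imported so the environment variable takes effect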
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
    #os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    import tensorflow as tf
    from DiffAugment_tf import DiffAugment
    import tensorflow_addons as tfa
    from models.resnet import resnet_18
    from test_normalization import scale_data, normalize_data, normalize_data_withperc, clip_data, get_percentiles_mean
    from model_blocks import Generator, Generator_conv_S2inp, Discriminator, PATCH_30m_DIM, PATCH_20m_DIM, PATCH_10m_DIM
    #from mpi4py import MPI
    #comm = MPI.COMM_WORLD
    
    hvd.init()
    print("local rank", hvd.local_rank())
    print("rank", hvd.rank())
    
    # Pin GPU to be used to process local rank (one GPU per process)
    gpus = tf.config.experimental.list_physical_devices('GPU')
    print("gpus: ", gpus)
    
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
        #print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))
    if gpus:
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
    
    print ("hvd size", hvd.size())
    #print("comm size", comm.Get_size())
    output_dir = "output_diffaug_128_l1loss_eta100_noinstgen_bench"
    sample_dir = os.path.join(output_dir, 'samples_training')
    generated_dir = os.path.join(sample_dir, 'generated')
    sample_temp_dir = os.path.join(sample_dir, 'temp')
    pathlib.Path(sample_temp_dir).mkdir(parents=True, exist_ok=True) 
    sample_test = os.path.join(sample_dir, 'test')
    pathlib.Path(sample_test).mkdir(parents=True, exist_ok=True) 
    
    load_pretrain = False
    TRAIN = True
    CLIP_NORM = True
    PRED_INTERM = False
    compute_percentiles = False
    BUFFER_SIZE = 400
    if TRAIN == True:
            BATCH_SIZE = 2
    else:
            BATCH_SIZE = 1
    POOL_SIZE = 10
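    # patch sizes per resolution; all three cover the same ground extent (128 * 30 m = 192 * 20 m = 384 * 10 m = 3840 m)
    # and re-assign the names imported from model_blocks above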
    PATCH_30m_DIM = 128
    PATCH_20m_DIM = 192
    PATCH_10m_DIM = 384
    IMG_HEIGHT = 256
    policy = 'color,translation,cutout'
    #policy = 'translation,cutout'
    print("inizio")
    
    # ==============================================================================
    # =                                    VGG-16                                  =
    # ==============================================================================
    
    VGG_model = tf.keras.applications.VGG16(
        include_top=False, weights='imagenet', input_tensor=None,
        input_shape=None, pooling=None
    )
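    # ImageNet-pretrained VGG-16 convolutional base, used as a fixed feature extractor (always called with training=False) in featureLossVGG below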
    
    #latest = tf.train.Checkpoint("/p/project/joaiml/remote_sensing/rocco_sedona/ben_TF2/scripts/checkpoints/checkpoint-50.h5")
    source_model = resnet_18()
    source_model.build(input_shape=(None, 120, 120, 6))
    source_model.load_weights("/p/project/joaiml/remote_sensing/rocco_sedona/rslab_col/network_code/pix2pix/checkpoint_resnet18/checkpoint-50.h5")
    perceptual_model = tf.keras.Sequential()
    for layer in source_model.layers[:-2]: # copy all layers except the last two (drop the classification head)
        perceptual_model.add(layer)
    perceptual_model.build(input_shape=(None, 120, 120, 6))
    perceptual_model.summary()
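    # the truncated ResNet-18 acts as the perceptual feature extractor for styleLoss and featureLoss below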
    
    # ==============================================================================
    # =                                    data                                    =
    # ==============================================================================
    #load retiled data
    """
    A_data =  np.load('data_numpy/A_data_30_40.npy').astype(np.float32)
    B_data_10m =  np.load('data_numpy/B_data_10_120.npy').astype(np.float32)
    B_data_20m =  np.load('data_numpy/B_data_20_60.npy').astype(np.float32)
    """
    #data patches
    #path_L8 = "/p/scratch/joaiml/sedona3/L8_Images/L8_patches_new.h5"
    path_S2 = "../../data/tile_33UVP/timeseries_project/S2_patches_128_shifted_4r.h5"
    #path_L8 = "/p/scratch/joaiml/sedona3/L8_Images/L8_patches_new.h5"
    #path_L8 = "../../data/tile_33UVP/timeseries_project/L8_patches_128.h5"  
    path_L8 = '../../data/tile_33UVP/timeseries_project/L8_patches_128_water.h5'
    
    h5_file_S2 = h5py.File(path_S2, "r")
    h5_file_L8 = h5py.File(path_L8, "r")
    
    val_S2 = h5_file_S2["val"][:]
    val_S2_scl = h5_file_S2["val_scl"][:]
    
    val_L8 = h5_file_L8["val"][:]
    val_L8_qa = h5_file_L8["val_qa"][:]
    
    validity = (val_L8_qa==1) & (val_L8==1) & (val_S2_scl==1) & (val_S2==1)
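    # a patch is kept only when all four quality/validity flags (L8 QA, L8 val, S2 SCL, S2 val) equal 1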
    
    interm_L8 = h5_file_L8["interm"][:]
    interm_S2 = h5_file_S2["interm"][:]
    date_L8 = h5_file_L8["date"][:]
    #len_entire_dataset = np.count_nonzero(validity==True)
    len_entire_dataset = validity.shape[0]
    #length of the dataset used for training
    len_dataset = np.floor(len_entire_dataset * 0.7).astype(int)
    num_steps = len_dataset // (hvd.size() * BATCH_SIZE)
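    # steps per epoch per worker: the 70% training split is divided across hvd.size() workers, each processing BATCH_SIZE patches per step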
    #length of the dataset used for test
    #len_test = len_entire_dataset - len_dataset
    if TRAIN == True:
            len_test = 1
            date = 0
            date_condition = (date_L8 == 0) | (date_L8 == 1) | (date_L8 == 2) | (date_L8 == 4) | (date_L8 == 5)
    else:
            #len_test = len_entire_dataset - len_dataset
            date = 3
            date_condition = date_L8 == date
            len_test = np.count_nonzero(date_condition)
            print(len_test)
            len_dataset = 0
    
    #number of non-intermediate patches
    #num_non_interm = np.count_nonzero(interm_L8==0)
    num_non_interm = len_test
    #num_non_interm_train = int(num_non_interm * 0.1)
    #num_non_interm_test = num_non_interm - num_non_interm_train
    size_chunk = int(len_dataset/hvd.size())
    #size_chunk = 1000
    if TRAIN == True:
            size_chunk_test = 1
    else:
            size_chunk_test = int(len_test/hvd.size())
    #size_chunk_test = 100
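    # manual sharding: each Horovod rank reads its own contiguous chunk of size_chunk patches from the HDF5 files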
    validity_train = validity[date_condition]
    validity_train = validity_train[(hvd.rank()*size_chunk):((hvd.rank()+1)*size_chunk)]
    if TRAIN == True:
            A_data_train = h5_file_L8["data_30m"][date_condition,:,:,:]
            A_data_train = A_data_train[(hvd.rank()*size_chunk):((hvd.rank()+1)*size_chunk)].astype(np.float32)
            B_data_train_10m = h5_file_S2["data_10m"][date_condition,:,:,:]
            B_data_train_10m = B_data_train_10m[(hvd.rank()*size_chunk):((hvd.rank()+1)*size_chunk)].astype(np.float32)
            B_data_train_20m = h5_file_S2["data_20m"][date_condition,:,:,:]
            B_data_train_20m = B_data_train_20m[(hvd.rank()*size_chunk):((hvd.rank()+1)*size_chunk)].astype(np.float32) 
            A_data_train[A_data_train<0] = 0
            #B_data = B_data * 0.0001
            B_data_train_10m[B_data_train_10m<0] = 0
            B_data_train_20m[B_data_train_20m<0] = 0
    
            validity_test = validity[(len_dataset+(hvd.rank()*size_chunk_test)):(len_dataset+((hvd.rank()+1)*size_chunk_test))]
            A_data_test = h5_file_L8["data_30m"][(len_dataset+(hvd.rank()*size_chunk_test)):(len_dataset+((hvd.rank()+1)*size_chunk_test))].astype(np.float32)
            B_data_test_10m = h5_file_S2["data_10m"][(len_dataset+(hvd.rank()*size_chunk_test)):(len_dataset+((hvd.rank()+1)*size_chunk_test))].astype(np.float32) 
            B_data_test_20m = h5_file_S2["data_20m"][(len_dataset+(hvd.rank()*size_chunk_test)):(len_dataset+((hvd.rank()+1)*size_chunk_test))].astype(np.float32) 
    else:
            validity_test = validity[(date*size_chunk_test):((date+1)*size_chunk_test)]
            A_data_test = h5_file_L8["data_30m"][(date*size_chunk_test):((date+1)*size_chunk_test)].astype(np.float32)
            B_data_test_10m = h5_file_S2["data_10m"][(date*size_chunk_test):((date+1)*size_chunk_test)].astype(np.float32) 
            B_data_test_20m = h5_file_S2["data_20m"][(date*size_chunk_test):((date+1)*size_chunk_test)].astype(np.float32) 
    #clip negative values to zero (valid reflectance range is 0-10000)
    #A_data = A_data * 0.0001 
    A_data_test[A_data_test<0] = 0
    B_data_test_10m[B_data_test_10m<0] = 0
    B_data_test_20m[B_data_test_20m<0] = 0
    
    if compute_percentiles == True:
            arr_p1A, arr_p99A, arr_p5A, arr_p95A, arr_meanA = get_percentiles_mean(h5_file_L8, True)
            arr_p1B1, arr_p99B1, arr_p5B1, arr_p95B1, arr_meanB1, arr_p1B2, arr_p99B2, arr_p5B2, arr_p95B2, arr_meanB2 = get_percentiles_mean(h5_file_S2, False)
    
            np.save("data_numpy/p1A.npy", arr_p1A)
            np.save("data_numpy/p99A.npy", arr_p99A)
            np.save("data_numpy/p5A.npy", arr_p5A)
            np.save("data_numpy/p95A.npy", arr_p95A)
            np.save("data_numpy/meanA.npy", arr_meanA)
    
            np.save("data_numpy/arr_p1B1.npy", arr_p1B1)
            np.save("data_numpy/arr_p99B1.npy", arr_p99B1)
            np.save("data_numpy/arr_p5B1.npy", arr_p5B1)
            np.save("data_numpy/arr_p95B1.npy", arr_p95B1)
            np.save("data_numpy/arr_meanB1.npy", arr_meanB1)
    
            np.save("data_numpy/arr_p1B2.npy", arr_p1B2)
            np.save("data_numpy/arr_p99B2.npy", arr_p99B2)
            np.save("data_numpy/arr_p5B2.npy", arr_p5B2)
            np.save("data_numpy/arr_p95B2.npy", arr_p95B2)
            np.save("data_numpy/arr_meanB2.npy", arr_meanB2)
    else:
            arr_p1A = np.load("data_numpy/p1A.npy")
            p1A = arr_p1A.mean(axis = 0).astype(int)
            arr_p99A = np.load("data_numpy/p99A.npy")
            p99A = arr_p99A.mean(axis = 0).astype(int)
            arr_p5A = np.load("data_numpy/p5A.npy")
            p5A = arr_p5A.mean(axis = 0).astype(int)
            arr_p95A = np.load("data_numpy/p95A.npy")
            p95A = arr_p95A.mean(axis = 0).astype(int)
            arr_meanA = np.load("data_numpy/meanA.npy")
            meanA = arr_meanA.mean(axis = 0)
    
            arr_p1B1 = np.load("data_numpy/arr_p1B1.npy")
            p1B1 = arr_p1B1.mean(axis = 0).astype(int)
            arr_p99B1 = np.load("data_numpy/arr_p99B1.npy")
            p99B1 = arr_p99B1.mean(axis = 0).astype(int)
            arr_p95B1 = np.load("data_numpy/arr_p95B1.npy")
            p95B1 = arr_p95B1.mean(axis = 0).astype(int)
            arr_p5B1 = np.load("data_numpy/arr_p5B1.npy")
            p5B1 = arr_p5B1.mean(axis = 0).astype(int)
            arr_meanB1 = np.load("data_numpy/arr_meanB1.npy")
            meanB1 = arr_meanB1.mean(axis = 0)
    
            arr_p1B2 = np.load("data_numpy/arr_p1B2.npy")
            p1B2 = arr_p1B2.mean(axis = 0).astype(int)
            arr_p99B2 = np.load("data_numpy/arr_p99B2.npy")
            p99B2 = arr_p99B2.mean(axis = 0).astype(int)
            arr_p95B2 = np.load("data_numpy/arr_p95B2.npy")
            p95B2 = arr_p95B2.mean(axis = 0).astype(int)
            arr_p5B2 = np.load("data_numpy/arr_p5B2.npy")
            p5B2 = arr_p5B2.mean(axis = 0).astype(int)
            arr_meanB2 = np.load("data_numpy/arr_meanB2.npy")
            meanB2 = arr_meanB2.mean(axis = 0)
    
    if TRAIN == True:
            
            #use this only during training
            A_data_train = A_data_train[validity_train,:,:,1:7]
            B_data_train_10m = B_data_train_10m[validity_train]
            B_data_train_20m = B_data_train_20m[validity_train,:,:,0:2]
    
            A_data_test = A_data_test[validity_test,:,:,1:7]
            B_data_test_10m = B_data_test_10m[validity_test]
            B_data_test_20m = B_data_test_20m[validity_test,:,:,0:2]
    #use this filter only during test; otherwise training would also include cloudy or snowy patches (an explicit condition should be added)
    elif PRED_INTERM == True:
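            # NOTE: this branch uses A_data_train and B_data_train_*, which are only loaded when TRAIN is True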
            interm_L8_train = interm_L8[:len_dataset]
            interm_L8_test = interm_L8[len_dataset:]
            A_data_train = A_data_train[interm_L8_train==0,:,:,1:7]
            B_data_train_10m = B_data_train_10m[interm_L8_train==0]
            B_data_train_20m = B_data_train_20m[interm_L8_train==0,:,:,0:2]
    
            A_data_test = A_data_test[interm_L8_test==0,:,:,1:7]
            B_data_test_10m = B_data_test_10m[interm_L8_test==0]
            B_data_test_20m = B_data_test_20m[interm_L8_test==0,:,:,0:2]
    else:
            #A_data_train = A_data_train[date_condition,:,:,1:7]
            #B_data_train_10m = B_data_train_10m[date_condition]
            #B_data_train_20m = B_data_train_20m[date_condition,:,:,0:2]
    
            A_data_test = A_data_test[:,:,:,1:7]
            B_data_test_10m = B_data_test_10m[:]
            B_data_test_20m = B_data_test_20m[:,:,:,0:2]
    #saturate extreme values below the 1st and above the 99th percentile
    print("before clipping")
    if (TRAIN == True) & (CLIP_NORM == True):
            A_data_train = clip_data(A_data_train, p1A, p99A).astype(np.float32)
            B_data_train_10m = clip_data(B_data_train_10m, p1B1, p99B1).astype(np.float32)
            B_data_train_20m = clip_data(B_data_train_20m, p1B2, p99B2).astype(np.float32)
    
    if CLIP_NORM == True:
            A_data_test = clip_data(A_data_test, p1A, p99A)
            B_data_test_10m = clip_data(B_data_test_10m, p1B1, p99B1)
            B_data_test_20m = clip_data(B_data_test_20m, p1B2, p99B2)
    
    print("A_data_test shape: ", A_data_test.shape)
    print("B_data_test_10m shape: ", B_data_test_10m.shape)
    print("B_data_test_20m shape: ", B_data_test_20m.shape)
    
    h5_file_S2.close()
    h5_file_L8.close()
    
    """
    #added percentile 1% and 99% to saturate extrema, astype(np.float32)
    A_data_train, minA, maxA = scale_data(A_data_train)
    B_data_train_10m, minB1, maxB1 = scale_data(B_data_train_10m)
    B_data_train_20m, minB2, maxB2 = scale_data(B_data_train_20m)
    A_data_test = scale_data_withperc(A_data_test, minA, maxA)
    B_data_test_10m = scale_data_withperc(B_data_test_10m, minB1, maxB1)
    B_data_test_20m = scale_data_withperc(B_data_test_20m, minB2, maxB2)
    """
    print("before normalizing")
    
    #normalize data with percentile
    if (TRAIN == True) & (CLIP_NORM == True):
            A_data_train = normalize_data_withperc(A_data_train, p5A.astype(np.float32), p95A.astype(np.float32), meanA.astype(np.float32))
            B_data_train_10m = normalize_data_withperc(B_data_train_10m, p5B1.astype(np.float32), p95B1.astype(np.float32), meanB1.astype(np.float32))
            B_data_train_20m = normalize_data_withperc(B_data_train_20m, p5B2.astype(np.float32), p95B2.astype(np.float32), meanB2.astype(np.float32))
    
    if CLIP_NORM == True:
            A_data_test = normalize_data_withperc(A_data_test, p5A, p95A, meanA)
            B_data_test_10m = normalize_data_withperc(B_data_test_10m, p5B1, p95B1, meanB1)
            B_data_test_20m = normalize_data_withperc(B_data_test_20m, p5B2, p95B2, meanB2)
    
    print("A_data_test shape: ", A_data_test.shape)
    print("B_data_test_10m shape: ", B_data_test_10m.shape)
    print("B_data_test_20m shape: ", B_data_test_20m.shape)
    
    #normalize
    """
    A_mean = np.mean(A_data_train, (0,1,2))
    #A_std = (np.std(A_data_train, (0,1,2))) * 10
    A_std = (np.std(A_data_train, (0,1,2))) * 10
    A_data_train = (A_data_train - A_mean) / A_std
    A_data_test = (A_data_test - A_mean) / A_std
    print("A_std: ", A_std, " A_mean: ", A_mean)
    print("A max: ", A_data_train.max(axis=3), " A min:", A_data_train.min(axis=3))
    B_mean_10m = np.mean(B_data_train_10m, (0,1,2))
    #B_std_10m = (np.std(B_data_train_10m, (0,1,2))) * 10
    B_std_10m = (np.std(B_data_train_10m, (0,1,2))) * 10
    print("B_std_10m: ", B_std_10m, " B_mean_10m: ", B_mean_10m)
    B_data_train_10m = (B_data_train_10m - B_mean_10m) / B_std_10m
    B_data_test_10m = (B_data_test_10m - B_mean_10m) / B_std_10m
    print("B10 max: ", B_data_train_10m.max(), " B10 min:", B_data_train_10m.min())
    B_mean_20m = np.mean(B_data_train_20m, (0,1,2))
    #B_std_20m = (np.std(B_data_train_20m, (0,1,2))) * 10
    B_std_20m = (np.std(B_data_train_20m, (0,1,2))) * 10
    print("B_std_20m: ", B_std_20m, " B_mean_20m: ", B_mean_20m)
    B_data_train_20m = (B_data_train_20m - B_mean_20m) / B_std_20m
    B_data_test_20m = (B_data_test_20m - B_mean_20m) / B_std_20m
    print("B20 max: ", B_data_train_20m.max(), " B20 min:", B_data_train_20m.min())
    """
    """
    #normalize
    A_mean = meanA
    #A_std = (np.std(A_data_train, (0,1,2))) * 10
    A_std = (np.std(A_data_train, (0,1,2))) * 1
    A_data_train = (A_data_train - A_mean) / A_std
    A_data_test = (A_data_test - A_mean) / A_std
    B_mean_10m = meanB1
    #B_std_10m = (np.std(B_data_train_10m, (0,1,2))) * 10
    B_std_10m = (np.std(B_data_train_10m, (0,1,2))) * 1
    B_data_train_10m = (B_data_train_10m - B_mean_10m) / B_std_10m
    B_data_test_10m = (B_data_test_10m - B_mean_10m) / B_std_10m
    B_mean_20m = meanB2
    #B_std_20m = (np.std(B_data_train_20m, (0,1,2))) * 10
    B_std_20m = (np.std(B_data_train_20m, (0,1,2))) * 1
    B_data_train_20m = (B_data_train_20m - B_mean_20m) / B_std_20m
    B_data_test_20m = (B_data_test_20m - B_mean_20m) / B_std_20m
    
    A_data_train = A_data[(hvd.local_rank()*size_chunk):((hvd.local_rank()+1)*size_chunk)]
    B_data_train_10m = B_data_10m[(hvd.local_rank()*size_chunk):((hvd.local_rank()+1)*size_chunk)]
    B_data_train_20m = B_data_20m[(hvd.local_rank()*size_chunk):((hvd.local_rank()+1)*size_chunk)]
    """
    if TRAIN == True:
            A_data_train = tf.data.Dataset.from_tensor_slices((A_data_train)).shard(hvd.size(), hvd.rank()).batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)
            B_data_train_10m = tf.data.Dataset.from_tensor_slices((B_data_train_10m)).shard(hvd.size(), hvd.rank()).batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)
            B_data_train_20m = tf.data.Dataset.from_tensor_slices((B_data_train_20m)).shard(hvd.size(), hvd.rank()).batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)
            A_B_dataset = tf.data.Dataset.zip((B_data_train_10m, B_data_train_20m, A_data_train))
    
    #A2B_pool = data.ItemPool(POOL_SIZE)
    #B2A_pool = data.ItemPool(POOL_SIZE)
    print("A_data_test shape: ", A_data_test.shape)
    print("B_data_test_10m shape: ", B_data_test_10m.shape)
    print("B_data_test_20m shape: ", B_data_test_20m.shape)
    A_data_test = tf.data.Dataset.from_tensor_slices((A_data_test)).batch(1)
    B_data_test_10m = tf.data.Dataset.from_tensor_slices((B_data_test_10m)).batch(1)
    B_data_test_20m = tf.data.Dataset.from_tensor_slices((B_data_test_20m)).batch(1)
    A_B_dataset_test = tf.data.Dataset.zip((B_data_test_10m, B_data_test_20m, A_data_test))
    
    #original_images_30m = np.ones([size_chunk, 40, 40, 6])
    #original_images_10m = np.ones([size_chunk, 120, 120, 4])
    #original_images_20m = np.ones([size_chunk, 60, 60, 2]
    #predicted_images = np.ones([size_chunk, 40, 40, 6])
    print("data done on rank: ", hvd.rank())
    #hvd.allreduce([0], name="Barrier")
    if hvd.rank()==0:
       path_pred_fold = "/p/project/joaiml/remote_sensing/rocco_sedona/rslab_col/data/tile_33UVP/timeseries_project/" + output_dir
       pathlib.Path(path_pred_fold).mkdir(parents=True, exist_ok=True) 
       path_pred = path_pred_fold + "/date_" + str(date) + ".h5"
       h5_file_pred = h5py.File(path_pred, "w")
       """
       d_pred = h5_file_pred.create_dataset("data_pred", shape=[len_test, 40, 40, 6],dtype=np.float32)
       d_or_30m = h5_file_pred.create_dataset("data_30m", shape=[len_test, 40, 40, 6],dtype=np.float32)
       d_or_10m = h5_file_pred.create_dataset("data_10m", shape=[len_test, 120, 120, 4],dtype=np.float32)
       d_or_20m = h5_file_pred.create_dataset("data_20m", shape=[len_test, 60, 60, 2],dtype=np.float32)
       """
       d_pred = h5_file_pred.create_dataset("data_pred", shape=[num_non_interm, PATCH_30m_DIM, PATCH_30m_DIM, 6],dtype=np.float32)
       d_or_30m = h5_file_pred.create_dataset("data_30m", shape=[num_non_interm, PATCH_30m_DIM, PATCH_30m_DIM, 6],dtype=np.float32)
       d_or_10m = h5_file_pred.create_dataset("data_10m", shape=[num_non_interm, PATCH_10m_DIM, PATCH_10m_DIM, 4],dtype=np.float32)
       d_or_20m = h5_file_pred.create_dataset("data_20m", shape=[num_non_interm, PATCH_20m_DIM, PATCH_20m_DIM, 2],dtype=np.float32)
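    # the allreduce below acts as a barrier: all ranks synchronize here, after rank 0 has created the output HDF5 file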
    barrier = hvd.allreduce(tf.random.normal(shape=[1]))
    # ==============================================================================
    # =                                    model                                   =
    # ==============================================================================
    def meanSquaredLoss(y_true,y_pred):
        return tf.reduce_mean(tf.keras.losses.MSE(y_true,y_pred))
    
    def gram_matrix(input_tensor):
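        # Gram matrix of the feature maps: G[b, c, d] = sum_{i,j} F[b, i, j, c] * F[b, i, j, d] / (H * W), used by the style and gram losses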
        result = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
        input_shape = tf.shape(input_tensor)
        num_locations = tf.cast(input_shape[1]*input_shape[2], tf.float32)
        return result/(num_locations)
    
    def styleLoss(generated_image, true_image):
        activatedModelVal = perceptual_model(generated_image,training=False)
        actualModelVal = perceptual_model(true_image,training=False)
        return meanSquaredLoss(gram_matrix(actualModelVal),gram_matrix(activatedModelVal))
    
    def gramMatrixLoss(generated_image, true_image):
        return meanSquaredLoss(gram_matrix(generated_image),gram_matrix(true_image))
    
    def featureLoss(generated_image, true_image):
        activatedModelVal = perceptual_model(generated_image,training=False)
        actualModelVal = perceptual_model(true_image,training=False)
        return meanSquaredLoss(actualModelVal,activatedModelVal)
    
    def featureLossVGG(generated_image, true_image):
        activatedModelVal_RGB = VGG_model(generated_image[:,:,:,0:3],training=False)
        actualModelVal_RGB = VGG_model(true_image[:,:,:,0:3],training=False)
        activatedModelVal_IR = VGG_model(generated_image[:,:,:,3:6],training=False)
        actualModelVal_IR = VGG_model(true_image[:,:,:,3:6],training=False)
        loss_RGB = meanSquaredLoss(actualModelVal_RGB,activatedModelVal_RGB)
        loss_IR = meanSquaredLoss(actualModelVal_IR,activatedModelVal_IR)
        loss_RGB_IR = loss_RGB + loss_IR
        return loss_RGB_IR
    
    """
    GAN_WEIGHT = 0.1
    LAMBDA = 4000
    ETA = 4000
    """
    """
    GAN_WEIGHT = 1
    LAMBDA = 4000
    ETA = 10
    """
    """
    GAN_WEIGHT = 5e-3
    LAMBDA = 1
    ETA = 1 #1e-2
    """
    GAN_WEIGHT = 1
    LAMBDA = 1
    ETA = 100
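    # loss weights: the active generator loss is GAN_WEIGHT * adversarial + ETA * L1; LAMBDA only appears in the commented-out feature/gram variants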
    
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    
    def generator_loss(disc_generated_output, disc_real_output, gen_output, target):
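      # relativistic adversarial term: the generator is pushed to make the discriminator score the fake output higher than the real target (D(fake) - D(real))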
      #gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
      gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output - disc_real_output)
      #print("gan loss")
      # mean absolute error
      #print("target ", target)
      #print("gen_output", gen_output)
    
      l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
      #huber_loss = tf.keras.losses.Huber(gen_output, target)
      huber_loss = tf.compat.v1.losses.huber_loss(gen_output, target)
      style_loss = styleLoss(gen_output, target)
      gram_loss = gramMatrixLoss(gen_output, target)
      feature_loss = featureLoss(gen_output, target)
      feature_loss_VGG = featureLossVGG(gen_output, target)
      #print("target - gen_output", target - gen_output)
      #l1_loss = np.mean(np.abs(target - gen_output))
      #print("L1 loss")
      total_gen_loss = (GAN_WEIGHT * gan_loss) + (ETA * l1_loss) # + style_loss
      #total_gen_loss = (GAN_WEIGHT * gan_loss) + (LAMBDA * gram_loss)
      #total_gen_loss = (GAN_WEIGHT * gan_loss) + (ETA * l1_loss)  + (LAMBDA * feature_loss_VGG)# + style_loss
      #print("total loss")
      return total_gen_loss, gan_loss, l1_loss, style_loss, feature_loss, gram_loss, huber_loss, feature_loss_VGG
    
    
    def discriminator_loss(disc_real_output, disc_generated_output):
      """
      real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    
      generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
      """
      real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output - disc_generated_output)
      generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output - disc_real_output)
    
      total_disc_loss = real_loss + generated_loss
    
      return total_disc_loss
    
    #define the Optimizers and Checkpoint-saver
    
    starter_learning_rate = 0.0001
    """
    end_learning_rate = 0.00001
    decay_steps = (len_dataset / BATCH_SIZE)
    scaled_lr = tf.keras.optimizers.schedules.PolynomialDecay(
           starter_learning_rate,
           decay_steps,
           end_learning_rate,
           power=1)
    """
    
    learning_rate_fn_disc = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        [60*num_steps, 120*num_steps],
        [1e-4, 5e-5, 2e-5]
        )
    
    learning_rate_fn_gen = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        [60*num_steps, 120*num_steps],
        [5e-5, 2e-5, 1e-5]
        )
    
    """
    learning_rate_fn_gen = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        [50*num_steps, 100*num_steps, 200*num_steps, 300*num_steps],
        [2e-4, 1e-4, 5e-5, 2e-5, 1e-5]
        )
    
    learning_rate_fn_disc = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        [50*num_steps, 100*num_steps, 200*num_steps, 300*num_steps],
        [2e-4, 1e-4, 5e-5, 2e-5, 1e-5]
        )
    """
    """
    learning_rate_fn_gen = tf.optimizers.schedules.ExponentialDecay(2e-4, num_steps * 10, 0.9)
    learning_rate_fn_disc = tf.optimizers.schedules.ExponentialDecay(2e-4, num_steps * 10, 0.9)
    """
    """
    generator_optimizer = tf.keras.optimizers.Adam(starter_learning_rate, beta_1=0.9)
    discriminator_optimizer = tf.keras.optimizers.Adam(starter_learning_rate, beta_1=0.9)
    """
    
    generator_optimizer = tfa.optimizers.AdamW(
        learning_rate = starter_learning_rate,
        beta_1= 0.9,
        beta_2 = 0.999,
        weight_decay = 0.0001
    )
    discriminator_optimizer = tfa.optimizers.AdamW(
        learning_rate = starter_learning_rate,
        beta_1= 0.9,
        beta_2 = 0.999,
        weight_decay = 0.0001
    )
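    # NOTE: both optimizers use the constant starter_learning_rate; the piecewise-decay schedules defined above are not passed to them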
    
    
    #generator_optimizer = tfa.optimizers.LAMB(learning_rate=starter_learning_rate)
    #discriminator_optimizer = tfa.optimizers.LAMB(learning_rate=scaled_lr)
    #generator_optimizer = tf.keras.optimizers.SGD(learning_rate=2e-4, momentum=0.9, nesterov=True)
    #discriminator_optimizer = tf.keras.optimizers.SGD(learning_rate=2e-4, momentum=0.9, nesterov=True)
    
    #instantiate the generator and discriminator models
    
    #generator = Generator()
    generator = Generator_conv_S2inp()
    generator.summary()
    
    
    discriminator = Discriminator()
    discriminator.summary()
    
    checkpoint = tf.train.Checkpoint(step_counter = tf.Variable(1, trainable=False),
                                     generator_optimizer=generator_optimizer,
                                     discriminator_optimizer=discriminator_optimizer,
                                     generator=generator,
                                     discriminator=discriminator)
    manager = tf.train.CheckpointManager(checkpoint, './tf_ckpts_' + output_dir, max_to_keep=3)
    
    if load_pretrain == True:
            checkpoint_pregen = tf.train.Checkpoint(generator=generator)
            manager_pregen = tf.train.CheckpointManager(checkpoint_pregen, './tf_ckpts_' + output_dir + '_gen', max_to_keep=3)
    
    #generate images
    
    def generate_images(model, example_input1, example_input2, tar):
      generate_images.counter += 1
      prediction = model([example_input1, example_input2], training=False)
      f = plt.figure(figsize=(15,15))
    
      display_list = [example_input1[0], tar[0], prediction[0]]
      title = ['Input Image', 'Ground Truth', 'Predicted Image']
      """
      for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i])
        # getting the pixel values between [0, 1] to plot it.
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis('off')
      f.savefig(os.path.join(generated_dir, "generated-%09d.png" % generate_images.counter))
      """
      #invert the percentile normalization: scale by (p95 - p5) and add back the band mean
      """
      A = (tar[0] * A_std) + A_mean
      B_10m = (example_input1[0] * B_std_10m) + B_mean_10m
      B_20m = (example_input2[0] * B_std_20m) + B_mean_20m
      B2A = (prediction[0]* A_std) + A_mean
      """
      
      A = (tar[0] * (p95A - p5A)) + meanA
      B_10m = (example_input1[0] * (p95B1 - p5B1)) + meanB1
      B_20m = (example_input2[0] * (p95B2 - p5B2)) + meanB2
      B2A = (prediction[0] * (p95A - p5A)) + meanA
      
      """
      A = test_input[0] * 10000
      B = tar[0] * 10000
      A2B = prediction[0] * 10000
    
      #print("A is: ", A)
      
      print("maxA : ", maxA, "minA : ", minA)
      print("maxB1 : ", maxB1, "minB1 : ", minB1)
      print("maxB2 : ", maxB2, "minB2 : ", minB2)
      print("A: ", np.flip(A[:, :,0:3]))
      print("B2A: ", np.flip(B2A[:, :, 0:3]))
      print("B: ", np.flip(B_10m[:, :, 0:3]))
      """
      sp.save_rgb(os.path.join(sample_temp_dir, 'A.png'), np.flip(A[:, :,0:3], 2), format='png', color_scale = [0,5000], bounds = [0,4096])
      sp.save_rgb(os.path.join(sample_temp_dir, 'B2A.png'), np.flip(B2A[:, :, 0:3], 2), format='png', color_scale = [0,5000], bounds = [0,4096])
      sp.save_rgb(os.path.join(sample_temp_dir, 'B.png'), np.flip(B_10m[:, :, 0:3], 2), format='png', color_scale = [0,5000], bounds = [0,4096])
      f, axarr = plt.subplots(1, 3)
      im_A = plt.imread(os.path.join(sample_temp_dir, "A.png"))
      im_B2A = plt.imread(os.path.join(sample_temp_dir, "B2A.png"))
      im_B = plt.imread(os.path.join(sample_temp_dir, "B.png"))
      axarr[0].imshow(im_A)
      axarr[0].title.set_text('A')
      axarr[1].imshow(im_B2A)
      axarr[1].title.set_text('B2A')
      axarr[2].imshow(im_B)
      axarr[2].title.set_text('B')
      plt.tight_layout()
      f.savefig(os.path.join(sample_dir, "generated-%09d.png" % generate_images.counter))
      plt.close() 
      return A, B_10m, B_20m, B2A
    
    generate_images.counter = 0 
    
    #training
    EPOCHS = 9
    initial_epoch = 0
    log_dir="logs/"
    if hvd.rank()==0:
       summary_writer = tf.summary.create_file_writer(log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    print("second barrier")
    barrier = hvd.allreduce(tf.random.normal(shape=[1]))
    @tf.function
    def train_step(input_image_1, input_image_2, target, epoch, first_batch):
      #print("entered train step")
      #with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      with tf.GradientTape(persistent=True) as tape:
        #print("entered tape, rank:", hvd.rank())
        gen_output = generator([input_image_1, input_image_2], training=True)
        #print("gen")
        disc_real_output = discriminator([input_image_1, input_image_2, target], training=True)
        #print("disc re")
        disc_generated_output = discriminator([input_image_1, input_image_2, gen_output], training=True)
        #print("disc gem")
        gen_total_loss, gen_gan_loss, gen_l1_loss, gen_style_loss, gen_feature_loss, gen_gram_loss, gen_huber_loss, gen_feature_loss_VGG = generator_loss(DiffAugment(disc_generated_output, policy=policy), DiffAugment(disc_real_output, policy=policy), gen_output, target)
        #print("gen loss")
        disc_loss = discriminator_loss(DiffAugment(disc_real_output, policy=policy), DiffAugment(disc_generated_output, policy=policy))
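        # NOTE: DiffAugment is applied here to the discriminator output maps; in the standard DiffAugment setup the augmentation is applied to the real and generated images before they are fed to the discriminator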
        #print("disc loss")
      # Horovod: wrap the tape in Horovod's DistributedGradientTape so gradients are averaged across workers.
      #gen_tape = hvd.DistributedGradientTape(gen_tape)
      #disc_tape = hvd.DistributedGradientTape(disc_tape)
      #print("done1")
      tape = hvd.DistributedGradientTape(tape)
      #print("done2")
      generator_gradients = tape.gradient(gen_total_loss,
                                              generator.trainable_variables)
      discriminator_gradients = tape.gradient(disc_loss,
                                                   discriminator.trainable_variables)
      #print("done3")
      generator_optimizer.apply_gradients(zip(generator_gradients,
                                              generator.trainable_variables))
      discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                  discriminator.trainable_variables))
      #print("done4")
      # Horovod: broadcast initial variable states from rank 0 to all other processes.
      # This is necessary to ensure consistent initialization of all workers when
      # training is started with random weights or restored from a checkpoint.
      #
      # Note: broadcast should be done after the first gradient step to ensure optimizer
      # initialization.
      if first_batch == True:
         hvd.broadcast_variables(generator.variables, root_rank=0)
         hvd.broadcast_variables(generator_optimizer.variables(), root_rank=0)
         hvd.broadcast_variables(discriminator.variables, root_rank=0)
         hvd.broadcast_variables(discriminator_optimizer.variables(), root_rank=0)
         #print("first batch done, rank:", hvd.rank())
      #elif first_batch == False:
         #print("first batch is false, rank:", hvd.rank())
      if hvd.rank()==0:
         with summary_writer.as_default():
           tf.summary.scalar('gen_total_loss', gen_total_loss, step=epoch)
           tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=epoch)
           tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=epoch)
           tf.summary.scalar('gen_style_loss', gen_style_loss, step=epoch)
           tf.summary.scalar('gen_gram_loss', gen_gram_loss, step=epoch)
           tf.summary.scalar('gen_huber_loss', gen_huber_loss, step=epoch)
           tf.summary.scalar('gen_feature_loss', gen_feature_loss, step=epoch)
           tf.summary.scalar('gen_feature_loss_VGG', gen_feature_loss_VGG, step=epoch)
           tf.summary.scalar('disc_loss', disc_loss, step=epoch)
           #print("done summary, rank:", hvd.rank())
    
    def fit(train_ds, initial_epoch, epochs, test_ds):
      for epoch in range(initial_epoch, epochs):
        start = time.time()
        if hvd.rank()==0:
           print("Epoch: ", epoch)
        # Train
        for n, (input_image1, input_image2, target) in (train_ds.enumerate()):
          if hvd.rank() == 0:
             print('.', end='')
             #print("boh")
             if (n+1) % 100 == 0:
                print()
          train_step(input_image1, input_image2, target, epoch, (n == 0) & (epoch == initial_epoch))
        if hvd.rank() == 0:
           print()
           print ('Time taken for epoch {} is {} sec\n'.format(epoch, time.time()-start))
           # saving (checkpoint) the model every 10 epochs
           if (epoch + 1) % 10 == 0:
              save_path = manager.save()
              print("Saved checkpoint for step {}: {}".format(int(checkpoint.step_counter), save_path))
        checkpoint.step_counter.assign_add(1)
        #plot sample image
        if hvd.rank()==0:
           for example_input1, example_input2, example_target in test_ds.take(1):
             generate_images(generator, example_input1, example_input2, example_target)
    
    """
    checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                     discriminator_optimizer=discriminator_optimizer,
                                     generator=generator,
                                     discriminator=discriminator)
    """
    def predict(test_ds):
      #i=0
      #for inp1, inp2, tar in A_B_dataset_test.take(len_test):
      for n, (input_image1, input_image2, target) in (test_ds.enumerate()):
         t_A, t_B_10m, t_B_20m, t_A2B = generate_images(generator, input_image1, input_image2, target)
         d_or_30m[n, :, :, :] = t_A
         d_or_10m[n, :, :, :] = t_B_10m
         d_or_20m[n, :, :, :] = t_B_20m
         d_pred[n, :, :, :] = clip_data(t_A2B, p1A, p99A)
         #i = i + 1
      """
      np.save("output/predicted_im.npy", predicted_images)
      np.save("output/original_im_10m.npy", original_images_10m)
      np.save("output/original_im_20m.npy", original_images_20m)
      np.save("output/original_im_30m.npy", original_images_30m)
      """
      #checkpoint.save(file_prefix = checkpoint_prefix)
    
    #run
    if __name__ == "__main__":
       if load_pretrain:
          checkpoint_pregen.restore(manager_pregen.latest_checkpoint)
          print("Restored pregenerator from {}".format(manager_pregen.latest_checkpoint))
          print("initial epoch: ", initial_epoch)
       elif manager.latest_checkpoint:
          checkpoint.restore(manager.latest_checkpoint)
          print("Restored from {}".format(manager.latest_checkpoint))
          initial_epoch = int(checkpoint.step_counter) + 1
          generate_images.counter = int(checkpoint.step_counter)
          print("initial epoch: ", initial_epoch)
       else:
          print("Initializing from scratch.")
    
       print("test_ds: ", A_B_dataset_test)
       if TRAIN == True:
          print("train_ds: ", A_B_dataset)
          fit(A_B_dataset, initial_epoch, EPOCHS, A_B_dataset_test)
       elif hvd.rank() == 0:
          generate_images.counter = 0
          #predict(A_B_dataset_test)
          #predict(A_B_dataset.concatenate(A_B_dataset_test))
          predict(A_B_dataset_test)
          h5_file_pred.close()
       else:
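          # non-zero ranks idle for up to 2 hours (60*60*2 s) while rank 0 runs prediction and writes the HDF5 file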
          time.sleep(60*60*2)