Select Git revision
quantum_SVM.py
train_pred.py 33.78 KiB
#import libraries
import os
import sys
import pathlib
import time
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import numpy as np
import h5py
import datetime
import data
import spectral as sp
import scipy.ndimage as ndi
import horovod.tensorflow as hvd
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import tensorflow as tf
from DiffAugment_tf import DiffAugment
import tensorflow_addons as tfa
from models.resnet import resnet_18
from test_normalization import scale_data, normalize_data, normalize_data_withperc, clip_data, get_percentiles_mean
from model_blocks import Generator, Generator_conv_S2inp, Discriminator, PATCH_30m_DIM, PATCH_20m_DIM, PATCH_10m_DIM
#from mpi4py import MPI
#comm = MPI.COMM_WORLD
# Initialise Horovod first: the rank/size queries below are made after this call.
hvd.init()
print("local rank", hvd.local_rank())
print("rank", hvd.rank())
# Pin GPU to be used to process local rank (one GPU per process)
gpus = tf.config.experimental.list_physical_devices('GPU')
print("gpus: ", gpus)
for gpu in gpus:
    # Turn on memory growth for every listed GPU before restricting visibility.
    tf.config.experimental.set_memory_growth(gpu, True)
#print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))
if gpus:
    # Expose only the GPU whose index equals this process's local rank.
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
print ("hvd size", hvd.size())
#print("comm size", comm.Get_size())
# ------------------------------------------------------------------------------
# Run configuration: output locations, mode switches and patch geometry.
# ------------------------------------------------------------------------------
output_dir = "output_diffaug_128_l1loss_eta100_noinstgen_bench"
sample_dir = os.path.join(output_dir, 'samples_training')
# NOTE: generated_dir is only referenced from disabled code and is not created.
generated_dir = os.path.join(sample_dir, 'generated')
sample_temp_dir = os.path.join(sample_dir, 'temp')
pathlib.Path(sample_temp_dir).mkdir(parents=True, exist_ok=True)
sample_test = os.path.join(sample_dir, 'test')
pathlib.Path(sample_test).mkdir(parents=True, exist_ok=True)

# Mode switches read throughout the rest of the script.
load_pretrain = False
TRAIN = True
CLIP_NORM = True
PRED_INTERM = False
compute_percentiles = False

# NOTE(review): BUFFER_SIZE, POOL_SIZE and IMG_HEIGHT are not read anywhere in
# the visible part of this script.
BUFFER_SIZE = 400
# Batches of 2 while training, single patches otherwise.
BATCH_SIZE = 2 if TRAIN else 1
POOL_SIZE = 10
# These override the PATCH_*_DIM values imported from model_blocks.
PATCH_30m_DIM = 128
PATCH_20m_DIM = 192
PATCH_10m_DIM = 384
IMG_HEIGHT = 256
# Augmentation policy string handed to DiffAugment in train_step.
policy = 'color,translation,cutout'
#policy = 'translation,cutout'
print("inizio")
# ==============================================================================
# =                perceptual networks (VGG16 and ResNet-18)                   =
# ==============================================================================
# ImageNet-pretrained VGG16 without its classification head; featureLossVGG
# calls it on three-channel slices of the images.
VGG_model = tf.keras.applications.VGG16(
    include_top=False, weights='imagenet', input_tensor=None,
    input_shape=None, pooling=None
)
#latest = tf.train.Checkpoint("/p/project/joaiml/remote_sensing/rocco_sedona/ben_TF2/scripts/checkpoints/checkpoint-50.h5")
# ResNet-18 built for 6-channel 120x120 inputs, with weights loaded from a checkpoint file.
source_model = resnet_18()
source_model.build(input_shape=(None, 120, 120, 6))
source_model.load_weights("/p/project/joaiml/remote_sensing/rocco_sedona/rslab_col/network_code/pix2pix/checkpoint_resnet18/checkpoint-50.h5")
# Feature extractor used by styleLoss/featureLoss: the ResNet-18 layers
# re-assembled into a Sequential model.
perceptual_model = tf.keras.Sequential()
for layer in source_model.layers[:-2]: # every layer except the final two
    perceptual_model.add(layer)
perceptual_model.build(input_shape=(None, 120, 120, 6))
perceptual_model.summary()
# ==============================================================================
# = data =
# ==============================================================================
#load retiled data
"""
A_data = np.load('data_numpy/A_data_30_40.npy').astype(np.float32)
B_data_10m = np.load('data_numpy/B_data_10_120.npy').astype(np.float32)
B_data_20m = np.load('data_numpy/B_data_20_60.npy').astype(np.float32)
"""
#data patches
#path_L8 = "/p/scratch/joaiml/sedona3/L8_Images/L8_patches_new.h5"
# HDF5 files holding the S2 and L8 patch sets (paths are relative to the working directory).
path_S2 = "../../data/tile_33UVP/timeseries_project/S2_patches_128_shifted_4r.h5"
#path_L8 = "/p/scratch/joaiml/sedona3/L8_Images/L8_patches_new.h5"
#path_L8 = "../../data/tile_33UVP/timeseries_project/L8_patches_128.h5"
path_L8 = '../../data/tile_33UVP/timeseries_project/L8_patches_128_water.h5'
# Both files stay open until the clipping step further down closes them.
h5_file_S2 = h5py.File(path_S2, "r")
h5_file_L8 = h5py.File(path_L8, "r")
# Per-patch flag arrays, read fully into memory.
val_S2 = h5_file_S2["val"][:]
val_S2_scl = h5_file_S2["val_scl"][:]
val_L8 = h5_file_L8["val"][:]
val_L8_qa = h5_file_L8["val_qa"][:]
# A patch is kept only when all four flags equal 1.
validity = (val_L8_qa==1) & (val_L8==1) & (val_S2_scl==1) & (val_S2==1)
interm_L8 = h5_file_L8["interm"][:]
# NOTE(review): interm_S2 is not read again in the visible part of the script.
interm_S2 = h5_file_S2["interm"][:]
date_L8 = h5_file_L8["date"][:]
#len_entire_dataset = np.count_nonzero(validity==True)
len_entire_dataset = validity.shape[0]
#length dataset for training
# 70 % of all patches (valid or not) are counted as the training portion.
len_dataset = np.floor(len_entire_dataset * 0.7).astype(int)
# Optimiser steps per epoch and per worker; used by the learning-rate schedules.
num_steps = len_dataset // (hvd.size() * BATCH_SIZE)
#length dataset for test
#len_test = len_entire_dataset - len_dataset
# Choose which acquisition dates are used and how many patches each rank gets.
if TRAIN:
    # Training uses dates 0, 1, 2, 4 and 5; one patch per rank is kept for previews.
    len_test = 1
    date = 0
    date_condition = np.logical_or.reduce([date_L8 == d for d in (0, 1, 2, 4, 5)])
else:
    #len_test = len_entire_dataset - len_dataset
    # Prediction handles a single date and loads no training rows.
    date = 3
    date_condition = date_L8 == date
    len_test = np.count_nonzero(date_condition)
    print(len_test)
    len_dataset = 0
#number of non intermediate patches
#num_non_interm = np.count_nonzero(interm_L8==0)
# Row count of the prediction datasets created further down.
num_non_interm = len_test
#num_non_interm_train = int(num_non_interm * 0.1)
#num_non_interm_test = num_non_interm - num_non_interm_train
# Rows per rank: training chunk, then test chunk.
size_chunk = int(len_dataset/hvd.size())
#size_chunk = 1000
size_chunk_test = 1 if TRAIN else int(len_test/hvd.size())
#size_chunk_test = 100
# Row windows are named once instead of repeating the slice arithmetic per array.
_train_rows = slice(hvd.rank()*size_chunk, (hvd.rank()+1)*size_chunk)
validity_train = validity[date_condition][_train_rows]
if TRAIN:
    # Select the training dates, then keep only this rank's window of rows.
    # NOTE(review): the boolean date mask is applied to the whole HDF5 dataset
    # before the rank window, so every rank reads all selected rows first.
    A_data_train = h5_file_L8["data_30m"][date_condition,:,:,:][_train_rows].astype(np.float32)
    B_data_train_10m = h5_file_S2["data_10m"][date_condition,:,:,:][_train_rows].astype(np.float32)
    B_data_train_20m = h5_file_S2["data_20m"][date_condition,:,:,:][_train_rows].astype(np.float32)
    # Negative values are set to zero in place.
    A_data_train[A_data_train<0] = 0
    #B_data = B_data * 0.0001
    B_data_train_10m[B_data_train_10m<0] = 0
    B_data_train_20m[B_data_train_20m<0] = 0
    # Test rows start right after the training portion, one window per rank.
    _test_rows = slice(len_dataset+(hvd.rank()*size_chunk_test),
                       len_dataset+((hvd.rank()+1)*size_chunk_test))
else:
    # Prediction: the window is indexed by the date number, not by the rank.
    _test_rows = slice(date*size_chunk_test, (date+1)*size_chunk_test)
validity_test = validity[_test_rows]
A_data_test = h5_file_L8["data_30m"][_test_rows].astype(np.float32)
B_data_test_10m = h5_file_S2["data_10m"][_test_rows].astype(np.float32)
B_data_test_20m = h5_file_S2["data_20m"][_test_rows].astype(np.float32)
#scale input data (valid range 0-10000)
#A_data = A_data * 0.0001
A_data_test[A_data_test<0] = 0
B_data_test_10m[B_data_test_10m<0] = 0
B_data_test_20m[B_data_test_20m<0] = 0
# Per-band percentile and mean statistics used for clipping and normalisation.
# They are either recomputed from the HDF5 files and cached under data_numpy/,
# or loaded from that cache.
if compute_percentiles:
    arr_p1A, arr_p99A, arr_p5A, arr_p95A, arr_meanA = get_percentiles_mean(h5_file_L8, True)
    (arr_p1B1, arr_p99B1, arr_p5B1, arr_p95B1, arr_meanB1,
     arr_p1B2, arr_p99B2, arr_p5B2, arr_p95B2, arr_meanB2) = get_percentiles_mean(h5_file_S2, False)
    np.save("data_numpy/p1A.npy", arr_p1A)
    np.save("data_numpy/p99A.npy", arr_p99A)
    np.save("data_numpy/p5A.npy", arr_p5A)
    np.save("data_numpy/p95A.npy", arr_p95A)
    np.save("data_numpy/meanA.npy", arr_meanA)
    np.save("data_numpy/arr_p1B1.npy", arr_p1B1)
    np.save("data_numpy/arr_p99B1.npy", arr_p99B1)
    np.save("data_numpy/arr_p5B1.npy", arr_p5B1)
    np.save("data_numpy/arr_p95B1.npy", arr_p95B1)
    np.save("data_numpy/arr_meanB1.npy", arr_meanB1)
    np.save("data_numpy/arr_p1B2.npy", arr_p1B2)
    np.save("data_numpy/arr_p99B2.npy", arr_p99B2)
    np.save("data_numpy/arr_p5B2.npy", arr_p5B2)
    np.save("data_numpy/arr_p95B2.npy", arr_p95B2)
    np.save("data_numpy/arr_meanB2.npy", arr_meanB2)
else:
    arr_p1A = np.load("data_numpy/p1A.npy")
    arr_p99A = np.load("data_numpy/p99A.npy")
    arr_p5A = np.load("data_numpy/p5A.npy")
    arr_p95A = np.load("data_numpy/p95A.npy")
    arr_meanA = np.load("data_numpy/meanA.npy")
    arr_p1B1 = np.load("data_numpy/arr_p1B1.npy")
    arr_p99B1 = np.load("data_numpy/arr_p99B1.npy")
    arr_p95B1 = np.load("data_numpy/arr_p95B1.npy")
    arr_p5B1 = np.load("data_numpy/arr_p5B1.npy")
    arr_meanB1 = np.load("data_numpy/arr_meanB1.npy")
    arr_p1B2 = np.load("data_numpy/arr_p1B2.npy")
    arr_p99B2 = np.load("data_numpy/arr_p99B2.npy")
    arr_p95B2 = np.load("data_numpy/arr_p95B2.npy")
    arr_p5B2 = np.load("data_numpy/arr_p5B2.npy")
    arr_meanB2 = np.load("data_numpy/arr_meanB2.npy")

# Reduce the arrays over axis 0 in BOTH modes. Previously this happened only
# after loading, so compute_percentiles=True left p1A ... meanB2 undefined and
# the clipping step below failed with a NameError.
# Percentiles are truncated to int; means keep their floating-point value.
p1A = arr_p1A.mean(axis=0).astype(int)
p99A = arr_p99A.mean(axis=0).astype(int)
p5A = arr_p5A.mean(axis=0).astype(int)
p95A = arr_p95A.mean(axis=0).astype(int)
meanA = arr_meanA.mean(axis=0)
p1B1 = arr_p1B1.mean(axis=0).astype(int)
p99B1 = arr_p99B1.mean(axis=0).astype(int)
p95B1 = arr_p95B1.mean(axis=0).astype(int)
p5B1 = arr_p5B1.mean(axis=0).astype(int)
meanB1 = arr_meanB1.mean(axis=0)
p1B2 = arr_p1B2.mean(axis=0).astype(int)
p99B2 = arr_p99B2.mean(axis=0).astype(int)
p95B2 = arr_p95B2.mean(axis=0).astype(int)
p5B2 = arr_p5B2.mean(axis=0).astype(int)
meanB2 = arr_meanB2.mean(axis=0)
# Select the patches and bands that are actually fed to the networks:
# bands 1..6 of the 30 m data, all bands of the 10 m data, and the first two
# bands of the 20 m data.
if TRAIN:
    #use this only during training
    # Drop patches whose validity flag is False.
    A_data_train = A_data_train[validity_train,:,:,1:7]
    B_data_train_10m = B_data_train_10m[validity_train]
    B_data_train_20m = B_data_train_20m[validity_train,:,:,0:2]
    A_data_test = A_data_test[validity_test,:,:,1:7]
    B_data_test_10m = B_data_test_10m[validity_test]
    B_data_test_20m = B_data_test_20m[validity_test,:,:,0:2]
#use this only during test, otherwise it would train using also cloudy or snowy patches: add an if condition
elif PRED_INTERM:
    # This branch runs only when TRAIN is False, so no *_train arrays exist;
    # the old code indexed them here and raised a NameError. The mask is also
    # cut to the same row window that was used to load the test arrays, so
    # that its length matches them.
    # NOTE(review): assumes interm_L8 holds one flag per patch in file order — confirm.
    interm_L8_test = interm_L8[(date*size_chunk_test):((date+1)*size_chunk_test)]
    keep_rows = interm_L8_test == 0
    A_data_test = A_data_test[keep_rows,:,:,1:7]
    B_data_test_10m = B_data_test_10m[keep_rows]
    B_data_test_20m = B_data_test_20m[keep_rows,:,:,0:2]
else:
    # Plain prediction: keep every loaded patch, only select the bands.
    A_data_test = A_data_test[:,:,:,1:7]
    B_data_test_10m = B_data_test_10m[:]
    B_data_test_20m = B_data_test_20m[:,:,:,0:2]
#saturate extremum values, below and above percentiles 1 and 99
print("before clipping")
# Saturate extreme values with clip_data, using the 1st/99th percentile arrays.
if CLIP_NORM:
    if TRAIN:
        A_data_train = clip_data(A_data_train, p1A, p99A).astype(np.float32)
        B_data_train_10m = clip_data(B_data_train_10m, p1B1, p99B1).astype(np.float32)
        B_data_train_20m = clip_data(B_data_train_20m, p1B2, p99B2).astype(np.float32)
    A_data_test = clip_data(A_data_test, p1A, p99A)
    B_data_test_10m = clip_data(B_data_test_10m, p1B1, p99B1)
    B_data_test_20m = clip_data(B_data_test_20m, p1B2, p99B2)
print("A_data_test shape: ", A_data_test.shape)
print("B_data_test_10m shape: ", B_data_test_10m.shape)
print("B_data_test_20m shape: ", B_data_test_20m.shape)
# Everything needed is in memory now; release the input files.
h5_file_S2.close()
h5_file_L8.close()
"""
#added percentile 1% and 99% to saturate extrema, astype(np.float32)
A_data_train, minA, maxA = scale_data(A_data_train)
B_data_train_10m, minB1, maxB1 = scale_data(B_data_train_10m)
B_data_train_20m, minB2, maxB2 = scale_data(B_data_train_20m)
A_data_test = scale_data_withperc(A_data_test, minA, maxA)
B_data_test_10m = scale_data_withperc(B_data_test_10m, minB1, maxB1)
B_data_test_20m = scale_data_withperc(B_data_test_20m, minB2, maxB2)
"""
print("before normalizing")
# Normalise with normalize_data_withperc, using the 5th/95th percentiles and the mean.
if CLIP_NORM:
    if TRAIN:
        # Training statistics are cast to float32 before being passed in.
        A_data_train = normalize_data_withperc(
            A_data_train, p5A.astype(np.float32), p95A.astype(np.float32), meanA.astype(np.float32))
        B_data_train_10m = normalize_data_withperc(
            B_data_train_10m, p5B1.astype(np.float32), p95B1.astype(np.float32), meanB1.astype(np.float32))
        B_data_train_20m = normalize_data_withperc(
            B_data_train_20m, p5B2.astype(np.float32), p95B2.astype(np.float32), meanB2.astype(np.float32))
    # NOTE(review): the test arrays receive the statistics without the float32 cast.
    A_data_test = normalize_data_withperc(A_data_test, p5A, p95A, meanA)
    B_data_test_10m = normalize_data_withperc(B_data_test_10m, p5B1, p95B1, meanB1)
    B_data_test_20m = normalize_data_withperc(B_data_test_20m, p5B2, p95B2, meanB2)
print("A_data_test shape: ", A_data_test.shape)
print("B_data_test_10m shape: ", B_data_test_10m.shape)
print("B_data_test_20m shape: ", B_data_test_20m.shape)
#normalize
"""
A_mean = np.mean(A_data_train, (0,1,2))
#A_std = (np.std(A_data_train, (0,1,2))) * 10
A_std = (np.std(A_data_train, (0,1,2))) * 10
A_data_train = (A_data_train - A_mean) / A_std
A_data_test = (A_data_test - A_mean) / A_std
print("A_std: ", A_std, " A_mean: ", A_mean)
print("A max: ", A_data_train.max(axis=3), " A min:", A_data_train.min(axis=3))
B_mean_10m = np.mean(B_data_train_10m, (0,1,2))
#B_std_10m = (np.std(B_data_train_10m, (0,1,2))) * 10
B_std_10m = (np.std(B_data_train_10m, (0,1,2))) * 10
print("B_std_10m: ", B_std_10m, " B_mean_10m: ", B_mean_10m)
B_data_train_10m = (B_data_train_10m - B_mean_10m) / B_std_10m
B_data_test_10m = (B_data_test_10m - B_mean_10m) / B_std_10m
print("B10 max: ", B_data_train_10m.max(), " B10 min:", B_data_train_10m.min())
B_mean_20m = np.mean(B_data_train_20m, (0,1,2))
#B_std_20m = (np.std(B_data_train_20m, (0,1,2))) * 10
B_std_20m = (np.std(B_data_train_20m, (0,1,2))) * 10
print("B_std_20m: ", B_std_20m, " B_mean_20m: ", B_mean_20m)
B_data_train_20m = (B_data_train_20m - B_mean_20m) / B_std_20m
B_data_test_20m = (B_data_test_20m - B_mean_20m) / B_std_20m
print("B20 max: ", B_data_train_20m.max(), " B20 min:", B_data_train_20m.min())
"""
"""
#normalize
A_mean = meanA
#A_std = (np.std(A_data_train, (0,1,2))) * 10
A_std = (np.std(A_data_train, (0,1,2))) * 1
A_data_train = (A_data_train - A_mean) / A_std
A_data_test = (A_data_test - A_mean) / A_std
B_mean_10m = meanB1
#B_std_10m = (np.std(B_data_train_10m, (0,1,2))) * 10
B_std_10m = (np.std(B_data_train_10m, (0,1,2))) * 1
B_data_train_10m = (B_data_train_10m - B_mean_10m) / B_std_10m
B_data_test_10m = (B_data_test_10m - B_mean_10m) / B_std_10m
B_mean_20m = meanB2
#B_std_20m = (np.std(B_data_train_20m, (0,1,2))) * 10
B_std_20m = (np.std(B_data_train_20m, (0,1,2))) * 1
B_data_train_20m = (B_data_train_20m - B_mean_20m) / B_std_20m
B_data_test_20m = (B_data_test_20m - B_mean_20m) / B_std_20m
A_data_train = A_data[(hvd.local_rank()*size_chunk):((hvd.local_rank()+1)*size_chunk)]
B_data_train_10m = B_data_10m[(hvd.local_rank()*size_chunk):((hvd.local_rank()+1)*size_chunk)]
B_data_train_20m = B_data_20m[(hvd.local_rank()*size_chunk):((hvd.local_rank()+1)*size_chunk)]
"""
def _batched_pipeline(array, batch_size):
    """Wrap an in-memory array as a batched, prefetching tf.data pipeline."""
    dataset = tf.data.Dataset.from_tensor_slices(array)
    return dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)


if TRAIN:
    # The training arrays were already cut down to this rank's row window when
    # they were read from HDF5. The previous extra
    # `.shard(hvd.size(), hvd.rank())` therefore threw away all but 1/size of
    # each rank's own rows, so it is not applied any more.
    A_data_train = _batched_pipeline(A_data_train, BATCH_SIZE)
    B_data_train_10m = _batched_pipeline(B_data_train_10m, BATCH_SIZE)
    B_data_train_20m = _batched_pipeline(B_data_train_20m, BATCH_SIZE)
    # Element order is (10 m input, 20 m input, 30 m target).
    A_B_dataset = tf.data.Dataset.zip((B_data_train_10m, B_data_train_20m, A_data_train))
#A2B_pool = data.ItemPool(POOL_SIZE)
#B2A_pool = data.ItemPool(POOL_SIZE)
print("A_data_test shape: ", A_data_test.shape)
print("B_data_test_10m shape: ", B_data_test_10m.shape)
print("B_data_test_20m shape: ", B_data_test_20m.shape)
# Test patches are always served one at a time.
A_data_test = tf.data.Dataset.from_tensor_slices((A_data_test)).batch(1)
B_data_test_10m = tf.data.Dataset.from_tensor_slices((B_data_test_10m)).batch(1)
B_data_test_20m = tf.data.Dataset.from_tensor_slices((B_data_test_20m)).batch(1)
A_B_dataset_test = tf.data.Dataset.zip((B_data_test_10m, B_data_test_20m, A_data_test))
print("data done on rank: ", hvd.rank())
#hvd.allreduce([0], name="Barrier")
# The prediction file is written only by predict(), which runs on rank 0 when
# TRAIN is False. Opening it in "w" mode during training as well truncated any
# existing date_<n>.h5 and left the handle open for the whole run, so the file
# is now created in prediction mode only.
if hvd.rank() == 0 and not TRAIN:
    path_pred_fold = "/p/project/joaiml/remote_sensing/rocco_sedona/rslab_col/data/tile_33UVP/timeseries_project/" + output_dir
    pathlib.Path(path_pred_fold).mkdir(parents=True, exist_ok=True)
    path_pred = path_pred_fold + "/date_" + str(date) + ".h5"
    h5_file_pred = h5py.File(path_pred, "w")
    # One row per test patch: prediction, 30 m target, 10 m input, 20 m input.
    d_pred = h5_file_pred.create_dataset("data_pred", shape=[num_non_interm, PATCH_30m_DIM, PATCH_30m_DIM, 6],dtype=np.float32)
    d_or_30m = h5_file_pred.create_dataset("data_30m", shape=[num_non_interm, PATCH_30m_DIM, PATCH_30m_DIM, 6],dtype=np.float32)
    d_or_10m = h5_file_pred.create_dataset("data_10m", shape=[num_non_interm, PATCH_10m_DIM, PATCH_10m_DIM, 4],dtype=np.float32)
    d_or_20m = h5_file_pred.create_dataset("data_20m", shape=[num_non_interm, PATCH_20m_DIM, PATCH_20m_DIM, 2],dtype=np.float32)
# A collective op that every rank has to reach; used here as a barrier.
barrier = hvd.allreduce(tf.random.normal(shape=[1]))
# ==============================================================================
# = model =
# ==============================================================================
def meanSquaredLoss(y_true,y_pred):
    """Return the mean of the Keras MSE between y_true and y_pred."""
    squared_error = tf.keras.losses.MSE(y_true, y_pred)
    return tf.reduce_mean(squared_error)
def gram_matrix(input_tensor):
    """Return the channel Gram matrix of a 4-D tensor, divided by dim1 * dim2.

    The einsum contracts axes 1 and 2 of ``input_tensor`` with itself, giving
    one (channels x channels) matrix per batch entry.
    """
    dims = tf.shape(input_tensor)
    location_count = tf.cast(dims[1] * dims[2], tf.float32)
    channel_products = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
    return channel_products / location_count
def styleLoss(generated_image, true_image):
    """Return the MSE between the Gram matrices of the two images' perceptual features."""
    generated_gram = gram_matrix(perceptual_model(generated_image, training=False))
    true_gram = gram_matrix(perceptual_model(true_image, training=False))
    return meanSquaredLoss(true_gram, generated_gram)
def gramMatrixLoss(generated_image, true_image):
    """Return the MSE between the Gram matrices of the raw (unprocessed) images."""
    generated_gram = gram_matrix(generated_image)
    true_gram = gram_matrix(true_image)
    return meanSquaredLoss(generated_gram, true_gram)
def featureLoss(generated_image, true_image):
    """Return the MSE between the perceptual_model features of the two images."""
    generated_features = perceptual_model(generated_image, training=False)
    true_features = perceptual_model(true_image, training=False)
    return meanSquaredLoss(true_features, generated_features)
def featureLossVGG(generated_image, true_image):
    """Return the summed VGG feature MSE over channels 0..2 and channels 3..5.

    VGG_model is applied separately to each three-channel slice of both
    images; the two resulting MSE values are added.
    """
    def _triplet_loss(channels):
        # Feature MSE for one three-channel slice of both images.
        generated_features = VGG_model(generated_image[:, :, :, channels], training=False)
        true_features = VGG_model(true_image[:, :, :, channels], training=False)
        return meanSquaredLoss(true_features, generated_features)

    return _triplet_loss(slice(0, 3)) + _triplet_loss(slice(3, 6))
"""
GAN_WEIGHT = 0.1
LAMBDA = 4000
ETA = 4000
"""
"""
GAN_WEIGHT = 1
LAMBDA = 4000
ETA = 10
"""
"""
GAN_WEIGHT = 5e-3
LAMBDA = 1
ETA = 1 #1e-2
"""
# Active loss weights (the triple-quoted blocks above are earlier settings).
# GAN_WEIGHT scales the adversarial term and ETA the L1 term in generator_loss;
# LAMBDA is only referenced from commented-out loss variants.
GAN_WEIGHT = 1
LAMBDA = 1
ETA = 100
# Binary cross-entropy on raw logits, shared by generator and discriminator losses.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def generator_loss(disc_generated_output, disc_real_output, gen_output, target):
    """Compute the generator objective plus auxiliary losses for logging.

    Args:
        disc_generated_output: discriminator logits for the generated image.
        disc_real_output: discriminator logits for the real image.
        gen_output: image produced by the generator.
        target: ground-truth image.

    Returns:
        Tuple ``(total_gen_loss, gan_loss, l1_loss, style_loss, feature_loss,
        gram_loss, huber_loss, feature_loss_VGG)``. Only ``gan_loss`` and
        ``l1_loss`` contribute to ``total_gen_loss``; the rest are returned
        for monitoring.
    """
    # Adversarial term: cross-entropy (target 1) on the gap between the
    # generated and the real logits.
    logit_gap = disc_generated_output - disc_real_output
    gan_loss = loss_object(tf.ones_like(disc_generated_output), logit_gap)

    # Pixel-wise reconstruction terms.
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    huber_loss = tf.compat.v1.losses.huber_loss(gen_output, target)

    # Feature- and Gram-based terms.
    style_loss = styleLoss(gen_output, target)
    gram_loss = gramMatrixLoss(gen_output, target)
    feature_loss = featureLoss(gen_output, target)
    feature_loss_VGG = featureLossVGG(gen_output, target)

    total_gen_loss = (GAN_WEIGHT * gan_loss) + (ETA * l1_loss)
    return (total_gen_loss, gan_loss, l1_loss, style_loss, feature_loss,
            gram_loss, huber_loss, feature_loss_VGG)
def discriminator_loss(disc_real_output, disc_generated_output):
    """Return the discriminator loss computed on logit differences.

    The real term uses target 1 on ``real - generated``; the generated term
    uses target 0 on ``generated - real``. Their sum is returned.
    """
    real_minus_fake = disc_real_output - disc_generated_output
    fake_minus_real = disc_generated_output - disc_real_output
    real_loss = loss_object(tf.ones_like(disc_real_output), real_minus_fake)
    generated_loss = loss_object(tf.zeros_like(disc_generated_output), fake_minus_real)
    return real_loss + generated_loss
#define the Optimizers and Checkpoint-saver
# Constant learning rate handed to both AdamW optimizers below.
starter_learning_rate = 0.0001
"""
end_learning_rate = 0.00001
decay_steps = (len_dataset / BATCH_SIZE)
scaled_lr = tf.keras.optimizers.schedules.PolynomialDecay(
starter_learning_rate,
decay_steps,
end_learning_rate,
power=1)
"""
# Piecewise-constant schedules with boundaries after 60 and 120 epochs' worth of steps.
# NOTE(review): neither schedule is passed to an optimizer in the visible code;
# both optimizers are built with starter_learning_rate — confirm this is intended.
learning_rate_fn_disc = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    [60*num_steps, 120*num_steps],
    [1e-4, 5e-5, 2e-5]
)
learning_rate_fn_gen = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    [60*num_steps, 120*num_steps],
    [5e-5, 2e-5, 1e-5]
)
"""
learning_rate_fn_gen = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
[50*num_steps, 100*num_steps, 200*num_steps, 300*num_steps],
[2e-4, 1e-4, 5e-5, 2e-5, 1e-5]
)
learning_rate_fn_disc = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
[50*num_steps, 100*num_steps, 200*num_steps, 300*num_steps],
[2e-4, 1e-4, 5e-5, 2e-5, 1e-5]
)
"""
"""
learning_rate_fn_gen = tf.optimizers.schedules.ExponentialDecay(2e-4, num_steps * 10, 0.9)
learning_rate_fn_disc = tf.optimizers.schedules.ExponentialDecay(2e-4, num_steps * 10, 0.9)
"""
"""
generator_optimizer = tf.keras.optimizers.Adam(starter_learning_rate, beta_1=0.9)
discriminator_optimizer = tf.keras.optimizers.Adam(starter_learning_rate, beta_1=0.9)
"""
# Both networks use AdamW with the same hyper-parameters; each gets its own
# optimizer instance.
_adamw_settings = dict(
    learning_rate=starter_learning_rate,
    beta_1=0.9,
    beta_2=0.999,
    weight_decay=0.0001,
)
generator_optimizer = tfa.optimizers.AdamW(**_adamw_settings)
discriminator_optimizer = tfa.optimizers.AdamW(**_adamw_settings)
#generator_optimizer = tfa.optimizers.LAMB(learning_rate=starter_learning_rate)
#discriminator_optimizer = tfa.optimizers.LAMB(learning_rate=scaled_lr)
#generator_optimizer = tf.keras.optimizers.SGD(learning_rate=2e-4, momentum=0.9, nesterov=True)
#discriminator_optimizer = tf.keras.optimizers.SGD(learning_rate=2e-4, momentum=0.9, nesterov=True)
#load models
#generator = Generator()
# Instantiate both networks and print their layer summaries.
generator = Generator_conv_S2inp()
generator.summary()
discriminator = Discriminator()
discriminator.summary()
# Full training state: both networks, both optimizers, and a counter that
# starts at 1 and is incremented in fit().
checkpoint = tf.train.Checkpoint(step_counter = tf.Variable(1, trainable=False),
                                 generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
# Keeps the three most recent checkpoints in a directory named after output_dir.
manager = tf.train.CheckpointManager(checkpoint, './tf_ckpts_' + output_dir, max_to_keep=3)
if load_pretrain == True:
    # Separate generator-only checkpoint, stored under the "_gen" directory.
    checkpoint_pregen = tf.train.Checkpoint(generator=generator)
    manager_pregen = tf.train.CheckpointManager(checkpoint_pregen, './tf_ckpts_' + output_dir + '_gen', max_to_keep=3)
#generate images
def generate_images(model, example_input1, example_input2, tar):
    """Run the generator on one batch, save a preview figure, return raw-scale arrays.

    Args:
        model: callable invoked as ``model([example_input1, example_input2],
            training=False)``.
        example_input1: first network input (used as the 10 m image here).
        example_input2: second network input (used as the 20 m image here).
        tar: target batch.

    Returns:
        Tuple ``(A, B_10m, B_20m, B2A)``: the first element of the target, of
        both inputs and of the prediction, each mapped back with
        ``x * (p95 - p5) + mean``.

    Side effects:
        Increments ``generate_images.counter``, overwrites ``A.png``,
        ``B2A.png`` and ``B.png`` in ``sample_temp_dir`` and writes
        ``generated-<counter>.png`` into ``sample_dir``.
    """
    generate_images.counter += 1
    prediction = model([example_input1, example_input2], training=False)

    #invert normalization
    A = (tar[0] * (p95A - p5A)) + meanA
    B_10m = (example_input1[0] * (p95B1 - p5B1)) + meanB1
    B_20m = (example_input2[0] * (p95B2 - p5B2)) + meanB2
    B2A = (prediction[0] * (p95A - p5A)) + meanA

    # Render the first three bands, reversed along the band axis, to temporary PNGs.
    sp.save_rgb(os.path.join(sample_temp_dir, 'A.png'), np.flip(A[:, :,0:3], 2), format='png', color_scale = [0,5000], bounds = [0,4096])
    sp.save_rgb(os.path.join(sample_temp_dir, 'B2A.png'), np.flip(B2A[:, :, 0:3], 2), format='png', color_scale = [0,5000], bounds = [0,4096])
    sp.save_rgb(os.path.join(sample_temp_dir, 'B.png'), np.flip(B_10m[:, :, 0:3], 2), format='png', color_scale = [0,5000], bounds = [0,4096])

    # Exactly one figure is created and it is closed explicitly. An additional
    # plt.figure(figsize=(15,15)) used to be opened at the top of this function
    # and never closed, leaking one figure per call.
    f, axarr = plt.subplots(1, 3)
    panels = (
        ('A', "A.png"),
        ('B2A', "B2A.png"),
        ('B', "B.png"),
    )
    for axis, (panel_title, file_name) in zip(axarr, panels):
        axis.imshow(plt.imread(os.path.join(sample_temp_dir, file_name)))
        axis.title.set_text(panel_title)
    plt.tight_layout()
    f.savefig(os.path.join(sample_dir, "generated-%09d.png" % generate_images.counter))
    plt.close(f)
    return A, B_10m, B_20m, B2A


# Running number used in the preview file names.
generate_images.counter = 0
#training
# NOTE(review): fit() saves a checkpoint only when (epoch + 1) % 10 == 0, so
# with EPOCHS = 9 and initial_epoch = 0 no checkpoint is written — confirm.
EPOCHS = 9
initial_epoch = 0
log_dir="logs/"
if hvd.rank()==0:
    # Only rank 0 creates a summary writer; train_step uses it under the same condition.
    summary_writer = tf.summary.create_file_writer(log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
print("second barrier")
# Collective op that all ranks must reach before training starts.
barrier = hvd.allreduce(tf.random.normal(shape=[1]))
@tf.function
def train_step(input_image_1, input_image_2, target, epoch, first_batch):
    """Run one optimisation step for generator and discriminator.

    Args:
        input_image_1: first generator/discriminator input batch.
        input_image_2: second generator/discriminator input batch.
        target: ground-truth batch.
        epoch: value used as the ``step`` of the summaries written on rank 0.
        first_batch: when true, variables of both networks and both optimizers
            are broadcast from rank 0 after the gradients have been applied.
    """
    #print("entered train step")
    #with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    # A single persistent tape, because gradients are taken from it twice below.
    with tf.GradientTape(persistent=True) as tape:
        #print("entered tape, rank:", hvd.rank())
        gen_output = generator([input_image_1, input_image_2], training=True)
        #print("gen")
        disc_real_output = discriminator([input_image_1, input_image_2, target], training=True)
        #print("disc re")
        disc_generated_output = discriminator([input_image_1, input_image_2, gen_output], training=True)
        #print("disc gem")
        # NOTE(review): DiffAugment is applied to the discriminator OUTPUTS here,
        # with a separate call for each of the four uses — confirm that this,
        # rather than augmenting the images fed to the discriminator, is intended.
        gen_total_loss, gen_gan_loss, gen_l1_loss, gen_style_loss, gen_feature_loss, gen_gram_loss, gen_huber_loss, gen_feature_loss_VGG = generator_loss(DiffAugment(disc_generated_output, policy=policy), DiffAugment(disc_real_output, policy=policy), gen_output, target)
        #print("gen loss")
        disc_loss = discriminator_loss(DiffAugment(disc_real_output, policy=policy), DiffAugment(disc_generated_output, policy=policy))
        #print("disc loss")
    # Horovod: add Horovod Distributed GradientTape.,
    #gen_tape = hvd.DistributedGradientTape(gen_tape)
    #disc_tape = hvd.DistributedGradientTape(disc_tape)
    #print("done1")
    tape = hvd.DistributedGradientTape(tape)
    #print("done2")
    generator_gradients = tape.gradient(gen_total_loss,
                                        generator.trainable_variables)
    discriminator_gradients = tape.gradient(disc_loss,
                                            discriminator.trainable_variables)
    #print("done3")
    generator_optimizer.apply_gradients(zip(generator_gradients,
                                            generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                discriminator.trainable_variables))
    #print("done4")
    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    #
    # Note: broadcast should be done after the first gradient step to ensure optimizer
    # initialization.
    if first_batch == True:
        hvd.broadcast_variables(generator.variables, root_rank=0)
        hvd.broadcast_variables(generator_optimizer.variables(), root_rank=0)
        hvd.broadcast_variables(discriminator.variables, root_rank=0)
        hvd.broadcast_variables(discriminator_optimizer.variables(), root_rank=0)
        #print("first batch done, rank:", hvd.rank())
    #elif first_batch == False:
        #print("first batch is false, rank:", hvd.rank())
    if hvd.rank()==0:
        # Every scalar is logged with the epoch as its step, not a per-batch counter.
        with summary_writer.as_default():
            tf.summary.scalar('gen_total_loss', gen_total_loss, step=epoch)
            tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=epoch)
            tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=epoch)
            tf.summary.scalar('gen_style_loss', gen_style_loss, step=epoch)
            tf.summary.scalar('gen_gram_loss', gen_gram_loss, step=epoch)
            tf.summary.scalar('gen_huber_loss', gen_huber_loss, step=epoch)
            tf.summary.scalar('gen_feature_loss', gen_feature_loss, step=epoch)
            tf.summary.scalar('gen_feature_loss_VGG', gen_feature_loss_VGG, step=epoch)
            tf.summary.scalar('disc_loss', disc_loss, step=epoch)
    #print("done summary, rank:", hvd.rank())
def fit(train_ds, initial_epoch, epochs, test_ds):
    """Train for epochs ``initial_epoch .. epochs - 1``.

    Args:
        train_ds: dataset yielding ``(input_image1, input_image2, target)`` batches.
        initial_epoch: index of the first epoch to run.
        epochs: exclusive upper bound of the epoch index.
        test_ds: dataset from which one batch per epoch is rendered on rank 0.

    Every epoch increments ``checkpoint.step_counter``. Every tenth epoch a
    checkpoint is written; after each epoch rank 0 saves a preview image.
    """
    is_root = hvd.rank() == 0
    for epoch in range(initial_epoch, epochs):
        start = time.time()
        if is_root:
            print("Epoch: ", epoch)
        # Hand the epoch to the tf.function as an int64 tensor: a Python int
        # argument is treated as a constant and triggers a new trace for every
        # distinct value, i.e. once per epoch.
        epoch_step = tf.constant(epoch, dtype=tf.int64)
        # Train
        for n, (input_image1, input_image2, target) in train_ds.enumerate():
            if is_root:
                # Progress dots, with a line break every 100 batches.
                print('.', end='')
                if (n+1) % 100 == 0:
                    print()
            train_step(input_image1, input_image2, target, epoch_step,
                       (n == 0) & (epoch == initial_epoch))
        if is_root:
            print()
            print ('Time taken for epoch {} is {} sec\n'.format(epoch, time.time()-start))
        # Save a checkpoint every 10 epochs. Only rank 0 writes it: all ranks
        # share the same checkpoint directory and would otherwise write the
        # same files concurrently.
        if (epoch + 1) % 10 == 0 and is_root:
            save_path = manager.save()
            print("Saved checkpoint for step {}: {}".format(int(checkpoint.step_counter), save_path))
        # The counter advances on every rank so that all ranks stay consistent.
        checkpoint.step_counter.assign_add(1)
        #plot sample image
        if is_root:
            for example_input1, example_input2, example_target in test_ds.take(1):
                generate_images(generator, example_input1, example_input2, example_target)
"""
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
generator=generator,
discriminator=discriminator)
"""
def predict(test_ds):
    """Run the generator over ``test_ds`` and store every result in the prediction file.

    For the element at position ``row`` the restored target, both restored
    inputs and the clipped prediction are written to row ``row`` of
    ``d_or_30m``, ``d_or_10m``, ``d_or_20m`` and ``d_pred``.
    """
    for row, (input_image1, input_image2, target) in enumerate(test_ds):
        restored = generate_images(generator, input_image1, input_image2, target)
        t_A, t_B_10m, t_B_20m, t_A2B = restored
        d_or_30m[row, :, :, :] = t_A
        d_or_10m[row, :, :, :] = t_B_10m
        d_or_20m[row, :, :, :] = t_B_20m
        # Predictions are clipped with clip_data and the p1A / p99A arrays before storing.
        d_pred[row, :, :, :] = clip_data(t_A2B, p1A, p99A)
"""
np.save("output/predicted_im.npy", predicted_images)
np.save("output/original_im_10m.npy", original_images_10m)
np.save("output/original_im_20m.npy", original_images_20m)
np.save("output/original_im_30m.npy", original_images_30m)
"""
#checkpoint.save(file_prefix = checkpoint_prefix)
#run
if __name__ == "__main__":
    # Restore state: a pre-trained generator, the latest full checkpoint, or nothing.
    if load_pretrain:
        checkpoint_pregen.restore(manager_pregen.latest_checkpoint)
        print("Restored pregenerator from {}".format(manager_pregen.latest_checkpoint))
        print("initial epoch: ", initial_epoch)
    elif manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        print("Restored from {}".format(manager.latest_checkpoint))
        # step_counter starts at 1 and fit() adds 1 per epoch after the save
        # check, so the stored value equals the number of completed epochs,
        # which is the index of the next epoch. The former "+ 1" skipped one
        # epoch index on every resume.
        # NOTE(review): relies on step_counter being incremented once per epoch — confirm.
        initial_epoch = int(checkpoint.step_counter)
        generate_images.counter = int(checkpoint.step_counter)
        print("initial epoch: ", initial_epoch)
    else:
        print("Initializing from scratch.")
    print("test_ds: ", A_B_dataset_test)
    if TRAIN:
        print("train_ds: ", A_B_dataset)
        fit(A_B_dataset, initial_epoch, EPOCHS, A_B_dataset_test)
    elif hvd.rank() == 0:
        # Prediction runs on rank 0 only.
        generate_images.counter = 0
        #predict(A_B_dataset_test)
        #predict(A_B_dataset.concatenate(A_B_dataset_test))
        try:
            predict(A_B_dataset_test)
        finally:
            # Close the output file even when prediction fails part-way.
            h5_file_pred.close()
    else:
        # The remaining ranks idle for two hours while rank 0 predicts.
        # NOTE(review): a fixed sleep is either too long or too short; consider
        # launching prediction with a single process instead.
        time.sleep(60*60*2)