Skip to content
Snippets Groups Projects
Commit 6fb659a9 authored by Yan Ji's avatar Yan Ji
Browse files

update main_predict.py

parent 3e35dd96
Branches
No related tags found
No related merge requests found
Pipeline #105909 failed
...@@ -80,7 +80,7 @@ class Postprocess(TrainModel): ...@@ -80,7 +80,7 @@ class Postprocess(TrainModel):
self.run_mode = run_mode self.run_mode = run_mode
self.data_mode = data_mode self.data_mode = data_mode
#self.channel = channel #self.channel = channel
#self.lquick = lquick self.lquick = lquick
#self.frac_data = frac_data #self.frac_data = frac_data
# Attributes set during runtime # Attributes set during runtime
#self.norm_cls = None #self.norm_cls = None
...@@ -116,6 +116,9 @@ class Postprocess(TrainModel): ...@@ -116,6 +116,9 @@ class Postprocess(TrainModel):
#self.cond_quantile_vars = self.init_cond_quantile_vars() #self.cond_quantile_vars = self.init_cond_quantile_vars()
# setup test dataset and model # setup test dataset and model
self.test_dataset, self.num_samples_per_epoch = self.setup_dataset() self.test_dataset, self.num_samples_per_epoch = self.setup_dataset()
self.lats = self.test_dataset.lats
self.lons = self.test_dataset.lons
self.vars_in = self.test_dataset.variables
# if lquick and self.test_dataset.shuffled: # if lquick and self.test_dataset.shuffled:
# self.num_samples_per_epoch = Postprocess.reduce_samples(self.num_samples_per_epoch, frac_data) # self.num_samples_per_epoch = Postprocess.reduce_samples(self.num_samples_per_epoch, frac_data)
# self.num_samples_per_epoch = 100 # reduced number of epoch samples -> useful for testing # self.num_samples_per_epoch = 100 # reduced number of epoch samples -> useful for testing
...@@ -569,12 +572,11 @@ class Postprocess(TrainModel): ...@@ -569,12 +572,11 @@ class Postprocess(TrainModel):
input_results = self.sess.run(self.input_iter) input_results = self.sess.run(self.input_iter)
t_starts = self.sess.run(self.ts_iter) t_starts = self.sess.run(self.ts_iter)
print('self.input_iter: {}'.format(self.input_iter)) # print('self.input_iter: {}'.format(self.input_iter.items()))
print('input_results: {}'.format(input_results))
# t_starts = input_results["T_start"] # t_starts = input_results["T_start"]
# feed_dict = {input_ph: input_results[name] for name, input_ph in self.input_iter.items()} # feed_dict = {input_ph: input_results[name] for name, input_ph in self.input_iter.items()}
feed_dict = {"x": input_results} feed_dict = {"IteratorGetNext:0": input_results}
gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict) gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict)
# sanity check on length of forecast sequence # sanity check on length of forecast sequence
...@@ -679,20 +681,20 @@ class Postprocess(TrainModel): ...@@ -679,20 +681,20 @@ class Postprocess(TrainModel):
""" """
method = Postprocess.get_init_time.__name__ method = Postprocess.get_init_time.__name__
t_starts = np.squeeze(np.asarray(t_starts)) #t_starts = np.squeeze(np.asarray(t_starts))
if not np.ndim(t_starts) == 1: #print('t_starts: {}'.format(t_starts))
raise ValueError("%{0}: Inputted t_starts must be a 1D list/array of date-strings with format %Y%m%d%H" #if not np.ndim(t_starts) == 1:
.format(method)) # raise ValueError("%{0}: Inputted t_starts must be a 1D list/array of date-strings with format %Y%m%d%H"
# .format(method))
for i, t_start in enumerate(t_starts): for i, t_start in enumerate(t_starts):
try: try:
#seq_ts = pd.date_range(dt.datetime.strptime(str(t_start), "%Y%m%d%H%M"), periods=self.context_frames, seq_ts = pd.date_range(dt.datetime.strptime(str(t_start[0])[2:-1], "%Y-%m-%dT%H:%M:00"), periods=self.context_frames,
# freq="10min")
print('t_start: ',t_start)
t0 = pd.date_range(dt.datetime.strptime(str(t_start), "%Y%m%d%H%M"), periods=4,
freq="-10min")
t1 = pd.date_range(dt.datetime.strptime(str(t_start), "%Y%m%d%H%M"),periods=self.context_frames-3,
freq="10min") freq="10min")
seq_ts = t0.append(t1)[1:] #t0 = pd.date_range(dt.datetime.strptime(str(t_start), "%Y-%m-%dT%H:%M:00"), periods=4,
# freq="-10min")
#t1 = pd.date_range(dt.datetime.strptime(str(t_start), "%Y-%m-%dT%H:%M:00"),periods=self.context_frames-3,
# freq="10min")
#seq_ts = t0.append(t1)[1:]
print('seq_ts: ',seq_ts) print('seq_ts: ',seq_ts)
except Exception as err: except Exception as err:
print("%{0}: Could not convert {1} to datetime object. Ensure that the date-string format is 'Y%m%d%H'". print("%{0}: Could not convert {1} to datetime object. Ensure that the date-string format is 'Y%m%d%H'".
......
...@@ -86,6 +86,9 @@ class GZawsDataset(BaseDataset): ...@@ -86,6 +86,9 @@ class GZawsDataset(BaseDataset):
self.n_samples = data_arr.shape[0] self.n_samples = data_arr.shape[0]
self.n_vars = len(self.variables) self.n_vars = len(self.variables)
self.lons = ds["lon"].values
self.lats = ds["lat"].values
return data_arr, init_times return data_arr, init_times
def make_dataset(self): def make_dataset(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment