Commit 00a5d43c authored by gong1

Fix the issue for plotting both mse and psnr

parent a5df5d07
Pipeline #65969 passed
@@ -635,7 +635,7 @@ class Postprocess(TrainModel):
         """
         save list to pickle file in results directory
         """
-        self.eval_metrics = {}
+        eval_metrics = {}
         if metric == "mse":
             fcst_metric_all = self.stochastic_loss_all_batches  # mse loss
             prst_metric_all = self.prst_mse_avg_batches
@@ -646,18 +646,19 @@ class Postprocess(TrainModel):
             raise ValueError(
                 "We currently only support metric 'mse' and 'psnr' as evaluation metric for detereminstic forecasting")
         for ts in range(self.future_length):
-            self.eval_metrics["persistent_ts_" + str(ts)] = [str(prst_metric_all[ts])]
+            eval_metrics["persistent_ts_" + str(ts)] = [str(prst_metric_all[ts])]
             # for stochastic_sample_ind in range(self.num_stochastic_samples):
-            self.eval_metrics["model_ts_" + str(ts)] = [str(i) for i in fcst_metric_all[:, ts]]
+            eval_metrics["model_ts_" + str(ts)] = [str(i) for i in fcst_metric_all[:, ts]]
         with open(os.path.join(self.results_dir, metric), "w") as fjs:
-            json.dump(self.eval_metrics, fjs)
+            json.dump(eval_metrics, fjs)
+        return eval_metrics
 
     def save_eval_metric_to_json(self):
         """
         Save all the evaluation metrics to the json file
         """
-        self.save_one_eval_metric_to_json(metric="mse")
-        self.save_one_eval_metric_to_json(metric="psnr")
+        self.mse_metrics = self.save_one_eval_metric_to_json(metric="mse")
+        self.psnr_metrics = self.save_one_eval_metric_to_json(metric="psnr")
 
     @staticmethod
     def check_gen_images_stochastic_shape(gen_images_stochastic):
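
Context for the hunk above: before this commit, both calls wrote into the shared attribute self.eval_metrics, so the second call (psnr) appears to have overwritten the mse results before they could be plotted. Returning a fresh local dict per call keeps the two metrics separate. A minimal standalone sketch of the fixed pattern; the function name and data below are illustrative, not from the repository:

    # Illustrative sketch only: each call builds and returns its own dict,
    # so a later metric no longer clobbers an earlier one.
    def save_one_metric(values):
        eval_metrics = {}  # fresh, local dict (was a shared self.eval_metrics)
        for ts, v in enumerate(values):
            eval_metrics["model_ts_" + str(ts)] = [str(v)]
        return eval_metrics

    mse_metrics = save_one_metric([0.12, 0.18, 0.25])
    psnr_metrics = save_one_metric([31.2, 29.8, 28.4])  # mse_metrics stays intact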
@@ -977,33 +978,50 @@ class Postprocess(TrainModel):
             var = pickle.load(infile)
         return var
 
-    def plot_evalution_metrics(self):
-        model_names = self.eval_metrics.keys()
+    def plot_evaluation_per_metric(self, eval_metrics, metric_name="mse"):
+        model_names = eval_metrics.keys()
         model_ts_errors = []  # [timestamps, stochastic_number]
         persistent_ts_errors = []
-        for ts in range(self.future_length - 1):
-            stochastic_err = self.eval_metrics["model_ts_" + str(ts)]
+        for ts in range(self.future_length):
+            stochastic_err = eval_metrics["model_ts_" + str(ts)]
             stochastic_err = [float(item) for item in stochastic_err]
             model_ts_errors.append(stochastic_err)
-            persistent_err = self.eval_metrics["persistent_ts_" + str(ts)]
+            persistent_err = eval_metrics["persistent_ts_" + str(ts)]
             persistent_err = float(persistent_err[0])
             persistent_ts_errors.append(persistent_err)
         if len(np.array(model_ts_errors).shape) == 1:
             model_ts_errors = np.expand_dims(np.array(model_ts_errors), axis=1)
         model_ts_errors = np.array(model_ts_errors)
         persistent_ts_errors = np.array(persistent_ts_errors)
         fig = plt.figure(figsize=(6, 4))
         ax = plt.axes([0.1, 0.15, 0.75, 0.75])
         for stoch_ind in range(len(model_ts_errors[0])):
-            plt.plot(model_ts_errors[:, stoch_ind], lw=1)
-        plt.plot(persistent_ts_errors)
-        plt.xticks(np.arange(1, self.future_length))
-        ax.set_ylim(0., 10)
-        legend = ax.legend(loc='upper left')
+            plt.plot(model_ts_errors[:, stoch_ind], lw=1, label=self.model + "_" + str(stoch_ind))
+        plt.plot(persistent_ts_errors, label="persistent")
+        if metric_name == "mse":
+            max_errors = 6
+            min_errors = 0
+        elif metric_name == "psnr":
+            max_errors = 0
+            min_errors = -13
+        else:
+            raise ValueError("Currently we only support evaluation metrics mse and psnr")
+        plt.xticks(np.arange(0, self.future_length))
+        ax.set_ylim(min_errors, max_errors)
+        legend = ax.legend(loc='upper right', bbox_to_anchor=(1.15, 1))
         ax.set_xlabel('Time stamps')
-        ax.set_ylabel("Errors")
+        ax.set_ylabel(metric_name)
         print("Saving plot for err")
-        plt.savefig(os.path.join(self.results_dir, "evaluation.png"))
+        plt.savefig(os.path.join(self.results_dir, metric_name + "_eval.png"))
+
+    def plot_evalution_metrics(self):
+        self.plot_evaluation_per_metric(eval_metrics=self.mse_metrics, metric_name="mse")
+        self.plot_evaluation_per_metric(eval_metrics=self.psnr_metrics, metric_name="psnr")
 
     def plot_example_forecasts(self, metric="mse", var_ind=0):
         """
...
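
With these pieces in place, the intended call order after this commit appears to be: save the metrics (which now also returns and caches the per-metric dicts), then plot each metric to its own file. A hedged usage sketch, assuming post is an already-initialised Postprocess instance from the surrounding pipeline (the variable name is hypothetical):

    # Assumed usage after this commit; `post` is hypothetical here.
    post.save_eval_metric_to_json()  # writes the json files, sets post.mse_metrics / post.psnr_metrics
    post.plot_evalution_metrics()    # saves mse_eval.png and psnr_eval.png in post.results_dir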