Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
AMBS
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
AMBS
Commits
40a570fe
Commit
40a570fe
authored
5 years ago
by
Bing Gong
Browse files
Options
Downloads
Patches
Plain Diff
update scripts
parent
ac96c837
Branches
Branches containing commit
Tags
Tags containing commit
No related merge requests found
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
scripts/generate_transfer_learning_finetune.py
+186
-187
186 additions, 187 deletions
scripts/generate_transfer_learning_finetune.py
with
186 additions
and
187 deletions
scripts/generate_transfer_learning_finetune.py
+
186
−
187
View file @
40a570fe
...
@@ -33,7 +33,7 @@ from matplotlib.colors import LinearSegmentedColormap
...
@@ -33,7 +33,7 @@ from matplotlib.colors import LinearSegmentedColormap
from
skimage.metrics
import
structural_similarity
as
ssim
from
skimage.metrics
import
structural_similarity
as
ssim
import
pickle
import
pickle
with
open
(
"
./splits_size_64_64_1/
geo_info.json
"
,
"
r
"
)
as
json_file
:
with
open
(
"
geo_info.json
"
,
"
r
"
)
as
json_file
:
geo
=
json
.
load
(
json_file
)
geo
=
json
.
load
(
json_file
)
lat
=
[
round
(
i
,
2
)
for
i
in
geo
[
"
lat
"
]]
lat
=
[
round
(
i
,
2
)
for
i
in
geo
[
"
lat
"
]]
lon
=
[
round
(
i
,
2
)
for
i
in
geo
[
"
lon
"
]]
lon
=
[
round
(
i
,
2
)
for
i
in
geo
[
"
lon
"
]]
...
@@ -196,77 +196,77 @@ def main():
...
@@ -196,77 +196,77 @@ def main():
gen_images_all
=
[]
gen_images_all
=
[]
input_images_all
=
[]
input_images_all
=
[]
# while True:
while
True
:
# print("Sample id", sample_ind)
print
(
"
Sample id
"
,
sample_ind
)
# gen_images_stochastic = []
gen_images_stochastic
=
[]
# if args.num_samples and sample_ind >= args.num_samples:
if
args
.
num_samples
and
sample_ind
>=
args
.
num_samples
:
# break
break
# try:
try
:
# input_results = sess.run(inputs)
input_results
=
sess
.
run
(
inputs
)
# input_images = input_results["images"]
input_images
=
input_results
[
"
images
"
]
# input_images_all.extend(input_images)
input_images_all
.
extend
(
input_images
)
# with open(os.path.join(args.output_png_dir, "input_images_all"), "wb") as input_files:
with
open
(
os
.
path
.
join
(
args
.
output_png_dir
,
"
input_images_all
"
),
"
wb
"
)
as
input_files
:
# pickle.dump(list(input_images_all), input_files)
pickle
.
dump
(
list
(
input_images_all
),
input_files
)
#
# except tf.errors.OutOfRangeError:
except
tf
.
errors
.
OutOfRangeError
:
# break
break
#
# feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
feed_dict
=
{
input_ph
:
input_results
[
name
]
for
name
,
input_ph
in
input_phs
.
items
()}
# for stochastic_sample_ind in range(args.num_stochastic_samples):
for
stochastic_sample_ind
in
range
(
args
.
num_stochastic_samples
):
# gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)
gen_images
=
sess
.
run
(
model
.
outputs
[
'
gen_images
'
],
feed_dict
=
feed_dict
)
# gen_images_stochastic.append(gen_images)
gen_images_stochastic
.
append
(
gen_images
)
# print("Stochastic_sample,", stochastic_sample_ind)
print
(
"
Stochastic_sample,
"
,
stochastic_sample_ind
)
# for i in range(args.batch_size):
for
i
in
range
(
args
.
batch_size
):
# print("batch", i)
print
(
"
batch
"
,
i
)
# #colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
#colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
# cmap_name = 'my_list'
cmap_name
=
'
my_list
'
# if sample_ind < 20 and i == 1:
if
sample_ind
<
20
and
i
==
1
:
# name = 'Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str(
name
=
'
Stochastic_id_
'
+
str
(
stochastic_sample_ind
)
+
'
Batch_id_
'
+
str
(
# sample_ind) + " + Sample_" + str(i)
sample_ind
)
+
"
+ Sample_
"
+
str
(
i
)
# gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :]))
gen_images_
=
np
.
array
(
list
(
input_images
[
i
,:
context_frames
])
+
list
(
gen_images
[
i
,
-
future_length
:,
:]))
# #gen_images_ = gen_images[i, :]
#gen_images_ = gen_images[i, :]
# input_images_ = input_images[i, :]
input_images_
=
input_images
[
i
,
:]
# input_gen_diff = (input_images_[:, :, :,0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - (gen_images_[:, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922)
input_gen_diff
=
(
input_images_
[:,
:,
:,
0
]
*
(
321.46630859375
-
235.2141571044922
)
+
235.2141571044922
)
-
(
gen_images_
[:,
:,
:,
0
]
*
(
321.46630859375
-
235.2141571044922
)
+
235.2141571044922
)
#
# gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in
gen_mse_avg_
=
[
np
.
mean
(
input_gen_diff
[
frame
,
:,
:]
**
2
)
for
frame
in
# range(sequence_length)] # return the list with 10 (sequence) mse
range
(
sequence_length
)]
# return the list with 10 (sequence) mse
fig
=
plt
.
figure
(
figsize
=
(
18
,
6
))
gs
=
gridspec
.
GridSpec
(
1
,
10
)
gs
.
update
(
wspace
=
0.
,
hspace
=
0.
)
ts
=
[
0
,
5
,
9
,
10
,
12
,
14
,
16
,
18
,
19
]
xlables
=
[
round
(
i
,
2
)
for
i
in
list
(
np
.
linspace
(
np
.
min
(
lon
),
np
.
max
(
lon
),
5
))]
ylabels
=
[
round
(
i
,
2
)
for
i
in
list
(
np
.
linspace
(
np
.
max
(
lat
),
np
.
min
(
lat
),
5
))]
for
t
in
range
(
len
(
ts
)):
#if t==0 : ax1=plt.subplot(gs[t])
ax1
=
plt
.
subplot
(
gs
[
t
])
input_image
=
input_images_
[
ts
[
t
],
:,
:,
0
]
*
(
321.46630859375
-
235.2141571044922
)
+
235.2141571044922
plt
.
imshow
(
input_image
,
cmap
=
'
jet
'
,
vmin
=
270
,
vmax
=
300
)
ax1
.
title
.
set_text
(
"
t =
"
+
str
(
ts
[
t
]
+
1
))
plt
.
setp
([
ax1
],
xticks
=
[],
xticklabels
=
[],
yticks
=
[],
yticklabels
=
[])
if
t
==
0
:
plt
.
setp
([
ax1
],
xticks
=
list
(
np
.
linspace
(
0
,
64
,
3
)),
xticklabels
=
xlables
,
yticks
=
list
(
np
.
linspace
(
0
,
64
,
3
)),
yticklabels
=
ylabels
)
plt
.
ylabel
(
"
Ground Truth
"
,
fontsize
=
10
)
plt
.
savefig
(
os
.
path
.
join
(
args
.
output_png_dir
,
"
Ground_Truth_Sample_
"
+
str
(
name
)
+
"
.jpg
"
))
plt
.
clf
()
fig
=
plt
.
figure
(
figsize
=
(
12
,
6
))
gs
=
gridspec
.
GridSpec
(
1
,
10
)
gs
.
update
(
wspace
=
0.
,
hspace
=
0.
)
ts
=
[
10
,
12
,
14
,
16
,
18
,
19
]
for
t
in
range
(
len
(
ts
)):
#if t==0 : ax1=plt.subplot(gs[t])
ax1
=
plt
.
subplot
(
gs
[
t
])
gen_image
=
gen_images_
[
ts
[
t
],
:,
:,
0
]
*
(
321.46630859375
-
235.2141571044922
)
+
235.2141571044922
plt
.
imshow
(
gen_image
,
cmap
=
'
jet
'
,
vmin
=
270
,
vmax
=
300
)
ax1
.
title
.
set_text
(
"
t =
"
+
str
(
ts
[
t
]
+
1
))
plt
.
setp
([
ax1
],
xticks
=
[],
xticklabels
=
[],
yticks
=
[],
yticklabels
=
[])
plt
.
savefig
(
os
.
path
.
join
(
args
.
output_png_dir
,
"
Predicted_Sample_
"
+
str
(
name
)
+
"
.jpg
"
))
plt
.
clf
()
# fig = plt.figure(figsize=(18,6))
# gs = gridspec.GridSpec(1, 10)
# gs.update(wspace = 0., hspace = 0.)
# ts = [0,5,9,10,12,14,16,18,19]
# xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
# ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))]
#
# for t in range(len(ts)):
# #if t==0 : ax1=plt.subplot(gs[t])
# ax1 = plt.subplot(gs[t])
# input_image = input_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
# plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300)
# ax1.title.set_text("t = " + str(ts[t]+1))
# plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
#
# if t == 0:
# plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels)
# plt.ylabel("Ground Truth", fontsize=10)
# plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg"))
# plt.clf()
#
# fig = plt.figure(figsize=(12,6))
# gs = gridspec.GridSpec(1, 10)
# gs.update(wspace = 0., hspace = 0.)
# ts = [10,12,14,16,18,19]
# for t in range(len(ts)):
# #if t==0 : ax1=plt.subplot(gs[t])
# ax1 = plt.subplot(gs[t])
# gen_image = gen_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
# plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300)
# ax1.title.set_text("t = " + str(ts[t]+1))
# plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
#
# plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg"))
# plt.clf()
#
# fig = plt.figure()
# fig = plt.figure()
# gs = gridspec.GridSpec(4,6)
# gs = gridspec.GridSpec(4,6)
# gs.update(wspace = 0.7,hspace=0.8)
# gs.update(wspace = 0.7,hspace=0.8)
...
@@ -356,30 +356,30 @@ def main():
...
@@ -356,30 +356,30 @@ def main():
# ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000,
# ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000,
# repeat_delay=2000)
# repeat_delay=2000)
# ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4"))
# ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4"))
#
# else:
# pass
else
:
pass
#
# if sample_ind == 0:
# gen_images_all = gen_images_stochastic
if
sample_ind
==
0
:
# else:
gen_images_all
=
gen_images_stochastic
# gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1)
else
:
#
gen_images_all
=
np
.
concatenate
((
np
.
array
(
gen_images_all
),
np
.
array
(
gen_images_stochastic
)),
axis
=
1
)
# if args.num_stochastic_samples == 1:
# with open(os.path.join(args.output_png_dir, "gen_images_all"), "wb") as gen_files:
if
args
.
num_stochastic_samples
==
1
:
# pickle.dump(list(gen_images_all[0]), gen_files)
with
open
(
os
.
path
.
join
(
args
.
output_png_dir
,
"
gen_images_all
"
),
"
wb
"
)
as
gen_files
:
# else:
pickle
.
dump
(
list
(
gen_images_all
[
0
]),
gen_files
)
# with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files:
else
:
# pickle.dump(list(gen_images_stochastic), gen_files)
with
open
(
os
.
path
.
join
(
args
.
output_png_dir
,
"
gen_images_sample_id_
"
+
str
(
sample_ind
)),
"
wb
"
)
as
gen_files
:
# with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files:
pickle
.
dump
(
list
(
gen_images_stochastic
),
gen_files
)
# pickle.dump(list(gen_images_all), gen_files)
with
open
(
os
.
path
.
join
(
args
.
output_png_dir
,
"
gen_images_all_stochastic
"
),
"
wb
"
)
as
gen_files
:
#
pickle
.
dump
(
list
(
gen_images_all
),
gen_files
)
#
#
#
# sample_ind += args.batch_size
sample_ind
+=
args
.
batch_size
# # for i, gen_mse_avg_ in enumerate(gen_mse_avg):
# # for i, gen_mse_avg_ in enumerate(gen_mse_avg):
...
@@ -390,8 +390,7 @@ def main():
...
@@ -390,8 +390,7 @@ def main():
# # plt.xlabel("Frames")
# # plt.xlabel("Frames")
# # plt.ylabel("MSE_AVG")
# # plt.ylabel("MSE_AVG")
# # #X = list(range(len(gen_mse_avg_)))
# # #X = list(range(len(gen_mse_avg_)))
# # #for t, gen_mse_avg_ in enume
# # #for t, gen_mse_avg_ in enumerate(gen_mse_avg):
# rate(gen_mse_avg):
# # def animate_metric(j):
# # def animate_metric(j):
# # data = gen_mse_avg_[:(j+1)]
# # data = gen_mse_avg_[:(j+1)]
# # x = list(range(len(gen_mse_avg_)))[:(j+1)]
# # x = list(range(len(gen_mse_avg_)))[:(j+1)]
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment