Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
AMBS
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
AMBS
Commits
4f5441d2
Commit
4f5441d2
authored
2 years ago
by
BING GONG
Browse files
Options
Downloads
Patches
Plain Diff
Add weatherBench models
parent
a503159a
No related branches found
No related tags found
No related merge requests found
Pipeline
#101962
passed
2 years ago
Stage: build
Stage: test
Stage: deploy
Changes
1
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
+144
-0
144 additions, 0 deletions
...odel_modules/video_prediction/models/weatherBench3DCNN.py
with
144 additions
and
0 deletions
video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
0 → 100644
+
144
−
0
View file @
4f5441d2
# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
#
# SPDX-License-Identifier: MIT
# Weather Bench models
__email__
=
"
b.gong@fz-juelich.de
"
__author__
=
"
Bing Gong
"
__date__
=
"
2021-04-13
"
import
tensorflow
as
tf
from
tensorflow.contrib.training
import
HParams
from
model_modules.video_prediction.layers
import
layer_def
as
ld
from
model_modules.video_prediction.losses
import
*
class
WeatherBenchModel
(
object
):
def
__init__
(
self
,
hparams_dict
=
None
):
"""
This is class for building weahterBench architecture by using updated hparameters
args:
mode :str,
"
train
"
or
"
val
"
, side note: mode may not be used in the convLSTM, but this will be a useful argument for the GAN-based model
hparams_dict: dict, the dictionary contains the hparaemters names and values
"""
self
.
hparams_dict
=
hparams_dict
self
.
hparams
=
self
.
parse_hparams
()
self
.
learning_rate
=
self
.
hparams
.
lr
self
.
filters
=
self
.
hparams
.
filters
self
.
kernels
=
self
.
hparams
.
kernes
self
.
context_frames
=
self
.
hparams
.
context_frames
self
.
sequence_length
=
self
.
hparams
.
sequence_length
self
.
predict_frames
=
self
.
sequence_length
-
self
.
context_frames
self
.
max_epochs
=
self
.
hparams
.
max_epochs
self
.
loss_fun
=
self
.
hparams
.
loss_fun
self
.
batch_size
=
self
.
hparams
.
batch_size
self
.
recon_weight
=
self
.
hparams
.
recon_weight
self
.
outputs
=
{}
self
.
total_loss
=
None
def
get_default_hparams
(
self
):
return
HParams
(
**
self
.
get_default_hparams_dict
())
def
parse_hparams
(
self
):
"""
Parse the hparams setting to ovoerride the default ones
"""
parsed_hparams
=
self
.
get_default_hparams
().
override_from_dict
(
self
.
hparams_dict
or
{})
return
parsed_hparams
def
get_default_hparams_dict
(
self
):
"""
The function that contains default hparams
Returns:
A dict with the following hyperparameters.
context_frames : the number of ground-truth frames to pass in at start.
sequence_length : the number of frames in the video sequence
max_epochs : the number of epochs to train model
lr : learning rate
loss_fun : the loss function
"""
hparams
=
dict
(
context_frames
=
12
,
sequence_length
=
24
,
max_epochs
=
20
,
batch_size
=
40
,
lr
=
0.001
,
loss_fun
=
"
mse
"
,
shuffle_on_val
=
True
,
filter
=
4
,
kernels
=
4
)
return
hparams
def
build_graph
(
self
,
x
):
self
.
is_build_graph
=
False
self
.
inputs
=
x
self
.
x
=
x
[
"
images
"
]
self
.
global_step
=
tf
.
train
.
get_or_create_global_step
()
original_global_variables
=
tf
.
global_variables
()
# Architecture
x_hat
=
self
.
build_model
(
x
,
self
.
filters
,
self
.
kernels
,
dr
=
0
)
# Loss
self
.
total_loss
=
l1_loss
(
x
[...,
0
],
x_hat
[...,
0
])
# Optimizer
self
.
train_op
=
tf
.
train
.
AdamOptimizer
(
learning_rate
=
self
.
learning_rate
).
minimize
(
self
.
total_loss
,
var_list
=
self
.
gen_vars
)
# outputs
self
.
outputs
[
"
total_loss
"
]
=
self
.
total_loss
# Summary op
tf
.
summary
.
scalar
(
"
total_loss
"
,
self
.
total_loss
)
self
.
summary_op
=
tf
.
summary
.
merge_all
()
global_variables
=
[
var
for
var
in
tf
.
global_variables
()
if
var
not
in
original_global_variables
]
self
.
saveable_variables
=
[
self
.
global_step
]
+
global_variables
self
.
is_build_graph
=
True
return
self
.
is_build_graph
def
build_model
(
self
,
x
,
filters
,
kernels
,
dr
=
0
):
"""
Fully convolutional network
"""
for
f
,
k
in
zip
(
filters
[:
-
1
],
kernels
[:
-
1
]):
x
=
PeriodicConv2D
(
x
,
f
,
k
)
x
=
tf
.
nn
.
elu
(
x
)
if
dr
>
0
:
x
=
tf
.
nn
.
dropout
(
x
,
dr
)
output
=
PeriodicConv2D
(
x
,
filters
[
-
1
],
kernels
[
-
1
])
return
output
class
PeriodicPadding2D
(
object
):
def
__init__
(
self
,
x
,
pad_width
):
self
.
pad_width
=
pad_width
def
call
(
self
,
inputs
,
**
kwargs
):
if
self
.
pad_width
==
0
:
return
inputs
inputs_padded
=
tf
.
concat
(
[
inputs
[:,
:,
-
self
.
pad_width
:,
:],
inputs
,
inputs
[:,
:,
:
self
.
pad_width
,
:]],
axis
=
2
)
# Zero padding in the lat direction
inputs_padded
=
tf
.
pad
(
inputs_padded
,
[[
0
,
0
],
[
self
.
pad_width
,
self
.
pad_width
],
[
0
,
0
],
[
0
,
0
]])
return
inputs_padded
class
PeriodicConv2D
(
object
):
def
__init__
(
self
,
filters
,
kernel_size
,
conv_kwargs
=
{}):
self
.
filters
=
filters
self
.
kernel_size
=
kernel_size
self
.
conv_kwargs
=
conv_kwargs
if
type
(
kernel_size
)
is
not
int
:
assert
kernel_size
[
0
]
==
kernel_size
[
1
],
'
PeriodicConv2D only works for square kernels
'
kernel_size
=
kernel_size
[
0
]
self
.
pad_width
=
(
kernel_size
-
1
)
//
2
def
call
(
self
,
inputs
):
self
.
padding
=
PeriodicPadding2D
(
inputs
,
self
.
pad_width
)
self
.
conv
=
ld
.
conv2D
(
self
.
padding
,
self
.
filters
,
self
.
kernel_size
,
padding
=
'
valid
'
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment