esde / machine-learning / downscaling_maelstrom · Merge request !1
Draft: Resolve "Loss functions incl. vector norms"
Merged: Michael Langguth requested to merge michael_issue076-vector_losses into develop (1 year ago)
Overview: 0 · Commits: 3 · Pipelines: 0 · Changes: 2
Closes #76 (closed)
Merge request reports
Compare: develop (base) and latest version
- latest version 60a5bd1a (3 commits, 1 year ago)
- version 2 e0e0063c (2 commits, 1 year ago)
- version 1 642bb425 (1 commit, 1 year ago)
2 files changed: +182 −30
downscaling_ap5/models/custom_losses.py (new file, 0 → 100644): +160 −0
# SPDX-FileCopyrightText: 2022 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
#
# SPDX-License-Identifier: MIT

"""
Some customized losses (e.g. on vector quantities).
"""

__author__ = "Michael Langguth"
__email__ = "m.langguth@fz-juelich.de"
__date__ = "2023-06-16"
__update__ = "2023-06-16"

# import modules
import inspect
import tensorflow as tf


def fix_channels(n_channels):
    """
    Decorator to fix number of channels in loss functions.
    """
    def decorator(func):
        def wrapper(y_true, y_pred, **func_kwargs):
            return func(y_true, y_pred, n_channels, **func_kwargs)
        return wrapper
    return decorator


def get_custom_loss(loss_name, **kwargs):
    """
    Loss factory including some customized losses and all available Keras losses.
    :param loss_name: name of the loss function
    :return: the respective loss function
    """
    known_losses = ["mse_channels", "mae_channels", "mae_vec", "mse_vec", "critic", "critic_generator"] + \
                   [loss_cls[0] for loss_cls in inspect.getmembers(tf.keras.losses, inspect.isclass)]

    loss_name = loss_name.lower()
    n_channels = kwargs.get("n_channels", None)

    if loss_name == "mse_channels":
        assert n_channels > 0, f"n_channels must be a number larger than zero, but is {n_channels}."
        loss_fn = fix_channels(**kwargs)(mse_channels)
    elif loss_name == "mae_channels":
        assert n_channels > 0, f"n_channels must be a number larger than zero, but is {n_channels}."
        loss_fn = fix_channels(**kwargs)(mae_channels)
    elif loss_name == "mae_vec":
        assert n_channels > 0, f"n_channels must be a number larger than zero, but is {n_channels}."
        loss_fn = fix_channels(**kwargs)(mae_vec)
    elif loss_name == "mse_vec":
        assert n_channels > 0, f"n_channels must be a number larger than zero, but is {n_channels}."
        loss_fn = fix_channels(**kwargs)(mse_vec)
    elif loss_name == "critic":
        loss_fn = critic_loss
    elif loss_name == "critic_generator":
        loss_fn = critic_gen_loss
    else:
        try:
            loss_fn = getattr(tf.keras.losses, loss_name)(**kwargs)
        except AttributeError:
            raise ValueError(f"{loss_name} is not a valid loss function. Choose one of the following: {known_losses}")

    return loss_fn


def mae_channels(x, x_hat, n_channels: int = None, channels_last: bool = True, avg_channels: bool = False):
    """
    Channel-wise MAE, accumulated over all n_channels and optionally averaged over the channels.
    """
    rloss = 0.
    if channels_last:
        # get MAE for all output heads
        for i in range(n_channels):
            rloss += tf.reduce_mean(tf.abs(tf.squeeze(x_hat[..., i]) - x[..., i]))
    else:
        for i in range(n_channels):
            rloss += tf.reduce_mean(tf.abs(tf.squeeze(x_hat[i, ...]) - x[i, ...]))

    if avg_channels:
        rloss /= n_channels

    return rloss


def mse_channels(x, x_hat, n_channels, channels_last: bool = True, avg_channels: bool = False):
    """
    Channel-wise MSE, accumulated over all n_channels and optionally averaged over the channels.
    """
    rloss = 0.
    if channels_last:
        # get MSE for all output heads
        for i in range(n_channels):
            rloss += tf.reduce_mean(tf.square(tf.squeeze(x_hat[..., i]) - x[..., i]))
    else:
        for i in range(n_channels):
            rloss += tf.reduce_mean(tf.square(tf.squeeze(x_hat[i, ...]) - x[i, ...]))

    if avg_channels:
        rloss /= n_channels

    return rloss


def mae_vec(x, x_hat, n_channels, channels_last: bool = True, avg_channels: bool = False, nd_vec: int = None):
    """
    MAE based on the L2-norm (vector norm) of the difference over the first nd_vec channels.
    """
    if nd_vec is None:
        nd_vec = n_channels

    rloss = 0.
    if channels_last:
        vec_ind = -1
        diff = tf.squeeze(x_hat[..., 0:nd_vec]) - x[..., 0:nd_vec]
    else:
        vec_ind = 1
        diff = tf.squeeze(x_hat[:, 0:nd_vec, ...]) - x[:, 0:nd_vec, ...]

    rloss = tf.reduce_mean(tf.norm(diff, axis=vec_ind))
    # rloss = tf.reduce_mean(tf.math.reduce_euclidean_norm(diff, axis=vec_ind))

    if nd_vec > n_channels:
        if channels_last:
            rloss += mae_channels(x[..., nd_vec::], x_hat[..., nd_vec::], True, avg_channels)
        else:
            rloss += mae_channels(x[:, nd_vec::, ...], x_hat[:, nd_vec::, ...], True, avg_channels)

    return rloss


def mse_vec(x, x_hat, n_channels, channels_last: bool = True, avg_channels: bool = False, nd_vec: int = None):
    """
    MSE based on the squared L2-norm (vector norm) of the difference over the first nd_vec channels.
    """
    if nd_vec is None:
        nd_vec = n_channels

    rloss = 0.
    if channels_last:
        vec_ind = -1
        diff = tf.squeeze(x_hat[..., 0:nd_vec]) - x[..., 0:nd_vec]
    else:
        vec_ind = 1
        diff = tf.squeeze(x_hat[:, 0:nd_vec, ...]) - x[:, 0:nd_vec, ...]

    rloss = tf.reduce_mean(tf.square(tf.norm(diff, axis=vec_ind)))

    if nd_vec > n_channels:
        if channels_last:
            rloss += mse_channels(x[..., nd_vec::], x_hat[..., nd_vec::], True, avg_channels)
        else:
            rloss += mse_channels(x[:, nd_vec::, ...], x_hat[:, nd_vec::, ...], True, avg_channels)

    return rloss


def critic_loss(critic_real, critic_gen):
    """
    The critic is optimized to maximize the difference between its scores on real and generated data,
    max(real - gen). This is equivalent to minimizing the negative of this difference, i.e. min(gen - real).
    :param critic_real: critic on the real data
    :param critic_gen: critic on the generated data
    :return c_loss: loss to optimize the critic
    """
    c_loss = tf.reduce_mean(critic_gen - critic_real)

    return c_loss


def critic_gen_loss(critic_gen):
    """
    Generator counterpart to the critic loss: the generator is optimized to maximize the mean critic
    score on generated data, i.e. to minimize -mean(critic_gen).
    :param critic_gen: critic on the generated data
    :return cg_loss: loss to optimize the generator
    """
    cg_loss = -tf.reduce_mean(critic_gen)

    return cg_loss
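
For orientation, a minimal usage sketch of the loss factory follows. The import path, tensor shapes and channel count are illustrative assumptions and not part of this merge request.

# Minimal usage sketch (illustrative; import path, shapes and channel count are assumptions).
import tensorflow as tf

from downscaling_ap5.models.custom_losses import get_custom_loss  # path as in this MR, assuming it is importable

# Vector-norm MAE over two channels, e.g. the two components of a horizontal wind vector.
loss_fn = get_custom_loss("mae_vec", n_channels=2)

y_true = tf.random.normal((4, 16, 16, 2))   # (batch, height, width, channels), channels_last layout
y_pred = tf.random.normal((4, 16, 16, 2))
vec_loss = loss_fn(y_true, y_pred)          # scalar tensor: mean L2-norm of the error vector

# Critic losses for adversarial training are returned directly (no n_channels required).
critic_fn = get_custom_loss("critic")
gen_fn = get_custom_loss("critic_generator")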