Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Stefan Kesselheim
jax-md
Commits
203a06e8
Commit
203a06e8
authored
Jun 10, 2020
by
ekindogus
Browse files
Added Behler-Parrinello neural network.
parent
3074c619
Changes
3
Hide whitespace changes
Inline
Side-by-side
jax_md/energy.py
View file @
203a06e8
...
...
@@ -23,7 +23,7 @@ from functools import wraps, partial
import
jax
import
jax.numpy
as
np
from
jax.tree_util
import
tree_map
from
jax
import
vmap
import
haiku
as
hk
from
jax_md
import
space
,
smap
,
partition
,
nn
...
...
@@ -414,6 +414,37 @@ def eam_from_lammps_parameters(displacement, f):
return
eam
(
displacement
,
*
load_lammps_eam_parameters
(
f
)[:
-
1
])
def
behler_parrinello
(
displacement
,
species
=
None
,
mlp_sizes
=
(
30
,
30
),
mlp_kwargs
=
None
,
sym_kwargs
=
None
):
if
sym_kwargs
is
None
:
sym_kwargs
=
{}
if
mlp_kwargs
is
None
:
mlp_kwargs
=
{
'activation'
:
np
.
tanh
}
sym_fn
=
nn
.
behler_parrinello_symmetry_functions
(
displacement
,
species
,
**
sym_kwargs
)
@
hk
.
transform
def
model
(
R
,
**
kwargs
):
embedding_fn
=
hk
.
nets
.
MLP
(
output_sizes
=
mlp_sizes
+
(
1
,),
activate_final
=
False
,
name
=
'BPEncoder'
,
**
mlp_kwargs
)
embedding_fn
=
vmap
(
embedding_fn
)
sym
=
sym_fn
(
R
,
**
kwargs
)
readout
=
embedding_fn
(
sym
)
return
np
.
sum
(
readout
)
return
model
.
init
,
model
.
apply
class
EnergyGraphNet
(
hk
.
Module
):
"""Implements a Graph Neural Network for energy fitting.
...
...
jax_md/nn.py
View file @
203a06e8
...
...
@@ -96,8 +96,8 @@ def radial_symmetry_functions(displacement_or_metric,
"""
metric
=
space
.
canonicalize_displacement_or_metric
(
displacement_or_metric
)
def
compute_fun
(
R
):
_metric
=
partial
(
metric
)
def
compute_fun
(
R
,
**
kwargs
):
_metric
=
partial
(
metric
,
**
kwargs
)
_metric
=
space
.
map_product
(
_metric
)
radial_fn
=
lambda
eta
,
dr
:
(
np
.
exp
(
-
eta
*
dr
**
2
)
*
_behler_parrinello_cutoff_fn
(
dr
,
cutoff_distance
))
...
...
@@ -167,7 +167,7 @@ def angular_symmetry_functions(displacement,
spatial_dimension]` and returns `[N_atoms, N_types * (N_types + 1) / 2]`
where N_types is the number of types of particles in the system.
"""
_angular_fn
=
vmap
(
single_pair_angular_symmetry_function
,
(
None
,
None
,
0
,
0
,
0
,
None
))
...
...
@@ -180,8 +180,9 @@ def angular_symmetry_functions(displacement,
_all_pairs_angular
=
vmap
(
vmap
(
vmap
(
_batched_angular_fn
,
(
0
,
None
)),
(
None
,
0
)),
0
)
def
compute_fun
(
R
):
D_fn
=
space
.
map_product
(
displacement
)
def
compute_fun
(
R
,
**
kwargs
):
D_fn
=
partial
(
displacement
,
**
kwargs
)
D_fn
=
space
.
map_product
(
D_fn
)
D_different_types
=
[
D_fn
(
R
[
species
==
atom_type
,
:],
R
)
for
atom_type
in
np
.
unique
(
species
)
]
...
...
@@ -210,8 +211,9 @@ def behler_parrinello_symmetry_functions(displacement,
f32
)
/
f32
(
0.529177
**
2
)
if
angular_etas
is
None
:
angular_etas
=
np
.
array
([
1e-4
]
*
4
+
[
0.003
]
*
4
+
[
0.008
]
*
2
+
[
0.015
]
*
4
+
[
0.025
]
*
4
+
[
0.045
]
*
4
,
f32
)
/
f32
(
0.529177
**
2
)
angular_etas
=
np
.
array
([
1e-4
]
*
4
+
[
0.003
]
*
4
+
[
0.008
]
*
2
+
[
0.015
]
*
4
+
[
0.025
]
*
4
+
[
0.045
]
*
4
,
f32
)
/
f32
(
0.529177
**
2
)
if
lambdas
is
None
:
lambdas
=
np
.
array
([
-
1
,
1
]
*
4
+
[
1
]
*
14
,
f32
)
...
...
@@ -229,7 +231,9 @@ def behler_parrinello_symmetry_functions(displacement,
lambdas
=
lambdas
,
zetas
=
zetas
,
cutoff_distance
=
cutoff_distance
)
return
lambda
R
:
np
.
hstack
((
radial_fn
(
R
),
angular_fn
(
R
)))
return
lambda
R
,
**
kwargs
:
np
.
hstack
((
radial_fn
(
R
,
**
kwargs
),
angular_fn
(
R
,
**
kwargs
)))
# Graph neural network primitives
...
...
tests/energy_test.py
View file @
203a06e8
...
...
@@ -45,6 +45,7 @@ SPATIAL_DIMENSION = [2, 3]
UNIT_CELL_SIZE
=
[
7
,
8
]
SOFT_SPHERE_ALPHA
=
[
2.0
,
3.0
]
N_TYPES_TO_TEST
=
[
1
,
2
]
if
FLAGS
.
jax_enable_x64
:
POSITION_DTYPE
=
[
f32
,
f64
]
...
...
@@ -415,6 +416,27 @@ class EnergyTest(jtu.JaxTestCase):
np
.
array
(
exact_force_fn
(
r
),
dtype
=
dtype
),
force_fn
(
r
,
nbrs
))
@
parameterized
.
named_parameters
(
jtu
.
cases_from_list
(
{
'testcase_name'
:
'_N_types={}_dtype={}'
.
format
(
N_types
,
dtype
.
__name__
),
'N_types'
:
N_types
,
'dtype'
:
dtype
,
}
for
N_types
in
N_TYPES_TO_TEST
for
dtype
in
POSITION_DTYPE
))
def
test_behler_parrinello_network
(
self
,
N_types
,
dtype
):
key
=
random
.
PRNGKey
(
1
)
R
=
np
.
array
([[
0
,
0
,
0
],
[
1
,
1
,
1
],
[
1
,
1
,
0
]],
dtype
)
species
=
np
.
array
([
1
,
1
,
N_types
])
box_size
=
f32
(
1.5
)
displacement
,
_
=
space
.
periodic
(
box_size
)
nn_init
,
nn_apply
=
energy
.
behler_parrinello
(
displacement
,
species
)
params
=
nn_init
(
key
,
R
)
nn_force_fn
=
grad
(
nn_apply
,
argnums
=
1
)
nn_force
=
nn_force_fn
(
params
,
R
)
nn_energy
=
nn_apply
(
params
,
R
)
self
.
assertAllClose
(
np
.
any
(
np
.
isnan
(
nn_energy
)),
False
)
self
.
assertAllClose
(
np
.
any
(
np
.
isnan
(
nn_force
)),
False
)
self
.
assertAllClose
(
nn_force
.
shape
,
[
3
,
3
])
@
parameterized
.
named_parameters
(
jtu
.
cases_from_list
(
{
'testcase_name'
:
'_dim={}_dtype={}'
.
format
(
dim
,
dtype
.
__name__
),
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment