Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Stefan Kesselheim
jax-md
Commits
95365de3
Commit
95365de3
authored
Nov 12, 2020
by
Sam Schoenholz
Browse files
Fixed bug when Behler-Parrinello network functions were jit.
parent
768907ad
Changes
3
Hide whitespace changes
Inline
Side-by-side
jax_md/nn.py
View file @
95365de3
...
...
@@ -14,7 +14,7 @@
"""Neural Network Primitives."""
from
typing
import
Callable
,
Tuple
,
Dict
,
Any
from
typing
import
Callable
,
Tuple
,
Dict
,
Any
,
Optional
import
numpy
as
onp
...
...
@@ -86,7 +86,7 @@ def _behler_parrinello_cutoff_fn(dr: Array,
def
radial_symmetry_functions
(
displacement_or_metric
:
DisplacementOrMetricFn
,
species
:
Array
,
species
:
Optional
[
Array
]
,
etas
:
Array
,
cutoff_distance
:
float
)
->
Callable
[[
Array
],
Array
]:
...
...
@@ -114,23 +114,28 @@ def radial_symmetry_functions(displacement_or_metric: DisplacementOrMetricFn,
"""
metric
=
space
.
canonicalize_displacement_or_metric
(
displacement_or_metric
)
def
compute_fun
(
R
:
Array
,
**
kwargs
)
->
Array
:
_metric
=
partial
(
metric
,
**
kwargs
)
_metric
=
space
.
map_product
(
_metric
)
radial_fn
=
lambda
eta
,
dr
:
(
np
.
exp
(
-
eta
*
dr
**
2
)
*
_behler_parrinello_cutoff_fn
(
dr
,
cutoff_distance
))
def
return_radial
(
atom_type
):
"""Returns the radial symmetry functions for neighbor type atom_type."""
R_neigh
=
R
[
species
==
atom_type
,
:]
dr
=
_metric
(
R
,
R_neigh
)
radial
=
vmap
(
radial_fn
,
(
0
,
None
))(
etas
,
dr
)
return
np
.
sum
(
radial
,
axis
=
1
).
T
return
np
.
hstack
([
return_radial
(
atom_type
)
for
atom_type
in
np
.
unique
(
species
)])
return
compute_fun
radial_fn
=
lambda
eta
,
dr
:
(
np
.
exp
(
-
eta
*
dr
**
2
)
*
_behler_parrinello_cutoff_fn
(
dr
,
cutoff_distance
))
radial_fn
=
vmap
(
radial_fn
,
(
0
,
None
))
if
species
is
None
:
def
compute_fn
(
R
:
Array
,
**
kwargs
)
->
Array
:
_metric
=
partial
(
metric
,
**
kwargs
)
_metric
=
space
.
map_product
(
_metric
)
return
util
.
high_precision_sum
(
radial_fn
(
etas
,
_metric
(
R
,
R
)),
axis
=
1
).
T
elif
isinstance
(
species
,
np
.
ndarray
):
species
=
onp
.
array
(
species
)
def
compute_fn
(
R
:
Array
,
**
kwargs
)
->
Array
:
_metric
=
partial
(
metric
,
**
kwargs
)
_metric
=
space
.
map_product
(
_metric
)
def
return_radial
(
atom_type
):
"""Returns the radial symmetry functions for neighbor type atom_type."""
R_neigh
=
R
[
species
==
atom_type
,
:]
dr
=
_metric
(
R
,
R_neigh
)
return
util
.
high_precision_sum
(
radial_fn
(
etas
,
dr
),
axis
=
1
).
T
return
np
.
hstack
([
return_radial
(
atom_type
)
for
atom_type
in
onp
.
unique
(
species
)])
return
compute_fn
def
radial_symmetry_functions_neighbor_list
(
...
...
@@ -162,7 +167,22 @@ def radial_symmetry_functions_neighbor_list(
"""
metric
=
space
.
canonicalize_displacement_or_metric
(
displacement_or_metric
)
def
compute_fun
(
R
:
Array
,
neighbor
:
NeighborList
,
**
kwargs
)
->
Array
:
def
radial_fn
(
eta
,
dr
):
return
(
np
.
exp
(
-
eta
*
dr
**
2
)
*
_behler_parrinello_cutoff_fn
(
dr
,
cutoff_distance
))
radial_fn
=
vmap
(
radial_fn
,
(
0
,
None
))
if
species
is
None
:
def
compute_fn
(
R
:
Array
,
neighbor
:
NeighborList
,
**
kwargs
)
->
Array
:
_metric
=
partial
(
metric
,
**
kwargs
)
_metric
=
space
.
map_neighbor
(
_metric
)
R_neigh
=
R
[
neighbor
.
idx
]
mask
=
(
neighbor
.
idx
<
R
.
shape
[
0
])[
np
.
newaxis
,
:,
:]
dr
=
_metric
(
R
,
R_neigh
)
return
util
.
high_precision_sum
(
radial_fn
(
etas
,
dr
)
*
mask
,
axis
=
2
).
T
return
compute_fn
def
compute_fn
(
R
:
Array
,
neighbor
:
NeighborList
,
**
kwargs
)
->
Array
:
_metric
=
partial
(
metric
,
**
kwargs
)
_metric
=
space
.
map_neighbor
(
_metric
)
radial_fn
=
lambda
eta
,
dr
:
(
np
.
exp
(
-
eta
*
dr
**
2
)
*
...
...
@@ -179,9 +199,9 @@ def radial_symmetry_functions_neighbor_list(
return
util
.
high_precision_sum
(
radial
*
mask
[
np
.
newaxis
,
:,
:],
axis
=
2
).
T
return
np
.
hstack
([
return_radial
(
atom_type
)
for
atom_type
in
np
.
unique
(
species
)])
atom_type
in
o
np
.
unique
(
species
)])
return
compute_f
u
n
return
compute_fn
def
single_pair_angular_symmetry_function
(
dR12
:
Array
,
...
...
@@ -238,7 +258,7 @@ def angular_symmetry_functions(displacement: DisplacementFn,
spatial_dimension]` and returns `[N_atoms, N_types * (N_types + 1) / 2]`
where N_types is the number of types of particles in the system.
"""
_angular_fn
=
vmap
(
single_pair_angular_symmetry_function
,
(
None
,
None
,
0
,
0
,
0
,
None
))
...
...
@@ -251,14 +271,23 @@ def angular_symmetry_functions(displacement: DisplacementFn,
_all_pairs_angular
=
vmap
(
vmap
(
vmap
(
_batched_angular_fn
,
(
0
,
None
)),
(
None
,
0
)),
0
)
def
compute_fun
(
R
,
**
kwargs
):
if
species
is
None
:
def
compute_fn
(
R
,
**
kwargs
):
D_fn
=
partial
(
displacement
,
**
kwargs
)
D_fn
=
space
.
map_product
(
D_fn
)
dR
=
D_fn
(
R
,
R
)
return
np
.
sum
(
_all_pairs_angular
(
dR
,
dR
),
axis
=
[
1
,
2
])
return
compute_fn
if
isinstance
(
species
,
np
.
ndarray
):
species
=
onp
.
array
(
species
)
def
compute_fn
(
R
,
**
kwargs
):
atom_types
=
onp
.
unique
(
species
)
D_fn
=
partial
(
displacement
,
**
kwargs
)
D_fn
=
space
.
map_product
(
D_fn
)
D_different_types
=
[
D_fn
(
R
[
species
==
atom_type
,
:],
R
)
for
atom_type
in
np
.
unique
(
species
)
]
D_different_types
=
[
D_fn
(
R
[
species
==
s
,
:],
R
)
for
s
in
atom_types
]
out
=
[]
atom_types
=
np
.
unique
(
species
)
for
i
in
range
(
len
(
atom_types
)):
for
j
in
range
(
i
,
len
(
atom_types
)):
out
+=
[
...
...
@@ -267,7 +296,7 @@ def angular_symmetry_functions(displacement: DisplacementFn,
axis
=
[
1
,
2
])
]
return
np
.
hstack
(
out
)
return
compute_f
u
n
return
compute_fn
def
angular_symmetry_functions_neighbor_list
(
displacement
:
DisplacementFn
,
...
...
@@ -297,7 +326,7 @@ def angular_symmetry_functions_neighbor_list(
spatial_dimension]` and returns `[N_atoms, N_types * (N_types + 1) / 2]`
where N_types is the number of types of particles in the system.
"""
_angular_fn
=
vmap
(
single_pair_angular_symmetry_function
,
(
None
,
None
,
0
,
0
,
0
,
None
))
...
...
@@ -310,14 +339,32 @@ def angular_symmetry_functions_neighbor_list(
_all_pairs_angular
=
vmap
(
vmap
(
vmap
(
_batched_angular_fn
,
(
0
,
None
)),
(
None
,
0
)),
0
)
def
compute_fun
(
R
:
Array
,
neighbor
:
NeighborList
,
**
kwargs
)
->
Array
:
if
species
is
None
:
def
compute_fn
(
R
:
Array
,
neighbor
:
NeighborList
,
**
kwargs
)
->
Array
:
D_fn
=
partial
(
displacement
,
**
kwargs
)
D_fn
=
space
.
map_neighbor
(
D_fn
)
R_neigh
=
R
[
neighbor
.
idx
]
mask
=
neighbor
.
idx
<
R
.
shape
[
0
]
dR
=
D_fn
(
R
,
R_neigh
)
all_angular
=
_all_pairs_angular
(
dR
,
dR
)
mask_i
=
mask
[:,
:,
np
.
newaxis
,
np
.
newaxis
]
mask_j
=
mask
[:,
np
.
newaxis
,
:,
np
.
newaxis
]
return
util
.
high_precision_sum
(
all_angular
*
mask_i
*
mask_j
,
axis
=
[
1
,
2
])
return
compute_fn
def
compute_fn
(
R
:
Array
,
neighbor
:
NeighborList
,
**
kwargs
)
->
Array
:
D_fn
=
partial
(
displacement
,
**
kwargs
)
D_fn
=
space
.
map_neighbor
(
D_fn
)
R_neigh
=
R
[
neighbor
.
idx
]
species_neigh
=
species
[
neighbor
.
idx
]
atom_types
=
np
.
unique
(
species
)
atom_types
=
o
np
.
unique
(
species
)
base_mask
=
neighbor
.
idx
<
len
(
R
)
mask
=
[
np
.
logical_and
(
base_mask
,
species_neigh
==
t
)
for
t
in
atom_types
]
...
...
@@ -331,10 +378,10 @@ def angular_symmetry_functions_neighbor_list(
for
j
in
range
(
i
,
len
(
atom_types
)):
mask_j
=
mask
[
j
][:,
np
.
newaxis
,
:,
np
.
newaxis
]
out
+=
[
np
.
sum
(
all_angular
*
mask_i
*
mask_j
,
axis
=
[
1
,
2
])
util
.
high_precision_
sum
(
all_angular
*
mask_i
*
mask_j
,
axis
=
[
1
,
2
])
]
return
np
.
hstack
(
out
)
return
compute_f
u
n
return
compute_fn
def
behler_parrinello_symmetry_functions_neighbor_list
(
...
...
@@ -356,7 +403,7 @@ def behler_parrinello_symmetry_functions_neighbor_list(
if
lambdas
is
None
:
lambdas
=
np
.
array
([
-
1
,
1
]
*
4
+
[
1
]
*
14
,
f32
)
if
zetas
is
None
:
zetas
=
np
.
array
([
1
,
1
,
2
,
2
]
*
2
+
[
1
,
2
]
+
[
1
,
2
,
4
,
16
]
*
3
,
f32
)
...
...
@@ -378,7 +425,7 @@ def behler_parrinello_symmetry_functions_neighbor_list(
def
behler_parrinello_symmetry_functions
(
displacement
:
DisplacementFn
,
species
:
Array
,
species
:
Array
=
None
,
radial_etas
:
Array
=
None
,
angular_etas
:
Array
=
None
,
lambdas
:
Array
=
None
,
...
...
tests/energy_test.py
View file @
95365de3
...
...
@@ -475,18 +475,42 @@ class EnergyTest(jtu.JaxTestCase):
def
test_behler_parrinello_network
(
self
,
N_types
,
dtype
):
key
=
random
.
PRNGKey
(
1
)
R
=
np
.
array
([[
0
,
0
,
0
],
[
1
,
1
,
1
],
[
1
,
1
,
0
]],
dtype
)
species
=
np
.
array
([
1
,
1
,
N_types
])
species
=
np
.
array
([
1
,
1
,
N_types
])
if
N_types
>
1
else
None
box_size
=
f32
(
1.5
)
displacement
,
_
=
space
.
periodic
(
box_size
)
nn_init
,
nn_apply
=
energy
.
behler_parrinello
(
displacement
,
species
)
params
=
nn_init
(
key
,
R
)
nn_force_fn
=
grad
(
nn_apply
,
argnums
=
1
)
nn_force
=
nn_force_fn
(
params
,
R
)
nn_energy
=
nn_apply
(
params
,
R
)
nn_force
=
jit
(
nn_force_fn
)
(
params
,
R
)
nn_energy
=
jit
(
nn_apply
)
(
params
,
R
)
self
.
assertAllClose
(
np
.
any
(
np
.
isnan
(
nn_energy
)),
False
)
self
.
assertAllClose
(
np
.
any
(
np
.
isnan
(
nn_force
)),
False
)
self
.
assertAllClose
(
nn_force
.
shape
,
[
3
,
3
])
@
parameterized
.
named_parameters
(
jtu
.
cases_from_list
(
{
'testcase_name'
:
'_N_types={}_dtype={}'
.
format
(
N_types
,
dtype
.
__name__
),
'N_types'
:
N_types
,
'dtype'
:
dtype
,
}
for
N_types
in
N_TYPES_TO_TEST
for
dtype
in
POSITION_DTYPE
))
def
test_behler_parrinello_network_neighbor_list
(
self
,
N_types
,
dtype
):
key
=
random
.
PRNGKey
(
1
)
R
=
np
.
array
([[
0
,
0
,
0
],
[
1
,
1
,
1
],
[
1
,
1
,
0
]],
dtype
)
species
=
np
.
array
([
1
,
1
,
N_types
])
if
N_types
>
1
else
None
box_size
=
f32
(
1.5
)
displacement
,
_
=
space
.
periodic
(
box_size
)
neighbor_fn
,
nn_init
,
nn_apply
=
energy
.
behler_parrinello_neighbor_list
(
displacement
,
box_size
,
species
)
nbrs
=
neighbor_fn
(
R
)
params
=
nn_init
(
key
,
R
,
nbrs
)
nn_force_fn
=
grad
(
nn_apply
,
argnums
=
1
)
nn_force
=
jit
(
nn_force_fn
)(
params
,
R
,
nbrs
)
nn_energy
=
jit
(
nn_apply
)(
params
,
R
,
nbrs
)
self
.
assertAllClose
(
np
.
any
(
np
.
isnan
(
nn_energy
)),
False
)
self
.
assertAllClose
(
np
.
any
(
np
.
isnan
(
nn_force
)),
False
)
self
.
assertAllClose
(
nn_force
.
shape
,
[
3
,
3
])
@
parameterized
.
named_parameters
(
jtu
.
cases_from_list
(
{
'testcase_name'
:
'_dim={}_dtype={}'
.
format
(
dim
,
dtype
.
__name__
),
...
...
tests/nn_test.py
View file @
95365de3
...
...
@@ -143,18 +143,19 @@ class SymmetryFunctionTest(jtu.JaxTestCase):
etas
=
np
.
linspace
(
1.
,
2.
,
N_etas
,
dtype
=
dtype
)
gr
=
nn
.
angular_symmetry_functions
(
displacement
,
species
,
etas
=
etas
,
lambdas
=
np
.
array
([
-
1.0
]
*
N_etas
,
dtype
),
etas
=
etas
,
lambdas
=
np
.
array
([
-
1.0
]
*
N_etas
,
dtype
),
zetas
=
np
.
array
([
1.0
]
*
N_etas
,
dtype
),
cutoff_distance
=
r_cutoff
)
gr_neigh
=
nn
.
angular_symmetry_functions_neighbor_list
(
displacement
,
species
,
etas
=
etas
,
lambdas
=
np
.
array
([
-
1.0
]
*
N_etas
,
dtype
),
zetas
=
np
.
array
([
1.0
]
*
N_etas
,
dtype
),
cutoff_distance
=
r_cutoff
)
gr_neigh
=
nn
.
angular_symmetry_functions_neighbor_list
(
displacement
,
species
,
etas
=
etas
,
lambdas
=
np
.
array
([
-
1.0
]
*
N_etas
,
dtype
),
zetas
=
np
.
array
([
1.0
]
*
N_etas
,
dtype
),
cutoff_distance
=
r_cutoff
)
nbrs
=
neighbor_fn
(
R
)
gr_exact
=
gr
(
R
)
gr_nbrs
=
gr_neigh
(
R
,
neighbor
=
nbrs
)
...
...
@@ -208,6 +209,36 @@ class SymmetryFunctionTest(jtu.JaxTestCase):
self
.
assertAllClose
(
gr_out
.
shape
,
(
3
,
N_etas
*
(
N_types
+
N_types
*
(
N_types
+
1
)
//
2
)))
self
.
assertAllClose
(
gr_out
[
2
,
0
],
dtype
(
1.885791
),
rtol
=
1e-6
,
atol
=
1e-6
)
@
parameterized
.
named_parameters
(
jtu
.
cases_from_list
(
{
'testcase_name'
:
'_N_types={}_N_etas={}_d_type={}'
.
format
(
N_types
,
N_etas
,
dtype
.
__name__
),
'dtype'
:
dtype
,
'N_types'
:
N_types
,
'N_etas'
:
N_etas
,
}
for
N_types
in
N_TYPES_TO_TEST
for
N_etas
in
N_ETAS_TO_TEST
for
dtype
in
DTYPES
))
def
test_behler_parrinello_symmetry_functions_neighbor_list
(
self
,
N_types
,
N_etas
,
dtype
):
displacement
,
shift
=
space
.
free
()
neighbor_fn
=
partition
.
neighbor_list
(
displacement
,
10.0
,
8.0
,
0.0
)
gr
=
nn
.
behler_parrinello_symmetry_functions_neighbor_list
(
displacement
,
np
.
array
([
1
,
1
,
N_types
]),
radial_etas
=
np
.
array
([
1e-4
/
(
0.529177
**
2
)]
*
N_etas
,
dtype
),
angular_etas
=
np
.
array
([
1e-4
/
(
0.529177
**
2
)]
*
N_etas
,
dtype
),
lambdas
=
np
.
array
([
-
1.0
]
*
N_etas
,
dtype
),
zetas
=
np
.
array
([
1.0
]
*
N_etas
,
dtype
),
cutoff_distance
=
8.0
)
R
=
np
.
array
([[
0
,
0
,
0
],
[
1
,
1
,
1
],
[
1
,
1
,
0
]],
dtype
)
nbrs
=
neighbor_fn
(
R
)
gr_out
=
gr
(
R
,
neighbor
=
nbrs
)
self
.
assertAllClose
(
gr_out
.
shape
,
(
3
,
N_etas
*
(
N_types
+
N_types
*
(
N_types
+
1
)
//
2
)))
self
.
assertAllClose
(
gr_out
[
2
,
0
],
dtype
(
1.885791
),
rtol
=
1e-6
,
atol
=
1e-6
)
def
_graph_network
(
graph_tuple
):
update_node_fn
=
lambda
n
,
se
,
re
,
g
:
n
update_edge_fn
=
lambda
e
,
sn
,
rn
,
g
:
e
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment