Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Stefan Kesselheim
jax-md
Commits
2a7136aa
Commit
2a7136aa
authored
Jan 17, 2020
by
ekindogus
Browse files
Added function to compute cosine of angles.
parent
97a7d9fa
Changes
2
Hide whitespace changes
Inline
Side-by-side
jax_md/quantity.py
View file @
2a7136aa
...
...
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
jax
import
grad
from
jax
import
grad
,
vmap
import
jax.numpy
as
np
from
jax_md
import
space
...
...
@@ -90,6 +90,29 @@ def canonicalize_mass(mass):
raise
ValueError
(
msg
)
def
cosine_angles
(
dR
):
"""Returns cosine of angles for all atom triplets.
Args:
dR: Matrix of displacements; ndarray(shape=[num_atoms, num_neighbors,
spatial_dim]).
Returns:
Tensor of cosine of angles;
ndarray(shape=[num_atoms, num_neighbors, num_neighbors]).
"""
def
angle_between_two_vectors
(
dR_12
,
dR_13
):
dr_12
=
space
.
distance
(
dR_12
)
+
1e-7
dr_13
=
space
.
distance
(
dR_13
)
+
1e-7
cos_angle
=
np
.
dot
(
dR_12
,
dR_13
)
/
dr_12
/
dr_13
return
np
.
clip
(
cos_angle
,
-
1.0
,
1.0
)
angles_between_all_triplets
=
vmap
(
vmap
(
vmap
(
angle_between_two_vectors
,
(
0
,
None
)),
(
None
,
0
)),
0
)
return
angles_between_all_triplets
(
dR
,
dR
)
def
pair_correlation
(
displacement_or_metric
,
rs
,
sigma
):
metric
=
space
.
canonicalize_displacement_or_metric
(
displacement_or_metric
)
...
...
tests/quantity_test.py
View file @
2a7136aa
...
...
@@ -25,8 +25,8 @@ from jax.config import config as jax_config
from
jax
import
random
import
jax.numpy
as
np
from
jax.api
import
jit
,
grad
from
jax_md
import
quantity
from
jax.api
import
jit
,
grad
,
vmap
from
jax_md
import
space
,
quantity
,
test_util
from
jax_md.util
import
*
from
jax
import
test_util
as
jtu
...
...
@@ -34,6 +34,7 @@ from jax import test_util as jtu
jax_config
.
parse_flags_with_absl
()
FLAGS
=
jax_config
.
FLAGS
test_util
.
update_test_tolerance
(
1e-5
,
2e-7
)
PARTICLE_COUNT
=
10
STOCHASTIC_SAMPLES
=
10
...
...
@@ -65,6 +66,62 @@ class QuantityTest(jtu.JaxTestCase):
grad
(
do_fn
)(
2.0
)
@
parameterized
.
named_parameters
(
jtu
.
cases_from_list
(
{
'testcase_name'
:
'_dtype={}'
.
format
(
dtype
.
__name__
),
'dtype'
:
dtype
,
}
for
dtype
in
DTYPES
))
def
test_cosine_angles
(
self
,
dtype
):
displacement
,
_
=
space
.
free
()
displacement
=
space
.
map_product
(
displacement
)
R
=
np
.
array
(
[[
0
,
0
],
[
0
,
1
],
[
1
,
1
]],
dtype
=
dtype
)
dR
=
displacement
(
R
,
R
)
cangles
=
quantity
.
cosine_angles
(
dR
)
c45
=
1
/
np
.
sqrt
(
2
)
true_cangles
=
np
.
array
(
[[[
0
,
0
,
0
],
[
0
,
1
,
c45
],
[
0
,
c45
,
1
]],
[[
1
,
0
,
0
],
[
0
,
0
,
0
],
[
0
,
0
,
1
]],
[[
1
,
c45
,
0
],
[
c45
,
1
,
0
],
[
0
,
0
,
0
]]],
dtype
=
dtype
)
self
.
assertAllClose
(
cangles
,
true_cangles
,
True
)
@
parameterized
.
named_parameters
(
jtu
.
cases_from_list
(
{
'testcase_name'
:
'_dtype={}'
.
format
(
dtype
.
__name__
),
'dtype'
:
dtype
,
}
for
dtype
in
DTYPES
))
def
test_cosine_angles_neighbors
(
self
,
dtype
):
displacement
,
_
=
space
.
free
()
displacement
=
vmap
(
vmap
(
displacement
,
(
None
,
0
)),
0
)
R
=
np
.
array
(
[[
0
,
0
],
[
0
,
1
],
[
1
,
1
]],
dtype
=
dtype
)
R_neigh
=
np
.
array
(
[[[
0
,
1
],
[
1
,
1
]],
[[
0
,
0
],
[
0
,
0
]],
[[
0
,
0
],
[
0
,
0
]]],
dtype
=
dtype
)
dR
=
displacement
(
R
,
R_neigh
)
cangles
=
quantity
.
cosine_angles
(
dR
)
c45
=
1
/
np
.
sqrt
(
2
)
true_cangles
=
np
.
array
(
[[[
1
,
c45
],
[
c45
,
1
]],
[[
1
,
1
],
[
1
,
1
]],
[[
1
,
1
],
[
1
,
1
]]],
dtype
=
dtype
)
self
.
assertAllClose
(
cangles
,
true_cangles
,
True
)
@
parameterized
.
named_parameters
(
jtu
.
cases_from_list
(
{
'testcase_name'
:
'_dtype={}'
.
format
(
dtype
.
__name__
),
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment