Commit 2a7136aa authored by ekindogus's avatar ekindogus
Browse files

Added function to compute cosine of angles.

parent 97a7d9fa
......@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from jax import grad
from jax import grad, vmap
import jax.numpy as np
from jax_md import space
......@@ -90,6 +90,29 @@ def canonicalize_mass(mass):
raise ValueError(msg)
def cosine_angles(dR):
"""Returns cosine of angles for all atom triplets.
Args:
dR: Matrix of displacements; ndarray(shape=[num_atoms, num_neighbors,
spatial_dim]).
Returns:
Tensor of cosine of angles;
ndarray(shape=[num_atoms, num_neighbors, num_neighbors]).
"""
def angle_between_two_vectors(dR_12, dR_13):
dr_12 = space.distance(dR_12) + 1e-7
dr_13 = space.distance(dR_13) + 1e-7
cos_angle = np.dot(dR_12, dR_13) / dr_12 / dr_13
return np.clip(cos_angle, -1.0, 1.0)
angles_between_all_triplets = vmap(
vmap(vmap(angle_between_two_vectors, (0, None)), (None, 0)), 0)
return angles_between_all_triplets(dR, dR)
def pair_correlation(displacement_or_metric, rs, sigma):
metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
......
......@@ -25,8 +25,8 @@ from jax.config import config as jax_config
from jax import random
import jax.numpy as np
from jax.api import jit, grad
from jax_md import quantity
from jax.api import jit, grad, vmap
from jax_md import space, quantity, test_util
from jax_md.util import *
from jax import test_util as jtu
......@@ -34,6 +34,7 @@ from jax import test_util as jtu
jax_config.parse_flags_with_absl()
FLAGS = jax_config.FLAGS
test_util.update_test_tolerance(1e-5, 2e-7)
PARTICLE_COUNT = 10
STOCHASTIC_SAMPLES = 10
......@@ -65,6 +66,62 @@ class QuantityTest(jtu.JaxTestCase):
grad(do_fn)(2.0)
@parameterized.named_parameters(jtu.cases_from_list(
{
'testcase_name': '_dtype={}'.format(dtype.__name__),
'dtype': dtype,
} for dtype in DTYPES))
def test_cosine_angles(self, dtype):
displacement, _ = space.free()
displacement = space.map_product(displacement)
R = np.array(
[[0, 0],
[0, 1],
[1, 1]], dtype=dtype)
dR = displacement(R, R)
cangles = quantity.cosine_angles(dR)
c45 = 1 / np.sqrt(2)
true_cangles = np.array(
[[[0, 0, 0],
[0, 1, c45],
[0, c45, 1]],
[[1, 0, 0],
[0, 0, 0],
[0, 0, 1]],
[[1, c45, 0],
[c45, 1, 0],
[0, 0, 0]]], dtype=dtype)
self.assertAllClose(cangles, true_cangles, True)
@parameterized.named_parameters(jtu.cases_from_list(
{
'testcase_name': '_dtype={}'.format(dtype.__name__),
'dtype': dtype,
} for dtype in DTYPES))
def test_cosine_angles_neighbors(self, dtype):
displacement, _ = space.free()
displacement = vmap(vmap(displacement, (None, 0)), 0)
R = np.array(
[[0, 0],
[0, 1],
[1, 1]], dtype=dtype)
R_neigh = np.array(
[[[0, 1], [1, 1]],
[[0, 0], [0, 0]],
[[0, 0], [0, 0]]], dtype=dtype)
dR = displacement(R, R_neigh)
cangles = quantity.cosine_angles(dR)
c45 = 1 / np.sqrt(2)
true_cangles = np.array(
[[[1, c45], [c45, 1]],
[[1, 1], [1, 1]],
[[1, 1], [1, 1]]], dtype=dtype)
self.assertAllClose(cangles, true_cangles, True)
@parameterized.named_parameters(jtu.cases_from_list(
{
'testcase_name': '_dtype={}'.format(dtype.__name__),
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment