Commit 95365de3 authored by Sam Schoenholz's avatar Sam Schoenholz
Browse files

Fixed bug when Behler-Parrinello network functions were jit.

parent 768907ad
......@@ -14,7 +14,7 @@
"""Neural Network Primitives."""
from typing import Callable, Tuple, Dict, Any
from typing import Callable, Tuple, Dict, Any, Optional
import numpy as onp
......@@ -86,7 +86,7 @@ def _behler_parrinello_cutoff_fn(dr: Array,
def radial_symmetry_functions(displacement_or_metric: DisplacementOrMetricFn,
species: Array,
species: Optional[Array],
etas: Array,
cutoff_distance: float
) -> Callable[[Array], Array]:
......@@ -114,23 +114,28 @@ def radial_symmetry_functions(displacement_or_metric: DisplacementOrMetricFn,
"""
metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
def compute_fun(R: Array, **kwargs) -> Array:
_metric = partial(metric, **kwargs)
_metric = space.map_product(_metric)
radial_fn = lambda eta, dr: (np.exp(-eta * dr**2) *
_behler_parrinello_cutoff_fn(dr, cutoff_distance))
def return_radial(atom_type):
"""Returns the radial symmetry functions for neighbor type atom_type."""
R_neigh = R[species == atom_type, :]
dr = _metric(R, R_neigh)
radial = vmap(radial_fn, (0, None))(etas, dr)
return np.sum(radial, axis=1).T
return np.hstack([return_radial(atom_type) for
atom_type in np.unique(species)])
return compute_fun
radial_fn = lambda eta, dr: (np.exp(-eta * dr**2) *
_behler_parrinello_cutoff_fn(dr, cutoff_distance))
radial_fn = vmap(radial_fn, (0, None))
if species is None:
def compute_fn(R: Array, **kwargs) -> Array:
_metric = partial(metric, **kwargs)
_metric = space.map_product(_metric)
return util.high_precision_sum(radial_fn(etas, _metric(R, R)), axis=1).T
elif isinstance(species, np.ndarray):
species = onp.array(species)
def compute_fn(R: Array, **kwargs) -> Array:
_metric = partial(metric, **kwargs)
_metric = space.map_product(_metric)
def return_radial(atom_type):
"""Returns the radial symmetry functions for neighbor type atom_type."""
R_neigh = R[species == atom_type, :]
dr = _metric(R, R_neigh)
return util.high_precision_sum(radial_fn(etas, dr), axis=1).T
return np.hstack([return_radial(atom_type) for
atom_type in onp.unique(species)])
return compute_fn
def radial_symmetry_functions_neighbor_list(
......@@ -162,7 +167,22 @@ def radial_symmetry_functions_neighbor_list(
"""
metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
def compute_fun(R: Array, neighbor: NeighborList, **kwargs) -> Array:
def radial_fn(eta, dr):
return (np.exp(-eta * dr**2) *
_behler_parrinello_cutoff_fn(dr, cutoff_distance))
radial_fn = vmap(radial_fn, (0, None))
if species is None:
def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array:
_metric = partial(metric, **kwargs)
_metric = space.map_neighbor(_metric)
R_neigh = R[neighbor.idx]
mask = (neighbor.idx < R.shape[0])[np.newaxis, :, :]
dr = _metric(R, R_neigh)
return util.high_precision_sum(radial_fn(etas, dr) * mask, axis=2).T
return compute_fn
def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array:
_metric = partial(metric, **kwargs)
_metric = space.map_neighbor(_metric)
radial_fn = lambda eta, dr: (np.exp(-eta * dr**2) *
......@@ -179,9 +199,9 @@ def radial_symmetry_functions_neighbor_list(
return util.high_precision_sum(radial * mask[np.newaxis, :, :], axis=2).T
return np.hstack([return_radial(atom_type) for
atom_type in np.unique(species)])
atom_type in onp.unique(species)])
return compute_fun
return compute_fn
def single_pair_angular_symmetry_function(dR12: Array,
......@@ -238,7 +258,7 @@ def angular_symmetry_functions(displacement: DisplacementFn,
spatial_dimension]` and returns `[N_atoms, N_types * (N_types + 1) / 2]`
where N_types is the number of types of particles in the system.
"""
_angular_fn = vmap(single_pair_angular_symmetry_function,
(None, None, 0, 0, 0, None))
......@@ -251,14 +271,23 @@ def angular_symmetry_functions(displacement: DisplacementFn,
_all_pairs_angular = vmap(
vmap(vmap(_batched_angular_fn, (0, None)), (None, 0)), 0)
def compute_fun(R, **kwargs):
if species is None:
def compute_fn(R, **kwargs):
D_fn = partial(displacement, **kwargs)
D_fn = space.map_product(D_fn)
dR = D_fn(R, R)
return np.sum(_all_pairs_angular(dR, dR), axis=[1, 2])
return compute_fn
if isinstance(species, np.ndarray):
species = onp.array(species)
def compute_fn(R, **kwargs):
atom_types = onp.unique(species)
D_fn = partial(displacement, **kwargs)
D_fn = space.map_product(D_fn)
D_different_types = [
D_fn(R[species == atom_type, :], R) for atom_type in np.unique(species)
]
D_different_types = [D_fn(R[species == s, :], R) for s in atom_types]
out = []
atom_types = np.unique(species)
for i in range(len(atom_types)):
for j in range(i, len(atom_types)):
out += [
......@@ -267,7 +296,7 @@ def angular_symmetry_functions(displacement: DisplacementFn,
axis=[1, 2])
]
return np.hstack(out)
return compute_fun
return compute_fn
def angular_symmetry_functions_neighbor_list(
displacement: DisplacementFn,
......@@ -297,7 +326,7 @@ def angular_symmetry_functions_neighbor_list(
spatial_dimension]` and returns `[N_atoms, N_types * (N_types + 1) / 2]`
where N_types is the number of types of particles in the system.
"""
_angular_fn = vmap(single_pair_angular_symmetry_function,
(None, None, 0, 0, 0, None))
......@@ -310,14 +339,32 @@ def angular_symmetry_functions_neighbor_list(
_all_pairs_angular = vmap(
vmap(vmap(_batched_angular_fn, (0, None)), (None, 0)), 0)
def compute_fun(R: Array, neighbor: NeighborList, **kwargs) -> Array:
if species is None:
def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array:
D_fn = partial(displacement, **kwargs)
D_fn = space.map_neighbor(D_fn)
R_neigh = R[neighbor.idx]
mask = neighbor.idx < R.shape[0]
dR = D_fn(R, R_neigh)
all_angular = _all_pairs_angular(dR, dR)
mask_i = mask[:, :, np.newaxis, np.newaxis]
mask_j = mask[:, np.newaxis, :, np.newaxis]
return util.high_precision_sum(all_angular * mask_i * mask_j,
axis=[1, 2])
return compute_fn
def compute_fn(R: Array, neighbor: NeighborList, **kwargs) -> Array:
D_fn = partial(displacement, **kwargs)
D_fn = space.map_neighbor(D_fn)
R_neigh = R[neighbor.idx]
species_neigh = species[neighbor.idx]
atom_types = np.unique(species)
atom_types = onp.unique(species)
base_mask = neighbor.idx < len(R)
mask = [np.logical_and(base_mask, species_neigh == t) for t in atom_types]
......@@ -331,10 +378,10 @@ def angular_symmetry_functions_neighbor_list(
for j in range(i, len(atom_types)):
mask_j = mask[j][:, np.newaxis, :, np.newaxis]
out += [
np.sum(all_angular * mask_i * mask_j , axis=[1, 2])
util.high_precision_sum(all_angular * mask_i * mask_j, axis=[1, 2])
]
return np.hstack(out)
return compute_fun
return compute_fn
def behler_parrinello_symmetry_functions_neighbor_list(
......@@ -356,7 +403,7 @@ def behler_parrinello_symmetry_functions_neighbor_list(
if lambdas is None:
lambdas = np.array([-1, 1] * 4 + [1] * 14, f32)
if zetas is None:
zetas = np.array([1, 1, 2, 2] * 2 + [1, 2] + [1, 2, 4, 16] * 3, f32)
......@@ -378,7 +425,7 @@ def behler_parrinello_symmetry_functions_neighbor_list(
def behler_parrinello_symmetry_functions(displacement: DisplacementFn,
species: Array,
species: Array=None,
radial_etas: Array=None,
angular_etas: Array=None,
lambdas: Array=None,
......
......@@ -475,18 +475,42 @@ class EnergyTest(jtu.JaxTestCase):
def test_behler_parrinello_network(self, N_types, dtype):
key = random.PRNGKey(1)
R = np.array([[0,0,0], [1,1,1], [1,1,0]], dtype)
species = np.array([1, 1, N_types])
species = np.array([1, 1, N_types]) if N_types > 1 else None
box_size = f32(1.5)
displacement, _ = space.periodic(box_size)
nn_init, nn_apply = energy.behler_parrinello(displacement, species)
params = nn_init(key, R)
nn_force_fn = grad(nn_apply, argnums=1)
nn_force = nn_force_fn(params, R)
nn_energy = nn_apply(params, R)
nn_force = jit(nn_force_fn)(params, R)
nn_energy = jit(nn_apply)(params, R)
self.assertAllClose(np.any(np.isnan(nn_energy)), False)
self.assertAllClose(np.any(np.isnan(nn_force)), False)
self.assertAllClose(nn_force.shape, [3,3])
@parameterized.named_parameters(jtu.cases_from_list(
{
'testcase_name': '_N_types={}_dtype={}'.format(N_types, dtype.__name__),
'N_types': N_types,
'dtype': dtype,
} for N_types in N_TYPES_TO_TEST for dtype in POSITION_DTYPE))
def test_behler_parrinello_network_neighbor_list(self, N_types, dtype):
key = random.PRNGKey(1)
R = np.array([[0,0,0], [1,1,1], [1,1,0]], dtype)
species = np.array([1, 1, N_types]) if N_types > 1 else None
box_size = f32(1.5)
displacement, _ = space.periodic(box_size)
neighbor_fn, nn_init, nn_apply = energy.behler_parrinello_neighbor_list(
displacement, box_size, species)
nbrs = neighbor_fn(R)
params = nn_init(key, R, nbrs)
nn_force_fn = grad(nn_apply, argnums=1)
nn_force = jit(nn_force_fn)(params, R, nbrs)
nn_energy = jit(nn_apply)(params, R, nbrs)
self.assertAllClose(np.any(np.isnan(nn_energy)), False)
self.assertAllClose(np.any(np.isnan(nn_force)), False)
self.assertAllClose(nn_force.shape, [3,3])
@parameterized.named_parameters(jtu.cases_from_list(
{
'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
......
......@@ -143,18 +143,19 @@ class SymmetryFunctionTest(jtu.JaxTestCase):
etas = np.linspace(1., 2., N_etas, dtype=dtype)
gr = nn.angular_symmetry_functions(displacement,
species,
etas=etas,
lambdas=np.array([-1.0] * N_etas, dtype),
etas=etas,
lambdas=np.array([-1.0] * N_etas, dtype),
zetas=np.array([1.0] * N_etas, dtype),
cutoff_distance=r_cutoff)
gr_neigh = nn.angular_symmetry_functions_neighbor_list(displacement,
species,
etas=etas,
lambdas=np.array([-1.0] * N_etas, dtype),
zetas=np.array([1.0] * N_etas, dtype),
cutoff_distance=r_cutoff)
gr_neigh = nn.angular_symmetry_functions_neighbor_list(
displacement,
species,
etas=etas,
lambdas=np.array([-1.0] * N_etas, dtype),
zetas=np.array([1.0] * N_etas, dtype),
cutoff_distance=r_cutoff)
nbrs = neighbor_fn(R)
gr_exact = gr(R)
gr_nbrs = gr_neigh(R, neighbor=nbrs)
......@@ -208,6 +209,36 @@ class SymmetryFunctionTest(jtu.JaxTestCase):
self.assertAllClose(gr_out.shape, (3, N_etas * (N_types + N_types * (N_types + 1) // 2)))
self.assertAllClose(gr_out[2, 0], dtype(1.885791), rtol=1e-6, atol=1e-6)
@parameterized.named_parameters(jtu.cases_from_list(
{
'testcase_name': '_N_types={}_N_etas={}_d_type={}'.format(
N_types, N_etas, dtype.__name__),
'dtype': dtype,
'N_types': N_types,
'N_etas': N_etas,
} for N_types in N_TYPES_TO_TEST
for N_etas in N_ETAS_TO_TEST
for dtype in DTYPES))
def test_behler_parrinello_symmetry_functions_neighbor_list(self,
N_types,
N_etas,
dtype):
displacement, shift = space.free()
neighbor_fn = partition.neighbor_list(displacement, 10.0, 8.0, 0.0)
gr = nn.behler_parrinello_symmetry_functions_neighbor_list(
displacement,np.array([1, 1, N_types]),
radial_etas=np.array([1e-4/(0.529177 ** 2)] * N_etas, dtype),
angular_etas=np.array([1e-4/(0.529177 ** 2)] * N_etas, dtype),
lambdas=np.array([-1.0] * N_etas, dtype),
zetas=np.array([1.0] * N_etas, dtype),
cutoff_distance=8.0)
R = np.array([[0,0,0], [1,1,1], [1,1,0]], dtype)
nbrs = neighbor_fn(R)
gr_out = gr(R, neighbor=nbrs)
self.assertAllClose(gr_out.shape,
(3, N_etas * (N_types + N_types * (N_types + 1) // 2)))
self.assertAllClose(gr_out[2, 0], dtype(1.885791), rtol=1e-6, atol=1e-6)
def _graph_network(graph_tuple):
update_node_fn = lambda n, se, re, g: n
update_edge_fn = lambda e, sn, rn, g: e
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment