Commit 203a06e8 authored by ekindogus

Added Behler-Parrinello neural network.

parent 3074c619
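A minimal usage sketch of the new `energy.behler_parrinello` function, mirroring the test added at the bottom of this commit; the positions, species labels, and box size here are illustrative only:

import jax.numpy as np
from jax import random, grad
from jax_md import space, energy

box_size = 1.5
displacement, _ = space.periodic(box_size)

R = np.array([[0.0, 0.0, 0.0],
              [1.0, 1.0, 1.0],
              [1.0, 1.0, 0.0]])
species = np.array([1, 1, 2])

# behler_parrinello returns a Haiku-style (init_fn, apply_fn) pair.
init_fn, apply_fn = energy.behler_parrinello(displacement, species)
params = init_fn(random.PRNGKey(1), R)

E = apply_fn(params, R)                    # total energy (a scalar)
F = -grad(apply_fn, argnums=1)(params, R)  # forces: minus dE/dR, shape (3, 3)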
......@@ -23,7 +23,7 @@ from functools import wraps, partial
import jax
import jax.numpy as np
from jax.tree_util import tree_map
from jax import vmap
import haiku as hk
from jax_md import space, smap, partition, nn
......@@ -414,6 +414,37 @@ def eam_from_lammps_parameters(displacement, f):
  return eam(displacement, *load_lammps_eam_parameters(f)[:-1])


def behler_parrinello(displacement,
                      species=None,
                      mlp_sizes=(30, 30),
                      mlp_kwargs=None,
                      sym_kwargs=None):
  """Behler-Parrinello energy: symmetry functions fed through a per-atom MLP.

  Returns a Haiku-style `(init_fn, apply_fn)` pair. `apply_fn(params, R)`
  evaluates the symmetry functions at positions `R`, maps each atom's
  descriptor through the MLP, and sums the per-atom readouts into a total
  energy.
  """
  if sym_kwargs is None:
    sym_kwargs = {}
  if mlp_kwargs is None:
    mlp_kwargs = {
        'activation': np.tanh
    }
  sym_fn = nn.behler_parrinello_symmetry_functions(displacement,
                                                   species,
                                                   **sym_kwargs)

  @hk.transform
  def model(R, **kwargs):
    embedding_fn = hk.nets.MLP(output_sizes=mlp_sizes + (1,),
                               activate_final=False,
                               name='BPEncoder',
                               **mlp_kwargs)
    embedding_fn = vmap(embedding_fn)
    sym = sym_fn(R, **kwargs)
    readout = embedding_fn(sym)
    return np.sum(readout)

  return model.init, model.apply


class EnergyGraphNet(hk.Module):
  """Implements a Graph Neural Network for energy fitting.
......
......@@ -96,8 +96,8 @@ def radial_symmetry_functions(displacement_or_metric,
"""
metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
def compute_fun(R):
_metric = partial(metric)
def compute_fun(R, **kwargs):
_metric = partial(metric, **kwargs)
_metric = space.map_product(_metric)
radial_fn = lambda eta, dr: (np.exp(-eta * dr**2) *
_behler_parrinello_cutoff_fn(dr, cutoff_distance))
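For reference, `radial_fn` above is the standard Behler-Parrinello radial (G2) symmetry-function term. A self-contained sketch of the same quantity, assuming the usual cosine cutoff from the Behler-Parrinello literature (which `_behler_parrinello_cutoff_fn` is expected to implement):

import jax.numpy as np

def cosine_cutoff(dr, cutoff_distance):
  # Smoothly decays from 1 at dr = 0 to 0 at the cutoff, and is 0 beyond it.
  return np.where(dr < cutoff_distance,
                  0.5 * (np.cos(np.pi * dr / cutoff_distance) + 1.0),
                  0.0)

def radial_symmetry_term(eta, dr, cutoff_distance):
  # Contribution of a single pair distance to the radial (G2) symmetry function.
  return np.exp(-eta * dr ** 2) * cosine_cutoff(dr, cutoff_distance)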
......@@ -167,7 +167,7 @@ def angular_symmetry_functions(displacement,
    spatial_dimension]` and returns `[N_atoms, N_types * (N_types + 1) / 2]`
    where N_types is the number of types of particles in the system.
  """
  _angular_fn = vmap(single_pair_angular_symmetry_function,
                     (None, None, 0, 0, 0, None))
......@@ -180,8 +180,9 @@ def angular_symmetry_functions(displacement,
  _all_pairs_angular = vmap(
      vmap(vmap(_batched_angular_fn, (0, None)), (None, 0)), 0)
  def compute_fun(R):
    D_fn = space.map_product(displacement)
  def compute_fun(R, **kwargs):
    D_fn = partial(displacement, **kwargs)
    D_fn = space.map_product(D_fn)
    D_different_types = [
        D_fn(R[species == atom_type, :], R) for atom_type in np.unique(species)
    ]
......@@ -210,8 +211,9 @@ def behler_parrinello_symmetry_functions(displacement,
                           f32) / f32(0.529177 ** 2)
  if angular_etas is None:
    angular_etas = np.array([1e-4] * 4 + [0.003] * 4 + [0.008] * 2 + [0.015] * 4
                            + [0.025] * 4 + [0.045] * 4, f32) / f32(0.529177 ** 2)
    angular_etas = np.array([1e-4] * 4 + [0.003] * 4 + [0.008] * 2 +
                            [0.015] * 4 + [0.025] * 4 + [0.045] * 4,
                            f32) / f32(0.529177 ** 2)
  if lambdas is None:
    lambdas = np.array([-1, 1] * 4 + [1] * 14, f32)
......@@ -229,7 +231,9 @@ def behler_parrinello_symmetry_functions(displacement,
                                          lambdas=lambdas,
                                          zetas=zetas,
                                          cutoff_distance=cutoff_distance)
  return lambda R: np.hstack((radial_fn(R), angular_fn(R)))
  return lambda R, **kwargs: np.hstack((radial_fn(R, **kwargs),
                                        angular_fn(R, **kwargs)))
# Graph neural network primitives
......
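The `compute_fun` changes above all follow the same pattern: bind any keyword arguments into the displacement or metric with `functools.partial`, then lift it to all pairs with `space.map_product`. A minimal standalone sketch of that pattern, using `space.free()` purely for illustration:

from functools import partial

import jax.numpy as np
from jax_md import space

displacement, _ = space.free()
R = np.array([[0.0, 0.0, 0.0],
              [1.0, 0.0, 0.0],
              [0.0, 2.0, 0.0]])

# Keyword arguments passed to compute_fun (for example, parameters of a
# time-dependent simulation box) would be bound here before mapping.
d_fn = partial(displacement)
pairwise_displacements = space.map_product(d_fn)(R, R)  # shape (3, 3, 3)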
......@@ -45,6 +45,7 @@ SPATIAL_DIMENSION = [2, 3]
UNIT_CELL_SIZE = [7, 8]
SOFT_SPHERE_ALPHA = [2.0, 3.0]
N_TYPES_TO_TEST = [1, 2]
if FLAGS.jax_enable_x64:
  POSITION_DTYPE = [f32, f64]
......@@ -415,6 +416,27 @@ class EnergyTest(jtu.JaxTestCase):
        np.array(exact_force_fn(r), dtype=dtype),
        force_fn(r, nbrs))
  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_N_types={}_dtype={}'.format(N_types, dtype.__name__),
          'N_types': N_types,
          'dtype': dtype,
      } for N_types in N_TYPES_TO_TEST for dtype in POSITION_DTYPE))
  def test_behler_parrinello_network(self, N_types, dtype):
    key = random.PRNGKey(1)
    R = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
    species = np.array([1, 1, N_types])
    box_size = f32(1.5)
    displacement, _ = space.periodic(box_size)
    nn_init, nn_apply = energy.behler_parrinello(displacement, species)
    params = nn_init(key, R)
    nn_force_fn = grad(nn_apply, argnums=1)
    nn_force = nn_force_fn(params, R)
    nn_energy = nn_apply(params, R)
    self.assertAllClose(np.any(np.isnan(nn_energy)), False)
    self.assertAllClose(np.any(np.isnan(nn_force)), False)
    self.assertAllClose(nn_force.shape, [3, 3])

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
......