Skip to content
Snippets Groups Projects
Commit 9cc3367a authored by sschoenholz's avatar sschoenholz
Browse files

Initial OSS commit of JAX, MD.

parents
Branches
No related tags found
No related merge requests found
*.pyc
*.so
*.egg-info
.ipynb_checkpoints
/bazel-*
.bazelrc
/tensorflow
.DS_Store
build/
dist/
.mypy_cache/
# How to Contribute
We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.
## Community Guidelines
This project follows
[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
LICENSE 0 → 100644
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
\ No newline at end of file
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
\ No newline at end of file
README.md 0 → 100644
# JAX, M.D.
# Accelerated, Differentiable, Molecular Dynamics
Molecular dynamics is a workhorse of modern computational condensed matter
physics. It is frequently used to simulate materials to observe how small scale
interactions can give rise to complex large-scale phenomenology. Most molecular
dynamics packages (e.g. HOOMD Blue or LAMMPS) are complicated, specialized
pieces of code that are many thousands of lines long. They typically involve
significant code duplication to allow for running simulations on CPU and GPU.
Additionally, large amounts of code is often devoted to taking derivatives
of quantities to compute functions of interest (e.g. gradients of energies
to compute forces).
However, recent work in machine learning has led to significant software
developments that might make it possible to write significantly more concise
molecular dynamics simulations that offer a range of benefits. Here we target
JAX, which allows us to write python code that gets compiled to XLA and allows
us to run on CPU, GPU, or TPU. Moreover, JAX allows us to take derivatives of
python code. Thus, not only is this molecular dynamics simulation automatically
hardware accelerated, it is also __end-to-end__ differentiable. This should
allow for some interesting experiments that we're excited to explore.
JAX, MD is a research project that is currently under development. Expect
sharp edges and possibly some API breaking changes as we continue to support
a broader set of simulations.
### Getting Started
To get started playing around with JAX, MD check out the following colab
notebooks on Google Cloud without needing to install anything.
- [Minimization] (https://colab.research.google.com/github/google/jax-md/blob/master/notebooks/minimization.ipynb)
- [NVE Simulation] (https://colab.research.google.com/github/google/jax-md/blob/master/notebooks/nve_simulation.ipynb)
- [NVT Simulation] (https://colab.research.google.com/github/google/jax-md/blob/master/notebooks/nvt_simulation.ipynb)
Alternatively, you can install JAX, MD by first following the [JAX's](https://www.github.com/google/jax/)
installation instructions. Then installing JAX, MD should be as easy as,
```
git clone https://github.com/google/jax-md
pip install -e jax-md
```
# Overview
There are several aspects of the library.
## Spaces
In general we must have a way of computing the pairwise distance between atoms.
We must also have efficient strategies for moving atoms in some space that may
or may not be globally isomorphic to R^N. For example, periodic boundary
conditions are commonplace in simulations and must be respected. This part of
the code implements these functions.
Example:
```python
box_size = 25.0
displacement_fn, shift_fn = periodic(box_size)
```
## Potential Energy
In the simplest case, molecular dynamics calculations are often based on a pair
potential that is defined by a user. This then is used to compute a total energy
whose negative gradient gives forces. One of the very nice things about JAX is
that we get forces for free! The second part of the code is devoted to computing
energies. We provide a Soft Sphere potential and a Lennard Jones potential. We
also offer a convenience wrapper to compute the force.
Example:
```python
N = 1000
spatial_dimension = 2
key = random.PRNGKey(0)
R = random.uniform(key, (N, spatial_dimension), minval=0.0, maxval=1.0)
energy_fn = lennard_jones_pairwise(displacement)
print('E = {}'.format(energy(R)))
force_fn = force(energy_fn)
print('Total Squared Force = {}'.format(np.sum(force_fn(R) ** 2)))
```
## Dynamics
Given an energy function and a system, there are a number of dynamics are useful
to simulate. The simulation code is based on the structure of the optimizers
found in JAX. In particular, each simulation function returns an initialization
function and an update function. The initialization function takes a set of
positions and creates the necessary dynamical state variables. The update
function does a single step of dynamics to the dynamical state variables and
returns an updated state.
We include a several different kinds of dynamics. However, there is certainly room
to add more for e.g. constaint strain simulations.
It is often desirable to find an energy minimum of the system. We provide
two methods to do this. We provide simple gradient descent minimization. This is
mostly for pedagogical purposes, since it often performs poorly. We additionally
include the FIRE algorithm which often sees significantly faster convergence.
Moreover a common experiment to run in the context of molecular dynamics is to
simulate a system with a fixed volume and temperature. We provide the function
`nvt_nose_hoover` to do this.
Example:
```python
temperature = 1.0
dt = 1e-3
init, update = nvt_nose_hoover(energy, wrap_fn, dt, temperature)
state = init(R)
for i in range(100):
state = update(i, state)
R = get_positions(state)
```
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example showing the simple minimization of a two-dimensional system."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from jax import random
import jax.numpy as np
from jax_md import space, energy, minimize, quantity
def main(unused_argv):
key = random.PRNGKey(0)
# Setup some variables describing the system.
N = 5000
dimension = 2
box_size = 80.0
# Create helper functions to define a periodic box of some size.
displacement, shift = space.periodic(box_size)
# Use JAX's random number generator to generate random initial positions.
key, split = random.split(key)
R = random.uniform(
split, (N, dimension), minval=0.0, maxval=box_size, dtype=np.float64)
# The system ought to be a 50:50 mixture of two types of particles, one
# large and one small.
sigma = np.array([[1.0, 1.2], [1.2, 1.4]])
N_2 = int(N / 2)
species = np.array([0] * N_2 + [1] * N_2, dtype=np.int32)
# Create an energy function.
energy_fn = energy.soft_sphere_pairwise(displacement, species, sigma)
force_fn = quantity.force(energy_fn)
# Create a minimizer.
init_fn, apply_fn = minimize.fire_descent(energy_fn, shift)
opt_state = init_fn(R)
# Minimize the system.
minimize_steps = 200
print_every = 10
print('Minimizing.')
print('Step\tEnergy\tMax Force')
print('-----------------------------------')
for step in range(minimize_steps):
opt_state = apply_fn(opt_state)
if step % print_every == 0:
R = opt_state.position
print('{:.2f}\t{:.2f}\t{:.2f}'.format(
step, energy_fn(R), np.max(force_fn(R))))
if __name__ == '__main__':
app.run(main)
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from jax_md import space, energy, minimize, simulate, smap
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definitions of various standard energy functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import jax.numpy as np
from jax_md import space, smap
def soft_sphere(dR, sigma=1.0, epsilon=1.0, alpha=2.0):
"""Finite ranged repulsive interaction between soft spheres.
Args:
dR: An ndarray of shape [n, m, spatial_dimension] of displacement vectors
between particles.
sigma: Particle radii. Should either be a floating point scalar or an
ndarray whose shape is [n, m].
epsilon: Interaction energy scale. Should either be a floating point scalar
or an ndarray whose shape is [n, m].
alpha: Exponent specifying interaction stiffness. Should either be a float
point scalar or an ndarray whose shape is [n, m].
Returns:
Matrix of energies whose shape is [n, m].
"""
dr = space.distance(dR)
dr = dr / sigma
U = epsilon * np.where(dr < 1.0, 1.0 / alpha * (1.0 - dr) ** alpha, 0.0)
# NOTE(schsam): This seems a little bit janky. However, it seems possibly
# necessary because constants seemed to be upcast to float64.
return np.array(U, dtype=dr.dtype)
def soft_sphere_pairwise(
metric, species=None, sigma=1.0, epsilon=1.0, alpha=2.0):
"""Convenience wrapper to compute soft sphere energy over a system."""
return smap.pairwise(
soft_sphere,
metric,
species=species,
sigma=sigma,
epsilon=epsilon,
alpha=alpha)
def lennard_jones(dR, sigma, epsilon):
"""Lennard-Jones interaction between particles with a minimum at sigma.
Args:
dR: An ndarray of shape [n, m, spatial_dimension] of displacement vectors
between particles.
sigma: Distance between particles where the energy has a minimum. Should
either be a floating point scalar or an ndarray whose shape is [n, m].
epsilon: Interaction energy scale. Should either be a floating point scalar
or an ndarray whose shape is [n, m].
Returns:
Matrix of energies of shape [n, m].
"""
dr = space.square_distance(dR)
dr = sigma ** 2 / dr
idr6 = dr ** 3.0
idr12 = idr6 ** 2.0
return epsilon * (idr12 - 2 * idr6)
def lennard_jones_pairwise(
metric, species=None, sigma=1.0, epsilon=1.0):
"""Convenience wrapper to compute Lennard-Jones energy over a system."""
return smap.pairwise(
lennard_jones, metric, species=species, sigma=sigma, epsilon=epsilon)
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code to minimize the energy of a system.
This file contains a number of different methods that can be used to find the
nearest minimum (inherent structure) to some initial system described by a
position R.
In general, minimization code follows the same overall structure as optimizers
in JAX. Optimizers return two functions:
init_fn: function that initializes the state of an optimizer. Should take
positions as an ndarray of shape [n, output_dimension]. Returns a state
which will be a namedtuple.
apply_fn: function that takes a state and produces a new state after one
step of optimization.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import jax.numpy as np
from jax_md import quantity
from jax_md.util import register_pytree_namedtuple
GradientDescentState = namedtuple('GradientDescentState', ['position'])
register_pytree_namedtuple(GradientDescentState)
# pylint: disable=invalid-name
def gradient_descent(
energy_or_force, shift_fn, step_size, quant=quantity.Energy):
"""Defines gradient descent minimization.
This is the simplest optimization strategy that moves particles down their
gradient to the nearest minimum. Generally, gradient descent is slower than
other methods and is included mostly for its simplicity.
Args:
energy_or_force: A function that produces either an energy or a force from
a set of particle positions specified as an ndarray of shape
[n, spatial_dimension].
shift_fn: A function that displaces positions, R, by an amount dR. Both R
and dR should be ndarrays of shape [n, spatial_dimension].
step_size: A floating point specifying the size of each step.
quant: Either a quantity.Energy or a quantity.Force specifying whether
energy_or_force is an energy or force respectively.
Returns:
See above.
"""
force = quantity.canonicalize_force(energy_or_force, quant)
def init_fun(R, **unused_kwargs):
return GradientDescentState(R)
def apply_fun(state, **kwargs):
R, = state
R = shift_fn(R, step_size * force(R, **kwargs), **kwargs)
return GradientDescentState(R)
return init_fun, apply_fun
class FireDescentState(namedtuple(
'FireDescentState',
['position', 'velocity', 'force', 'dt', 'alpha', 'n_pos'])):
"""A tuple containing state information for the Fire Descent minimizer.
Attributes:
position: The current position of particles. An ndarray of floats
with shape [n, spatial_dimension].
velocity: The current velocity of particles. An ndarray of floats
with shape [n, spatial_dimension].
force: The current force on particles. An ndarray of floats
with shape [n, spatial_dimension].
dt: A float specifying the current step size.
alpha: A float specifying the current momentum.
n_pos: The number of steps in the right direction, so far.
"""
def __new__(cls, position, velocity, force, dt, alpha, n_pos):
return super(FireDescentState, cls).__new__(
cls, position, velocity, force, dt, alpha, n_pos)
register_pytree_namedtuple(FireDescentState)
def fire_descent(
energy_or_force, shift_fn, quant=quantity.Energy, dt_start=0.1,
dt_max=0.4, n_min=5, f_inc=1.1, f_dec=0.5, alpha_start=0.1, f_alpha=0.99):
"""Defines FIRE minimization.
This code implements the "Fast Inertial Relaxation Engine" from [1].
Args:
energy_or_force: A function that produces either an energy or a force from
a set of particle positions specified as an ndarray of shape
[n, spatial_dimension].
shift_fn: A function that displaces positions, R, by an amount dR. Both R
and dR should be ndarrays of shape [n, spatial_dimension].
quant: Either a quantity.Energy or a quantity.Force specifying whether
energy_or_force is an energy or force respectively.
dt_start: The initial step size during minimization as a float.
dt_max: The maximum step size during minimization as a float.
n_min: An integer specifying the minimum number of steps moving in the
correct direction before dt and f_alpha should be updated.
f_inc: A float specifying the fractional rate by which the step size
should be increased.
f_dec: A float specifying the fractional rate by which the step size
should be decreased.
alpha_start: A float specifying the initial momentum.
f_alpha: A float specifying the fractional change in momentum.
Returns:
See above.
[1] Bitzek, Erik, Pekka Koskinen, Franz Gahler, Michael Moseler,
and Peter Gumbsch. "Structural relaxation made simple."
Physical review letters 97, no. 17 (2006): 170201.
"""
force = quantity.canonicalize_force(energy_or_force, quant)
def init_fun(R, **kwargs):
V = np.zeros_like(R)
return FireDescentState(
R, V, force(R, **kwargs), np.array(dt_start), alpha_start, np.array(0))
def apply_fun(state, **kwargs):
R, V, F_old, dt, alpha, n_pos = state
R = shift_fn(R, dt * V + dt ** 2 * F_old, **kwargs)
F = force(R, **kwargs)
V = V + dt * 0.5 * (F_old + F)
# NOTE(schsam): This will be wrong if F_norm ~< 1e-8.
# TODO(schsam): We should check for forces below 1e-6. @ErrorChecking
F_norm = np.sqrt(np.sum(F ** 2) + 1e-6)
V_norm = np.sqrt(np.sum(V ** 2))
P = np.array(np.dot(np.reshape(F, (-1)), np.reshape(V, (-1))))
V = V + alpha * (F * V_norm / F_norm - V)
# NOTE(schsam): Can we clean this up at all?
n_pos = np.where(P > 0, n_pos + 1.0, n_pos)
dt_choice = np.array([dt * f_inc, dt_max])
dt = np.where(
P > 0, np.where(n_pos > n_min, np.min(dt_choice), dt), dt)
dt = np.where(P < 0, dt * f_dec, dt)
alpha = np.where(
P > 0, np.where(n_pos > n_min, alpha * f_alpha, alpha), alpha)
alpha = np.where(P < 0, alpha_start, alpha)
V = (P < 0) * np.zeros_like(V) + (P >= 0) * V
return FireDescentState(R, V, F, dt, alpha, n_pos)
return init_fun, apply_fun
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Describes different physical quantities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from jax import grad
import jax.numpy as np
def force(energy):
"""Computes the force as the negative gradient of an energy."""
return grad(lambda R, *args, **kwargs: -energy(R, *args, **kwargs))
def canonicalize_force(energy_or_force, quantity):
if quantity is Force:
return energy_or_force
elif quantity is Energy:
return force(energy_or_force)
raise ValueError(
'Expected quantity to be Energy or Force, but found {}'.format(quantity))
class Force(object):
"""Dummy object to denote whether a quantity is a force."""
pass
Force = Force()
class Energy(object):
"""Dummy object to denote whether a quantity is an energy."""
pass
Energy = Energy()
class Dynamic(object):
"""Object used to denote dynamic shapes and species."""
pass
Dynamic = Dynamic()
def kinetic_energy(V, mass=1.0):
"""Computes the kinetic energy of a system with some velocities."""
return 0.5 * np.sum(mass * V ** 2)
def temperature(V, mass=1.0):
"""Computes the temperature of a system with some velocities."""
N, dim = V.shape
return np.sum(mass * V ** 2) / (N * dim)
def canonicalize_mass(mass):
if isinstance(mass, float):
return mass
elif isinstance(mass, np.ndarray):
if len(mass.shape) == 2 and mass.shape[1] == 1:
return mass
elif len(mass.shape) == 1:
return np.reshape(mass, (mass.shape[0], 1))
elif len(mass.shape) == 0:
return mass
msg = (
'Expected mass to be either a floating point number or a one-dimensional'
'ndarray. Found {}.'.format(mass)
)
raise ValueError(msg)
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code to simulate systems in various statistical ensembles.
This file contains a number of different methods that can be used to
simulate systems in a variety of ensembles.
In general, simulation code follows the same overall structure as optimizers
in JAX. Simulations are tuples of two functions:
init_fn: function that initializes the state of a system. Should take
positions as an ndarray of shape [n, output_dimension]. Returns a state
which will be a namedtuple.
apply_fn: function that takes a state and produces a new state after one
step of optimization.
One question that we need to think about is whether the simulations should
also return a function that computes the invariant for that ensemble. This can
be used for testing purposes, but is not often used otherwise.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
from jax import ops
from jax import random
import jax.numpy as np
from jax_md import quantity
from jax_md import time_dependence
from jax_md.util import register_pytree_namedtuple
class NVEState(namedtuple(
'NVEState', ['position', 'velocity', 'acceleration', 'mass'])):
"""A tuple containing the state of an NVE simulation.
This tuple stores the state of a simulation that samples from the
microcanonical ensemble in which the (N)umber of particles, the (V)olume, and
the (E)nergy of the system are held fixed.
Attributes:
position: An ndarray of shape [n, spatial_dimension] storing the position
of particles.
velocity: An ndarray of shape [n, spatial_dimension] storing the velocity
of particles.
acceleration: An ndarray of shape [n, spatial_dimension] storing the
acceleration of particles from the previous step.
mass: A float or an ndarray of shape [n] containing the masses of the
particles.
"""
def __new__(cls, position, velocity, acceleration, mass):
return super(NVEState, cls).__new__(
cls, position, velocity, acceleration, mass)
register_pytree_namedtuple(NVEState)
# pylint: disable=invalid-name
def nve(energy_or_force, shift_fn, dt, quant=quantity.Energy):
"""Simulates a system in the NVE ensemble.
Samples from the microcanonical ensemble in which the number of particles (N),
the system volume (V), and the energy (E) are held constant. We use a standard
velocity verlet integration scheme.
Args:
energy_or_force: A function that produces either an energy or a force from
a set of particle positions specified as an ndarray of shape
[n, spatial_dimension].
shift_fn: A function that displaces positions, R, by an amount dR. Both R
and dR should be ndarrays of shape [n, spatial_dimension].
dt: Floating point number specifying the timescale (step size) of the
simulation.
quant: Either a quantity.Energy or a quantity.Force specifying whether
energy_or_force is an energy or force respectively.
Returns:
See above.
"""
force = quantity.canonicalize_force(energy_or_force, quant)
dt_2 = 0.5 * dt ** 2
def init_fun(key, R, velocity_scale=1.0, mass=1.0):
V = np.sqrt(velocity_scale) * random.normal(key, R.shape)
mass = quantity.canonicalize_mass(mass)
return NVEState(R, V, force(R) / mass, mass)
def apply_fun(state, t=None, **kwargs):
R, V, A, mass = state
R = shift_fn(R, V * dt + A * dt_2, t=t, **kwargs)
A_prime = force(R, t=t, **kwargs) / mass
V = V + 0.5 * (A + A_prime) * dt
return NVEState(R, V, A_prime, mass)
return init_fun, apply_fun
class NVTNoseHooverState(namedtuple(
'NVTNoseHooverState',
[
'position',
'velocity',
'mass',
'kinetic_energy',
'xi',
'v_xi',
'Q',
])):
"""A tuple containing state information for the Nose-Hoover chain thermostat.
Attributes:
position: The current position of particles. An ndarray of floats
with shape [n, spatial_dimension].
velocity: The velocity of particles. An ndarray of floats
with shape [n, spatial_dimension].
mass: The mass of the particles. Can either be a float or an ndarray
of floats with shape [n].
kinetic_energy: A float that stores the current kinetic energy of the
system.
xi: An ndarray of shape [chain_length] that stores the "positional" degrees
of freedom for the Nose-Hoover thermostat.
v_xi: An ndarray of shape [chain_length] that stores the "velocity" degrees
of freedom for the Nose-Hoover thermostat.
Q: An ndarray of shape [chain_length] that stores the mass of the
Nose-Hoover chain.
"""
def __new__(cls, position, velocity, mass, kinetic_energy, xi, v_xi, Q):
return super(NVTNoseHooverState, cls).__new__(
cls, position, velocity, mass, kinetic_energy, xi, v_xi, Q)
register_pytree_namedtuple(NVTNoseHooverState)
def nvt_nose_hoover(
energy_or_force, shift_fn, dt, T_schedule, quant=quantity.Energy,
chain_length=5, tau=0.01):
"""Simulation in the NVT ensemble using a Nose Hoover Chain thermostat.
Samples from the canonical ensemble in which the number of particles (N),
the system volume (V), and the temperature (T) are held constant. We use a
Nose Hoover Chain thermostat described in [1, 2, 3]. We employ a similar
notation to [2] and the interested reader might want to look at that paper as
a reference.
Currently, the implementation only does a single timestep per Nose-Hoover
step. At some point we should support the multi-step case.
Args:
energy_or_force: A function that produces either an energy or a force from
a set of particle positions specified as an ndarray of shape
[n, spatial_dimension].
shift_fn: A function that displaces positions, R, by an amount dR. Both R
and dR should be ndarrays of shape [n, spatial_dimension].
dt: Floating point number specifying the timescale (step size) of the
simulation.
T_schedule: Either a floating point number specifying a constant temperature
or a function specifying temperature as a function of time.
quant: Either a quantity.Energy or a quantity.Force specifying whether
energy_or_force is an energy or force respectively.
chain_length: An integer specifying the length of the Nose-Hoover chain.
tau: A floating point timescale over which temperature equilibration occurs.
The performance of the Nose-Hoover chain thermostat is quite sensitive to
this choice.
Returns:
See above.
[1] Martyna, Glenn J., Michael L. Klein, and Mark Tuckerman.
"Nose-Hoover chains: The canonical ensemble via continuous dynamics."
The Journal of chemical physics 97, no. 4 (1992): 2635-2643.
[2] Martyna, Glenn, Mark Tuckerman, Douglas J. Tobias, and Michael L. Klein.
"Explicit reversible integrators for extended systems dynamics."
Molecular Physics 87. (1998) 1117-1157.
[3] Tuckerman, Mark E., Jose Alejandre, Roberto Lopez-Rendon,
Andrea L. Jochim, and Glenn J. Martyna.
"A Liouville-operator derived measure-preserving integrator for molecular
dynamics simulations in the isothermal-isobaric ensemble."
Journal of Physics A: Mathematical and General 39, no. 19 (2006): 5629.
"""
force = quantity.canonicalize_force(energy_or_force, quant)
dt_2 = dt / 2.0
dt_4 = dt_2 / 2.0
dt_8 = dt_4 / 2.0
T_schedule = time_dependence.canonicalize(T_schedule)
def init_fun(key, R, mass=1.0, T_initial=1.0):
mass = quantity.canonicalize_mass(mass)
V = np.sqrt(T_initial / mass) * random.normal(key, R.shape)
V = V - np.mean(V, axis=0, keepdims=True)
KE = quantity.kinetic_energy(V, mass)
# Nose-Hoover parameters.
xi = np.zeros(chain_length)
v_xi = np.zeros(chain_length)
# TODO(schsam): Really, it seems like Q should be set by the goal
# temperature rather than the initial temperature.
DOF = R.shape[0] * R.shape[1]
Q = T_initial * tau ** 2 * np.ones(chain_length)
Q = ops.index_update(Q, 0, Q[0] * DOF)
return NVTNoseHooverState(R, V, mass, KE, xi, v_xi, Q)
def step_chain(KE, V, xi, v_xi, Q, DOF, T):
"""Applies a single update to the chain parameters and rescales velocity."""
M = chain_length - 1
# TODO(schsam): We can probably cache the G parameters from the previous
# update.
# TODO(schsam): It is also probably the case that we could do a better job
# of vectorizing this code.
G = (Q[M - 1] * v_xi[M - 1] ** 2 - T) / Q[M]
v_xi = ops.index_add(v_xi, M, dt_4 * G)
for m in range(M - 1, 0, -1):
G = (Q[m - 1] * v_xi[m - 1] ** 2 - T) / Q[m]
scale = np.exp(-dt_8 * v_xi[m + 1])
v_xi = ops.index_update(v_xi, m, scale * (scale * v_xi[m] + dt_4 * G))
G = (2.0 * KE - DOF * T) / Q[0]
scale = np.exp(-dt_8 * v_xi[1])
v_xi = ops.index_update(v_xi, 0, scale * (scale * v_xi[0] + dt_4 * G))
scale = np.exp(-dt_2 * v_xi[0])
KE = KE * scale ** 2
V = V * scale
xi = xi + dt_2 * v_xi
G = (2.0 * KE - DOF * T) / Q[0]
for m in range(M):
scale = np.exp(-dt_8 * v_xi[m + 1])
v_xi = ops.index_update(v_xi, m, scale * (scale * v_xi[m] + dt_4 * G))
G = (Q[m] * v_xi[m] ** 2 - T) / Q[m + 1]
v_xi = ops.index_add(v_xi, M, dt_4 * G)
return KE, V, xi, v_xi
def apply_fun(state, t=0.0, **kwargs):
T = T_schedule(t)
R, V, mass, KE, xi, v_xi, Q = state
DOF = R.shape[0] * R.shape[1]
Q = T * tau ** 2 * np.ones(chain_length)
Q = ops.index_update(Q, 0, Q[0] * DOF)
KE, V, xi, v_xi = step_chain(KE, V, xi, v_xi, Q, DOF, T)
R = shift_fn(R, V * dt_2, t=t, **kwargs)
F = force(R, t=t, **kwargs)
V = V + dt * F / mass
# NOTE(schsam): Do we need to mean subtraction here?
V = V - np.mean(V, axis=0, keepdims=True)
KE = quantity.kinetic_energy(V, mass)
R = shift_fn(R, V * dt_2, t=t, **kwargs)
KE, V, xi, v_xi = step_chain(KE, V, xi, v_xi, Q, DOF, T)
return NVTNoseHooverState(R, V, mass, KE, xi, v_xi, Q)
return init_fun, apply_fun
This diff is collapsed.
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code to define different spaces in which particles are simulated.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from jax import ad_util
from jax import custom_transforms
from jax.interpreters import ad
import jax.numpy as np
# Primitive Spatial Transforms
# pylint: disable=invalid-name
def _check_transform_shapes(T, v=None):
"""Check whether a transform and collection of vectors have valid shape."""
if len(T.shape) != 2:
raise ValueError(
('Transform has invalid rank.'
' Found rank {}, expected rank 2.'.format(len(T.shape))))
if T.shape[0] != T.shape[1]:
raise ValueError('Found non-square transform.')
if v is not None and v.shape[-1] != T.shape[1]:
raise ValueError(
('Transform and vectors have incommensurate spatial dimension. '
'Found {} and {} respectively.'.format(T.shape[1], v.shape[-1])))
def _small_inverse(T):
"""Compute the inverse of a small matrix."""
_check_transform_shapes(T)
dim = T.shape[0]
# TODO(schsam): Check whether matrices are singular. @ErrorChecking
if dim == 2:
det = T[0, 0] * T[1, 1] - T[0, 1] * T[1, 0]
return np.array([[T[1, 1], -T[0, 1]], [-T[1, 0], T[0, 0]]]) / det
# TODO(schsam): Fill in the 3x3 case by hand.
return np.linalg.inv(T)
@custom_transforms
def transform(T, v):
"""Apply a linear transformation, T, to a collection of vectors, v.
Transform is written such that it acts as the identity during gradient
backpropagation.
Args:
T: Transformation; ndarray(shape=[spatial_dim, spatial_dim]).
v: Collection of vectors; ndarray(shape=[..., spatial_dim]).
Returns:
Transformed vectors; ndarray(shape=[..., spatial_dim]).
"""
_check_transform_shapes(T, v)
return np.dot(v, T)
ad.defjvp(
transform.primitive,
lambda g, T, v: ad_util.zero,
lambda g, T, v: g
)
def pairwise_displacement(Ra, Rb):
"""Compute a matrix of pairwise displacements given two sets of positions.
Args:
Ra: Vector of positions; ndarray(shape=[n, spatial_dim]).
Rb: Vector of positions; ndarray(shape=[m, spatial_dim]).
Returns:
Matrix of displacements; ndarray(shape=[n, m, spatial_dim]).
"""
return Ra[:, np.newaxis, :] - Rb[np.newaxis, :, :]
def periodic_displacement(side, dR):
"""Wraps displacement vectors into a hypercube.
Args:
side: Specification of hypercube size. Either,
(a) float if all sides have equal length.
(b) ndarray(spatial_dim) if sides have different lengths.
dR: Matrix of displacements; ndarray(shape=[..., spatial_dim]).
Returns:
Matrix of wrapped displacements; ndarray(shape=[..., spatial_dim]).
"""
zero = np.array(0.0, dtype=dR.dtype)
return dR - np.where(np.abs(dR) < 0.5 * side, zero, np.sign(dR) * side)
def square_distance(dR):
"""Computes square distances.
Args:
dR: Matrix of displacements; ndarray(shape=[..., spatial_dim]).
Returns:
Matrix of squared distances; ndarray(shape=[...]).
"""
return np.sum(dR ** 2, axis=-1)
def distance(dR):
"""Computes distances.
Args:
dR: Matrix of displacements; ndarray(shape=[..., spatial_dim]).
Returns:
Matrix of distances; ndarray(shape=[...]).
"""
return np.sqrt(square_distance(dR))
def periodic_shift(side, R, dR):
"""Shifts positions, wrapping them back within a periodic hypercube."""
return np.mod(R + dR, side)
"""
Spaces
The following functions provide the necessary transformations to perform
simulations in different spaces.
Spaces are tuples containing:
displacement_fn(Ra, Rb, **kwargs): Computes displacements between pairs of
particles. Ra and Rb should be ndarrays of shape [N, spatial_dim] and
[M, spatial_dim] respectively. Returns an ndarray of shape
[N, M, spatial_dim].
shift_fn(R, dR, **kwargs): Moves points at position R by an amount dR.
In each case, **kwargs is optional keyword arguments that can be supplied to
the different functions. In cases where the space features time dependence
this will be passed through a "t" keyword argument.
"""
def free():
"""Free boundary conditions."""
def displacement_fn(Ra, Rb, **unused_kwargs):
return pairwise_displacement(Ra, Rb)
def shift_fn(R, dR, **unused_kwargs):
return R + dR
return displacement_fn, shift_fn
def periodic(side):
"""Periodic boundary conditions on a hypercube of sidelength side."""
def displacement_fn(Ra, Rb, **unused_kwargs):
return periodic_displacement(side, pairwise_displacement(Ra, Rb))
def shift_fn(R, dR, **unused_kwargs):
return periodic_shift(side, R, dR)
return displacement_fn, shift_fn
def _check_time_dependence(t):
if t is None:
msg = ('Space has time-dependent transform, but no time has been '
'provided. (t = {})'.format(t))
raise ValueError(msg)
def periodic_general(T):
"""Periodic boundary conditions on a parallelepiped.
This function defines a simulation on a parellelepiped formed by applying an
affine transformation to the unit hypercube [0, 1]^spatial_dimension.
When using periodic_general, particles positions should be stored in the unit
hypercube. To get real positions from the simulation you should call
R_sim = space.transform(T, R_unit_cube).
The affine transformation can feature time dependence (if T is a function
instead of a scalar). In this case the resulting space will also be time
dependent. This can be useful for simulating systems under mechanical strain.
Args:
T: An affine transformation.
Either:
1) An ndarray of shape [spatial_dim, spatial_dim].
2) A function that takes floating point times and produces ndarrays of
shape [spatial_dim, spatial_dim].
Returns:
(metric_fn, displacement_fn, shift_fn) tuple.
"""
if callable(T):
def displacement(Ra, Rb, t=None, **unused_kwargs):
_check_time_dependence(t)
dR = periodic_displacement(1.0, pairwise_displacement(Ra, Rb))
return transform(T(t), dR)
# Can we cache the inverse? @Optimization
def shift(R, dR, t=None, **unused_kwargs):
_check_time_dependence(t)
return periodic_shift(1.0, R, transform(_small_inverse(T(t)), dR))
else:
T_inv = _small_inverse(T)
def displacement(Ra, Rb, **unused_kwargs):
dR = periodic_displacement(1.0, pairwise_displacement(Ra, Rb))
return transform(T, dR)
def shift(R, dR, **unused_kwargs):
return periodic_shift(1.0, R, transform(T_inv, dR))
return displacement, shift
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for constructing various interpolating functions.
This code was adapted from the way learning rate schedules are are built in JAX.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import jax.numpy as np
def constant(f):
def schedule(unused_t):
return f
return schedule
def canonicalize(scalar_or_schedule_fun):
if callable(scalar_or_schedule_fun):
return scalar_or_schedule_fun
elif np.ndim(scalar_or_schedule_fun) == 0:
return constant(scalar_or_schedule_fun)
else:
raise TypeError(type(scalar_or_schedule_fun))
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from jax.tree_util import register_pytree_node
def register_pytree_namedtuple(cls):
register_pytree_node(
cls,
lambda xs: (tuple(xs), None),
lambda _, xs: cls(*xs))
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
setup.py 0 → 100644
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import setuptools
INSTALL_REQUIRES = [
'absl-py',
'numpy',
'jax',
]
setuptools.setup(
name='jax-md',
version='0.0.0',
license='Apache 2.0',
author='Google',
author_email='jax-md-dev@google.com',
install_requires=INSTALL_REQUIRES,
url='https://github.com/google/jax-md',
packages=setuptools.find_packages()
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment