Unverified Commit a8b0643f authored by ekindogus's avatar ekindogus Committed by GitHub
Browse files

Merge pull request #75 from google/neighbor_list

Update neighbor list code for increased safety.
parents 3f9fd3a6 3966f466
......@@ -38,6 +38,7 @@ To get started playing around with JAX MD check out the following colab notebook
- [Minimization](https://colab.research.google.com/github/google/jax-md/blob/master/notebooks/minimization.ipynb)
- [NVE Simulation](https://colab.research.google.com/github/google/jax-md/blob/master/notebooks/nve_simulation.ipynb)
- [NVT Simulation](https://colab.research.google.com/github/google/jax-md/blob/master/notebooks/nvt_simulation.ipynb)
- [NVE with Neighbor Lists](https://colab.research.google.com/github/google/jax-md/blob/master/notebooks/nve_neighbor_list.ipynb)
You can install JAX MD locally with pip,
```
......
......@@ -12,4 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from jax_md import space, energy, minimize, simulate, smap, partition
from jax_md import space
from jax_md import energy
from jax_md import minimize
from jax_md import simulate
from jax_md import smap
from jax_md import partition
from jax_md import dataclasses
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for defining dataclasses that can be used with jax transformations.
This code was copied and adapted from https://github.com/google/flax/struct.py.
Accessed on 04/29/2020.
"""
import dataclasses
import jax
def dataclass(clz):
"""Create a class which can be passed to functional transformations.
Jax transformations such as `jax.jit` and `jax.grad` require objects that are
immutable and can be mapped over using the `jax.tree_util` methods.
The `dataclass` decorator makes it easy to define custom classes that can be
passed safely to Jax.
Args:
clz: the class that will be transformed by the decorator.
Returns:
The new class.
"""
data_clz = dataclasses.dataclass(frozen=True)(clz)
meta_fields = []
data_fields = []
for name, field_info in data_clz.__dataclass_fields__.items():
is_static = field_info.metadata.get('static', False)
if is_static:
meta_fields.append(name)
else:
data_fields.append(name)
def iterate_clz(x):
meta = tuple(getattr(x, name) for name in meta_fields)
data = tuple(getattr(x, name) for name in data_fields)
return data, meta
def clz_from_iterable(meta, data):
meta_args = tuple(zip(meta_fields, meta))
data_args = tuple(zip(data_fields, data))
kwargs = dict(meta_args + data_args)
return data_clz(**kwargs)
jax.tree_util.register_pytree_node(data_clz,
iterate_clz,
clz_from_iterable)
return data_clz
def static_field():
return dataclasses.field(metadata={'static': True})
......@@ -94,19 +94,20 @@ def soft_sphere_pair(
def soft_sphere_neighbor_list(
displacement_or_metric,
box_size,
example_R,
species=None,
sigma=1.0,
epsilon=1.0,
alpha=2.0,
list_cutoff=1.2):
dr_threshold=0.2):
"""Convenience wrapper to compute soft spheres using a neighbor list."""
sigma = np.array(sigma, dtype=f32)
epsilon = np.array(epsilon, dtype=f32)
alpha = np.array(alpha, dtype=f32)
list_cutoff = f32(np.max(sigma) * list_cutoff)
list_cutoff = f32(np.max(sigma))
dr_threshold = f32(list_cutoff * dr_threshold)
neighbor_fn = partition.neighbor_list(
displacement_or_metric, box_size, list_cutoff, example_R)
displacement_or_metric, box_size, list_cutoff, dr_threshold)
energy_fn = smap.pair_neighbor_list(
soft_sphere,
space.canonicalize_displacement_or_metric(displacement_or_metric),
......@@ -159,23 +160,22 @@ def lennard_jones_pair(
def lennard_jones_neighbor_list(
displacement_or_metric,
box_size,
example_R,
species=None,
sigma=1.0,
epsilon=1.0,
alpha=2.0,
r_onset=2.0,
r_cutoff=2.5,
neighborlist_cutoff=3.0): # TODO(schsam) Optimize this.
dr_threshold=0.5): # TODO(schsam) Optimize this.
"""Convenience wrapper to compute lennard-jones using a neighbor list."""
sigma = np.array(sigma, f32)
epsilon = np.array(epsilon, f32)
r_onset = np.array(r_onset * np.max(sigma), f32)
r_cutoff = np.array(r_cutoff * np.max(sigma), f32)
list_cutoff = np.array(np.max(sigma) * neighborlist_cutoff, f32)
dr_threshold = np.array(np.max(sigma) * dr_threshold, f32)
neighbor_fn = partition.neighbor_list(
displacement_or_metric, box_size, list_cutoff, example_R)
displacement_or_metric, box_size, r_cutoff, dr_threshold)
energy_fn = smap.pair_neighbor_list(
multiplicative_isotropic_cutoff(lennard_jones, r_onset, r_cutoff),
space.canonicalize_displacement_or_metric(displacement_or_metric),
......
......@@ -20,6 +20,7 @@ from __future__ import print_function
from functools import reduce, partial
from collections import namedtuple
from typing import Any, Callable
import math
from operator import mul
......@@ -30,7 +31,7 @@ from jax.abstract_arrays import ShapedArray
from jax.interpreters import partial_eval as pe
import jax.numpy as np
from jax_md import quantity, space
from jax_md import quantity, space, dataclasses
from jax_md.util import *
......@@ -361,48 +362,108 @@ def _displacement_or_metric_to_metric_sq(displacement_or_metric):
'than 4.')
def neighbor_list(
displacement_or_metric, box_size, cutoff, example_R,
buffer_size_multiplier=1.1, cell_size=None, **static_kwargs):
"""Returns a function that builds a list neighbors for each point.
@dataclasses.dataclass
class NeighborList(object):
"""A struct containing the state of a Neighbor List.
Attributes:
idx: For an N particle system this is an `[N, max_occupancy]` array of
integers such that `idx[i, j]` is the jth neighbor of particle i.
reference_position: The positions of particles when the neighbor list was
constructed. This is used to decide whether the neighbor list ought to be
updated.
did_buffer_overflow: A boolean that starts out False. If there are ever
more neighbors than max_neighbors this is set to true to indicate that
there was a buffer overflow. If this happens, it means that the results
of the simulation will be incorrect and the simulation needs to be rerun
using a larger buffer.
max_occupancy: A static integer specifying the maximum size of the
neighbor list. Changing this will involk a recompilation.
cell_list_fn: A static python callable that is used to construct a cell
list used in an intermediate step of the neighbor list calculation.
"""
idx: np.ndarray
reference_position: np.ndarray
did_buffer_overflow: bool
max_occupancy: int = dataclasses.static_field()
cell_list_fn: Callable = dataclasses.static_field()
Since XLA requires fixed shape, we use example point configurations to
estimate the maximum number of points within a neighborhood. However, if the
configuration changes substantially over time it might be necessary to
revise this estimate.
def neighbor_list(
displacement_or_metric, box_size, r_cutoff, dr_threshold,
capacity_multiplier=1.25, cell_size=None, **static_kwargs):
"""Returns a function that builds a list neighbors for collections of points.
Neighbor lists must balance the need to be jit compatable with the fact that
under a jit the maximum number of neighbors cannot change (owing to static
shape requirements). To deal with this, our `neighbor_list` returns a
function `neighbor_fn` that can operate in two modes: 1) create a new
neighbor list or 2) update an existing neighbor list. Case 1) cannot be jit
and it creates a neighbor list with a maximum neighbor count of the current
neighbor count times capacity_multiplier. Case 2) is jit compatable, if any
particle has more neighbors than the maximum, the `did_buffer_overflow` bit
will be set to `True` and a new neighbor list will need to be created.
Here is a typical example of a simulation loop with neighbor lists:
>>> init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
>>> exact_init_fn, exact_apply_fn = simulate.nve(exact_energy_fn, shift, 1e-3)
>>>
>>> nbrs = neighbor_fn(R)
>>> state = init_fn(random.PRNGKey(0), R, neighbor_idx=nbrs.idx)
>>>
>>> def body_fn(i, state):
>>> state, nbrs = state
>>> nbrs = neighbor_fn(state.position, nbrs)
>>> state = apply_fn(state, neighbor_idx=nbrs.idx)
>>> return state, nbrs
>>>
>>> step = 0
>>> for _ in range(20):
>>> new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
>>> if nbrs.did_buffer_overflow:
>>> nbrs = neighbor_fn(state.position)
>>> else:
>>> state = new_state
>>> step += 1
Args:
displacement: A function `d(R_a, R_b)` that computes the displacement
between pairs of points.
box_size: Either a float specifying the size of the box or an array of
shape [spatial_dim] specifying the box size in each spatial dimension.
cutoff: A scalar specifying the neighborhood radius.
example_R: An ndarray of example points of shape [point_count, spatial_dim]
used to estimate a maximum neighborhood size.
buffer_size_multiplier: A floating point scalar specifying the fractional
r_cutoff: A scalar specifying the neighborhood radius.
dr_threshold: A scalar specifying the maximum distance particles can move
before rebuilding the neighbor list.
capacity_multiplier: A floating point scalar specifying the fractional
increase in maximum neighborhood occupancy we allocate compared with the
maximum in the example positions.
cell_size: A scalar specifying the size of cells in the cell list used
in an intermediate step.
cell_size: An optional scalar specifying the size of cells in the cell list
used in an intermediate step.
**static_kwargs: kwargs that get threaded through the calculation of
example positions.
Returns:
An ndarray of shape [point_count, maximum_neighbors_per_point] of ids
specifying points in the neighborhood of each point. Empty elements are
given an id = point_count.
A pair. The first element is a NeighborList containing the current neighbor
list. The second element contains a function
`neighbor_list_fn(R, neighbor_list=None)` that will update the neighbor
list. If neighbor_list is None then the function will construct a new
neighbor list whose capacity is inferred from R. If neighbor_list is given
then it will update the neighbor list (with fixed capacity) if any particle
has moved more than dr_threshold / 2. Note that only
`neighbor_list_fn(R, neighbor_list)` can be `jit` since it keeps array
shapes fixed.
"""
box_size = f32(box_size)
cutoff = r_cutoff + dr_threshold
cutoff_sq = cutoff ** 2
threshold_sq = (dr_threshold / f32(2)) ** 2
metric_sq = _displacement_or_metric_to_metric_sq(displacement_or_metric)
if cell_size is None:
cell_size = cutoff
cell_list_fn = cell_list(
box_size, cell_size, example_R, buffer_size_multiplier)
def neighbor_list_candidate_fn(R, **kwargs):
def neighbor_list_candidate_fn(cell_list_fn, R, **kwargs):
cl = cell_list_fn(R)
N, dim = R.shape
......@@ -437,36 +498,53 @@ def neighbor_list(
neighbor_idx = copy_values_from_cell(neighbor_idx, cell_idx, idx)
return neighbor_idx[:-1, :, 0]
# Use the example positions to estimate the maximum occupancy of the verlet
# list.
d_ex = partial(metric_sq, **static_kwargs)
d_ex = vmap(vmap(d_ex, (None, 0)))
N = example_R.shape[0]
example_idx = neighbor_list_candidate_fn(example_R)
example_neigh_R = example_R[example_idx]
example_neigh_dR = d_ex(example_R, example_neigh_R)
mask = np.logical_and(example_neigh_dR < cutoff_sq, example_idx < N)
max_occupancy = np.max(np.sum(mask, axis=1))
max_occupancy = int(max_occupancy * buffer_size_multiplier)
def neighbor_list_fn(R, **kwargs):
idx = neighbor_list_candidate_fn(R, **kwargs)
def prune_neighbor_list(R, idx, **kwargs):
d = partial(metric_sq, **kwargs)
d = vmap(vmap(d, (None, 0)))
N = R.shape[0]
neigh_R = R[idx]
dR = d(R, neigh_R)
argsort = np.argsort(
f32(1) - np.logical_and(dR < cutoff_sq, idx < N), axis=1)
mask = np.logical_and(dR < cutoff_sq, idx < N)
max_occupancy = np.max(np.sum(mask, axis=1))
argsort = np.argsort(f32(1) - mask, axis=1)
# TODO(schsam): Error checking for list exceeding maximum occupancy.
idx = np.take_along_axis(idx, argsort, axis=1)
idx = idx[:, :max_occupancy]
return idx, max_occupancy
def mask_self(idx):
self_mask = idx == np.reshape(np.arange(idx.shape[0]), (idx.shape[0], 1))
idx = np.where(self_mask, idx.shape[0], idx)
return np.where(self_mask, idx.shape[0], idx)
def neighbor_list_fn(R, neighbor_list=None, extra_capacity=0, **kwargs):
nbrs = neighbor_list
def neighbor_fn(R_and_overflow, max_occupancy=None):
R, overflow = R_and_overflow
idx = neighbor_list_candidate_fn(cell_list_fn, R, **kwargs)
idx, occupancy = prune_neighbor_list(R, idx, **kwargs)
if max_occupancy is None:
max_occupancy = int(occupancy * capacity_multiplier + extra_capacity)
return NeighborList(
mask_self(idx[:, :max_occupancy]), R,
np.logical_or(overflow, (max_occupancy <= occupancy)),
max_occupancy,
cell_list_fn)
if nbrs is None:
cell_list_fn = cell_list(box_size, cell_size, R, capacity_multiplier)
return neighbor_fn((R, False))
else:
cell_list_fn = nbrs.cell_list_fn
neighbor_fn = partial(neighbor_fn, max_occupancy=nbrs.max_occupancy)
return idx
d = partial(metric_sq, **kwargs)
d = vmap(d)
return lax.cond(
np.any(d(R, nbrs.reference_position) > threshold_sq),
(R, nbrs.did_buffer_overflow), neighbor_fn,
nbrs, lambda x: x)
return neighbor_list_fn
This diff is collapsed.
This diff is collapsed.
......@@ -25,12 +25,13 @@ INSTALL_REQUIRES = [
'absl-py',
'numpy',
'jax>=0.1.55',
'jaxlib>=0.1.37'
'jaxlib>=0.1.37',
'dataclasses'
]
setuptools.setup(
name='jax-md',
version='0.1.4',
version='0.1.5',
license='Apache 2.0',
author='Google',
author_email='jax-md-dev@google.com',
......
......@@ -53,7 +53,7 @@ if FLAGS.jax_enable_x64:
else:
POSITION_DTYPE = [f32]
update_test_tolerance(2e-5, 1e-7)
update_test_tolerance(2e-5, 1e-6)
def lattice_repeater(small_cell_pos, latvec, no_rep):
......@@ -243,13 +243,13 @@ class EnergyTest(jtu.JaxTestCase):
R = box_size * random.uniform(
key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(
displacement, box_size, R)
displacement, box_size)
idx = neighbor_fn(R)
nbrs = neighbor_fn(R)
self.assertAllClose(
np.array(exact_energy_fn(R), dtype=dtype),
energy_fn(R, idx), True)
energy_fn(R, nbrs.idx), True)
@parameterized.named_parameters(jtu.cases_from_list(
{
......@@ -269,12 +269,12 @@ class EnergyTest(jtu.JaxTestCase):
R = box_size * random.uniform(
key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
displacement, box_size, R)
displacement, box_size)
idx = neighbor_fn(R)
nbrs = neighbor_fn(R)
self.assertAllClose(
np.array(exact_energy_fn(R), dtype=dtype),
energy_fn(R, idx), True)
energy_fn(R, nbrs.idx), True)
@parameterized.named_parameters(jtu.cases_from_list(
{
......@@ -293,10 +293,10 @@ class EnergyTest(jtu.JaxTestCase):
r = box_size * random.uniform(
key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
displacement, box_size, r)
displacement, box_size)
force_fn = quantity.force(energy_fn)
idx = neighbor_fn(r)
idx = neighbor_fn(r).idx
self.assertAllClose(
np.array(exact_force_fn(r), dtype=dtype),
force_fn(r, neighbor_idx=idx), True)
......@@ -321,11 +321,10 @@ class EnergyTest(jtu.JaxTestCase):
assert embedding_fn(np.array(1.0, dtype)).dtype == dtype
assert pairwise_fn(np.array(1.0, dtype)).dtype == dtype
eam_energy = energy.eam(displacement, charge_fn, embedding_fn, pairwise_fn)
tol = 1e-5 if dtype == np.float32 else 1e-6
self.assertAllClose(
eam_energy(
np.dot(atoms_repeated, inv_latvec)) / np.array(num_repetitions ** 3, dtype),
dtype(-3.363338), True, tol, tol)
dtype(-3.363338), True)
if __name__ == '__main__':
absltest.main()
......
......@@ -191,10 +191,10 @@ class NeighborListTest(jtu.JaxTestCase):
R = box_size * random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
N = R.shape[0]
neighbor_list_fn = partition.neighbor_list(
displacement, box_size, cutoff, R)
neighbor_fn = partition.neighbor_list(
displacement, box_size, cutoff, 0.0, 1.1)
idx = neighbor_list_fn(R)
idx = neighbor_fn(R).idx
R_neigh = R[idx]
mask = idx < N
......@@ -226,7 +226,7 @@ class NeighborListTest(jtu.JaxTestCase):
'dtype': dtype,
'dim': dim,
} for dtype in POSITION_DTYPE for dim in SPATIAL_DIMENSION))
def disabled_test_neighbor_list_build_time_dependent(self, dtype, dim):
def test_neighbor_list_build_time_dependent(self, dtype, dim):
key = random.PRNGKey(1)
if dim == 2:
......@@ -238,18 +238,21 @@ class NeighborListTest(jtu.JaxTestCase):
[[9.0, 0.0, t],
[0.0, 4.0, 0.0],
[0.0, 0.0, 7.25]])
min_length = np.min(np.diag(box_fn(0.)))
cutoff = f32(1.23)
cell_size = cutoff / np.diag(box_fn(0.))
# TODO(schsam): Get cell-list working with anisotropic cell sizes.
cell_size = cutoff / min_length
displacement, _ = space.periodic_general(box_fn)
metric = space.metric(displacement)
R = random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
N = R.shape[0]
neighbor_list_fn = partition.neighbor_list(metric, 1., cutoff, R,
cell_size=cell_size, t=f32(0.))
neighbor_list_fn = partition.neighbor_list(metric, 1., cutoff, 0.0,
1.1, cell_size=cell_size,
t=np.array(0.))
idx = neighbor_list_fn(R, t=0.25)
idx = neighbor_list_fn(R, t=np.array(0.25)).idx
R_neigh = R[idx]
mask = idx < N
......
......@@ -96,6 +96,57 @@ class SimulateTest(jtu.JaxTestCase):
assert np.abs(E_total - E_initial) < E_initial * 0.01
assert state.position.dtype == dtype
@parameterized.named_parameters(jtu.cases_from_list(
{
'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
'spatial_dimension': dim,
'dtype': dtype
} for dim in SPATIAL_DIMENSION for dtype in DTYPE))
def test_nve_neighbor_list(self, spatial_dimension, dtype):
Nx = particles_per_side = 8
spacing = f32(1.25)
tol = 1e-10 if dtype == np.float64 else 1e-3
L = Nx * spacing
if spatial_dimension == 2:
R = np.stack([np.array(r) for r in onp.ndindex(Nx, Nx)]) * spacing
elif spatial_dimension == 3:
R = np.stack([np.array(r) for r in onp.ndindex(Nx, Nx, Nx)]) * spacing
R = np.array(R, dtype)
displacement, shift = space.periodic(L)
neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(displacement, L)
exact_energy_fn = energy.lennard_jones_pair(displacement)
init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
exact_init_fn, exact_apply_fn = simulate.nve(exact_energy_fn, shift, 1e-3)
nbrs = neighbor_fn(R)
state = init_fn(random.PRNGKey(0), R, neighbor_idx=nbrs.idx)
exact_state = exact_init_fn(random.PRNGKey(0), R)
def body_fn(i, state):
state, nbrs, exact_state = state
nbrs = neighbor_fn(state.position, nbrs)
state = apply_fn(state, neighbor_idx=nbrs.idx)
return state, nbrs, exact_apply_fn(exact_state)
step = 0
for i in range(20):
new_state, nbrs, new_exact_state = lax.fori_loop(
0, 100, body_fn, (state, nbrs, exact_state))
if nbrs.did_buffer_overflow:
nbrs = neighbor_fn(state.position)
else:
state = new_state
exact_state = new_exact_state
step += 1
assert state.position.dtype == dtype
self.assertAllClose(state.position, exact_state.position, True, tol, tol)
@parameterized.named_parameters(jtu.cases_from_list(
{
'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
......
......@@ -544,8 +544,8 @@ class SMapTest(jtu.JaxTestCase):
R = box_size * random.uniform(
split, (N, spatial_dimension), dtype=dtype)
sigma = random.uniform(key, (), minval=0.5, maxval=2.5)
neighbor_fn = jit(partition.neighbor_list(disp, box_size, sigma, R))
idx = neighbor_fn(R)
neighbor_fn = partition.neighbor_list(disp, box_size, sigma, 0.0)
idx = neighbor_fn(R).idx
self.assertAllClose(mapped_square(R, sigma=sigma),
neighbor_square(R, idx, sigma=sigma), True)
......@@ -577,8 +577,8 @@ class SMapTest(jtu.JaxTestCase):
R = box_size * random.uniform(
split, (N, spatial_dimension), dtype=dtype)
sigma = random.uniform(key, (), minval=0.5, maxval=2.5)
neighbor_fn = jit(partition.neighbor_list(disp, box_size, sigma, R))
idx = neighbor_fn(R)
neighbor_fn = partition.neighbor_list(disp, box_size, sigma, 0.0)
idx = neighbor_fn(R).idx
self.assertAllClose(mapped_square(R, sigma=sigma),
neighbor_square(R, idx, sigma=sigma), True)
......@@ -610,8 +610,8 @@ class SMapTest(jtu.JaxTestCase):
R = box_size * random.uniform(
split, (N, spatial_dimension), dtype=dtype)
sigma = random.uniform(key, (), minval=0.5, maxval=4.5)
neighbor_fn = jit(partition.neighbor_list(disp, box_size, sigma, R))
idx = neighbor_fn(R)
neighbor_fn = partition.neighbor_list(disp, box_size, sigma, 0.0)
idx = neighbor_fn(R).idx
self.assertAllClose(mapped_square(R, sigma=sigma),
neighbor_square(R, idx, sigma=sigma), True)
......@@ -642,9 +642,8 @@ class SMapTest(jtu.JaxTestCase):
key, split = random.split(key)
R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype)
sigma = random.uniform(key, (N,), minval=0.5, maxval=1.5)
neighbor_fn = jit(
partition.neighbor_list(disp, box_size, np.max(sigma), R))
idx = neighbor_fn(R)
neighbor_fn = partition.neighbor_list(disp, box_size, np.max(sigma), 0.)
idx = neighbor_fn(R).idx
self.assertAllClose(mapped_square(R, sigma=sigma),
neighbor_square(R, idx, sigma=sigma), True)
......@@ -676,9 +675,8 @@ class SMapTest(jtu.JaxTestCase):
R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype)
sigma = random.uniform(key, (N, N), minval=0.5, maxval=1.5)
sigma = 0.5 * (sigma + sigma.T)
neighbor_fn = jit(