Commit fbbc342c authored by Sam Schoenholz's avatar Sam Schoenholz
Browse files

Reverted custom transforms.

parent 47d48f66
......@@ -77,9 +77,53 @@ def _small_inverse(T):
# TODO(schsam): Check whether matrices are singular. @ErrorChecking
return np.linalg.inv(T)
# Tempororay code to do transforms while JAX's custom_transform + vmap is
# broken. See TODO below.
from jax.interpreters import ad
from jax.interpreters import batching
from jax.abstract_arrays import raise_to_shaped
from jax import lax
from jax.lib import xla_client
from jax.api import _dtype
def transform_shape_rule(T, v):
return v.shape
def transform_dtype_rule(T, v):
return v.dtype
def transform_translation_rule(c, T, v):
v_dim = len(c.GetShape(v).dimensions()) - 1
dimension_numbers = xla_client.make_dot_dimension_numbers((((v_dim,), (0,)), ((), ())))
return xla_client.ops.DotGeneral(v, T, dimension_numbers)
transform_p = lax.standard_primitive(
transform_shape_rule, transform_dtype_rule,
'transform', transform_translation_rule)
def transform_batching_rule(operands, batch_dims):
T, v = operands
T_dim, v_dim = batch_dims
assert T_dim is None
assert v_dim == 0
return transform_p.bind(T, v), v_dim
batching.primitive_batchers[transform_p] = transform_batching_rule # type: ignore
def transform_jvp(primals, tangents):
T, v = primals
gT, gv = tangents
return transform_p.bind(T, v), gv
ad.primitive_jvps[transform_p] = transform_jvp
@custom_jvp
def transform(T, v):
T_dtype, v_dtype = _dtype(T), _dtype(v)
if T_dtype != v_dtype:
higher_dtype = lax.dtypes.promote_types(T_dtype, v_dtype)
if higher_dtype == v_dtype:
T = lax.convert_element_type(T, v_dtype)
else:
v = lax.convert_element_type(v, T_dtype)
return transform_p.bind(T, v)
@custom_jvp
def _transform(T, v):
"""Apply a linear transformation, T, to a collection of vectors, v.
Transform is written such that it acts as the identity during gradient
......@@ -96,7 +140,7 @@ def transform(T, v):
return np.dot(v, T)
@transform.defjvp
@_transform.defjvp
def transform_jvp(primals, tangents):
T, v = primals
dT, dv = tangents
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment