Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Stefan Kesselheim
jax-md
Commits
fbbc342c
Commit
fbbc342c
authored
Jun 10, 2020
by
Sam Schoenholz
Browse files
Reverted custom transforms.
parent
47d48f66
Changes
1
Hide whitespace changes
Inline
Side-by-side
jax_md/space.py
View file @
fbbc342c
...
...
@@ -77,9 +77,53 @@ def _small_inverse(T):
# TODO(schsam): Check whether matrices are singular. @ErrorChecking
return
np
.
linalg
.
inv
(
T
)
# Tempororay code to do transforms while JAX's custom_transform + vmap is
# broken. See TODO below.
from
jax.interpreters
import
ad
from
jax.interpreters
import
batching
from
jax.abstract_arrays
import
raise_to_shaped
from
jax
import
lax
from
jax.lib
import
xla_client
from
jax.api
import
_dtype
def
transform_shape_rule
(
T
,
v
):
return
v
.
shape
def
transform_dtype_rule
(
T
,
v
):
return
v
.
dtype
def
transform_translation_rule
(
c
,
T
,
v
):
v_dim
=
len
(
c
.
GetShape
(
v
).
dimensions
())
-
1
dimension_numbers
=
xla_client
.
make_dot_dimension_numbers
((((
v_dim
,),
(
0
,)),
((),
())))
return
xla_client
.
ops
.
DotGeneral
(
v
,
T
,
dimension_numbers
)
transform_p
=
lax
.
standard_primitive
(
transform_shape_rule
,
transform_dtype_rule
,
'transform'
,
transform_translation_rule
)
def
transform_batching_rule
(
operands
,
batch_dims
):
T
,
v
=
operands
T_dim
,
v_dim
=
batch_dims
assert
T_dim
is
None
assert
v_dim
==
0
return
transform_p
.
bind
(
T
,
v
),
v_dim
batching
.
primitive_batchers
[
transform_p
]
=
transform_batching_rule
# type: ignore
def
transform_jvp
(
primals
,
tangents
):
T
,
v
=
primals
gT
,
gv
=
tangents
return
transform_p
.
bind
(
T
,
v
),
gv
ad
.
primitive_jvps
[
transform_p
]
=
transform_jvp
@
custom_jvp
def
transform
(
T
,
v
):
T_dtype
,
v_dtype
=
_dtype
(
T
),
_dtype
(
v
)
if
T_dtype
!=
v_dtype
:
higher_dtype
=
lax
.
dtypes
.
promote_types
(
T_dtype
,
v_dtype
)
if
higher_dtype
==
v_dtype
:
T
=
lax
.
convert_element_type
(
T
,
v_dtype
)
else
:
v
=
lax
.
convert_element_type
(
v
,
T_dtype
)
return
transform_p
.
bind
(
T
,
v
)
@
custom_jvp
def
_transform
(
T
,
v
):
"""Apply a linear transformation, T, to a collection of vectors, v.
Transform is written such that it acts as the identity during gradient
...
...
@@ -96,7 +140,7 @@ def transform(T, v):
return
np
.
dot
(
v
,
T
)
@
transform
.
defjvp
@
_
transform
.
defjvp
def
transform_jvp
(
primals
,
tangents
):
T
,
v
=
primals
dT
,
dv
=
tangents
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment