
Adding value_and_jac* functions#762

Open
matt-graham wants to merge 1 commit intojax-ml:mainfrom
matt-graham:side-values

Conversation

@matt-graham
Contributor

At the moment, while the Jacobian functions returned by both jacfwd and jacrev calculate the value of the function at the passed arguments as part of computing the Jacobian, there is no way to access this value. Often the value of the function is needed in conjunction with the Jacobian - for example, when applying Newton's method to find a root of a vector-valued function, we require both the Jacobian and the value of the function on each iteration.
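For illustration (this is not code from the PR), here is a minimal Newton iteration in current JAX, where the function value and Jacobian are obtained by two separate calls - the value is recomputed inside the Jacobian call, which is exactly the redundancy being discussed:

```python
import jax
import jax.numpy as jnp

def f(x):
    # vector-valued function with a root at x = (sqrt(2), sqrt(3))
    return x ** 2 - jnp.array([2.0, 3.0])

jac_f = jax.jacfwd(f)

x = jnp.array([1.0, 1.0])
for _ in range(10):
    value = f(x)          # function value, computed once here...
    jacobian = jac_f(x)   # ...and recomputed internally here
    x = x - jnp.linalg.solve(jacobian, value)
```

After a few iterations x converges to the root, but each step evaluated f twice.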

This pull-request adds an optional return_value argument to both jacfwd and jacrev which if True makes the returned function return a tuple (jacobian, value) where jacobian is the Jacobian evaluated at the passed arguments and value is the value of the function evaluated at the passed arguments. If return_value is False (the default), jacfwd and jacrev retain their current behaviour - i.e. they return a function which returns only the Jacobian.

There already exists a value_and_grad function which implements an analogous operator for the specific case of scalar-valued functions. It would be more consistent with this existing function to define value_and_jacrev and value_and_jacfwd functions rather than using an additional argument, however this seemed like it could end up producing a lot of code and documentation duplication. The design proposed here is also more in keeping with the has_aux optional argument to grad and value_and_grad, which alters the return signature of the returned function.
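The two existing precedents mentioned here can be seen side by side in current JAX (a small illustrative example, not from the PR):

```python
import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(w ** 2)

# Precedent 1: a separate value_and_* function for the scalar case.
value, grad = jax.value_and_grad(loss)(jnp.array([1.0, 2.0]))
# value is 5.0, grad is [2., 4.]

# Precedent 2: an optional flag (has_aux) that changes the return
# signature of the returned function.
def loss_with_aux(w):
    y = jnp.sum(w ** 2)
    return y, {"residual": w}  # auxiliary data alongside the loss

grad2, aux = jax.grad(loss_with_aux, has_aux=True)(jnp.array([1.0, 2.0]))
```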

The same idea could also be applied to the hessian function to optionally return both the Jacobian and value of the function being differentiated in addition to the Hessian itself. This would be useful for example when using Newton's method in an optimisation setting, where the objective function would be scalar-valued and we would need both the Hessian and Jacobian (gradient) for the Newton iteration, and the value of the objective function to monitor convergence. The current pull-request implements part of the necessary functionality to do this, as hessian internally calls jacfwd on the output of a jacrev call, with the Jacobian and value respectively being computed as intermediate values during these calls. It seems the best way to plumb everything together, however, would be to add a similar has_aux optional argument to jacfwd so that it can be applied to a function returned by jacrev which returns both the Jacobian and value of a function. Adding has_aux to jacfwd and jacrev was mentioned as a todo in PR #499, but it looks like that PR was closed without merging and the related issue #497 remains open, so I'm not sure if there is some issue with adding has_aux for forward mode?
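For what it's worth, has_aux did eventually land on both jacfwd and jacrev in later JAX releases, so the plumbing described above can be sketched today. The helper name value_jac_and_hessian below is hypothetical, and this is only a sketch of the composition, not the PR's implementation:

```python
import jax
import jax.numpy as jnp

def value_jac_and_hessian(f):
    """Hypothetical helper returning (hessian, jacobian, value) of a
    scalar function f, threading intermediates through has_aux."""
    def f_with_value(x):
        y = f(x)
        return y, y  # output to differentiate, plus the value as aux

    def jac_with_aux(x):
        jac, value = jax.jacrev(f_with_value, has_aux=True)(x)
        return jac, (jac, value)  # differentiate jac, carry both as aux

    def wrapped(x):
        hess, (jac, value) = jax.jacfwd(jac_with_aux, has_aux=True)(x)
        return hess, jac, value

    return wrapped

hess, jac, value = value_jac_and_hessian(lambda x: jnp.sum(x ** 3))(
    jnp.array([1.0, 2.0]))
```

For f(x) = sum(x**3) at x = [1, 2] this yields value 9, gradient [3, 12], and Hessian diag(6, 12) in a single composed pass.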

@proteneer
Contributor

FWIW I'm not sure I like the idea of functions that sometimes return one value and sometimes two values. I view jacfwd as a function that should be kept as simple as possible; anything more complicated you should just vmap over jvps manually. Otherwise you'd also have to do this for jax.grad.

@mattjj
Collaborator

mattjj commented May 23, 2019

Thanks so much for this contribution, and crystal-clear explanation! All that thinking makes sense to me. The code looks perfect as well. The only thing we should think through is this kwargs business, and I have to say up front that I don't know what the right balance is.

The kwargs-vs-separate-functions thing is always tricky. We've clearly gone with a bit of both already, as you pointed out (value_and_grad vs has_aux). As @matt-graham wrote, by making this a flag we might avoid code and documentation duplication, but as @proteneer pointed out it means the functions themselves get more complicated.

One slight bias that might tip the scales is just to be consistent with value_and_grad.

I have an idea, based on something @skye recently suggested for jit: maybe we can factor things into internal _jacfwd and _jacrev functions that have the return_value parameter (and look like the code in this PR, thus avoiding code duplication), then write thin wrappers jacfwd, jacrev, value_and_jacfwd, value_and_jacrev that have the right docstrings and just call the underscore functions with the appropriate value for the return_value arg (thus keeping each API function simple). (That plan wouldn't necessarily avoid duplicating the docstrings themselves, but if we really wanted to avoid that we could templatize the docstring and build it programmatically.)
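A rough sketch of that factoring, simplified to single-argument functions from vectors to vectors (the real implementation would also handle argnums, pytrees, and proper docstrings - the names here mirror the suggestion, not actual JAX internals):

```python
import functools
import jax
import jax.numpy as jnp

def _jacfwd(f, return_value=False):
    # Internal helper carrying the flag; builds the Jacobian by
    # vmapping jvp over one-hot tangent vectors, so the primal value
    # is computed as a by-product of the same pass.
    def jacfun(x):
        pushfwd = functools.partial(jax.jvp, f, (x,))
        basis = jnp.eye(x.size, dtype=x.dtype)
        y, jac = jax.vmap(pushfwd, out_axes=(None, 1))((basis,))
        return (jac, y) if return_value else jac
    return jacfun

def jacfwd(f):
    """Thin public wrapper: Jacobian only."""
    return _jacfwd(f, return_value=False)

def value_and_jacfwd(f):
    """Thin public wrapper: (value, jacobian) pairs."""
    jacfun = _jacfwd(f, return_value=True)
    def wrapped(x):
        jac, y = jacfun(x)
        return y, jac  # swap to match value_and_grad's ordering
    return wrapped
```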

I'm not sure if that idea would be preferable, but I think it's worth trying out so we can step back and look over the alternatives. What do you think, @matt-graham? Want to give that a shot, or should I try it?

@mattjj
Collaborator

mattjj commented May 23, 2019

(By the way, I can't remember why I closed #499...)

@matt-graham
Contributor Author

Defining internal _jacrev and _jacfwd functions with a return_value keyword argument and providing separate value_and_jac* and jac* functions as thin-wrappers seems a good compromise to me - I've pushed a commit which attempts to implement something along those lines.

At the moment I've kept the order of the return values for _jacfwd and _jacrev with return_value set to True the same, i.e. (jacobian, value) rather than (value, jacobian). This means the value_and_jac* wrappers have to wrap the returned function to swap the return order, which is a bit ugly. My rationale for keeping the return values in this order, though, is to allow _jacfwd and _jacrev with return_value=True to be easily composed with differential operators which implement has_aux to allow higher-order differentiation - for example, to add a value_jacobian_and_hessian function. It might be that this is a rare use-case, though, and it would be better to do any swapping within such compositions rather than in the value_and_jac* functions. @mattjj any thoughts on this?

@shoyer
Collaborator

shoyer commented Aug 1, 2019

+1 for adding value_and_* rather than return_value to JAX's public APIs.

@pierthodo

Is this or the PR #499 to add has_aux going to be merged?

@mattjj
Collaborator

mattjj commented Jan 21, 2020

Sorry, I dropped the ball on this issue.

I think probably not, because we want to revise the api a little. But we’ll merge something like it.

Is this affecting you? Maybe we can give you a workaround for the time being.

@pierthodo

pierthodo commented Jan 21, 2020

It's okay I can work around it in the meantime until it is merged. I just wanted to see if it was going to be included at some point or not. Thanks !

@matt-graham
Contributor Author

This is still something I am interested in too! I would be happy to work on submitting a new pull-request or updating this one. What changes to the API do you think are needed @mattjj?

@jekbradbury
Contributor

Matt might have been talking about the kwarg-vs-separate functions thing, which it looks like you addressed as he suggested (it certainly looks good to me). Can you rebase, though?

@matt-graham
Contributor Author

@jekbradbury I've now rebased. Let me know if anything else needs changing!

@matt-graham matt-graham changed the title Adding option to return value from Jacobian functions Adding value_and_jac* functions Feb 3, 2020
@matt-graham
Contributor Author

@mattjj and @jekbradbury: is there anything else that needs changing / adding for this to be merged?

@wagnew3

wagnew3 commented Dec 28, 2020

This would be a useful feature to have--are there plans to release it?

@shoyer
Collaborator

shoyer commented Dec 29, 2020

In my opinion, the ideal would be that writing your own fully featured version of value_and_jacobian is only four lines of code and could be included as a recipe somewhere in the docs. We are actually pretty close to that already, aside from handling pytrees (which could be handled by tree_vectorize, if that ever gets merged!).

@kach

kach commented Mar 5, 2021

Are there any updates on this? It would be super helpful for my research work. :)

(Or, alternatively, @shoyer, could you post your recipe for a DIY value_and_jacobian here, even if it doesn't quite handle pytrees?)

@fordmatt18

fordmatt18 commented Dec 26, 2021

Are there any updates on this? It would be super helpful for my research work. :)

(Or, alternatively, @shoyer, could you post your recipe for a DIY value_and_jacobian here, even if it doesn't quite handle pytrees?)

@mattjj @shoyer Any updates or workarounds? I agree with @kach, this would be a very useful feature to have - either merged into JAX or given as some type of workaround.

@kach

kach commented Dec 27, 2021

Hi all - as a temporary workaround, I wrote my own small value_and_jacfwd by analogy to the API and implementation of the value_and_grad provided by JAX. This was enough to unblock me from my own research. :)

I posted my code here in case anyone else finds it useful: https://github.com/kach/jax.value_and_jacfwd.

@mattjj, I'm happy to create a pull request to have this formally merged into JAX, if you're interested. Of course, if you're still thinking about the API design and don't want to commit to this solution, I totally understand. :)

@fordmatt18

Thanks everyone, helpful discussion here. I was able to make a solution similar to the OP’s but updated to account for the changes in more recent versions of JAX. I would just like to pass along that I think it would be beneficial to other users to have value_and_jacfwd and value_and_jacrev functions in the API, but of course that’s up to the JAX team and I’m not fully aware of all the potential issues at play.

@shoyer
Collaborator

shoyer commented Dec 28, 2021

One reason not to bother with supporting this functionality in JAX is that it isn't clear whether it would actually be more computationally efficient. Unless you only have a few outputs (or inputs), the cost of computing the Jacobian is going to dominate the cost of evaluating the function, so evaluating it twice probably doesn't matter.

@shoyer
Collaborator

shoyer commented Dec 28, 2021

(Or, alternatively, @shoyer, could you post your recipe for a DIY value_and_jacobian here, even if it doesn't quite handle pytrees?)

For reference, here are minimal versions of value_and_jacfwd and value_and_jacrev. They only work on functions that map from vectors to vectors:

import functools
import jax
import jax.numpy as jnp

def value_and_jacfwd(f, x):
  pushfwd = functools.partial(jax.jvp, f, (x,))
  basis = jnp.eye(x.size, dtype=x.dtype)
  y, jac = jax.vmap(pushfwd, out_axes=(None, 1))((basis,))
  return y, jac

def value_and_jacrev(f, x):
  y, pullback = jax.vjp(f, x)
  basis = jnp.eye(y.size, dtype=y.dtype)
  # pullback returns a tuple of cotangents (one per primal), so unpack it
  jac, = jax.vmap(pullback)(basis)
  return y, jac

@Gattocrucco
Contributor

One reason not to bother with supporting this functionality in JAX is that it isn't clear whether it would actually be more computationally efficient. Unless you only have a few outputs (or inputs), the cost of computing the Jacobian is going to dominate the cost of evaluating the function, so evaluating it twice probably doesn't matter.

I'm doing a computation where the bottleneck is decomposing a matrix, so jacobian and hessian are relatively light once I've computed the base function. The way I'm doing value_and_jac is with has_aux:

import functools
import jax

def value_and_ops(f, *ops, **kw):
    if not ops:
        return f
    def fop(*args, **kw):
        return f(*args, **kw), ()
    for op in ops:
        def fop(*args, _fop=fop, **kw):
            y, aux = _fop(*args, **kw)
            return y, aux + (y,)
        fop = op(fop, has_aux=True, **kw)
    @functools.wraps(f)
    def fop(*args, _fop=fop, **kw):
        y, aux = _fop(*args, **kw)
        return aux + (y,)
    return fop

def f(x):
    print('ciao')
    return 1/2 * x ** 2

fg = value_and_ops(f, jax.grad, jax.grad, jax.grad)

print(fg(3.))
ciao
(DeviceArray(4.5, dtype=float32, weak_type=True), DeviceArray(3., dtype=float32, weak_type=True), DeviceArray(1., dtype=float32, weak_type=True), DeviceArray(0., dtype=float32, weak_type=True))

Doubts I have:

  1. Is this a legal usage of has_aux? Will it break depending on internal changes?
  2. Is the function really being executed once because it printed ciao once, or somehow jax is tracing and re-executing stuff? (I guess not)

@itk22

itk22 commented Jun 26, 2023

@shoyer, thanks for providing your implementation of value_and_jacfwd! I was wondering if you know how I could apply this to a nested Pytree input. In particular, I am looking to apply this to a function whose input is the params object from Haiku.
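One possible direction (a hedged sketch, not an official JAX API): jax.vjp already handles pytree inputs such as a Haiku-style params dict, so the reverse-mode recipe mainly needs the *output* to be a flat vector. The helper below assumes that, and returns a Jacobian pytree matching the params structure with a leading output axis on each leaf:

```python
import jax
import jax.numpy as jnp

def value_and_jacrev(f, params):
    # params may be any pytree; f(params) must return a 1-D array.
    y, pullback = jax.vjp(f, params)
    basis = jnp.eye(y.size, dtype=y.dtype)
    # vmap the pullback over one-hot output cotangents; the result is
    # a pytree matching `params`, with a leading output axis per leaf.
    jac, = jax.vmap(pullback)(basis)
    return y, jac

# Hypothetical example params dict standing in for a Haiku params object.
params = {"w": jnp.array([[1.0, 2.0], [3.0, 4.0]]),
          "b": jnp.array([0.5, -0.5])}

def model(p):
    return p["w"] @ jnp.array([1.0, 1.0]) + p["b"]

y, jac = value_and_jacrev(model, params)
# jac["w"] has shape (2, 2, 2): output index first, then w's shape.
```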

@axch axch added the P3 (no schedule) We have no plan to work on this and, if it is unassigned, we would be happy to review a PR label Jun 27, 2023
@lockwo
Contributor

lockwo commented Sep 17, 2024

Just revisiting this from #23645, in which I basically implemented what was here. I think this thread/function remains valuable, either to be in the JAX source code or at least to have some representation in the docs. Part of it stems from a minor follow-up to:

One reason not to bother with supporting this functionality in JAX is that it isn't clear whether it would actually be more computationally efficient. Unless you only have a few outputs (or inputs), the cost of computing the Jacobian is going to dominate the cost of evaluating the function, so evaluating it twice probably doesn't matter.

I agree with the textual statement, but disagree that this makes it unimportant. There are a non-negligible number of situations in which you have (a) a specialized Jacobian which you can evaluate on the same order of magnitude as the function, or (b) small parameter regimes where even if the Jacobian computation costs more than the evaluation, the amount of time saved would be non-trivial (e.g. certain regimes of differential equation solving).
