Conversation
|
FWIW I'm not sure I like the idea of functions that sometimes return one value and sometimes two value. I view as jacfwd one a function that should be kept as simple as possible, anything more complicated you should just vmap over jvps manually. Else you'd also have to do this for jax.grad |
|
Thanks so much for this contribution, and crystal-clear explanation! All that thinking makes sense to me. The code looks perfect as well. The only thing we should think through is this kwargs business, and I have to say up front that I don't know what the right balance is. The kwargs-vs-separate-functions thing is always tricky. We've clearly gone with a bit of both already, as you pointed out ( One slight bias that might tip the scales is just to be consistent with I have an idea, based on something @skye recently suggested for I'm not sure if that idea would be preferable, but I think it's worth trying out so we can step back and look over the alternatives. What do you think, @matt-graham? Want to give that a shot, or should I try it? |
|
(By the way, I can't remember why I closed #499...) |
|
Defining internal At the moment I've kept the order the return values for |
|
+1 for adding |
|
Is this or the PR #499 to add has_aux going to be merged? |
|
Sorry, I dropped the ball on this issue. I think probably not, because we want to revise the api a little. But we’ll merge something like it. Is this affecting you? Maybe we can give you a workaround for the time being. |
|
It's okay I can work around it in the meantime until it is merged. I just wanted to see if it was going to be included at some point or not. Thanks ! |
|
This is still something I am interested in too! I would be happy to work on submitting a new pull-request or updating this one. What changes to the API do you think are needed @mattjj? |
|
Matt might have been talking about the kwarg-vs-separate functions thing, which it looks like you addressed as he suggested (it certainly looks good to me). Can you rebase, though? |
2231b26 to
a131828
Compare
|
@jekbradbury I've now rebased. Let me know if anything else needs changing! |
|
@mattjj and @jekbradbury: is there anything else that needs changing / adding for this to be merged? |
|
This would be a useful feature to have--are there plans to release it? |
|
In my opinion, the ideal would be that writing your own fully featured version of |
|
Are there any updates on this? It would be super helpful for my research work. :) (Or, alternatively, @shoyer, could you post your recipe for a DIY value_and_jacobian here, even if it doesn't quite handle pytrees?) |
@mattjj @shoyer Any updates or workarounds? I agree with @kach, this would be a very useful feature to have - either merged into JAX or given as some type of workaround. |
|
Hi all - as a temporary workaround, I wrote my own small I posted my code here in case anyone else finds it useful: https://github.com/kach/jax.value_and_jacfwd. @mattjj, I'm happy to create a pull request to have this formally merged into JAX, if you're interested. Of course, if you're still thinking about the API design and don't want to commit to this solution, I totally understand. :) |
|
Thanks everyone, helpful discussion here. I was able to make a solution similar to the OP’s but updated to account for the changes in more recent versions of JAX. I would just like to pass along that I think it would be beneficial to other users to have value_and_jacfwd and value_and_jacrev functions in the API, but of course that’s up to the JAX team and I’m not fully aware of all the potential issues at play. |
|
One reason not to bother with supporting this functionality in JAX that it isn't clear if that would actually more computationally efficient. Unless you only have a few outputs (or inputs), the cost of computing the Jacobian is going to dominate the cost the cost of evaluating the function, so evaluating it twice probably doesn't matter. |
For reference, here are minimal versions of import functools
import jax
import jax.numpy as jnp
def value_and_jacfwd(f, x):
pushfwd = functools.partial(jax.jvp, f, (x,))
basis = jnp.eye(x.size, dtype=x.dtype)
y, jac = jax.vmap(pushfwd, out_axes=(None, 1))((basis,))
return y, jac
def value_and_jacrev(f, x):
y, pullback = jax.vjp(f, x)
basis = jnp.eye(y.size, dtype=y.dtype)
jac = jax.vmap(pullback)(basis)
return y, jac |
I'm doing a computation where the bottleneck is decomposing a matrix, so jacobian and hessian are relatively light once I've computed the base function. The way I'm doing import functools
import jax
def value_and_ops(f, *ops, **kw):
if not ops:
return f
def fop(*args, **kw):
return f(*args, **kw), ()
for op in ops:
def fop(*args, _fop=fop, **kw):
y, aux = _fop(*args, **kw)
return y, aux + (y,)
fop = op(fop, has_aux=True, **kw)
@functools.wraps(f)
def fop(*args, _fop=fop, **kw):
y, aux = _fop(*args, **kw)
return aux + (y,)
return fop
def f(x):
print('ciao')
return 1/2 * x ** 2
fg = value_and_ops(f, jax.grad, jax.grad, jax.grad)
print(fg(3.))Doubts I have:
|
|
@shoyer, thanks for providing your implementation of |
|
Just revisiting this from #23645, in which I basically implemented what was here. I think this thread/function remains valuable to either be in the jax source code or at least have some representation with the docs. Part of it stems from a minor followup with:
I agree with the textual statement, but disagree that this makes it unimportant. There are a non-negliable amount of situations in which you have (a) a specialized jacobian which you can evaluate on the same order of magnitude as the function, or (b) small parameter regimes where even if the Jacobian computation is more than the evaluation, the amount of time saved would be non trivial (e.g. certain regimes of differentiation equation solving). |
At the moment while the Jacobian functions returned by both
jacfwdandjacrevcalculate the value of the function evaluated at the passed arguments the Jacobian is being computed for, there is no way to access this value. Often the value of the function will be needed in conjunction with the Jacobian - for example in a when applying Newton's method to find a root of a vector-valued function we require both the Jacobian and value of the function on each iteration.This pull-request adds an optional
return_valueargument to bothjacfwdandjacrevwhich if True makes the returned function return a tuple(jacobian, value)wherejacobianis the Jacobian evaluated at the passed arguments andvalueis the value of the function evaluated at the passed arguments. Ifreturn_valueis False (the default),jacfwdandjacrevretain their current behaviour - i.e. they return a function which returns only the Jacobian.There already exists a
value_and_gradfunction which implements an analogous operator for the specific case of scalar-valued functions. It would be more consistent with this existing function to definevalue_and_jacrevandvalue_and_jacfwdfunctions rather than using an additional argument, however this seemed like it could end up producing a lot of code and documentation duplication. The design proposed here is also more in keeping with thehas_auxoptional argument togradandvalue_and_gradwhich alter the return signature of the returned function.The same idea could also be applied to
hessianfunction to optionally return both the Jacobian and value of the function being differentiated in addition to the Hessian itself. This would be useful for example when using Newton's method in an optimisation setting, where the objective function would be scalar valued and we would need both the Hessian and Jacobian (gradient) for the Newton iteration, and the value objective function to monitor convergence. The current pull-request implements part of the necessary functionality to do this ashessianinternally callsjacfwdon the output of ajacrevcall, with respectively the Jacobian and value being computed as intermediate values during these calls. It seems the best way to plumb everything together however would be to add a similarhas_auxoptional argument tojacfwdto allow it to be applied to a function returned byjacrevwhich returns both the Jacobian and value of a function. Addinghas_auxtojacfwdandjacrevwas mentioned as a todo in PR #499 however it looks like that PR was closed without merging and the related issue #497 remains open, so I'm not sure if there is some issue with addinghas_auxfor forward-mode?