
Conversion to unitless seems to fail when ForwardDiff.Duals are involved #682

@Ickaser

Here is an MWE:

using DifferentialEquations
using Unitful

function ode_system!(du, u, p, t)
    R0, τ, Tref = p
    
    T = u[1]*u"K"

    # R0*(Tref - T) is dimensionless, but its units (mK^-1 K) are not simplified
    dTdt =  -T / (1 + R0*(Tref - T)) / τ

    du[1] = ustrip(u"K/hr", dTdt)
end

params = (3.0u"1/mK", 100.0u"s", 260u"K")
u0_ndim = [250.0]
tspan = (0.0, 100.0)

prob = ODEProblem(ode_system!, u0_ndim, tspan, params)
sol = solve(prob, Rosenbrock23())

Essentially, ode_system! adds units, does the calculations, and then strips the units on exit. This works fine with explicit ODE solvers, but it fails when automatic differentiation is used (as in the Rosenbrock23() solver). Whenever a plain number is added to a Quantity whose dimensions cancel (i.e. 1 + R0*(Tref - T) in the example above), I get a big ol' error if Duals are involved:

ERROR: LoadError: First call to automatic differentiation for the Jacobian
failed. This means that the user `f` function is not compatible
with automatic differentiation. Methods to fix this include:

1. Turn off automatic differentiation (e.g. Rosenbrock23() becomes
   Rosenbrock23(autodiff=false)). More details can befound at
   https://docs.sciml.ai/DiffEqDocs/stable/features/performance_overloads/
2. Improving the compatibility of `f` with ForwardDiff.jl automatic
   differentiation (using tools like PreallocationTools.jl). More details
   can be found at https://docs.sciml.ai/DiffEqDocs/stable/basics/faq/#Autodifferentiation-and-Dual-Numbers
3. Defining analytical Jacobians. More details can be
   found at https://docs.sciml.ai/DiffEqDocs/stable/types/ode_types/#SciMLBase.ODEFunction

Note: turning off automatic differentiation tends to have a very minimal
performance impact (for this use case, because it's forward mode for a
square Jacobian. This is different from optimization gradient scenarios).
However, one should be careful as some methods are more sensitive to
accurate gradients than others. Specifically, Rodas methods like `Rodas4`
and `Rodas5P` require accurate Jacobians in order to have good convergence,
while many other methods like BDF (`QNDF`, `FBDF`), SDIRK (`KenCarp4`),
and Rosenbrock-W (`Rosenbrock23`) do not. Thus if using an algorithm which
is sensitive to autodiff and solving at a low tolerance, please change the
algorithm as well.

MethodError: convert(::Type{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, ::Quantity{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}, NoDims, Unitful.FreeUnits{(mK^-1, K), NoDims, nothing}}) is ambiguous.

Candidates:
  convert(::Type{T}, y::Quantity) where T<:Real
    @ Unitful C:\Users\iwheeler\.julia\packages\Unitful\orvol\src\conversion.jl:145
  convert(::Type{ForwardDiff.Dual{T, V, N}}, x::Number) where {T, V, N}
    @ ForwardDiff C:\Users\iwheeler\.julia\packages\ForwardDiff\vXysl\src\dual.jl:435
  convert(::Type{T}, x::Number) where T<:Number
    @ Base number.jl:7
  convert(::Type{ForwardDiff.Dual{T, V, N}}, x) where {T, V, N}
    @ ForwardDiff C:\Users\iwheeler\.julia\packages\ForwardDiff\vXysl\src\dual.jl:434

Possible fix, define
  convert(::Type{ForwardDiff.Dual{T, V, N}}, ::Quantity) where {T, V, N}

Stacktrace:

(The stacktrace then runs through a long chain of internal SciML functions.)
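The failing step can be isolated from the solver. Adding a plain number to a dimensionless Quantity appears to go through promotion, and for a Dual-valued quantity that promotion asks for convert(Dual, Quantity), which is the call named in the MethodError. The sketch below assumes only Unitful and ForwardDiff are installed and builds the Dual by hand instead of letting Rosenbrock23() create it; the variable names are illustrative. On the package versions in this report the Dual case is expected to hit the ambiguous convert, while the Float64 case and the ustrip(NoUnits, ...) workaround should not; other versions may behave differently.

```julia
# Sketch: reproduce the failing promotion without DifferentialEquations.
# Assumes Unitful and ForwardDiff are installed; the Dual is built by hand.
using Unitful, ForwardDiff

R0 = 3.0u"1/mK"

# Plain Float64: dimensionless with unsimplified units (mK^-1 K); adding 1 works.
plain = 1 + R0 * (10.0u"K")            # 1 + 3.0 * 10 * 1000

# Same expression with a Dual-valued temperature difference.
ΔT = ForwardDiff.Dual(10.0, 1.0) * u"K"
dual_term = R0 * ΔT                    # Quantity{Dual, NoDims, (mK^-1, K)}

result = try
    1 + dual_term                      # promotion requests convert(Dual, Quantity)
catch err
    err                                # expected: ambiguous-convert MethodError
end
println(typeof(result))

# Workaround from this report: strip to a plain (Dual) number before adding.
fixed = 1 + ustrip(NoUnits, dual_term)
println(ForwardDiff.value(fixed), " ", ForwardDiff.partials(fixed, 1))
```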

There is a workaround: if I rewrite the ODE function above to call ustrip(NoUnits, ...) explicitly on the dimensionless quantity before adding it to a plain number, the function runs without issue:

function ode_system!(du, u, p, t)
    R0, τ, Tref = p
    
    T = u[1]*u"K"

    # strip the dimensionless group to a plain number before adding 1
    dTdt =  -T / (1 + ustrip(NoUnits, R0*(Tref - T))) / τ

    du[1] = ustrip(u"K/hr", dTdt)
end
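To check that the workaround is AD-compatible without involving DifferentialEquations at all, the right-hand side can be differentiated directly with ForwardDiff.jacobian, whose in-place form jacobian(f!, y, x) differentiates a function f!(y, x). This is a sketch: t is fixed at 0.0 because the RHS does not use it, and wrapping ode_system! in a closure is an illustrative choice, not how the solver builds its Jacobian internally.

```julia
# Sketch: differentiate the workaround RHS directly with ForwardDiff.
using Unitful, ForwardDiff

function ode_system!(du, u, p, t)
    R0, τ, Tref = p
    T = u[1]*u"K"
    # strip the dimensionless group before adding it to the plain number 1
    dTdt = -T / (1 + ustrip(NoUnits, R0*(Tref - T))) / τ
    du[1] = ustrip(u"K/hr", dTdt)
    return nothing
end

params = (3.0u"1/mK", 100.0u"s", 260u"K")
u0_ndim = [250.0]

# In-place Jacobian of u -> f(u); t is fixed because the RHS ignores it.
du = similar(u0_ndim)
J = ForwardDiff.jacobian((du, u) -> ode_system!(du, u, params, 0.0), du, u0_ndim)
println(J)
```

The result can be checked by hand. With a = 3000 K⁻¹ (i.e. 3.0 mK⁻¹) and c = 36 (3600 s/hr divided by τ = 100 s), the RHS is f(T) = -cT/(1 + a(Tref - T)), so ∂f/∂T = -c(1 + a·Tref)/(1 + a(Tref - T))², which is about -0.0312 at T = 250.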

Would it be reasonable to define the ForwardDiff conversion suggested in the error message, perhaps in a package extension?
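A hypothetical sketch of the method named in the "Possible fix" line, defined in user code purely for illustration. Defining it outside Unitful or ForwardDiff is type piracy, which is why a package extension would be the natural home; the body, which routes through uconvert(NoUnits, x), is an assumption and not an existing implementation. The signature copies the error message exactly on purpose: a looser signature such as ::Type{<:ForwardDiff.Dual} would not be more specific than ForwardDiff's convert(::Type{Dual{T,V,N}}, ::Number) in its first argument and could stay ambiguous.

```julia
# Hypothetical sketch of the convert method suggested by the MethodError.
# Type piracy when defined in user code; shown only to illustrate the idea.
using Unitful, ForwardDiff

function Base.convert(::Type{ForwardDiff.Dual{T,V,N}}, x::Unitful.Quantity) where {T,V,N}
    # uconvert(NoUnits, x) throws a DimensionError unless x is dimensionless,
    # and otherwise returns the plain value with the unit factor applied.
    return convert(ForwardDiff.Dual{T,V,N}, uconvert(NoUnits, x))
end

R0 = 3.0u"1/mK"
ΔT = ForwardDiff.Dual(10.0, 1.0) * u"K"

# With the method above, 1 + (dimensionless Quantity) no longer needs ustrip.
y = 1 + R0 * ΔT
println(ForwardDiff.value(y), " ", ForwardDiff.partials(y, 1))
```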
