From 6e2d7c45bc8657e057f22203c21f0db393deec31 Mon Sep 17 00:00:00 2001 From: Hamid Arian Date: Fri, 19 Jun 2026 20:37:17 -0400 Subject: [PATCH] feat(pde): Deep-BSDE neural solver (Lux.jl) --- PARITY.md | 16 ++++-- Project.toml | 8 ++- src/Pde/DeepBSDESolver.jl | 106 ++++++++++++++++++++++++++++++++++++++ src/Pde/Pde.jl | 6 ++- src/RiskLabAI.jl | 6 +-- test/runtests.jl | 18 +++++++ 6 files changed, 152 insertions(+), 8 deletions(-) create mode 100644 src/Pde/DeepBSDESolver.jl diff --git a/PARITY.md b/PARITY.md index d866c0a..d351f88 100644 --- a/PARITY.md +++ b/PARITY.md @@ -461,8 +461,18 @@ with a deep-learning backend (Lux.jl). **Deliberate divergence:** torch tensors → `Matrix`es (`batch × dim`); the method names are `pde_driver` (`r_u`), `pde_hamiltonian` (`h_z`), `pde_terminal`, -`pde_sigma`. The neural solver (`FBSDESolver`/`DeepBSDE` + the PyTorch -set-transformer architectures) is deferred to the Lux.jl slice; the transformer -variants are research scaffolding. +`pde_sigma`. + +### Deep-BSDE solver (Lux.jl backend) — wired + +| Concept | Python | Julia | Notes | +|---|---|---|---| +| Deep-BSDE solver | `FBSDESolver(..., "DeepBSDE")` | `Pde.solve_deep_bsde` | **behavioural**; per-time-step MLP (`Lux.jl`), trainable `Y_0`/`Z_0`, terminal-MSE loss, Adam (`Optimisers.jl`) + `Zygote.jl` gradients; validated structurally (loss is finite and decreases) | + +**Deliberate divergence:** PyTorch → Lux.jl + Zygote + Optimisers (added +dependencies). The bias-free MLP matches the Python `DeepBSDE` subnetwork (the +commented-out batch-norm is omitted). The set-transformer architectures +(`ISAB`/`MAB`/`SAB`/`PMA`/`DeepTimeSetTransformer`) and the `Monte-Carlo`/`DTNN`/ +`FBSNN` solver variants are research scaffolding and are not ported. _(further submodules appended as they are wired)_ diff --git a/Project.toml b/Project.toml index 317db18..2928f59 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RiskLabAI" uuid = "a72881da-fdaa-49c1-8962-99caf4ccfee8" -authors = ["RiskLab AI "] version = "0.0.1" +authors = ["RiskLab AI "] [deps] Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" @@ -12,9 +12,12 @@ DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" TimeSeries = "9e3dc215-6440-5c97-bce1-76c03772f85e" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Clustering = "0.15" @@ -23,7 +26,10 @@ DataFrames = "1" DecisionTree = "0.12" Distributions = "0.25" HypothesisTests = "0.11" +Lux = "1.31.4" +Optimisers = "0.4.7" TimeSeries = "0.20 - 0.24" +Zygote = "0.7.11" julia = "1.10" [extras] diff --git a/src/Pde/DeepBSDESolver.jl b/src/Pde/DeepBSDESolver.jl new file mode 100644 index 0000000..f1a4cb6 --- /dev/null +++ b/src/Pde/DeepBSDESolver.jl @@ -0,0 +1,106 @@ +""" +Deep BSDE solver — native Julia port mirroring the Python +`RiskLabAI.pde.solver.FBSDESolver` (`"DeepBSDE"` method) on the **Lux.jl** +backend (Han, Jentzen & E, 2018). A separate feed-forward subnetwork per time +step predicts the BSDE control `Z_t`; the initial value `Y_0` and gradient `Z_0` +are trainable scalars/vectors. Training minimises the terminal-condition MSE by +Adam (`Optimisers.jl`) with reverse-mode gradients (`Zygote.jl`). + +Parity note: this is a **behavioural** port — a trained network is not +numerically identical to the PyTorch one. It is validated structurally (the +training loss is finite and decreases). The per-step subnetwork is the plain MLP +of the Python `DeepBSDE` module (ReLU between bias-free linear layers; the +commented-out batch-norm is omitted). The PyTorch set-transformer architectures +(`ISAB`/`MAB`/`SAB`/`PMA`/`DeepTimeSetTransformer`) are research scaffolding and +are not ported. + +Reference: Han, J., Jentzen, A., E, W. (2018), PNAS. +""" + +using Random: AbstractRNG, default_rng +using Lux +using Zygote +using Optimisers + +_act(x) = max(x, zero(x)) + +# Per-time-step MLP: ReLU after each hidden bias-free linear layer, raw output. +function _build_subnet(sizes) + layers = Any[] + for i = 1:(length(sizes)-2) + push!(layers, Lux.Dense(sizes[i] => sizes[i+1], _act; use_bias = false)) + end + push!(layers, Lux.Dense(sizes[end-1] => sizes[end]; use_bias = false)) + return Lux.Chain(layers...) +end + +# Terminal-condition MSE of the propagated backward equation. +function _bsde_loss(θ, eq, x, dw, models, states) + batch = size(x, 1) + y = ones(batch, 1) .* θ.y0 # Y_0 (trainable, 1-vector) + for i = 1:eq.num_time_interval + s0 = x[:, :, i] + if i == 1 + out_z = ones(batch, 1) * θ.z0 # Z_0 (trainable, 1×dim) + else + net_out, _ = Lux.apply(models[i-1], permutedims(s0), θ.nets[i-1], states[i-1]) + out_z = permutedims(net_out) + end + dwi = dw[:, :, i] + rate = pde_driver(eq, 0.0, s0, y, out_z) + hamiltonian = pde_hamiltonian(eq, 0.0, s0, y, out_z) + y = y .* (1 .+ rate .* eq.delta_t) .+ hamiltonian .* eq.delta_t .+ + sum(out_z .* dwi; dims = 2) + end + payoff = pde_terminal(eq, eq.total_time, x[:, :, end]) + return sum((payoff .- y) .^ 2) / batch +end + +""" + solve_deep_bsde(eq; hidden_sizes=[16], iterations=20, batch_size=64, + init_y=0.0, learning_rate=0.01, rng=default_rng()) + -> (losses, y0_estimates) + +Train the Deep BSDE solver for equation `eq` (any `Pde.Equation`). Returns the +per-iteration validation loss and the running `Y_0` estimate. Behavioural. +Mirrors Python's `FBSDESolver(..., "DeepBSDE").solve`. +""" +function solve_deep_bsde( + eq::Equation; + hidden_sizes::AbstractVector{<:Integer} = [16], + iterations::Integer = 20, + batch_size::Integer = 64, + init_y::Real = 0.0, + learning_rate::Real = 0.01, + rng::AbstractRNG = default_rng(), +) + dim = eq.dim + sizes = vcat(dim, collect(hidden_sizes), dim) + n_nets = eq.num_time_interval - 1 + + models = [_build_subnet(sizes) for _ = 1:n_nets] + net_params = [] + net_states = [] + for model in models + ps, st = Lux.setup(rng, model) + push!(net_params, ps) + push!(net_states, st) + end + states = Tuple(net_states) + + θ = (y0 = [Float64(init_y)], z0 = zeros(1, dim), nets = Tuple(net_params)) + + dw_val, x_val = pde_sample(eq, 128; rng = rng) + opt_state = Optimisers.setup(Optimisers.Adam(learning_rate), θ) + + losses = Float64[] + y0_estimates = Float64[] + for _ = 1:iterations + dw_train, x_train = pde_sample(eq, batch_size; rng = rng) + gradient = Zygote.gradient(p -> _bsde_loss(p, eq, x_train, dw_train, models, states), θ)[1] + opt_state, θ = Optimisers.update(opt_state, θ, gradient) + push!(losses, _bsde_loss(θ, eq, x_val, dw_val, models, states)) + push!(y0_estimates, θ.y0[1]) + end + return losses, y0_estimates +end diff --git a/src/Pde/Pde.jl b/src/Pde/Pde.jl index 2c1dafb..aeca702 100644 --- a/src/Pde/Pde.jl +++ b/src/Pde/Pde.jl @@ -12,6 +12,9 @@ module Pde include("Equations.jl") +# Neural Deep-BSDE solver (Lux.jl backend). +include("DeepBSDESolver.jl") + export Equation, HJBLQ, @@ -22,6 +25,7 @@ export pde_driver, pde_hamiltonian, pde_terminal, - pde_sigma + pde_sigma, + solve_deep_bsde end # module Pde diff --git a/src/RiskLabAI.jl b/src/RiskLabAI.jl index f7d9e7d..100ce7e 100644 --- a/src/RiskLabAI.jl +++ b/src/RiskLabAI.jl @@ -62,7 +62,7 @@ using .Ensemble: bagging_classifier_accuracy, fit_bagging, bagging_evaluate_sche include("Pde/Pde.jl") using .Pde: Equation, HJBLQ, BlackScholesBarenblatt, PricingDefaultRisk, PricingDiffRate, - pde_sample, pde_driver, pde_hamiltonian, pde_terminal, pde_sigma + pde_sample, pde_driver, pde_hamiltonian, pde_terminal, pde_sigma, solve_deep_bsde # --------------------------------------------------------------------------- # # Top-level exports. @@ -120,9 +120,9 @@ export # Ensemble — bagging accuracy bagging_classifier_accuracy, fit_bagging, bagging_evaluate_schemes, calculate_bootstrap_accuracy, - # Pde — equations (Deep BSDE) + # Pde — equations & Deep-BSDE solver HJBLQ, BlackScholesBarenblatt, PricingDefaultRisk, PricingDiffRate, - pde_sample, pde_driver, pde_hamiltonian, pde_terminal, pde_sigma, + pde_sample, pde_driver, pde_hamiltonian, pde_terminal, pde_sigma, solve_deep_bsde, # Backtest (legacy) probabilityOfBacktestOverfitting, # BetSize diff --git a/test/runtests.jl b/test/runtests.jl index ff6761d..e27d952 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1226,3 +1226,21 @@ end @test size(xs) == (16, 2, 5) @test xs[:, :, 1] == repeat([1.0 0.5], 16, 1) end + +@testset "Pde — Deep BSDE solver (Lux)" begin + P = RiskLabAI.Pde + eq = P.HJBLQ(1, 0.5, 4) + losses, inits = P.solve_deep_bsde( + eq; + hidden_sizes = [8], + iterations = 12, + batch_size = 64, + init_y = 3.0, + learning_rate = 0.02, + rng = MersenneTwister(1), + ) + @test length(losses) == 12 + @test all(isfinite, losses) + @test all(isfinite, inits) + @test minimum(losses) <= losses[1] # training improves on the initial loss +end