Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions PARITY.md
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,18 @@ with a deep-learning backend (Lux.jl).

**Deliberate divergence:** torch tensors → `Matrix`es (`batch × dim`); the method
names are `pde_driver` (`r_u`), `pde_hamiltonian` (`h_z`), `pde_terminal`,
`pde_sigma`. The neural solver (`FBSDESolver`/`DeepBSDE` + the PyTorch
set-transformer architectures) is deferred to the Lux.jl slice; the transformer
variants are research scaffolding.
`pde_sigma`.

### Deep-BSDE solver (Lux.jl backend) — wired

| Concept | Python | Julia | Notes |
|---|---|---|---|
| Deep-BSDE solver | `FBSDESolver(..., "DeepBSDE")` | `Pde.solve_deep_bsde` | **behavioural**; per-time-step MLP (`Lux.jl`), trainable `Y_0`/`Z_0`, terminal-MSE loss, Adam (`Optimisers.jl`) + `Zygote.jl` gradients; validated structurally (loss is finite and decreases) |

**Deliberate divergence:** PyTorch → Lux.jl + Zygote + Optimisers (added
dependencies). The bias-free MLP matches the Python `DeepBSDE` subnetwork (the
commented-out batch-norm is omitted). The set-transformer architectures
(`ISAB`/`MAB`/`SAB`/`PMA`/`DeepTimeSetTransformer`) and the `Monte-Carlo`/`DTNN`/
`FBSNN` solver variants are research scaffolding and are not ported.

_(further submodules appended as they are wired)_
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RiskLabAI"
uuid = "a72881da-fdaa-49c1-8962-99caf4ccfee8"
authors = ["RiskLab AI <research@risklab.ai>"]
version = "0.0.1"
authors = ["RiskLab AI <research@risklab.ai>"]

[deps]
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
Expand All @@ -12,9 +12,12 @@ DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TimeSeries = "9e3dc215-6440-5c97-bce1-76c03772f85e"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Clustering = "0.15"
Expand All @@ -23,7 +26,10 @@ DataFrames = "1"
DecisionTree = "0.12"
Distributions = "0.25"
HypothesisTests = "0.11"
Lux = "1.31.4"
Optimisers = "0.4.7"
TimeSeries = "0.20 - 0.24"
Zygote = "0.7.11"
julia = "1.10"

[extras]
Expand Down
106 changes: 106 additions & 0 deletions src/Pde/DeepBSDESolver.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""
Deep BSDE solver — native Julia port mirroring the Python
`RiskLabAI.pde.solver.FBSDESolver` (`"DeepBSDE"` method) on the **Lux.jl**
backend (Han, Jentzen & E, 2018). A separate feed-forward subnetwork per time
step predicts the BSDE control `Z_t`; the initial value `Y_0` and gradient `Z_0`
are trainable scalars/vectors. Training minimises the terminal-condition MSE by
Adam (`Optimisers.jl`) with reverse-mode gradients (`Zygote.jl`).

Parity note: this is a **behavioural** port — a trained network is not
numerically identical to the PyTorch one. It is validated structurally (the
training loss is finite and decreases). The per-step subnetwork is the plain MLP
of the Python `DeepBSDE` module (ReLU between bias-free linear layers; the
commented-out batch-norm is omitted). The PyTorch set-transformer architectures
(`ISAB`/`MAB`/`SAB`/`PMA`/`DeepTimeSetTransformer`) are research scaffolding and
are not ported.

Reference: Han, J., Jentzen, A., E, W. (2018), PNAS.
"""

using Random: AbstractRNG, default_rng
using Lux
using Zygote
using Optimisers

_act(x) = max(x, zero(x))

# Per-time-step MLP: ReLU after each hidden bias-free linear layer, raw output.
function _build_subnet(sizes)
layers = Any[]
for i = 1:(length(sizes)-2)
push!(layers, Lux.Dense(sizes[i] => sizes[i+1], _act; use_bias = false))
end
push!(layers, Lux.Dense(sizes[end-1] => sizes[end]; use_bias = false))
return Lux.Chain(layers...)
end

# Terminal-condition MSE of the propagated backward equation.
function _bsde_loss(θ, eq, x, dw, models, states)
batch = size(x, 1)
y = ones(batch, 1) .* θ.y0 # Y_0 (trainable, 1-vector)
for i = 1:eq.num_time_interval
s0 = x[:, :, i]
if i == 1
out_z = ones(batch, 1) * θ.z0 # Z_0 (trainable, 1×dim)
else
net_out, _ = Lux.apply(models[i-1], permutedims(s0), θ.nets[i-1], states[i-1])
out_z = permutedims(net_out)
end
dwi = dw[:, :, i]
rate = pde_driver(eq, 0.0, s0, y, out_z)
hamiltonian = pde_hamiltonian(eq, 0.0, s0, y, out_z)
y = y .* (1 .+ rate .* eq.delta_t) .+ hamiltonian .* eq.delta_t .+
sum(out_z .* dwi; dims = 2)
end
payoff = pde_terminal(eq, eq.total_time, x[:, :, end])
return sum((payoff .- y) .^ 2) / batch
end

"""
solve_deep_bsde(eq; hidden_sizes=[16], iterations=20, batch_size=64,
init_y=0.0, learning_rate=0.01, rng=default_rng())
-> (losses, y0_estimates)

Train the Deep BSDE solver for equation `eq` (any `Pde.Equation`). Returns the
per-iteration validation loss and the running `Y_0` estimate. Behavioural.
Mirrors Python's `FBSDESolver(..., "DeepBSDE").solve`.
"""
function solve_deep_bsde(
eq::Equation;
hidden_sizes::AbstractVector{<:Integer} = [16],
iterations::Integer = 20,
batch_size::Integer = 64,
init_y::Real = 0.0,
learning_rate::Real = 0.01,
rng::AbstractRNG = default_rng(),
)
dim = eq.dim
sizes = vcat(dim, collect(hidden_sizes), dim)
n_nets = eq.num_time_interval - 1

models = [_build_subnet(sizes) for _ = 1:n_nets]
net_params = []
net_states = []
for model in models
ps, st = Lux.setup(rng, model)
push!(net_params, ps)
push!(net_states, st)
end
states = Tuple(net_states)

θ = (y0 = [Float64(init_y)], z0 = zeros(1, dim), nets = Tuple(net_params))

dw_val, x_val = pde_sample(eq, 128; rng = rng)
opt_state = Optimisers.setup(Optimisers.Adam(learning_rate), θ)

losses = Float64[]
y0_estimates = Float64[]
for _ = 1:iterations
dw_train, x_train = pde_sample(eq, batch_size; rng = rng)
gradient = Zygote.gradient(p -> _bsde_loss(p, eq, x_train, dw_train, models, states), θ)[1]
opt_state, θ = Optimisers.update(opt_state, θ, gradient)
push!(losses, _bsde_loss(θ, eq, x_val, dw_val, models, states))
push!(y0_estimates, θ.y0[1])
end
return losses, y0_estimates
end
6 changes: 5 additions & 1 deletion src/Pde/Pde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ module Pde

include("Equations.jl")

# Neural Deep-BSDE solver (Lux.jl backend).
include("DeepBSDESolver.jl")

export
Equation,
HJBLQ,
Expand All @@ -22,6 +25,7 @@ export
pde_driver,
pde_hamiltonian,
pde_terminal,
pde_sigma
pde_sigma,
solve_deep_bsde

end # module Pde
6 changes: 3 additions & 3 deletions src/RiskLabAI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ using .Ensemble: bagging_classifier_accuracy, fit_bagging, bagging_evaluate_sche

include("Pde/Pde.jl")
using .Pde: Equation, HJBLQ, BlackScholesBarenblatt, PricingDefaultRisk, PricingDiffRate,
pde_sample, pde_driver, pde_hamiltonian, pde_terminal, pde_sigma
pde_sample, pde_driver, pde_hamiltonian, pde_terminal, pde_sigma, solve_deep_bsde

# --------------------------------------------------------------------------- #
# Top-level exports.
Expand Down Expand Up @@ -120,9 +120,9 @@ export
# Ensemble — bagging accuracy
bagging_classifier_accuracy, fit_bagging, bagging_evaluate_schemes,
calculate_bootstrap_accuracy,
# Pde — equations (Deep BSDE)
# Pde — equations & Deep-BSDE solver
HJBLQ, BlackScholesBarenblatt, PricingDefaultRisk, PricingDiffRate,
pde_sample, pde_driver, pde_hamiltonian, pde_terminal, pde_sigma,
pde_sample, pde_driver, pde_hamiltonian, pde_terminal, pde_sigma, solve_deep_bsde,
# Backtest (legacy)
probabilityOfBacktestOverfitting,
# BetSize
Expand Down
18 changes: 18 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1226,3 +1226,21 @@ end
@test size(xs) == (16, 2, 5)
@test xs[:, :, 1] == repeat([1.0 0.5], 16, 1)
end

@testset "Pde — Deep BSDE solver (Lux)" begin
P = RiskLabAI.Pde
eq = P.HJBLQ(1, 0.5, 4)
losses, inits = P.solve_deep_bsde(
eq;
hidden_sizes = [8],
iterations = 12,
batch_size = 64,
init_y = 3.0,
learning_rate = 0.02,
rng = MersenneTwister(1),
)
@test length(losses) == 12
@test all(isfinite, losses)
@test all(isfinite, inits)
@test minimum(losses) <= losses[1] # training improves on the initial loss
end
Loading