Skip to content

Commit 6cbc0bf

Browse files
authored
Merge pull request DJ4Earth#15 from milankl/mk/speed
Speedups
2 parents a9b88a8 + 4020f16 commit 6cbc0bf

File tree

6 files changed

+158
-41
lines changed

6 files changed

+158
-41
lines changed

explicit_solver/ExplicitSolver.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module ExplicitSolver
2+
3+
using Plots, SparseArrays, Parameters, UnPack
4+
using JLD2, LinearAlgebra
5+
using Enzyme, Checkpointing, Zygote
6+
7+
include("init_structs.jl")
8+
include("init_params.jl")
9+
include("build_grid.jl")
10+
include("build_discrete_operators.jl")
11+
include("advance.jl")
12+
include("compute_time_deriv.jl")
13+
14+
end

explicit_solver/advance.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,7 @@ function advance(u_v_eta::gyre_vector,
1616
)
1717

1818
nx = grid.nx
19-
dt = params.dt
20-
21-
# we now use RK4 as the timestepper, here I'm storing the coefficients needed for this
22-
rk_a = [1/6, 1/3, 1/3, 1/6]
23-
rk_b = [1/2, 1/2, 1.]
19+
(;dt, rk_a, rk_b) = params
2420

2521
rhs.umid .= u_v_eta.u
2622
rhs.vmid .= u_v_eta.v
@@ -34,7 +30,7 @@ function advance(u_v_eta::gyre_vector,
3430
rhs.v1 .= u_v_eta.v
3531
rhs.eta1 .= u_v_eta.eta
3632

37-
for j in 1:4
33+
@inbounds for j in 1:4
3834

3935
comp_u_v_eta_t(nx, rhs, params, interp, grad, advec)
4036

@@ -50,9 +46,14 @@ function advance(u_v_eta::gyre_vector,
5046

5147
end
5248

53-
@assert all(x -> x < 7.0, rhs.u0)
54-
@assert all(x -> x < 7.0, rhs.v0)
55-
@assert all(x -> x < 7.0, rhs.eta0)
49+
# Diffusion and bottom friction as Euler forward
50+
dissipative_terms!(nx, rhs, params, interp, grad, advec)
51+
rhs.u0 .+= dt .* rhs.u_t
52+
rhs.v0 .+= dt .* rhs.v_t
53+
54+
# @assert all(x -> x < 7.0, rhs.u0)
55+
# @assert all(x -> x < 7.0, rhs.v0)
56+
# @assert all(x -> x < 7.0, rhs.eta0)
5657

5758
copyto!(u_v_eta.u, rhs.u0)
5859
copyto!(u_v_eta.v, rhs.v0)
@@ -152,7 +153,7 @@ function integrate(T, nx, ny; Lx = 3840e3, Ly = 3840e3)
152153
# u_v_eta_mat = vec_to_mat(u_v_eta.u, u_v_eta.v, u_v_eta.eta, grid_params)
153154

154155
return u_v_eta
155-
156+
# return (u_v_eta, grid_params, rhs_terms, gyre_params, interp_ops, grad_ops, advec_ops)
156157
end
157158

158159
# ****IMPORTANT**** not yet sure if I'm moving between high and low res grids, need to check with Patrick

explicit_solver/build_discrete_operators.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,3 +369,23 @@ function build_advec(grid)
369369
return advec_ops
370370

371371
end
372+
373+
"""
374+
@inplacemul c = A*b
375+
376+
Macro to translate c = A*b with `A::SparseMatrixCSC`, `b` and `c` `Vector`s into
377+
`SparseArrays.mul!(c,A,b,true,false)` to perform the sparse matrix -
378+
dense vector multiplication in-place."""
379+
macro inplacemul(ex)
380+
@assert ex.head == :(=) "@inplacemul requires expression a = b*c"
381+
@assert ex.args[2].args[1] == :(*) "@inplacemul requires expression a = b*c"
382+
383+
return quote
384+
local c = $(esc(ex.args[1])) # output dense vector
385+
local A = $(esc(ex.args[2].args[2])) # input sparse matrix
386+
local b = $(esc(ex.args[2].args[3])) # input dense vector
387+
388+
# c = β*c + α*A*b, with α=1, β=0 so that c = A*b
389+
SparseArrays.mul!(c,A,b,true,false)
390+
end
391+
end

explicit_solver/compute_time_deriv.jl

Lines changed: 103 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,48 +8,122 @@ function comp_u_v_eta_t(nx::Int,
88
interp::Interps,
99
grad::Derivatives,
1010
advec::Advection
11-
)
11+
)
12+
13+
# unpack stuff
14+
u = rhs.u1
15+
v = rhs.v1
16+
eta = rhs.eta1
17+
18+
(;ITu, ITv, ITq, IuT, IvT) = interp # interpolation operators
19+
(;Iuv, Ivu, Iqu, Iqv) = interp
20+
(;GTx, GTy, Gux, Guy, Gvx, Gvy) = grad # gradient operators
21+
(;h, h_u, h_v, h_q, U, V, p, q) = rhs # diagnostic variables
22+
(;kinetic) = rhs
23+
(;GTx_p, GTy_p, Gux_U, Gvy_V, Gvx_v1, Guy_u1) = rhs
24+
(;adv_u, adv_v) = rhs
25+
(;IuT_u1, IvT_v1, ITu_ksq,ITv_ksq) = rhs
26+
(;u_t, v_t, eta_t) = rhs # tendencies
27+
(;H, coriolis, g, wind_stress) = params
28+
29+
h .= eta .+ H
30+
31+
@inplacemul h_q = ITq * h
32+
@inplacemul h_u = ITu * h
33+
@inplacemul h_v = ITv * h
34+
U .= u .* h_u # volume fluxes U,V
35+
V .= v .* h_v
36+
37+
# kinetic energy u² + v²
38+
u²_T, u² = IuT_u1, ITu_ksq # reuse and rename arrays for u²
39+
v²_T, v² = IvT_v1, ITv_ksq # and v², _T is on T-grid
40+
u² .= u.^2
41+
v² .= v.^2
42+
@inplacemul u²_T = IuT *
43+
@inplacemul v²_T = IvT *
44+
kinetic .= u²_T .+ v²_T
45+
46+
# Kloewer defined new terms q and p corresponding to potential vorticity and
47+
# Bernoulli potential respectively. To avoid errors in my mimic I'm following
48+
# along and doing the same
49+
@inplacemul Guy_u1 = Guy * u
50+
@inplacemul Gvx_v1 = Gvx * v
51+
52+
q .= (coriolis .+ Gvx_v1 .- Guy_u1) ./ h_q
53+
p .= 0.5 .* kinetic .+ g .* h
1254

13-
rhs.h .= rhs.eta1 .+ params.H
55+
# deal with the advection term
56+
# comp_advection(nx, rhs, advec) # Arakawa and Lamb advection scheme
1457

15-
rhs.h_u .= interp.ITu * rhs.h
16-
rhs.h_v .= interp.ITv * rhs.h
17-
rhs.h_q .= interp.ITq * rhs.h
58+
# Sadourny, 1975 enstrophy conserving advection scheme
59+
V_u = ITu_ksq # reuse and rename array
60+
@inplacemul V_u = Ivu * V
61+
@inplacemul adv_u = Iqu * q # u-component qhv
62+
adv_u .*= V_u
1863

19-
rhs.U .= rhs.u1 .* rhs.h_u
20-
rhs.V .= rhs.v1 .* rhs.h_v
64+
U_v = ITv_ksq # reuse and rename array
65+
@inplacemul U_v = Iuv * U
66+
# @inplacemul adv_v = Iqv * q # v-component -qhu
67+
adv_v .*= .-U_v
2168

22-
rhs.kinetic .= interp.IuT * (rhs.u1.^2) .+ interp.IvT * (rhs.v1.^2)
69+
# bernoulli gradient ∇p = ∇(1/2(u²+v² + gh))
70+
@inplacemul GTx_p = GTx * p
71+
@inplacemul GTy_p = GTy * p
2372

24-
# Kloewer defined new terms q and p corresponding to potential vorticity and
25-
# Bernoulli potential respectively. To avoid errors in my mimic I'm following
26-
# along and doing the same
27-
rhs.q .= (params.coriolis .+ grad.Gvx * rhs.v1 .- grad.Guy * rhs.u1) ./ rhs.h_q
28-
rhs.p .= 0.5 .* rhs.kinetic .+ params.g .* rhs.h
73+
# momentum equations
74+
u_t .= adv_u .- GTx_p .+ wind_stress ./ h_u
75+
v_t .= adv_v .- GTy_p
2976

30-
# bottom friction
31-
rhs.kinetic_sq .= sqrt.(rhs.kinetic)
32-
rhs.bfric_u .= params.bottom_drag .* ((interp.ITu * rhs.kinetic_sq) .* rhs.u1) ./ rhs.h_u
33-
rhs.bfric_v .= params.bottom_drag .* ((interp.ITv * rhs.kinetic_sq) .* rhs.v1) ./ rhs.h_v
77+
# continuity equations
78+
@inplacemul Gux_U = Gux * U # volume flux divergence dUdx + dVdy
79+
@inplacemul Gvy_V = Gvy * V
80+
@. eta_t = -(Gux_U + Gvy_V)
3481

35-
# deal with the advection term
36-
comp_advection(nx, rhs, advec)
82+
return nothing
83+
end
3784

38-
# rhs.Mu .= params.A_h .* (grad.LLu * rhs.u1)
39-
# rhs.Mv .= params.A_h .* (grad.LLv * rhs.v1)
40-
41-
rhs.Mu .= (interp.ITu * params.nu) .* (grad.LLu * rhs.u1)
42-
rhs.Mv .= (interp.ITv * params.nu) .* (grad.LLv * rhs.v1)
85+
function dissipative_terms!(nx::Int,
86+
rhs::RHS_terms,
87+
params::Params,
88+
interp::Interps,
89+
grad::Derivatives,
90+
advec::Advection
91+
)
4392

44-
rhs.u_t .= rhs.adv_u .- grad.GTx * rhs.p .+ params.wind_stress ./ rhs.h_u .- rhs.Mu .- rhs.bfric_u
93+
# unpack stuff
94+
u = rhs.u0 # calculate based on prognostics at
95+
v = rhs.v0 # t + dt of the non-dissipative RHS
4596

46-
rhs.v_t .= rhs.adv_v .- grad.GTy * rhs.p .- rhs.Mv .- rhs.bfric_v
97+
(;ITu, ITv) = interp # interpolation operators
98+
(;LLu, LLv) = grad # gradient operators
99+
(;h_u, h_v) = rhs # diagnostic variables
100+
(;kinetic, kinetic_sq, Mu, Mv, nu_u, nu_v) = rhs
101+
(;bfric_u, bfric_v) = rhs
102+
(;ITu_ksq,ITv_ksq) = rhs
103+
(;u_t, v_t) = rhs # tendencies
104+
(;nu, bottom_drag) = params
47105

48-
rhs.eta_t .= - (grad.Gux * rhs.U .+ grad.Gvy * rhs.V)
106+
# bottom friction
107+
kinetic_sq .= sqrt.(kinetic)
108+
@inplacemul ITu_ksq = ITu * kinetic_sq
109+
@inplacemul ITv_ksq = ITv * kinetic_sq
110+
bfric_u .= bottom_drag .* ITu_ksq .* u ./ h_u
111+
bfric_v .= bottom_drag .* ITv_ksq .* v ./ h_v
112+
113+
# diffusion term ν∇⁴(u,v)
114+
@inplacemul nu_u = ITu * nu
115+
@inplacemul nu_v = ITv * nu
116+
@inplacemul Mu = LLu * u
117+
@inplacemul Mv = LLv * v
118+
Mu .*= nu_u
119+
Mv .*= nu_v
120+
121+
# tendencies for bottom friction and diffusion
122+
u_t .= .- Mu .- bfric_u
123+
v_t .= .- Mv .- bfric_v
49124

50125
return nothing
51-
52-
end
126+
end
53127

54128
function comp_u_v_eta_t(nx::Int,
55129
rhs::SWM_pde,

explicit_solver/init_params.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,14 @@ function def_params(grid)
5555
# removing the requirement that dt be an integer, not sure why that's there
5656
dt = (0.9 * min(dx, dy)) / (sqrt(g * H)) # CFL condition for dt [seconds]
5757

58+
# we now use RK4 as the timestepper, here I'm storing the coefficients needed for this
59+
rk_a = [1/6, 1/3, 1/3, 1/6]
60+
rk_b = [1/2, 1/2, 1.]
61+
5862
gyre_params = Params(
5963
dt,
64+
rk_a,
65+
rk_b,
6066
g,
6167
f0,
6268
beta,

explicit_solver/init_structs.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ end
106106
# Parameters that appear in the model in various places
107107
struct Params
108108
dt::Float64 # timestep
109+
rk_a::Vector{Float64} # Runge Kutta 4th order coefficients
110+
rk_b::Vector{Float64}
109111
g::Float64 # gravity
110112
f0::Float64 # Coriolis parameter
111113
beta::Float64 # Coriolis parameter
@@ -255,8 +257,8 @@ end
255257
bfric_u::Vector{Float64} = zeros(Nu)
256258
bfric_v::Vector{Float64} = zeros(Nv)
257259

258-
LLu_u1::Vector{Float64} = zeros(Nu)
259-
LLv_v1::Vector{Float64} = zeros(Nv)
260+
nu_u::Vector{Float64} = zeros(Nu)
261+
nu_v::Vector{Float64} = zeros(Nv)
260262
Mu::Vector{Float64} = zeros(Nu)
261263
Mv::Vector{Float64} = zeros(Nv)
262264

0 commit comments

Comments
 (0)