Skip to content

Commit b059363

Browse files
authored
Merge pull request DJ4Earth#23 from swilliamson7/main
Removing old version in this repo
2 parents 04139b9 + 6f25586 commit b059363

20 files changed

+298
-1789
lines changed

.DS_Store

-2 KB
Binary file not shown.

explicit_solver/.DS_Store

-6 KB
Binary file not shown.
File renamed without changes.

explicit_solver/advance.jl

Lines changed: 73 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -5,107 +5,59 @@
55
# I needed all of the terms that matter to the derivative to appear in a *single* structure, so keeping
66
# RHS_terms separate from the states was no longer a good idea.
77

8+
function advance(states_rhs::SWM_pde,
9+
grid::Grid,
10+
params::Params,
11+
interp::Interps,
12+
grad::Derivatives,
13+
advec::Advection
14+
)
815

9-
# function advance(u_v_eta::gyre_vector,
10-
# grid::Grid,
11-
# rhs::RHS_terms,
12-
# params::Params,
13-
# interp::Interps,
14-
# grad::Derivatives,
15-
# advec::Advection
16-
# )
17-
18-
# nx = grid.nx
19-
# dt = params.dt
20-
21-
# # we now use RK4 as the timestepper, here I'm storing the coefficients needed for this
22-
# rk_a = [1/6, 1/3, 1/3, 1/6]
23-
# rk_b = [1/2, 1/2, 1.]
24-
25-
# rhs.umid .= u_v_eta.u
26-
# rhs.vmid .= u_v_eta.v
27-
# rhs.etamid .= u_v_eta.eta
28-
29-
# rhs.u0 .= u_v_eta.u
30-
# rhs.v0 .= u_v_eta.v
31-
# rhs.eta0 .= u_v_eta.eta
32-
33-
# rhs.u1 .= u_v_eta.u
34-
# rhs.v1 .= u_v_eta.v
35-
# rhs.eta1 .= u_v_eta.eta
36-
37-
# @inbounds for j in 1:4
38-
39-
# comp_u_v_eta_t(nx, rhs, params, interp, grad, advec)
40-
41-
# if j < 4
42-
# rhs.u1 .= rhs.umid .+ rk_b[j] .* dt .* rhs.u_t
43-
# rhs.v1 .= rhs.vmid .+ rk_b[j] .* dt .* rhs.v_t
44-
# rhs.eta1 .= rhs.etamid .+ rk_b[j] .* dt .* rhs.eta_t
45-
# end
46-
47-
# rhs.u0 .= rhs.u0 .+ rk_a[j] .* dt .* rhs.u_t
48-
# rhs.v0 .= rhs.v0 .+ rk_a[j] .* dt .* rhs.v_t
49-
# rhs.eta0 .= rhs.eta0 .+ rk_a[j] .* dt .* rhs.eta_t
50-
51-
# end
52-
53-
# @assert all(x -> x < 7.0, rhs.u0)
54-
# @assert all(x -> x < 7.0, rhs.v0)
55-
# @assert all(x -> x < 7.0, rhs.eta0)
56-
57-
# copyto!(u_v_eta.u, rhs.u0)
58-
# copyto!(u_v_eta.v, rhs.v0)
59-
# copyto!(u_v_eta.eta, rhs.eta0)
60-
61-
# return nothing
62-
63-
# end
16+
nx = grid.nx
17+
(;dt, rk_a, rk_b) = params
6418

65-
function advance(states_rhs::SWM_pde,
66-
grid::Grid,
67-
params::Params,
68-
interp::Interps,
69-
grad::Derivatives,
70-
advec::Advection
71-
)
19+
states_rhs.umid .= states_rhs.u
20+
states_rhs.vmid .= states_rhs.v
21+
states_rhs.etamid .= states_rhs.eta
7222

73-
@unpack nx = grid
74-
@unpack dt, rk_a, rk_b = params
23+
states_rhs.u0 .= states_rhs.u
24+
states_rhs.v0 .= states_rhs.v
25+
states_rhs.eta0 .= states_rhs.eta
7526

76-
states_rhs.umid .= states_rhs.u
77-
states_rhs.vmid .= states_rhs.v
78-
states_rhs.etamid .= states_rhs.eta
27+
states_rhs.u1 .= states_rhs.u
28+
states_rhs.v1 .= states_rhs.v
29+
states_rhs.eta1 .= states_rhs.eta
7930

80-
states_rhs.u0 .= states_rhs.u
81-
states_rhs.v0 .= states_rhs.v
82-
states_rhs.eta0 .= states_rhs.eta
31+
@inbounds for j in 1:4
8332

84-
states_rhs.u1 .= states_rhs.u
85-
states_rhs.v1 .= states_rhs.v
86-
states_rhs.eta1 .= states_rhs.eta
33+
comp_u_v_eta_t(nx, states_rhs, params, interp, grad, advec)
8734

88-
@inbounds for j in 1:4
35+
if j < 4
36+
states_rhs.u1 .= states_rhs.umid .+ rk_b[j] .* dt .* states_rhs.u_t
37+
states_rhs.v1 .= states_rhs.vmid .+ rk_b[j] .* dt .* states_rhs.v_t
38+
states_rhs.eta1 .= states_rhs.etamid .+ rk_b[j] .* dt .* states_rhs.eta_t
39+
end
8940

90-
comp_u_v_eta_t(nx, states_rhs, params, interp, grad, advec)
41+
states_rhs.u0 .= states_rhs.u0 .+ rk_a[j] .* dt .* states_rhs.u_t
42+
states_rhs.v0 .= states_rhs.v0 .+ rk_a[j] .* dt .* states_rhs.v_t
43+
states_rhs.eta0 .= states_rhs.eta0 .+ rk_a[j] .* dt .* states_rhs.eta_t
9144

92-
if j < 4
93-
states_rhs.u1 .= states_rhs.umid .+ rk_b[j] .* dt .* states_rhs.u_t
94-
states_rhs.v1 .= states_rhs.vmid .+ rk_b[j] .* dt .* states_rhs.v_t
95-
states_rhs.eta1 .= states_rhs.etamid .+ rk_b[j] .* dt .* states_rhs.eta_t
96-
end
45+
end
9746

98-
states_rhs.u0 .= states_rhs.u0 .+ rk_a[j] .* dt .* states_rhs.u_t
99-
states_rhs.v0 .= states_rhs.v0 .+ rk_a[j] .* dt .* states_rhs.v_t
100-
states_rhs.eta0 .= states_rhs.eta0 .+ rk_a[j] .* dt .* states_rhs.eta_t
47+
# Diffusion and bottom friction as Euler forward
48+
dissipative_terms!(nx, states_rhs, params, interp, grad, advec)
49+
states_rhs.u0 .+= dt .* states_rhs.u_t
50+
states_rhs.v0 .+= dt .* states_rhs.v_t
10151

102-
end
52+
@assert all(x -> x < 7.0, states_rhs.u0)
53+
@assert all(x -> x < 7.0, states_rhs.v0)
54+
@assert all(x -> x < 7.0, states_rhs.eta0)
10355

104-
@assert all(x -> x < 7.0, states_rhs.u0)
105-
@assert all(x -> x < 7.0, states_rhs.v0)
106-
@assert all(x -> x < 7.0, states_rhs.eta0)
56+
copyto!(states_rhs.u, states_rhs.u0)
57+
copyto!(states_rhs.v, states_rhs.v0)
58+
copyto!(states_rhs.eta, states_rhs.eta0)
10759

108-
return nothing
60+
return nothing
10961

11062
end
11163

@@ -117,7 +69,7 @@ end
11769
# mostly keeping because I like having a function that just integrates for some amount
11870
# of time
11971
function integrate(days, nx, ny; Lx = 3840e3, Ly = 3840e3)
120-
72+
12173
grid = build_grid(Lx, Ly, nx, ny)
12274
params = def_params(grid)
12375

@@ -140,107 +92,57 @@ function integrate(days, nx, ny; Lx = 3840e3, Ly = 3840e3)
14092
T = T
14193
)
14294

143-
@time for t in 1:T
95+
for t in 1:states_rhs.T
96+
14497
advance(states_rhs, grid, params, interp, grad, advec)
98+
14599
copyto!(states_rhs.u, states_rhs.u0)
146100
copyto!(states_rhs.v, states_rhs.v0)
147101
copyto!(states_rhs.eta, states_rhs.eta0)
102+
148103
end
149104

150105
return states_rhs
151-
106+
152107
end
153108

154-
# ****IMPORTANT**** not yet sure if I'm moving between high and low res grids, need to check with Patrick
155-
# and come back here if there are issues with how I did it
109+
function integrate(u0, v0, eta0, days, nx, ny; Lx = 3840e3, Ly = 3840e3)
156110

157-
# This function needs to be given
158-
# days - how many days to integrate the model for
159-
# nx_lowres, ny_lowres - grid resolution (number of cells in the x and y directions
160-
# respectively) of the courser grid
161-
# Lx, Ly - size of the domain, have a default value but can set
162-
# manually if needed
163-
# data_steps - which timesteps we want to store data at
164-
# Theoretically, we should only ever run this function *once*, from there
165-
# the data points will be stored as a JLD2 data file
166-
function create_data(days, nx_lowres, ny_lowres, data_steps; scaling = 4, Lx = 3840e3, Ly = 3840e3)
167-
168-
nx_highres = nx_lowres * scaling
169-
ny_highres = ny_lowres * scaling
170-
171-
grid_highres = build_grid(Lx, Ly, nx_highres, ny_highres)
172-
params = def_params(grid_highres)
111+
grid = build_grid(Lx, Ly, nx, ny)
112+
params = def_params(grid)
173113

174114
# building discrete operators
175-
grad = build_derivs(grid_highres) # discrete gradient operators
176-
interp = build_interp(grid_highres, grad) # discrete interpolation operators (travels between grids)
177-
advec = build_advec(grid_highres)
115+
grad = build_derivs(grid) # discrete gradient operators
116+
interp = build_interp(grid, grad) # discrete interpolation operators (travels between grids)
117+
advec = build_advec(grid)
178118

179-
Trun = days_to_seconds(days, params.dt)
119+
Nu = grid.Nu
120+
Nv = grid.Nv
121+
NT = grid.NT
122+
Nq = grid.Nq
180123

181-
Nu = grid_highres.Nu
182-
Nv = grid_highres.Nv
183-
NT = grid_highres.NT
184-
Nq = grid_highres.Nq
124+
T = days_to_seconds(days, params.dt)
185125

186-
u_v_eta_rhs = SWM_pde(Nu = Nu,
126+
states_rhs = SWM_pde(Nu = Nu,
187127
Nv = Nv,
188128
NT = NT,
189-
Nq = Nq
129+
Nq = Nq,
130+
T = T,
131+
u = u0,
132+
v = v0,
133+
eta = eta0
190134
)
191-
192-
# In order to compare high res data to low res velocities I'm going to (1) interpolate
193-
# the velocities to the T-grid (cell centers) (2) average the high
194-
# res data points down to the low res grid (3) interpolate the low
195-
# res results to the cell centers. Then I'll be comparing apples to apples (hopefully)
196-
# will check with Patrick that this is a valid method
197135

198-
# Building the averaging operator needed for step (2) above
199-
diag1 = (1 / scaling^2) .* ones(NT)
200-
M = spdiagm(NT, NT, 0.0 .* diag1)
201-
for k = 1:scaling
202-
for j = 1:scaling
203-
M += spdiagm(NT, NT, j+(k-1)*nx_highres - 1 => diag1[1:end-j-(k-1)*nx_highres + 1])
204-
end
205-
end
206-
M = M[1:scaling:end, :]
207-
index1 = 1:nx_lowres*ny_lowres*scaling
208-
for k in 1:ny_lowres
209-
index1 = filter(x -> x [j for j in ((k-1)*nx_lowres*scaling+nx_lowres+1):((k-1)*nx_lowres*scaling+nx_lowres*scaling)], index1)
210-
end
211-
M = M[index1, :]
212-
213-
# initializing where to store the data
214-
data = zeros(grid_highres.Nu + grid_highres.Nv + grid_highres.NT, length(data_steps))
215-
216-
# the steps where we want data in the high res model correspond to (roughly) scaling * t for t
217-
# in the low res model. for simplicity I'm going to keep the times in the low res where I want to have
218-
# data and then just scale them in the for loop to find corresponding high res data points
219-
220-
# for an initial effort I'm just going to run pretty course resolution models for both the high and low res
136+
@btime for t in 1:T
221137

222-
if 1 in scaling .* data_steps
223-
data[:, 1] .= [u_v_eta_rhs.u; u_v_eta_rhs.v; u_v_eta_rhs.eta]
224-
j = 2
225-
else
226-
j = 1
227-
end
228-
229-
for t in 2:Trun
230-
231-
advance(u_v_eta_rhs, grid_highres, params, interp, grad, advec)
232-
233-
if t in scaling .* data_steps
234-
data[:, j] .= [u_v_eta_rhs.u; u_v_eta_rhs.v; u_v_eta_rhs.eta]
235-
j += 1
236-
end
138+
advance(states_rhs, grid, params, interp, grad, advec)
139+
copyto!(states_rhs.u, states_rhs.u0)
140+
copyto!(states_rhs.v, states_rhs.v0)
141+
copyto!(states_rhs.eta, states_rhs.eta0)
142+
143+
end
144+
145+
return states_rhs
237146

238-
copyto!(u_v_eta_rhs.u, u_v_eta_rhs.u0)
239-
copyto!(u_v_eta_rhs.v, u_v_eta_rhs.v0)
240-
copyto!(u_v_eta_rhs.eta, u_v_eta_rhs.eta0)
147+
end
241148

242-
end
243-
244-
return data, M
245-
246-
end

explicit_solver/build_discrete_operators.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,24 @@
1-
# this function will build the discrete operators needed for things such as
1+
"""
2+
@inplacemul c = A*b
3+
4+
Macro to translate c = A*b with `A::SparseMatrixCSC`, `b` and `c` `Vector`s into
5+
`SparseArrays.mul!(c,A,b,true,false)` to perform the sparse matrix -
6+
dense vector multiplication in-place."""
7+
macro inplacemul(ex)
8+
@assert ex.head == :(=) "@inplacemul requires expression a = b*c"
9+
@assert ex.args[2].args[1] == :(*) "@inplacemul requires expression a = b*c"
10+
11+
return quote
12+
local c = $(esc(ex.args[1])) # output dense vector
13+
local A = $(esc(ex.args[2].args[2])) # input sparse matrix
14+
local b = $(esc(ex.args[2].args[3])) # input dense vector
15+
16+
# c = β*c + α*A*b, with α=1, β=0 so that c = A*b
17+
SparseArrays.mul!(c,A,b,true,false)
18+
end
19+
end
20+
21+
# Functions to build the discrete operators needed for things such as
222
# -- moving between grids
323
# -- discrete gradients
424
# -- discrete laplacians

0 commit comments

Comments
 (0)