Skip to content
This repository was archived by the owner on Jul 25, 2023. It is now read-only.

Commit 4411e02

Browse files
committed
FEAT: tri-diagonal solver
1 parent 68a3b33 commit 4411e02

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

day_6/2023/tridiagonal.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
2+
import numpy as np
3+
4+
# ----------------------------------------------------------------------
5+
# Taken from:
6+
# https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm
7+
# BUT - they go from 1 -> n and python goes from 0 -> n-1
8+
# ----------------------------------------------------------------------
9+
10+
def solve_tridiagonal(a, b, c, d):
11+
12+
""" Function that solves a tridiagonal matrix
13+
14+
Parameters
15+
----------
16+
a - i-1 coefficient
17+
b - i coefficient
18+
c - i+1 coefficient
19+
d - source term
20+
21+
Returns
22+
-------
23+
x - solution to tri-diagonal matrix
24+
25+
Notes
26+
-----
27+
Solve system of equation
28+
29+
"""
30+
31+
nPts = len(a)
32+
33+
cp = np.zeros(nPts)
34+
dp = np.zeros(nPts)
35+
x = np.zeros(nPts)
36+
37+
# calculate c':
38+
i = 0
39+
cp[i] = c[i]/b[i]
40+
for i in range(1, nPts-1):
41+
cp[i] = c[i] / (b[i] - a[i] * cp[i-1])
42+
43+
# calculate d':
44+
i = 0
45+
dp[i] = d[i]/b[i]
46+
for i in range(1, nPts):
47+
dp[i] = (d[i] - a[i] * dp[i-1]) / (b[i] - a[i] * cp[i-1])
48+
49+
# calculate x:
50+
i = nPts - 1
51+
x[i] = dp[i]
52+
for i in range(nPts-2, -1, -1):
53+
x[i] = dp[i] - cp[i] * x[i + 1]
54+
55+
return x
56+

0 commit comments

Comments
 (0)