Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
MAINT: Update ma tutorial plt patterns.
  • Loading branch information
rossbar committed Nov 25, 2025
commit 47effeceba35e80237a133f44c4a57d4c99ac6d5
44 changes: 26 additions & 18 deletions content/tutorial-ma.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,11 @@ First of all, we can plot the whole set of data we have and see what it looks li
import matplotlib.pyplot as plt

selected_dates = [0, 3, 11, 13]
plt.plot(dates, nbcases.T, "--")
plt.xticks(selected_dates, dates[selected_dates])
plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020")

fig, ax = plt.subplots()
ax.plot(dates, nbcases.T, "--")
ax.set_xticks(selected_dates, dates[selected_dates])
ax.set_title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020")
```

The graph has a strange shape from January 24th to February 1st. It would be interesting to know where this data comes from. If we look at the `locations` array we extracted from the `.csv` file, we can see that we have two columns, where the first would contain regions and the second would contain the name of the country. However, only the first few rows contain data for the the first column (province names in China). Following that, we only have country names. So it would make sense to group all the data from China into a single row. For this, we'll select from the `nbcases` array only the rows for which the second entry of the `locations` array corresponds to China. Next, we'll use the [numpy.sum](https://numpy.org/devdocs/reference/generated/numpy.sum.html#numpy.sum) function to sum all the selected rows (`axis=0`). Note also that row 35 corresponds to the total counts for the whole country for each date. Since we want to calculate the sum ourselves from the provinces data, we have to remove that row first from both `locations` and `nbcases`:
Expand Down Expand Up @@ -183,9 +185,10 @@ Let's try and see what the data looks like excluding the first row (data from th
closely:

```{code-cell}
plt.plot(dates, nbcases_ma[1:].T, "--")
plt.xticks(selected_dates, dates[selected_dates])
plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020")
fig, ax = plt.subplots()
ax.plot(dates, nbcases_ma[1:].T, "--")
ax.set_xticks(selected_dates, dates[selected_dates])
ax.set_title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020")
```

Now that our data has been masked, let's try summing up all the cases in China:
Expand Down Expand Up @@ -232,9 +235,10 @@ china_total
We can replace the data with this information and plot a new graph, focusing on Mainland China:

```{code-cell}
plt.plot(dates, china_total.T, "--")
plt.xticks(selected_dates, dates[selected_dates])
plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020 - Mainland China")
fig, ax = plt.subplots()
ax.plot(dates, china_total.T, "--")
ax.set_xticks(selected_dates, dates[selected_dates])
ax.set_title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020 - Mainland China")
```

It's clear that masked arrays are the right solution here. We cannot represent the missing data without mischaracterizing the evolution of the curve.
Expand Down Expand Up @@ -271,21 +275,25 @@ package to create a cubic polynomial model that fits the data as best as possibl
```{code-cell}
t = np.arange(len(china_total))
model = np.polynomial.Polynomial.fit(t[~china_total.mask], valid, deg=3)
plt.plot(t, china_total)
plt.plot(t, model(t), "--")

fig, ax = plt.subplots()
ax.plot(t, china_total)
ax.plot(t, model(t), "--")
```

This plot is not so readable since the lines seem to be over each other, so let's summarize in a more elaborate plot. We'll plot the real data when
available, and show the cubic fit for unavailable data, using this fit to compute an estimate to the observed number of cases on January 28th 2020, 7 days after the beginning of the records:

```{code-cell}
plt.plot(t, china_total)
plt.plot(t[china_total.mask], model(t)[china_total.mask], "--", color="orange")
plt.plot(7, model(7), "r*")
plt.xticks([0, 7, 13], dates[[0, 7, 13]])
plt.yticks([0, model(7), 10000, 17500])
plt.legend(["Mainland China", "Cubic estimate", "7 days after start"])
plt.title(
fig, ax = plt.subplots()
ax.plot(t, china_total)
ax.plot(t[china_total.mask], model(t)[china_total.mask], "--", color="orange")
ax.plot(7, model(7), "r*")

ax.set_xticks([0, 7, 13], dates[[0, 7, 13]])
ax.set_yticks([0, model(7), 10000, 17500])
ax.legend(["Mainland China", "Cubic estimate", "7 days after start"])
ax.set_title(
"COVID-19 cumulative cases from Jan 21 to Feb 3 2020 - Mainland China\n"
"Cubic estimate for 7 days after start"
)
Expand Down