Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions bot/exts/moderation/infraction/superstarify.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import random
import textwrap
Expand All @@ -6,10 +7,12 @@
from discord import Embed, Member
from discord.ext.commands import Cog, Context, command, has_any_role
from discord.utils import escape_markdown
from pydis_core.site_api import ResponseCodeError
from pydis_core.utils.members import get_or_fetch_member

from bot import constants
from bot.bot import Bot
from bot.constants import URLs
from bot.converters import Duration, DurationOrExpiry
from bot.decorators import ensure_future_timestamp
from bot.exts.moderation.infraction import _utils
Expand All @@ -18,6 +21,8 @@
from bot.utils import time
from bot.utils.messages import format_user

MAX_RETRY_ATTEMPTS = URLs.connect_max_retries
BACKOFF_INITIAL_DELAY = 5 # seconds
log = get_logger(__name__)
NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy"
SUPERSTARIFY_DEFAULT_DURATION = "1h"
Expand All @@ -43,9 +48,7 @@ async def on_member_update(self, before: Member, after: Member) -> None:
f"{after.display_name}. Checking if the user is in superstar-prison..."
)

active_superstarifies = await self.bot.api_client.get(
"bot/infractions",
params={
active_superstarifies = await self._fetch_with_retries(params={
"active": "true",
"type": "superstar",
"user__id": str(before.id)
Expand Down Expand Up @@ -84,9 +87,7 @@ async def on_member_update(self, before: Member, after: Member) -> None:
@Cog.listener()
async def on_member_join(self, member: Member) -> None:
"""Reapply active superstar infractions for returning members."""
active_superstarifies = await self.bot.api_client.get(
"bot/infractions",
params={
active_superstarifies = await self._fetch_with_retries(params={
"active": "true",
"type": "superstar",
"user__id": member.id
Expand Down Expand Up @@ -238,6 +239,25 @@ async def cog_check(self, ctx: Context) -> bool:
"""Only allow moderators to invoke the commands in this cog."""
return await has_any_role(*constants.MODERATION_ROLES).predicate(ctx)

async def _fetch_with_retries(self,
retries: int = MAX_RETRY_ATTEMPTS,
params: dict[str, str] | None = None) -> list[dict]:
"""Fetch infractions from the API with retries and exponential backoff."""
for attempt in range(retries):
try:
return await self.bot.api_client.get("bot/infractions", params=params)
except Exception as e:
if attempt == retries - 1 or not self._check_error_is_retriable(e):
raise
await asyncio.sleep(BACKOFF_INITIAL_DELAY * (2 ** (attempt - 1)))
return None

async def _check_error_is_retriable(self, error: Exception) -> bool:
"""Return whether loading filter lists failed due to some temporary error, thus retrying could help."""
if isinstance(error, ResponseCodeError):
return error.status in (408, 429) or error.status >= 500

return isinstance(error, (TimeoutError, OSError))

async def setup(bot: Bot) -> None:
"""Load the Superstarify cog."""
Expand Down
105 changes: 105 additions & 0 deletions tests/bot/exts/moderation/infraction/test_superstarify_cog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import unittest
from unittest.mock import AsyncMock, MagicMock, patch

from bot.exts.moderation.infraction.superstarify import Superstarify
from tests.helpers import MockBot


class TestSuperstarify(unittest.IsolatedAsyncioTestCase):

async def asyncSetUp(self):
self.bot = MockBot()

self.cog = Superstarify(self.bot)

self.bot.api_client = MagicMock()
self.bot.api_client.get = AsyncMock()

self.cog._check_error_is_retriable = MagicMock(return_value=True)

async def test_fetch_from_api_success(self):
"""API succeeds on first attempt."""
expected = [{"id": 1}]
self.bot.api_client.get.return_value = expected

result = await self.cog._fetch_with_retries(
params={"user__id": "123"}
)
self.assertEqual(result, expected)

self.bot.api_client.get.assert_awaited_once_with(
"bot/infractions",
params={"user__id": "123"},
)

@patch("asyncio.sleep", new_callable=AsyncMock)
async def test_fetch_retries_then_succeeds(self, _):
self.bot.api_client.get.side_effect = [
OSError("temporary failure"),
[{"id": 42}],
]

result = await self.cog._fetch_with_retries(
params={"user__id": "123"}
)

self.assertEqual(result, [{"id": 42}])
self.assertEqual(self.bot.api_client.get.await_count, 2)


@patch("asyncio.sleep", new_callable=AsyncMock)
async def test_fetch_fails_after_max_retries(self, _):
error = OSError("API down")

self.bot.api_client.get.side_effect = error

with self.assertRaises(OSError):
await self.cog._fetch_with_retries(
retries=3,
params={"user__id": "123"},
)

self.assertEqual(self.bot.api_client.get.await_count, 3)


@patch("asyncio.sleep", new_callable=AsyncMock)
async def test_non_retriable_error_stops_immediately(self, _):
error = ValueError("bad request")

self.bot.api_client.get.side_effect = error
self.cog._check_error_is_retriable.return_value = False

with self.assertRaises(ValueError):
await self.cog._fetch_with_retries()

# only one attempt
self.bot.api_client.get.assert_awaited_once()


@patch("asyncio.sleep", new_callable=AsyncMock)
async def test_member_update_recovers_from_api_failure(self, _):
before = MagicMock(display_name="Old", id=123)
after = MagicMock(display_name="New", id=123)
after.edit = AsyncMock()

self.bot.api_client.get.side_effect = [
OSError(),
[{"id": 42}],
]

self.cog.get_nick = MagicMock(return_value="Taylor Swift")

with patch(
"bot.exts.moderation.infraction._utils.notify_infraction",
new=AsyncMock(return_value=True),
):
await self.cog.on_member_update(before, after)

after.edit.assert_awaited_once()

@patch("asyncio.sleep", new_callable=AsyncMock)
async def test_alert_triggered_after_total_failure(self, _):
self.bot.api_client.get.side_effect = OSError("down")

with self.assertRaises(OSError):
await self.cog._fetch_with_retries(retries=3)