diff --git a/bot/exts/moderation/infraction/superstarify.py b/bot/exts/moderation/infraction/superstarify.py index 006334755d..180a49d304 100644 --- a/bot/exts/moderation/infraction/superstarify.py +++ b/bot/exts/moderation/infraction/superstarify.py @@ -1,3 +1,4 @@ +import asyncio import json import random import textwrap @@ -6,10 +7,12 @@ from discord import Embed, Member from discord.ext.commands import Cog, Context, command, has_any_role from discord.utils import escape_markdown +from pydis_core.site_api import ResponseCodeError from pydis_core.utils.members import get_or_fetch_member from bot import constants from bot.bot import Bot +from bot.constants import URLs from bot.converters import Duration, DurationOrExpiry from bot.decorators import ensure_future_timestamp from bot.exts.moderation.infraction import _utils @@ -18,6 +21,8 @@ from bot.utils import time from bot.utils.messages import format_user +MAX_RETRY_ATTEMPTS = URLs.connect_max_retries +BACKOFF_INITIAL_DELAY = 5 # seconds log = get_logger(__name__) NICKNAME_POLICY_URL = "https://pythondiscord.com/pages/rules/#nickname-policy" SUPERSTARIFY_DEFAULT_DURATION = "1h" @@ -43,9 +48,7 @@ async def on_member_update(self, before: Member, after: Member) -> None: f"{after.display_name}. Checking if the user is in superstar-prison..." ) - active_superstarifies = await self.bot.api_client.get( - "bot/infractions", - params={ + active_superstarifies = await self._fetch_with_retries(params={ "active": "true", "type": "superstar", "user__id": str(before.id) @@ -84,9 +87,7 @@ async def on_member_update(self, before: Member, after: Member) -> None: @Cog.listener() async def on_member_join(self, member: Member) -> None: """Reapply active superstar infractions for returning members.""" - active_superstarifies = await self.bot.api_client.get( - "bot/infractions", - params={ + active_superstarifies = await self._fetch_with_retries(params={ "active": "true", "type": "superstar", "user__id": member.id @@ -238,6 +239,25 @@ async def cog_check(self, ctx: Context) -> bool: """Only allow moderators to invoke the commands in this cog.""" return await has_any_role(*constants.MODERATION_ROLES).predicate(ctx) + async def _fetch_with_retries(self, + retries: int = MAX_RETRY_ATTEMPTS, + params: dict[str, str] | None = None) -> list[dict]: + """Fetch infractions from the API with retries and exponential backoff.""" + for attempt in range(retries): + try: + return await self.bot.api_client.get("bot/infractions", params=params) + except Exception as e: + if attempt == retries - 1 or not self._check_error_is_retriable(e): + raise + await asyncio.sleep(BACKOFF_INITIAL_DELAY * (2 ** (attempt - 1))) + return None + + async def _check_error_is_retriable(self, error: Exception) -> bool: + """Return whether loading filter lists failed due to some temporary error, thus retrying could help.""" + if isinstance(error, ResponseCodeError): + return error.status in (408, 429) or error.status >= 500 + + return isinstance(error, (TimeoutError, OSError)) async def setup(bot: Bot) -> None: """Load the Superstarify cog.""" diff --git a/tests/bot/exts/moderation/infraction/test_superstarify_cog.py b/tests/bot/exts/moderation/infraction/test_superstarify_cog.py new file mode 100644 index 0000000000..54473c7064 --- /dev/null +++ b/tests/bot/exts/moderation/infraction/test_superstarify_cog.py @@ -0,0 +1,105 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from bot.exts.moderation.infraction.superstarify import Superstarify +from tests.helpers import MockBot + + +class TestSuperstarify(unittest.IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + self.bot = MockBot() + + self.cog = Superstarify(self.bot) + + self.bot.api_client = MagicMock() + self.bot.api_client.get = AsyncMock() + + self.cog._check_error_is_retriable = MagicMock(return_value=True) + + async def test_fetch_from_api_success(self): + """API succeeds on first attempt.""" + expected = [{"id": 1}] + self.bot.api_client.get.return_value = expected + + result = await self.cog._fetch_with_retries( + params={"user__id": "123"} + ) + self.assertEqual(result, expected) + + self.bot.api_client.get.assert_awaited_once_with( + "bot/infractions", + params={"user__id": "123"}, + ) + + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_fetch_retries_then_succeeds(self, _): + self.bot.api_client.get.side_effect = [ + OSError("temporary failure"), + [{"id": 42}], + ] + + result = await self.cog._fetch_with_retries( + params={"user__id": "123"} + ) + + self.assertEqual(result, [{"id": 42}]) + self.assertEqual(self.bot.api_client.get.await_count, 2) + + + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_fetch_fails_after_max_retries(self, _): + error = OSError("API down") + + self.bot.api_client.get.side_effect = error + + with self.assertRaises(OSError): + await self.cog._fetch_with_retries( + retries=3, + params={"user__id": "123"}, + ) + + self.assertEqual(self.bot.api_client.get.await_count, 3) + + + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_non_retriable_error_stops_immediately(self, _): + error = ValueError("bad request") + + self.bot.api_client.get.side_effect = error + self.cog._check_error_is_retriable.return_value = False + + with self.assertRaises(ValueError): + await self.cog._fetch_with_retries() + + # only one attempt + self.bot.api_client.get.assert_awaited_once() + + + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_member_update_recovers_from_api_failure(self, _): + before = MagicMock(display_name="Old", id=123) + after = MagicMock(display_name="New", id=123) + after.edit = AsyncMock() + + self.bot.api_client.get.side_effect = [ + OSError(), + [{"id": 42}], + ] + + self.cog.get_nick = MagicMock(return_value="Taylor Swift") + + with patch( + "bot.exts.moderation.infraction._utils.notify_infraction", + new=AsyncMock(return_value=True), + ): + await self.cog.on_member_update(before, after) + + after.edit.assert_awaited_once() + + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_alert_triggered_after_total_failure(self, _): + self.bot.api_client.get.side_effect = OSError("down") + + with self.assertRaises(OSError): + await self.cog._fetch_with_retries(retries=3)