Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/snippets/providers/pagerduty-snippet-autogenerated.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ actions:
priority: {value} # Priority reference ID for incidents
status: {value} # Status for incident updates (resolved/acknowledged)
resolution: {value} # Resolution note for resolved incidents
client: {value} # Name of the monitoring client triggering this event (Events API v2 only)
client_url: {value} # URL of the monitoring client triggering this event (Events API v2 only)
body: {value} # Body of the incident as per https://developer.pagerduty.com/api-reference/a7d81b0e9200f-create-an-incident#request-body
kwargs: {value} # Additional event/incident fields
```
Expand Down
21 changes: 20 additions & 1 deletion keep/providers/pagerduty_provider/pagerduty_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ def _build_alert(
severity: typing.Literal["critical", "error", "warning", "info"] | None = None,
event_type: typing.Literal["trigger", "acknowledge", "resolve"] | None = None,
source: str | None = None,
client: str | None = None,
client_url: str | None = None,
**kwargs,
) -> typing.Dict[str, typing.Any]:
"""
Expand Down Expand Up @@ -437,6 +439,12 @@ def _build_alert(
if kwargs.get("class"):
payload["payload"]["class"] = kwargs.get("class")

if client:
payload["client"] = client

if client_url:
payload["client_url"] = client_url

if kwargs.get("images"):
images = kwargs.get("images", [])
if isinstance(images, str):
Expand All @@ -458,6 +466,8 @@ def _send_alert(
severity: typing.Literal["critical", "error", "warning", "info"] | None = None,
event_type: typing.Literal["trigger", "acknowledge", "resolve"] | None = None,
source: str | None = None,
client: str | None = None,
client_url: str | None = None,
**kwargs,
):
"""
Expand All @@ -468,11 +478,14 @@ def _send_alert(
alert_body: UTF-8 string of custom message for alert. Shown in incident body
dedup: Any string, max 255, characters used to deduplicate alerts
event_type: The type of event to send to PagerDuty
client: Name of the monitoring client triggering this event
client_url: URL of the monitoring client triggering this event
"""
url = "https://events.pagerduty.com/v2/enqueue"

payload = self._build_alert(
title, routing_key, dedup, severity, event_type, source, **kwargs
title, routing_key, dedup, severity, event_type, source,
client=client, client_url=client_url, **kwargs
)
result = requests.post(url, json=payload)
result.raise_for_status()
Expand Down Expand Up @@ -708,6 +721,8 @@ def _notify(
priority: str = "",
status: typing.Literal["resolved", "acknowledged"] = "",
resolution: str = "",
client: str = "",
client_url: str = "",
**kwargs: dict,
):
"""
Expand All @@ -729,6 +744,8 @@ def _notify(
source (str): Source field for events API
status (str): Status for incident updates (resolved/acknowledged)
resolution (str): Resolution note for resolved incidents
client (str): Name of the monitoring client triggering this event (Events API v2 only)
client_url (str): URL of the monitoring client triggering this event (Events API v2 only)
kwargs (dict): Additional event/incident fields
"""
if not routing_key: # If routing_key not specified in workflow, fallback to config routing_key
Expand All @@ -741,6 +758,8 @@ def _notify(
routing_key=routing_key,
source=source,
severity=severity,
client=client or None,
client_url=client_url or None,
**kwargs,
)
else:
Expand Down
46 changes: 46 additions & 0 deletions tests/test_pagerduty_provider.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import unittest
from unittest.mock import MagicMock

from keep.api.models.db.incident import IncidentSeverity, IncidentStatus
from keep.providers.pagerduty_provider.pagerduty_provider import PagerdutyProvider
Expand All @@ -18,6 +19,51 @@ def test_format_alert(self):
self.assertEqual(formatted_alert.status, IncidentStatus.FIRING)
self.assertEqual(formatted_alert.alert_sources, ["pagerduty"])

def _make_provider(self):
"""Create a minimal PagerdutyProvider for unit testing _build_alert."""
ctx = MagicMock()
ctx.event_context = None
config = MagicMock()
config.authentication = {"routing_key": "test-key"}
provider = object.__new__(PagerdutyProvider)
provider.context_manager = ctx
provider.logger = MagicMock()
return provider

def test_build_alert_includes_client_fields(self):
provider = self._make_provider()
payload = provider._build_alert(
title="Test alert",
routing_key="test-routing-key",
client="My Monitoring Tool",
client_url="https://monitoring.example.com",
)
self.assertEqual(payload["client"], "My Monitoring Tool")
self.assertEqual(payload["client_url"], "https://monitoring.example.com")
# client and client_url should be top-level, not inside payload.payload
self.assertNotIn("client", payload["payload"])
self.assertNotIn("client_url", payload["payload"])

def test_build_alert_omits_client_fields_when_not_provided(self):
provider = self._make_provider()
payload = provider._build_alert(
title="Test alert",
routing_key="test-routing-key",
)
self.assertNotIn("client", payload)
self.assertNotIn("client_url", payload)

def test_build_alert_images_and_links_in_payload(self):
provider = self._make_provider()
payload = provider._build_alert(
title="Test alert",
routing_key="test-routing-key",
images=[{"src": "https://example.com/img.png"}],
links=[{"href": "https://example.com", "text": "Example"}],
)
self.assertIn("images", payload["payload"])
self.assertIn("links", payload["payload"])


if __name__ == "__main__":
unittest.main()
Loading