diff --git a/env_sample b/env_sample index fa2d5e3..b7bf774 100644 --- a/env_sample +++ b/env_sample @@ -1,2 +1,7 @@ SN_USER_NAME=example SN_PASSWORD=example + +# optional values +SN_SET_USE_OAUTH=true / false +SN_SET_CLIENT_ID="example" +SN_SET_CLIENT_SECRET="secret" diff --git a/requirements.txt b/requirements.txt index 068266e..1ac0128 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ annotated-doc==0.0.4 +Authlib==1.7.2 black==26.3.1 cachetools==7.0.6 certifi==2026.4.22 diff --git a/setup.py b/setup.py index d8ed99e..d1c5c6d 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,7 @@ "requests==2.33.1", "click==8.3.3", "xlsxwriter==3.2.9", + "Authlib==1.7.2", ] test_dependencies = [ "pytest==9.0.3", diff --git a/sn_set/requests_lib.py b/sn_set/requests_lib.py index 473e33d..3fe734f 100644 --- a/sn_set/requests_lib.py +++ b/sn_set/requests_lib.py @@ -1,11 +1,57 @@ from datetime import datetime -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import requests +from authlib.integrations.requests_client import OAuth2Session from requests.exceptions import HTTPError from .settings import Settings +# context holder to persist oauth2 tokens through +# the execution +context: Dict = {} + + +def client_factory(*args, **kwargs) -> Tuple: + if not (base_url := kwargs.get("base_url")): + raise ValueError("base_url must be specified") + # we store the request config indexed by the base_url, since we + # need different tokens for each instance + if clientConfig := context.get(base_url): + return clientConfig.get("client"), clientConfig.get("auth") + + settings = Settings() + if not settings.get_user() or not settings.get_password(): + raise ValueError("Username or Password is empty") + if settings.get_use_oauth() and ( + not settings.get_client_id() + or not settings.get_client_secret() + or not settings.get_grant_type() + ): + raise ValueError( + "Client ID, Client Secret, and Grant Type are required to use OAuth2" + ) + if settings.get_use_oauth(): + client = OAuth2Session( + client_id=settings.get_client_id(), + client_secret=settings.get_client_secret(), + scope="useraccount", + ) + client.fetch_token( + f"{base_url}/oauth_token.do", + username=settings.get_user(), + password=settings.get_password(), + ) + clientConfig: Dict = {"client": client} + context[base_url] = clientConfig + return (client, None) + else: + client = requests + auth = requests.auth.HTTPBasicAuth(settings.get_user(), settings.get_password()) + clientConfig: Dict = {"client": client, "auth": auth} + context[base_url] = clientConfig + return client, auth + def get_update_sets(instance_name: str) -> List[Dict[str, str]]: """ @@ -23,11 +69,12 @@ def get_update_sets(instance_name: str) -> List[Dict[str, str]]: raise ValueError("Please enter a valid instance name") uri = f"https://{instance_name}.service-now.com/api/now/table/sys_update_set" + base_url: str = f"https://{instance_name}.service-now.com" params = { "sysparm_query": "state=complete^ORstate=ignore", "sysparm_fields": "name", } - return make_request(uri, path_params=params) + return make_request(uri, path_params=params, base_url=base_url) def get_install_order(instance_name: str, set_ids: List[str]) -> List[Dict[str, str]]: @@ -48,10 +95,6 @@ def get_install_order(instance_name: str, set_ids: List[str]) -> List[Dict[str, if not isinstance(set_ids, List): raise ValueError("set_ids must be a list") - # id_regex = re.compile("[a-zA-Z0-9]{32}") - # for sys_id in set_ids: - # if not id_regex.match(sys_id): - # raise ValueError("Each ID must be a valid sys_id") for name in set_ids: if not name or not isinstance(name, str): raise ValueError("IDs cannot be null or empty") @@ -69,6 +112,7 @@ def get_install_order(instance_name: str, set_ids: List[str]) -> List[Dict[str, ] id_list = ",".join(set_ids) + base_url: str = f"https://{instance_name}.service-now.com" uri = f"https://{instance_name}.service-now.com/api/now/table/sys_remote_update_set" params = { "sysparm_query": ( @@ -79,7 +123,7 @@ def get_install_order(instance_name: str, set_ids: List[str]) -> List[Dict[str, "sysparm_display_value": "true", } try: - return make_request(uri, path_params=params) + return make_request(uri, path_params=params, base_url=base_url) except HTTPError as e: if e.response.status_code != 400 and e.response.status_code != 414: raise e @@ -100,7 +144,7 @@ def get_install_order(instance_name: str, set_ids: List[str]) -> List[Dict[str, "sysparm_fields": ",".join(fields), "sysparm_display_value": "true", } - results.append(make_request(uri, path_params=params)) + results.append(make_request(uri, path_params=params, base_url=base_url)) results = [elem[0] for elem in results if len(elem) > 0] return order_sets(results) @@ -154,6 +198,7 @@ def get_install_order_new( ] id_list = ",".join(set_ids) + base_url: str = f"https://{instance_name}.service-now.com" uri = f"https://{instance_name}.service-now.com/api/now/table/sys_update_set" params = { "sysparm_query": ( @@ -163,7 +208,7 @@ def get_install_order_new( "sysparm_fields": ",".join(fields), } try: - return make_request(uri, path_params=params) + return make_request(uri, path_params=params, base_url=base_url) except HTTPError as ex: if ex.response.status_code != 400 and ex.response.status_code != 414: raise ex @@ -183,13 +228,15 @@ def get_install_order_new( ), "sysparm_fields": ",".join(fields), } - results.append(make_request(uri, path_params=params)) + results.append(make_request(uri, path_params=params, base_url=base_url)) results = [elem[0] for elem in results if len(elem) > 0] return order_sets(results, order_by_field="sys_updated_on") -def make_request(uri: str, path_params: Dict[str, str] = None) -> Optional[Dict]: +def make_request( + uri: str, path_params: Dict[str, str] = None, base_url: str | None = None +) -> Optional[Dict]: """ Makes a request to the given uri @@ -197,18 +244,16 @@ def make_request(uri: str, path_params: Dict[str, str] = None) -> Optional[Dict] uri: str - The HTTP URI to gake the request against path_params: Dict - Dictionary of path params and their values to be added to the request + base_url - optional base_url to include when using OAuth2 """ - settings = Settings() - if not settings.get_user() or not settings.get_password(): - raise ValueError("Username or Password is empty") + client, basicAuth = client_factory(base_url=base_url) - r = requests.get( - uri, - params=path_params, - auth=requests.auth.HTTPBasicAuth(settings.get_user(), settings.get_password()), + r: requests.Response = ( + client.get(uri, params=path_params, auth=basicAuth) + if basicAuth + else client.get(uri, params=path_params) ) r.raise_for_status() - return r.json().get("result") diff --git a/sn_set/settings.py b/sn_set/settings.py index 7bf7e8d..04e046a 100644 --- a/sn_set/settings.py +++ b/sn_set/settings.py @@ -3,13 +3,33 @@ class Settings: def __init__(self): - env = Env() + env: Env = Env() env.read_env() - self.user = env.str("SN_USER_NAME") - self.password = env.str("SN_PASSWORD") + self.user: str = env.str("SN_USER_NAME") + self.password: str = env.str("SN_PASSWORD") + self.use_oauth: bool = env.bool("SN_SET_USE_OAUTH", False) + if self.use_oauth: + self.client_id: str = env.str("SN_SET_CLIENT_ID") + self.client_secret: str = env.str("SN_SET_CLIENT_SECRET") + self.grant_type: str = "password" - def get_user(self): + def get_user(self) -> str: return self.user - def get_password(self): + def get_password(self) -> str: return self.password + + def get_use_oauth(self) -> bool: + return self.use_oauth + + def get_client_id(self) -> str | None: + if self.use_oauth: + return self.client_id + + def get_client_secret(self) -> str | None: + if self.use_oauth: + return self.client_secret + + def get_grant_type(self) -> str | None: + if self.use_oauth: + return self.grant_type diff --git a/tests/conftest.py b/tests/conftest.py index e931c8d..296bf77 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,9 +11,11 @@ def runner(): def mock_env_vars(monkeypatch): monkeypatch.setenv("SN_USER_NAME", "user") monkeypatch.setenv("SN_PASSWORD", "password") + monkeypatch.setenv("SN_SET_USE_OAUTH", "False") @pytest.fixture def mock_empty_env_vars(monkeypatch): monkeypatch.setenv("SN_USER_NAME", "user") monkeypatch.setenv("SN_PASSWORD", "") + monkeypatch.setenv("SN_SET_USE_OAUTH", "False") diff --git a/tests/test_requests_lib.py b/tests/test_requests_lib.py index edd8574..05494c7 100644 --- a/tests/test_requests_lib.py +++ b/tests/test_requests_lib.py @@ -38,7 +38,7 @@ def test_make_request_valid(requests_mock, mock_env_vars): } ] requests_mock.get(mock_uri, json=mock_response, status_code=200) - r = make_request(test_uri, path_params=test_params) + r = make_request(test_uri, path_params=test_params, base_url="nyudev") assert r == valid_response @@ -46,14 +46,14 @@ def test_make_request_unauthorized(requests_mock, mock_env_vars): mock_uri = "mock://some-test.com" requests_mock.get(mock_uri, status_code=401) with pytest.raises(HTTPError): - make_request(mock_uri) + make_request(mock_uri, base_url="some-test") def test_make_request_not_found(requests_mock, mock_env_vars): mock_uri = "mock://some-test.com" requests_mock.get(mock_uri, status_code=404) with pytest.raises(HTTPError): - make_request(mock_uri) + make_request(mock_uri, base_url="test") def test_make_request_no_data(requests_mock, mock_env_vars): @@ -61,20 +61,14 @@ def test_make_request_no_data(requests_mock, mock_env_vars): mock_response = {"result": []} requests_mock.get(mock_uri, json=mock_response, status_code=200) - r = make_request(mock_uri) + r = make_request(mock_uri, base_url="test") assert r == [] -def test_make_request_missing_pass(requests_mock, mock_empty_env_vars): - mock_uri = "mock://some-test.com" - with pytest.raises(ValueError): - make_request(mock_uri) - - def test_get_update_sets_valid(monkeypatch): mock_payload = [{"name": "an update set", "sys_id": "12345"}] - def mock_make_request(instance_name, path_params): + def mock_make_request(instance_name, path_params, base_url): return mock_payload from sn_set import requests_lib @@ -115,6 +109,7 @@ def test_get_install_order_invalid(): def test_get_install_order_400(mock_make_request): mock_payload = [{"name": "a set", "commit_date": "2021-05-08 18:39:00"}] mock_uri = "https://nyudev.service-now.com/api/now/table/sys_remote_update_set" + mock_base_uri = "https://nyudev.service-now.com" test_fields = [ "name", "state", @@ -128,7 +123,7 @@ def test_get_install_order_400(mock_make_request): ] mock_params1 = { "sysparm_query": ( - "state=committed^nameINa,b" "^commit_dateISNOTEMPTY^ORDERBYcommit_date" + "state=committed^nameINa,b^commit_dateISNOTEMPTY^ORDERBYcommit_date" ), "sysparm_fields": ",".join(test_fields), "sysparm_display_value": "true", @@ -155,9 +150,15 @@ def test_get_install_order_400(mock_make_request): from sn_set.requests_lib import get_install_order get_install_order("nyudev", ["a", "b"]) - mock_make_request.assert_any_call(mock_uri, path_params=mock_params1) - mock_make_request.assert_any_call(mock_uri, path_params=mock_params2) - mock_make_request.assert_any_call(mock_uri, path_params=mock_params3) + mock_make_request.assert_any_call( + mock_uri, path_params=mock_params1, base_url=mock_base_uri + ) + mock_make_request.assert_any_call( + mock_uri, path_params=mock_params2, base_url=mock_base_uri + ) + mock_make_request.assert_any_call( + mock_uri, path_params=mock_params3, base_url=mock_base_uri + ) def test_get_install_list_internal_order():