diff --git a/frappe/tests/test_oauth20.py b/frappe/tests/test_oauth20.py index 8375d42851..533f2a49a7 100644 --- a/frappe/tests/test_oauth20.py +++ b/frappe/tests/test_oauth20.py @@ -12,6 +12,7 @@ from frappe.integrations.oauth2 import encode_params from frappe.tests import IntegrationTestCase from frappe.tests.test_api import get_test_client, make_request, suppress_stdout from frappe.tests.utils import make_test_records +from frappe.utils.oauth import build_oauth_url if TYPE_CHECKING: from frappe.integrations.doctype.social_login_key.social_login_key import SocialLoginKey @@ -360,6 +361,28 @@ class TestOAuth20(FrappeRequestTestCase): self.assertTrue(payload.get("nonce") == nonce) + def test_build_oauth_url(self): + self.assertEqual(build_oauth_url("https://example.com", "/endpoint"), "https://example.com/endpoint") + + self.assertEqual(build_oauth_url("https://example.com"), "https://example.com") + + self.assertEqual(build_oauth_url("https://example.com", None), "https://example.com") + + self.assertEqual( + build_oauth_url("https://example.com", "//endpoint.com/test"), + "https://example.com//endpoint.com/test", + ) + + self.assertEqual( + build_oauth_url("https://example.com", "http://endpoint.com/test"), "http://endpoint.com/test" + ) + + self.assertEqual( + build_oauth_url("https://example.com", "https://endpoint.com"), "https://endpoint.com" + ) + + self.assertEqual(build_oauth_url("https://example.com", ""), "https://example.com") + def decode_id_token(self, id_token): import jwt diff --git a/frappe/utils/oauth.py b/frappe/utils/oauth.py index d41304084b..8f7c7a6d1f 100644 --- a/frappe/utils/oauth.py +++ b/frappe/utils/oauth.py @@ -5,6 +5,7 @@ import base64 import json from collections.abc import Callable from typing import TYPE_CHECKING +from urllib.parse import urlparse import frappe import frappe.utils @@ -19,18 +20,50 @@ if TYPE_CHECKING: class SignupDisabledError(frappe.PermissionError): ... +def build_oauth_url(base_url: str, url: str | None = None) -> str: + """ + Build a complete OAuth authorization URL. + + This helper constructs a full OAuth URL starting from a given base URL. + + If `url` is omitted, the function simply returns the normalized base URL. If the + `url` contains the relative or absolute path, the function will return this + appended to the base URL. If the `url` contains a `scheme` (e.g. "https://" and a + `netloc` (e.g. "www.example.com")), the function will return the passed `url` alone. + + Args: + base_url (str): The base OAuth endpoint (e.g. "https://example.com"). + url (str | None): An optional path or override URL to combine with the base. + + Returns: + str: The fully qualified OAuth URL ready for use in redirects or API calls. + """ + if url is None: + return base_url + parsed = urlparse(url) + if not (parsed.scheme and parsed.netloc): + return base_url + url + return url + + def get_oauth2_providers() -> dict[str, dict]: out = {} providers = frappe.get_all("Social Login Key", fields=["*"]) for provider in providers: - authorize_url, access_token_url = provider.authorize_url, provider.access_token_url + authorize_url, access_token_url, api_endpoint_url = ( + provider.authorize_url, + provider.access_token_url, + provider.api_endpoint, + ) + if provider.custom_base_url: - authorize_url = provider.base_url + provider.authorize_url - access_token_url = provider.base_url + provider.access_token_url + authorize_url = build_oauth_url(provider.base_url, provider.authorize_url) + access_token_url = build_oauth_url(provider.base_url, provider.access_token_url) + api_endpoint_url = build_oauth_url(provider.base_url, provider.api_endpoint) # Keycloak needs this, the base URL also has a route, that urljoin() ignores if provider.name == "keycloak": - provider.api_endpoint = provider.base_url + provider.api_endpoint + api_endpoint_url = build_oauth_url(provider.base_url, provider.api_endpoint) out[provider.name] = { "flow_params": { @@ -40,7 +73,7 @@ def get_oauth2_providers() -> dict[str, dict]: "base_url": provider.base_url, }, "redirect_uri": provider.redirect_url, - "api_endpoint": provider.api_endpoint, + "api_endpoint": api_endpoint_url, } if provider.auth_url_data: out[provider.name]["auth_url_data"] = json.loads(provider.auth_url_data)