diff --git a/frappe/auth.py b/frappe/auth.py index 598e98e3e0..05a9e30ebe 100644 --- a/frappe/auth.py +++ b/frappe/auth.py @@ -86,6 +86,7 @@ class HTTPRequest: (frappe.get_request_header("X-Frappe-CSRF-Token") or frappe.form_dict.pop("csrf_token", None)) == saved_token ) + or self.is_allowed_referrer() ): return @@ -95,6 +96,21 @@ class HTTPRequest: def set_lang(self): frappe.local.lang = get_language() + def is_allowed_referrer(self): + referrer = frappe.get_request_header("Referer") + origin = frappe.get_request_header("Origin") + + # Get the list of allowed referrers from cache or configuration + allowed_referrers = frappe.cache.get_value( + "allowed_referrers", + generator=lambda: frappe.conf.get("allowed_referrers", []), + ) + + # Check if the referrer or origin is in the allowed list + return (referrer and any(referrer.startswith(allowed) for allowed in allowed_referrers)) or ( + origin and any(origin == allowed for allowed in allowed_referrers) + ) + class LoginManager: __slots__ = ("user", "info", "full_name", "user_type", "user_lang", "resume") diff --git a/frappe/tests/test_auth.py b/frappe/tests/test_auth.py index 23a98d55be..016b46528c 100644 --- a/frappe/tests/test_auth.py +++ b/frappe/tests/test_auth.py @@ -4,12 +4,14 @@ import datetime import time import requests +from werkzeug.test import EnvironBuilder +from werkzeug.wrappers import Request import frappe from frappe.auth import LoginAttemptTracker from frappe.frappeclient import AuthError, FrappeClient from frappe.sessions import Session, get_expired_sessions, get_expiry_in_seconds -from frappe.tests import IntegrationTestCase +from frappe.tests import IntegrationTestCase, UnitTestCase from frappe.tests.test_api import FrappeAPITestCase from frappe.utils import get_datetime, get_site_url, now from frappe.utils.data import add_to_date @@ -165,6 +167,39 @@ class TestAuth(IntegrationTestCase): self.assertAlmostEqual(get_expiry_in_seconds(), expiry_time - current_time, delta=60 * 60) +class TestAllowedReferrer(UnitTestCase): + def test_is_allowed_referrer(self): + def create_request(headers): + builder = EnvironBuilder(headers=headers) + env = builder.get_environ() + return Request(env) + + # Test with valid referrer + frappe.cache.set_value("allowed_referrers", ["https://example.com"]) + frappe.local.request = create_request({"Referer": "https://example.com/some/path"}) + http_request = frappe.auth.HTTPRequest() + self.assertTrue(http_request.is_allowed_referrer()) + + # Test with invalid referrer + frappe.local.request = create_request({"Referer": "https://malicious.com"}) + http_request = frappe.auth.HTTPRequest() + self.assertFalse(http_request.is_allowed_referrer()) + + # Test with valid origin + frappe.local.request = create_request({"Origin": "https://example.com"}) + http_request = frappe.auth.HTTPRequest() + self.assertTrue(http_request.is_allowed_referrer()) + + # Test with invalid origin + frappe.local.request = create_request({"Origin": "https://malicious.com"}) + http_request = frappe.auth.HTTPRequest() + self.assertFalse(http_request.is_allowed_referrer()) + + # Clean up + frappe.cache.delete_value("allowed_referrers") + frappe.local.request = None + + class TestLoginAttemptTracker(IntegrationTestCase): def test_account_lock(self): """Make sure that account locks after `n consecutive failures"""