feat: add allowed referrers to loosen csrf prevention (#27841)

* fix: add allowed referrers to loosen csrf prevention

* feat: Add test case for is_allowed_referrer functionality
This commit is contained in:
David Arnold 2024-11-15 07:39:53 +01:00 committed by GitHub
parent 6a568daa75
commit d4382dc020
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 52 additions and 1 deletions

View file

@ -86,6 +86,7 @@ class HTTPRequest:
(frappe.get_request_header("X-Frappe-CSRF-Token") or frappe.form_dict.pop("csrf_token", None)) (frappe.get_request_header("X-Frappe-CSRF-Token") or frappe.form_dict.pop("csrf_token", None))
== saved_token == saved_token
) )
or self.is_allowed_referrer()
): ):
return return
@ -95,6 +96,21 @@ class HTTPRequest:
def set_lang(self): def set_lang(self):
frappe.local.lang = get_language() frappe.local.lang = get_language()
def is_allowed_referrer(self):
referrer = frappe.get_request_header("Referer")
origin = frappe.get_request_header("Origin")
# Get the list of allowed referrers from cache or configuration
allowed_referrers = frappe.cache.get_value(
"allowed_referrers",
generator=lambda: frappe.conf.get("allowed_referrers", []),
)
# Check if the referrer or origin is in the allowed list
return (referrer and any(referrer.startswith(allowed) for allowed in allowed_referrers)) or (
origin and any(origin == allowed for allowed in allowed_referrers)
)
class LoginManager: class LoginManager:
__slots__ = ("user", "info", "full_name", "user_type", "user_lang", "resume") __slots__ = ("user", "info", "full_name", "user_type", "user_lang", "resume")

View file

@ -4,12 +4,14 @@ import datetime
import time import time
import requests import requests
from werkzeug.test import EnvironBuilder
from werkzeug.wrappers import Request
import frappe import frappe
from frappe.auth import LoginAttemptTracker from frappe.auth import LoginAttemptTracker
from frappe.frappeclient import AuthError, FrappeClient from frappe.frappeclient import AuthError, FrappeClient
from frappe.sessions import Session, get_expired_sessions, get_expiry_in_seconds from frappe.sessions import Session, get_expired_sessions, get_expiry_in_seconds
from frappe.tests import IntegrationTestCase from frappe.tests import IntegrationTestCase, UnitTestCase
from frappe.tests.test_api import FrappeAPITestCase from frappe.tests.test_api import FrappeAPITestCase
from frappe.utils import get_datetime, get_site_url, now from frappe.utils import get_datetime, get_site_url, now
from frappe.utils.data import add_to_date from frappe.utils.data import add_to_date
@ -165,6 +167,39 @@ class TestAuth(IntegrationTestCase):
self.assertAlmostEqual(get_expiry_in_seconds(), expiry_time - current_time, delta=60 * 60) self.assertAlmostEqual(get_expiry_in_seconds(), expiry_time - current_time, delta=60 * 60)
class TestAllowedReferrer(UnitTestCase):
def test_is_allowed_referrer(self):
def create_request(headers):
builder = EnvironBuilder(headers=headers)
env = builder.get_environ()
return Request(env)
# Test with valid referrer
frappe.cache.set_value("allowed_referrers", ["https://example.com"])
frappe.local.request = create_request({"Referer": "https://example.com/some/path"})
http_request = frappe.auth.HTTPRequest()
self.assertTrue(http_request.is_allowed_referrer())
# Test with invalid referrer
frappe.local.request = create_request({"Referer": "https://malicious.com"})
http_request = frappe.auth.HTTPRequest()
self.assertFalse(http_request.is_allowed_referrer())
# Test with valid origin
frappe.local.request = create_request({"Origin": "https://example.com"})
http_request = frappe.auth.HTTPRequest()
self.assertTrue(http_request.is_allowed_referrer())
# Test with invalid origin
frappe.local.request = create_request({"Origin": "https://malicious.com"})
http_request = frappe.auth.HTTPRequest()
self.assertFalse(http_request.is_allowed_referrer())
# Clean up
frappe.cache.delete_value("allowed_referrers")
frappe.local.request = None
class TestLoginAttemptTracker(IntegrationTestCase): class TestLoginAttemptTracker(IntegrationTestCase):
def test_account_lock(self): def test_account_lock(self):
"""Make sure that account locks after `n consecutive failures""" """Make sure that account locks after `n consecutive failures"""