diff --git a/frappe/oauth.py b/frappe/oauth.py index 5e170db5d2..67595555f9 100644 --- a/frappe/oauth.py +++ b/frappe/oauth.py @@ -2,13 +2,14 @@ import base64 import datetime import hashlib import re -from http import cookies -from urllib.parse import unquote, urljoin, urlparse +from urllib.parse import urljoin, urlparse +from oauthlib.common import Request from oauthlib.openid import RequestValidator import frappe from frappe.auth import LoginManager +from frappe.integrations.doctype.oauth_client.oauth_client import OAuthClient from frappe.utils.data import cstr, get_system_timezone, now_datetime @@ -73,13 +74,11 @@ class OAuthWebRequestValidator(RequestValidator): # Post-authorization def save_authorization_code(self, client_id, code, request, *args, **kwargs): - cookie_dict = get_cookie_dict_from_headers(request) - oac = frappe.new_doc("OAuth Authorization Code") oac.scopes = get_url_delimiter().join(request.scopes) oac.redirect_uri_bound_to_authorization_code = request.redirect_uri oac.client = client_id - oac.user = unquote(cookie_dict["user_id"].value) + oac.user = frappe.session.user oac.authorization_code = code["code"] if request.nonce: @@ -92,43 +91,32 @@ class OAuthWebRequestValidator(RequestValidator): oac.save(ignore_permissions=True) frappe.db.commit() - def authenticate_client(self, request, *args, **kwargs): + def authenticate_client(self, request: Request, *args, **kwargs) -> bool | None: + """ + Loads the client based on request parameters and sets in oauth request. + Returns True on success, None on error. + """ # Get ClientID in URL if request.client_id: - oc = frappe.get_doc("OAuth Client", request.client_id) + client_name = request.client_id else: # Extract token, instantiate OAuth Bearer Token and use clientid from there. if "refresh_token" in frappe.form_dict: - oc = frappe.get_doc( - "OAuth Client", - frappe.db.get_value( - "OAuth Bearer Token", - {"refresh_token": frappe.form_dict["refresh_token"]}, - "client", - ), - ) + token_filters = {"refresh_token": frappe.form_dict["refresh_token"]} elif "token" in frappe.form_dict: - oc = frappe.get_doc( - "OAuth Client", - frappe.db.get_value("OAuth Bearer Token", frappe.form_dict["token"], "client"), - ) + token_filters = {"name": frappe.form_dict["token"]} else: - oc = frappe.get_doc( - "OAuth Client", - frappe.db.get_value( - "OAuth Bearer Token", - frappe.get_request_header("Authorization").split(" ")[1], - "client", - ), - ) + token_filters = {"name": frappe.get_request_header("Authorization").split(" ")[1]} + + client_name = frappe.db.get_value("OAuth Bearer Token", filters=token_filters, fieldname="client") + + oc: OAuthClient = frappe.get_doc("OAuth Client", client_name) try: request.client = request.client or oc.as_dict() except Exception as e: return generate_json_error_response(e) - cookie_dict = get_cookie_dict_from_headers(request) - user_id = unquote(cookie_dict.get("user_id").value) if "user_id" in cookie_dict else "Guest" - return frappe.session.user == user_id + return True def authenticate_client_id(self, client_id, request, *args, **kwargs): cli_id = frappe.db.get_value("OAuth Client", client_id, "name") @@ -506,13 +494,6 @@ class OAuthWebRequestValidator(RequestValidator): return True -def get_cookie_dict_from_headers(r): - cookie = cookies.BaseCookie() - if r.headers.get("Cookie"): - cookie.load(r.headers.get("Cookie")) - return cookie - - def calculate_at_hash(access_token, hash_alg): """Helper method for calculating an access token hash, as described in http://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken