Merge pull request #19029 from gavindsouza/runtime-type-checks-api

feat(whitelisted): Runtime typing hints validation
This commit is contained in:
Ankush Menat 2022-12-19 15:46:49 +05:30 committed by GitHub
commit ee9bfed4ec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 304 additions and 85 deletions

View file

@ -735,14 +735,21 @@ def whitelist(allow_guest=False, xss_safe=False, methods=None):
methods = ["GET", "POST", "PUT", "DELETE"]
def innerfn(fn):
from frappe.utils.typing_validations import validate_argument_types
global whitelisted, guest_methods, xss_safe_methods, allowed_http_methods_for_whitelisted_func
# validate argument types only if request is present
in_request_or_test = lambda: getattr(local, "request", None) or local.flags.in_test # noqa: E731
# get function from the unbound / bound method
# this is needed because functions can be compared, but not methods
method = None
if hasattr(fn, "__func__"):
method = fn
method = validate_argument_types(fn, apply_condition=in_request_or_test)
fn = method.__func__
else:
fn = validate_argument_types(fn, apply_condition=in_request_or_test)
whitelisted.append(fn)
allowed_http_methods_for_whitelisted_func[fn] = methods

View file

@ -40,9 +40,6 @@ def get_attached_images(doctype: str, names: list[str]) -> frappe._dict:
@frappe.whitelist()
def get_files_in_folder(folder: str, start: int = 0, page_length: int = 20) -> dict:
start = cint(start)
page_length = cint(page_length)
attachment_folder = frappe.db.get_value(
"File",
"Home/Attachments",
@ -101,10 +98,11 @@ def create_new_folder(file_name: str, folder: str) -> File:
@frappe.whitelist()
def move_file(file_list: list[File], new_parent: str, old_parent: str) -> None:
def move_file(file_list: list[File | dict] | str, new_parent: str, old_parent: str) -> None:
if isinstance(file_list, str):
file_list = json.loads(file_list)
# will check for permission on each file & update parent
for file_obj in file_list:
setup_folder_path(file_obj.get("name"), new_parent)

View file

@ -326,7 +326,7 @@ class Report(Document):
return data
@frappe.whitelist()
def toggle_disable(self, disable):
def toggle_disable(self, disable: bool):
if not self.has_permission("write"):
frappe.throw(_("You are not allowed to edit the report."))

View file

@ -383,9 +383,7 @@ class TestUser(FrappeTestCase):
# reset password
update_password(old_password, old_password=new_password)
self.assertRaisesRegex(
frappe.exceptions.ValidationError, "Invalid key type", update_password, "test", 1, ["like", "%"]
)
self.assertRaises(TypeError, update_password, "test", 1, ["like", "%"])
password_strength_response = {
"feedback": {"password_policy_validation_passed": False, "suggestions": ["Fix password"]}

View file

@ -1,6 +1,7 @@
# Copyright (c) 2015, Frappe Technologies Pvt. Ltd. and Contributors
# License: MIT. See LICENSE
from datetime import timedelta
from typing import Optional, Sequence
import frappe
import frappe.defaults
@ -538,7 +539,7 @@ class User(Document):
if self.__new_password:
user_data = (self.first_name, self.middle_name, self.last_name, self.email, self.birth_date)
result = test_password_strength(self.__new_password, "", None, user_data)
result = test_password_strength(self.__new_password, user_data=user_data)
feedback = result.get("feedback", None)
if feedback and not feedback.get("password_policy_validation_passed", False):
@ -679,12 +680,19 @@ def get_perm_info(role):
@frappe.whitelist(allow_guest=True)
def update_password(new_password, logout_all_sessions=0, key=None, old_password=None):
# validate key to avoid key input like ['like', '%'], '', ['in', ['']]
if key and not isinstance(key, str):
frappe.throw(_("Invalid key type"))
def update_password(
new_password: str, logout_all_sessions: int = 0, key: str = None, old_password: str = None
):
"""Update password for the current user.
result = test_password_strength(new_password, key, old_password)
Args:
new_password (str): New password.
logout_all_sessions (int, optional): If set to 1, all other sessions will be logged out. Defaults to 0.
key (str, optional): Password reset key. Defaults to None.
old_password (str, optional): Old password. Defaults to None.
"""
result = test_password_strength(new_password)
feedback = result.get("feedback", None)
if feedback and not feedback.get("password_policy_validation_passed", False):
@ -718,22 +726,22 @@ def update_password(new_password, logout_all_sessions=0, key=None, old_password=
if user_doc.user_type == "System User":
return "/app"
else:
return redirect_url if redirect_url else "/"
return redirect_url or "/"
@frappe.whitelist(allow_guest=True)
def test_password_strength(new_password, key=None, old_password=None, user_data=None):
def test_password_strength(
new_password: str, key=None, old_password=None, user_data: tuple | None = None
):
from frappe.utils.deprecations import deprecation_warning
from frappe.utils.password_strength import test_password_strength as _test_password_strength
password_policy = (
frappe.db.get_value(
"System Settings", None, ["enable_password_policy", "minimum_password_score"], as_dict=True
if key is not None or old_password is not None:
deprecation_warning(
"Arguments `key` and `old_password` are deprecated in function `test_password_strength`."
)
or {}
)
enable_password_policy = cint(password_policy.get("enable_password_policy", 0))
minimum_password_score = cint(password_policy.get("minimum_password_score", 0))
enable_password_policy = frappe.get_system_settings("enable_password_policy") or 0
if not enable_password_policy:
return {}
@ -746,6 +754,7 @@ def test_password_strength(new_password, key=None, old_password=None, user_data=
if new_password:
result = _test_password_strength(new_password, user_inputs=user_data)
password_policy_validation_passed = False
minimum_password_score = cint(frappe.get_system_settings("minimum_password_score")) or 0
# score should be greater than 0 and minimum_password_score
if result.get("score") and result.get("score") >= minimum_password_score:
@ -755,9 +764,8 @@ def test_password_strength(new_password, key=None, old_password=None, user_data=
return result
# for login
@frappe.whitelist()
def has_email_account(email):
def has_email_account(email: str):
return frappe.get_list("Email Account", filters={"email_id": email})
@ -824,7 +832,7 @@ def verify_password(password):
@frappe.whitelist(allow_guest=True)
def sign_up(email, full_name, redirect_to):
def sign_up(email: str, full_name: str, redirect_to: str) -> tuple[int, str]:
if is_signup_disabled():
frappe.throw(_("Sign Up is disabled"), title=_("Not Allowed"))
@ -876,12 +884,12 @@ def sign_up(email, full_name, redirect_to):
@frappe.whitelist(allow_guest=True)
@rate_limit(limit=get_password_reset_limit, seconds=24 * 60 * 60, methods=["POST"])
def reset_password(user):
def reset_password(user: str) -> str:
if user == "Administrator":
return "not allowed"
try:
user = frappe.get_doc("User", user)
user: User = frappe.get_doc("User", user)
if not user.enabled:
return "disabled"
@ -1071,13 +1079,12 @@ def throttle_user_creation():
@frappe.whitelist()
def get_role_profile(role_profile):
roles = frappe.get_doc("Role Profile", {"role_profile": role_profile})
return roles.roles
def get_role_profile(role_profile: str):
return frappe.get_doc("Role Profile", {"role_profile": role_profile}).roles
@frappe.whitelist()
def get_module_profile(module_profile):
def get_module_profile(module_profile: str):
module_profile = frappe.get_doc("Module Profile", {"module_profile_name": module_profile})
return module_profile.get("block_modules")
@ -1150,14 +1157,14 @@ def get_restricted_ip_list(user):
@frappe.whitelist()
def generate_keys(user):
def generate_keys(user: str):
"""
generate api key and api secret
:param user: str
"""
frappe.only_for("System Manager")
user_details = frappe.get_doc("User", user)
user_details: User = frappe.get_doc("User", user)
api_secret = frappe.generate_hash(length=15)
# if api key is not set generate api key
if not user_details.api_key:

View file

@ -15,6 +15,10 @@ class ValidationError(Exception):
http_status_code = 417
class FrappeTypeError(TypeError):
http_status_code = 417
class AuthenticationError(Exception):
http_status_code = 401

View file

@ -269,7 +269,7 @@ def ping():
def run_doc_method(method, docs=None, dt=None, dn=None, arg=None, args=None):
"""run a whitelisted controller method"""
from inspect import getfullargspec
from inspect import signature
if not args and arg:
args = arg
@ -298,7 +298,7 @@ def run_doc_method(method, docs=None, dt=None, dn=None, arg=None, args=None):
is_whitelisted(fn)
is_valid_http_method(fn)
fnargs = getfullargspec(method_obj).args
fnargs = list(signature(method_obj).parameters)
if not fnargs or (len(fnargs) == 1 and fnargs[0] == "self"):
response = doc.run_method(method)

View file

@ -103,10 +103,7 @@ def get_redis_server():
@frappe.whitelist(allow_guest=True)
def can_subscribe_doc(doctype, docname):
if os.environ.get("CI"):
return True
def can_subscribe_doc(doctype: str, docname: str) -> bool:
from frappe.exceptions import PermissionError
from frappe.sessions import Session
@ -118,7 +115,7 @@ def can_subscribe_doc(doctype, docname):
@frappe.whitelist(allow_guest=True)
def can_subscribe_list(doctype):
def can_subscribe_list(doctype: str) -> bool:
from frappe.exceptions import PermissionError
if not frappe.has_permission(user=frappe.session.user, doctype=doctype, ptype="read"):

View file

@ -223,7 +223,7 @@ class TestRenameDoc(FrappeTestCase):
new_name = f"{dn}-new"
# pass invalid types to API
with self.assertRaises(ValidationError):
with self.assertRaises(TypeError):
update_document_title(doctype=dt, docname=dn, title={}, name={"hack": "this"})
doc_before = frappe.get_doc(test_doctype, dn)

View file

@ -919,3 +919,28 @@ class TestMiscUtils(FrappeTestCase):
self.assertEqual(safe_json_loads("{}"), {})
self.assertEqual(safe_json_loads("{ /}"), "{ /}")
self.assertEqual(safe_json_loads("12"), 12) # this is a quirk
class TestTypingValidations(FrappeTestCase):
ERR_REGEX = f"^Argument '.*' should be of type '.*' but got '.*' instead.$"
def test_validate_whitelisted_api(self):
from inspect import signature
whitelisted_fn = next(x for x in frappe.whitelisted if x.__annotations__)
bad_params = (object(),) * len(signature(whitelisted_fn).parameters)
with self.assertRaisesRegex(frappe.FrappeTypeError, self.ERR_REGEX):
whitelisted_fn(*bad_params)
def test_validate_whitelisted_doc_method(self):
report = frappe.get_last_doc("Report")
with self.assertRaisesRegex(frappe.FrappeTypeError, self.ERR_REGEX):
report.toggle_disable(["disable"])
current_value = report.disabled
changed_value = not current_value
report.toggle_disable(changed_value)
report.toggle_disable(current_value)

View file

@ -1279,7 +1279,7 @@ def get_translator_url():
@frappe.whitelist(allow_guest=True)
def get_all_languages(with_language_name=False):
def get_all_languages(with_language_name: bool = False) -> list:
"""Returns all enabled language codes ar, ch etc"""
def get_language_codes():
@ -1298,7 +1298,7 @@ def get_all_languages(with_language_name=False):
@frappe.whitelist(allow_guest=True)
def set_preferred_language_cookie(preferred_language):
def set_preferred_language_cookie(preferred_language: str):
frappe.local.cookie_manager.set_cookie("preferred_language", preferred_language)

View file

@ -489,7 +489,7 @@ def search(text, start=0, limit=20, doctype=""):
@frappe.whitelist(allow_guest=True)
def web_search(text, scope=None, start=0, limit=20):
def web_search(text: str, scope: str | None = None, start: int = 0, limit: int = 20):
"""
Search for given text in __global_search where published = 1
:param text: phrase to be searched

View file

@ -0,0 +1,175 @@
from functools import lru_cache, wraps
from inspect import _empty, isclass, signature
from types import EllipsisType
from typing import Any, Callable, ForwardRef, TypeVar, Union
from pydantic.config import BaseConfig
from pydantic.error_wrappers import ValidationError as PyValidationError
from pydantic.tools import NameFactory, _generate_parsing_type_name
from frappe.exceptions import FrappeTypeError
SLACK_DICT = {
bool: (int, bool, float),
}
T = TypeVar("T")
class FrappePydanticConfig:
arbitrary_types_allowed = True
def validate_argument_types(func: Callable, apply_condition: Callable = lambda: True):
@wraps(func)
def wrapper(*args, **kwargs):
"""Validate argument types of whitelisted functions.
:param args: Function arguments.
:param kwargs: Function keyword arguments."""
if apply_condition():
args, kwargs = transform_parameter_types(func, args, kwargs)
return func(*args, **kwargs)
return wrapper
def qualified_name(obj) -> str:
"""
Return the qualified name (e.g. package.module.Type) for the given object.
Builtins and types from the :mod:typing package get special treatment by having the module
name stripped from the generated name.
"""
discovered_type = obj if isclass(obj) else type(obj)
module, qualname = discovered_type.__module__, discovered_type.__qualname__
if module in {"typing", "types"}:
return obj
elif module in {"builtins"}:
return qualname
else:
return f"{module}.{qualname}"
def raise_type_error(
arg_name: str, arg_type: type, arg_value: object, current_exception: Exception = None
):
"""
Raise a TypeError with a message that includes the name of the argument, the expected type
and the actual type of the value passed.
"""
raise FrappeTypeError(
f"Argument '{arg_name}' should be of type '{qualified_name(arg_type)}' but got "
f"'{qualified_name(arg_value)}' instead."
) from current_exception
@lru_cache(maxsize=2048)
def _get_parsing_type(
type_: Any, *, type_name: NameFactory | None = None, config: type[BaseConfig] = None
) -> Any:
# Note: this is a copy of pydantic.tools._get_parsing_type with the addition of allowing a config argument
from pydantic.main import create_model
if type_name is None:
type_name = _generate_parsing_type_name
if not isinstance(type_name, str):
type_name = type_name(type_)
return create_model(type_name, __root__=(type_, ...), __config__=config)
def parse_obj_as(
type_: type[T],
obj: Any,
*,
type_name: NameFactory | None = None,
config: type[BaseConfig] | None = None,
) -> T:
# Note: This is a copy of pydantic.tools.parse_obj_as with the addition of allowing a config argument
model_type = _get_parsing_type(type_, type_name=type_name, config=config) # type: ignore[arg-type]
return model_type(__root__=obj).__root__
def transform_parameter_types(func: Callable, args: tuple, kwargs: dict):
"""
Validate the types of the arguments passed to a function with the type annotations
defined on the function.
"""
if not (args or kwargs) or not func.__annotations__:
return args, kwargs
annotations = func.__annotations__
new_args, new_kwargs = list(args), kwargs
# generate kwargs dict from args
arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
if not args:
prepared_args = kwargs
elif kwargs:
arg_values = args or func.__defaults__ or []
prepared_args = dict(zip(arg_names, arg_values))
prepared_args.update(kwargs)
else:
prepared_args = dict(zip(arg_names, args))
# check if type hints dont match the default values
func_signature = signature(func)
func_params = dict(func_signature.parameters)
# check if the argument types are correct
for current_arg, current_arg_type in annotations.items():
if current_arg not in prepared_args:
continue
current_arg_value = prepared_args[current_arg]
# if the type is a ForwardRef or str, ignore it
if isinstance(current_arg_type, (ForwardRef, str)):
continue
elif any(isinstance(x, (ForwardRef, str)) for x in getattr(current_arg_type, "__args__", [])):
continue
# allow slack for Frappe types
if current_arg_type in SLACK_DICT:
current_arg_type = SLACK_DICT[current_arg_type]
param_def = func_params.get(current_arg)
# add default value's type in acceptable types
if param_def.default is not _empty:
if isinstance(current_arg_type, tuple):
if type(param_def.default) not in current_arg_type:
current_arg_type += (type(param_def.default),)
current_arg_type = Union[current_arg_type]
elif param_def.default != current_arg_type:
current_arg_type = Union[current_arg_type, type(param_def.default)]
elif isinstance(current_arg_type, tuple):
current_arg_type = Union[current_arg_type]
# validate the type set using pydantic - raise a TypeError if Validation is raised or Ellipsis is returned
try:
current_arg_value_after = parse_obj_as(
current_arg_type, current_arg_value, type_name=current_arg, config=FrappePydanticConfig
)
except (TypeError, PyValidationError) as e:
raise_type_error(current_arg, current_arg_type, current_arg_value, current_exception=e)
if isinstance(current_arg_value_after, EllipsisType):
raise_type_error(current_arg, current_arg_type, current_arg_value)
# update the args and kwargs with possibly casted value
if current_arg in kwargs:
new_kwargs[current_arg] = current_arg_value_after
else:
new_args[arg_names.index(current_arg)] = current_arg_value_after
return new_args, new_kwargs

View file

@ -5,6 +5,7 @@ import copy
import json
import os
import re
from typing import TYPE_CHECKING, Optional
import frappe
from frappe import _, get_module_path
@ -13,6 +14,10 @@ from frappe.core.doctype.document_share_key.document_share_key import is_expired
from frappe.utils import cint, sanitize_html, strip_html
from frappe.utils.jinja_globals import is_rtl
if TYPE_CHECKING:
from frappe.model.document import Document
from frappe.printing.doctype.print_format.print_format import PrintFormat
no_cache = 1
standard_format = "templates/print_formats/standard.html"
@ -88,13 +93,12 @@ def get_print_format_doc(print_format_name, meta):
def get_rendered_template(
doc,
name=None,
print_format=None,
doc: "Document",
print_format: str | None = None,
meta=None,
no_letterhead=None,
letterhead=None,
trigger_print=False,
no_letterhead: bool | None = None,
letterhead: str | None = None,
trigger_print: bool = False,
settings=None,
):
@ -184,7 +188,7 @@ def get_rendered_template(
letter_head.footer, {"doc": doc.as_dict()}
)
convert_markdown(doc, meta)
convert_markdown(doc)
args = {}
# extract `print_heading_template` from the first field and remove it
@ -257,9 +261,9 @@ def set_title_values_for_table_and_multiselect_fields(meta, doc):
set_title_values_for_link_and_dynamic_link_fields(_meta, value, doc)
def convert_markdown(doc, meta):
def convert_markdown(doc: "Document"):
"""Convert text field values to markdown if necessary"""
for field in meta.fields:
for field in doc.meta.fields:
if field.fieldtype == "Text Editor":
value = doc.get(field.fieldname)
if value and "<!-- markdown -->" in value:
@ -268,34 +272,32 @@ def convert_markdown(doc, meta):
@frappe.whitelist()
def get_html_and_style(
doc,
name=None,
print_format=None,
meta=None,
no_letterhead=None,
letterhead=None,
trigger_print=False,
style=None,
settings=None,
templates=None,
doc: str,
name: str | None = None,
print_format: str | None = None,
no_letterhead: bool | None = None,
letterhead: str | None = None,
trigger_print: bool = False,
style: str | None = None,
settings: str | None = None,
):
"""Returns `html` and `style` of print format, used in PDF etc"""
if isinstance(doc, str) and isinstance(name, str):
doc = frappe.get_doc(doc, name)
if isinstance(name, str):
document = frappe.get_doc(doc, name)
else:
document = frappe.get_doc(json.loads(doc))
if isinstance(doc, str):
doc = frappe.get_doc(json.loads(doc))
document.check_permission()
print_format = get_print_format_doc(print_format, meta=meta or frappe.get_meta(doc.doctype))
set_link_titles(doc)
print_format = get_print_format_doc(print_format, meta=document.meta)
set_link_titles(document)
try:
html = get_rendered_template(
doc,
name=name,
doc=document,
print_format=print_format,
meta=meta,
meta=document.meta,
no_letterhead=no_letterhead,
letterhead=letterhead,
trigger_print=trigger_print,
@ -309,16 +311,17 @@ def get_html_and_style(
@frappe.whitelist()
def get_rendered_raw_commands(doc, name=None, print_format=None, meta=None, lang=None):
def get_rendered_raw_commands(doc: str, name: str | None = None, print_format: str | None = None):
"""Returns Rendered Raw Commands of print format, used to send directly to printer"""
if isinstance(doc, str) and isinstance(name, str):
doc = frappe.get_doc(doc, name)
if isinstance(name, str):
document = frappe.get_doc(doc, name)
else:
document = frappe.get_doc(json.loads(doc))
if isinstance(doc, str):
doc = frappe.get_doc(json.loads(doc))
document.check_permission()
print_format = get_print_format_doc(print_format, meta=meta or frappe.get_meta(doc.doctype))
print_format = get_print_format_doc(print_format, meta=document.meta)
if not print_format or (print_format and not print_format.raw_printing):
frappe.throw(
@ -326,7 +329,9 @@ def get_rendered_raw_commands(doc, name=None, print_format=None, meta=None, lang
)
return {
"raw_commands": get_rendered_template(doc, name=name, print_format=print_format, meta=meta)
"raw_commands": get_rendered_template(
doc=document, name=name, print_format=print_format, meta=document.meta
)
}
@ -361,7 +366,7 @@ def validate_key(key, doc):
raise frappe.exceptions.InvalidKeyError
def get_letter_head(doc, no_letterhead, letterhead=None):
def get_letter_head(doc: "Document", no_letterhead: bool, letterhead: str | None = None):
if no_letterhead:
return {}
if letterhead:
@ -519,7 +524,9 @@ def has_value(df, doc):
return True
def get_print_style(style=None, print_format=None, for_legacy=False):
def get_print_style(
style: str | None = None, print_format: Optional["PrintFormat"] = None, for_legacy: bool = False
):
print_settings = frappe.get_doc("Print Settings")
if not style:

View file

@ -20,7 +20,7 @@ def get_context(context):
@frappe.whitelist(allow_guest=True)
def get_search_results(text, scope=None, start=0, as_html=False):
def get_search_results(text: str, scope: str = None, start: int = 0, as_html: bool = False):
results = web_search(text, scope, start, limit=21)
out = frappe._dict()

View file

@ -55,7 +55,7 @@ def get_first_login(client):
@frappe.whitelist()
def delete_client(client_id):
def delete_client(client_id: str):
active_client_id_tokens = frappe.get_all(
"OAuth Bearer Token", filters=[["user", "=", frappe.session.user], ["client", "=", client_id]]
)

View file

@ -52,6 +52,7 @@ dependencies = [
"psycopg2-binary~=2.9.1",
"pyOpenSSL~=22.1.0",
"pycryptodome~=3.10.1",
"pydantic~=1.10.2",
"pyotp~=2.6.0",
"python-dateutil~=2.8.1",
"pytz==2022.1",