diff --git a/frappe/__init__.py b/frappe/__init__.py index a751217e3d..bd8068741b 100644 --- a/frappe/__init__.py +++ b/frappe/__init__.py @@ -35,7 +35,6 @@ from typing import ( import click from werkzeug.datastructures import Headers -from werkzeug.local import Local, LocalProxy, release_local import frappe from frappe.query_builder.utils import ( @@ -45,7 +44,7 @@ from frappe.query_builder.utils import ( from frappe.utils.caching import deprecated_local_cache as local_cache from frappe.utils.caching import request_cache from frappe.utils.data import as_unicode, bold, cint, cstr, safe_decode, safe_encode, sbool -from frappe.utils.local import FrappeLocal +from frappe.utils.local import Local, LocalProxy, release_local # Local application imports from .exceptions import * @@ -76,7 +75,7 @@ if TYPE_CHECKING: # pragma: no cover from frappe.utils.redis_wrapper import ClientCache, RedisWrapper controllers: dict[str, "Document"] = {} -local = FrappeLocal() +local = Local() cache: Optional["RedisWrapper"] = None client_cache: Optional["ClientCache"] = None STANDARD_USERS = ("Guest", "Administrator") diff --git a/frappe/utils/local.py b/frappe/utils/local.py index 0d3f72e664..d724abcd10 100644 --- a/frappe/utils/local.py +++ b/frappe/utils/local.py @@ -1,39 +1,35 @@ from contextvars import ContextVar -from typing import Any +from typing import Any, Generic, TypeVar -from werkzeug.local import Local, LocalProxy +from werkzeug.local import LocalProxy as _LocalProxy +from werkzeug.local import _ProxyLookup +from werkzeug.local import release_local as _release_local _contextvar = ContextVar("frappe_local") -_local_attributes = frozenset(dir(Local)) -_local_proxy_attributes = frozenset(dir(LocalProxy)) + +T = TypeVar("T") -def get_local(name: str) -> Any: - obj = _contextvar.get(None) - if obj is not None and name in obj: - return obj[name] - - raise AttributeError(name) - - -class FrappeLocal(Local): +class Local: """ For internal use only. Do not use this class directly. """ __slots__ = () - def __init__(self): - super().__init__(_contextvar) - def __getattribute__(self, name: str) -> Any: - if name in _local_attributes: - return object.__getattribute__(self, name) + # this is not needed as long as we have no other attributes than special methods + # if name in _local_attributes: + # return object.__getattribute__(self, name) - return get_local(name) + obj = _contextvar.get(None) + if obj is not None and name in obj: + return obj[name] - def __getattr__(self, name: str) -> Any: - return get_local(name) + raise AttributeError(name) + + def __iter__(self): + return iter((_contextvar.get({})).items()) def __setattr__(self, name: str, value: Any) -> None: obj = _contextvar.get(None) @@ -51,10 +47,7 @@ class FrappeLocal(Local): raise AttributeError(name) - def __release_local__(self): - _contextvar.set({}) - - def __call__(self, name: str) -> LocalProxy: + def __call__(self, name: str) -> "LocalProxy": def _get_current_object() -> Any: obj = _contextvar.get(None) if obj is not None and name in obj: @@ -62,13 +55,20 @@ class FrappeLocal(Local): raise RuntimeError("object is not bound") from None - lp = FrappeLocalProxy(_get_current_object) + lp = LocalProxy(_get_current_object) object.__setattr__(lp, "_get_current_object", _get_current_object) return lp -class FrappeLocalProxy(LocalProxy): - __slots__ = () +class LocalProxy(Generic[T]): + __slots__ = _LocalProxy.__slots__ + __init__ = _LocalProxy.__init__ + + for attr, val in vars(_LocalProxy).items(): + if attr == "__getattr__" or not isinstance(val, _ProxyLookup): + continue + + locals()[attr] = val def __getattribute__(self, name: str) -> Any: if name in _local_proxy_attributes: @@ -76,9 +76,6 @@ class FrappeLocalProxy(LocalProxy): return getattr(get_obj(self), name) - def __getattr__(self, name: str) -> Any: - return getattr(get_obj(self), name) - def __setattr__(self, name: str, value: str) -> None: setattr(get_obj(self), name, value) @@ -107,5 +104,17 @@ class FrappeLocalProxy(LocalProxy): return str(get_obj(self)) -def get_obj(lp: FrappeLocalProxy) -> Any: +def get_obj(lp: LocalProxy) -> Any: return object.__getattribute__(lp, "_get_current_object")() + + +def release_local(local): + if isinstance(local, Local): + _contextvar.set({}) + return + + _release_local(local) + + +# _local_attributes = frozenset(attr for attr in dir(Local)) +_local_proxy_attributes = frozenset(attr for attr in dir(LocalProxy))