562 lines
17 KiB
Python
562 lines
17 KiB
Python
import re
|
|
import sqlite3
|
|
import warnings
|
|
from datetime import date, datetime, time
|
|
from pathlib import Path
|
|
|
|
import frappe
|
|
from frappe.database.database import (
|
|
TRANSACTION_DISABLED_MSG,
|
|
Database,
|
|
ImplicitCommitError,
|
|
)
|
|
from frappe.database.sqlite.schema import SQLiteTable
|
|
from frappe.utils import get_datetime, get_table_name
|
|
|
|
# Pattern for a pyformat named placeholder, e.g. ``%(name)s``.
_PARAM_COMP = re.compile(r"%\([\w]*\)s")
# Statement types that ``check_implicit_commit`` treats as implicitly committing:
# they are refused while the current transaction has pending writes.
IMPLICIT_COMMIT_QUERY_TYPES = frozenset({"start", "alter", "drop", "create", "truncate"})
|
|
|
|
|
|
class SQLiteExceptionUtil:
    """Classify ``sqlite3`` exceptions for the generic database layer.

    Most predicates look for a fixed fragment of SQLite's error message. The
    key-violation predicates prefer the extended result code carried in
    ``sqlite_errorcode`` and fall back to the message when it is absent.
    """

    # DB-API exception classes exposed under the names the database layer uses.
    ProgrammingError = sqlite3.ProgrammingError
    OperationalError = sqlite3.OperationalError
    InternalError = sqlite3.InternalError
    DataError = sqlite3.DataError
    # SQLite reports both missing tables and generic SQL errors as OperationalError.
    TableMissingError = sqlite3.OperationalError
    SQLError = sqlite3.OperationalError

    @staticmethod
    def _message_contains(e, fragment: str) -> bool:
        """Return True when *fragment* occurs in the text of exception *e*."""
        return fragment in str(e)

    @staticmethod
    def is_deadlocked(e: sqlite3.Error) -> bool:
        return SQLiteExceptionUtil._message_contains(e, "database is locked")

    @staticmethod
    def is_timedout(e: sqlite3.Error) -> bool:
        # A busy/locked database is the closest SQLite has to a lock-wait timeout.
        return SQLiteExceptionUtil._message_contains(e, "database is locked")

    @staticmethod
    def is_read_only_mode_error(e: sqlite3.Error) -> bool:
        return SQLiteExceptionUtil._message_contains(e, "attempt to write a readonly database")

    @staticmethod
    def is_table_missing(e: sqlite3.Error) -> bool:
        return SQLiteExceptionUtil._message_contains(e, "no such table")

    @staticmethod
    def is_missing_column(e: sqlite3.Error) -> bool:
        return SQLiteExceptionUtil._message_contains(e, "no such column")

    @staticmethod
    def is_duplicate_fieldname(e: sqlite3.Error) -> bool:
        return SQLiteExceptionUtil._message_contains(e, "duplicate column name")

    @staticmethod
    def is_duplicate_entry(e: sqlite3.Error) -> bool:
        return SQLiteExceptionUtil._message_contains(e, "UNIQUE constraint failed")

    @staticmethod
    def is_access_denied(e: sqlite3.Error) -> bool:
        return SQLiteExceptionUtil._message_contains(e, "access denied")

    @staticmethod
    def cant_drop_field_or_key(e: sqlite3.Error) -> bool:
        return SQLiteExceptionUtil._message_contains(e, "cannot drop")

    @staticmethod
    def is_syntax_error(e: sqlite3.Error) -> bool:
        return SQLiteExceptionUtil._message_contains(e, "syntax error")

    @staticmethod
    def is_statement_timeout(e: sqlite3.Error) -> bool:
        return SQLiteExceptionUtil._message_contains(e, "statement timeout")

    @staticmethod
    def is_data_too_long(e: sqlite3.Error) -> bool:
        return SQLiteExceptionUtil._message_contains(e, "string or blob too big")

    @staticmethod
    def is_db_table_size_limit(e: sqlite3.Error) -> bool:
        return SQLiteExceptionUtil._message_contains(e, "too many columns")

    @staticmethod
    def is_primary_key_violation(e: sqlite3.IntegrityError) -> bool:
        try:
            code = e.sqlite_errorcode
        except AttributeError:
            # No result code on the exception: only the message can be inspected.
            return SQLiteExceptionUtil._message_contains(e, "UNIQUE constraint failed")
        return code == 1555  # SQLITE_CONSTRAINT_PRIMARYKEY

    @staticmethod
    def is_unique_key_violation(e: sqlite3.IntegrityError) -> bool:
        try:
            code = e.sqlite_errorcode
        except AttributeError:
            # No result code on the exception: only the message can be inspected.
            return SQLiteExceptionUtil._message_contains(e, "UNIQUE constraint failed")
        return code == 2067  # SQLITE_CONSTRAINT_UNIQUE

    @staticmethod
    def is_interface_error(e: sqlite3.Error):
        return isinstance(e, sqlite3.InterfaceError)
|
|
|
|
|
|
class SQLiteDatabase(SQLiteExceptionUtil, Database):
    """``Database`` implementation backed by one SQLite file per site.

    The file lives at ``<site>/db/<db name>.db``. Queries written for MariaDB are
    rewritten by ``modify_query`` before execution, and MySQL-style ``%s`` /
    ``%(name)s`` placeholders are translated to SQLite parameters in
    ``execute_query``.
    """

    REGEX_CHARACTER = "regexp"
    default_port = None
    MAX_ROW_SIZE_LIMIT = None

    # Tokens recognised in a query when values are passed as a dict:
    # "%%" (escaped literal percent) or "%(name)s" (named placeholder, group 1).
    _PYFORMAT_TOKEN = re.compile(r"%%|%\((\w*)\)s")

    def get_connection(self, read_only: bool = False):
        """Open a connection configured for explicit transaction control."""
        conn = self.create_connection(read_only)
        # Driver-level autocommit: transactions are opened explicitly with BEGIN in `begin()`.
        conn.isolation_level = None
        # SQLite defines no regexp() function by default; REGEXP needs a user function.
        conn.create_function("regexp", 2, regexp)
        return conn

    def create_connection(self, read_only: bool = False):
        """Connect to the site database, converting DATE/TIME/TIMESTAMP columns on read."""
        db_path = self.get_db_path()
        sqlite3.register_converter("timestamp", lambda x: datetime.fromisoformat(x.decode()))
        sqlite3.register_converter("date", lambda x: date.fromisoformat(x.decode()))
        sqlite3.register_converter("time", lambda x: time.fromisoformat(x.decode()))
        if read_only:
            # Percent-encoded absolute URI, so characters such as "?", "#" or "%" in the
            # path are not misread as URI syntax.
            uri = f"{db_path.resolve().as_uri()}?mode=ro"
            return sqlite3.connect(uri, uri=True, detect_types=sqlite3.PARSE_DECLTYPES)
        return sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES)

    def get_db_path(self):
        """Return the path of this site's SQLite database file."""
        return Path(frappe.get_site_path()) / "db" / f"{self.cur_db_name}.db"

    def set_execution_timeout(self, seconds: int):
        """Set how long SQLite waits on a locked database (``busy_timeout``, in ms)."""
        self.sql(f"PRAGMA busy_timeout = {int(seconds) * 1000}")

    def setup_type_map(self):
        """Map Frappe fieldtypes to SQLite column types."""
        self.db_type = "sqlite"
        self.type_map = {
            "Currency": ("REAL", None),
            "Int": ("INTEGER", None),
            "Long Int": ("INTEGER", None),
            "Float": ("REAL", None),
            "Percent": ("REAL", None),
            "Check": ("INTEGER", None),
            "Small Text": ("TEXT", None),
            "Long Text": ("TEXT", None),
            "Code": ("TEXT", None),
            "Text Editor": ("TEXT", None),
            "Markdown Editor": ("TEXT", None),
            "HTML Editor": ("TEXT", None),
            "Date": ("DATE", None),
            "Datetime": ("TIMESTAMP", None),
            "Time": ("TIME", None),
            "Text": ("TEXT", None),
            "Data": ("TEXT", None),
            "Link": ("TEXT", None),
            "Dynamic Link": ("TEXT", None),
            "Password": ("TEXT", None),
            "Select": ("TEXT", None),
            "Rating": ("REAL", None),
            "Read Only": ("TEXT", None),
            "Attach": ("TEXT", None),
            "Attach Image": ("TEXT", None),
            "Signature": ("TEXT", None),
            "Color": ("TEXT", None),
            "Barcode": ("TEXT", None),
            "Geolocation": ("TEXT", None),
            "Duration": ("REAL", None),
            "Icon": ("TEXT", None),
            "Phone": ("TEXT", None),
            "Autocomplete": ("TEXT", None),
            "JSON": ("TEXT", None),
        }

    def get_database_size(self):
        """Return database size in MB."""
        return self.get_db_path().stat().st_size / (1024 * 1024)

    def _clean_up(self):
        pass

    @staticmethod
    def escape(s, percent=True):
        """Escape quotes and percent in given string."""
        s = s.replace("'", "''")
        if percent:
            s = s.replace("%", "%%")
        return "'" + s + "'"

    @staticmethod
    def is_type_number(code):
        """Return True if *code* names a numeric SQLite column type.

        The ``sqlite3`` module defines no type-code constants, and its
        ``cursor.description`` reports ``None`` as every type code. So only a
        declared type name (e.g. ``"INTEGER"`` from ``PRAGMA table_info``) can
        be classified.
        """
        return isinstance(code, str) and code.upper() in ("INTEGER", "REAL", "NUMERIC")

    @staticmethod
    def is_type_datetime(code):
        """Return True if *code* is one of the declared temporal column types.

        See ``is_type_number`` for why only declared type names are recognised.
        """
        return isinstance(code, str) and code.upper() in ("DATE", "DATETIME", "TIMESTAMP", "TIME")

    def rename_table(self, old_name: str, new_name: str) -> list | tuple:
        old_name = get_table_name(old_name)
        new_name = get_table_name(new_name)
        return self.sql(f"ALTER TABLE `{old_name}` RENAME TO `{new_name}`")

    def describe(self, doctype: str) -> list | tuple:
        table_name = get_table_name(doctype)
        return self.sql(f"PRAGMA table_info(`{table_name}`)")

    def change_column_type(
        self, doctype: str, column: str, type: str, nullable: bool = False
    ) -> list | tuple:
        """Change a column's type by recreating the table (SQLite has no ALTER COLUMN).

        As with MariaDB's ``MODIFY``, the column gets exactly the new definition
        (type plus optional NOT NULL), so its previous DEFAULT is not kept. All
        other columns, the primary key, UNIQUE constraints and indexes are preserved.

        :raises frappe.InvalidColumnName: if *column* is not in the table.
        """
        table_name = get_table_name(doctype)
        columns = self.sql(f"PRAGMA table_info(`{table_name}`)", as_dict=1)
        if not any(col["name"] == column for col in columns):
            raise frappe.InvalidColumnName(f"Column {column} does not exist in table {table_name}")

        null_str = "" if nullable else " NOT NULL"
        definitions = [
            f"`{column}` {type}{null_str}" if col["name"] == column else self._column_definition(col)
            for col in columns
        ]
        select_list = [f"`{col['name']}`" for col in columns]
        self._rebuild_table(table_name, columns, definitions, select_list)

    def rename_column(self, doctype: str, old_column_name: str, new_column_name: str):
        """Rename a column, keeping its type, constraints and indexes.

        Uses ``ALTER TABLE ... RENAME COLUMN`` (SQLite 3.25+). Older libraries fall
        back to rebuilding the table.

        :raises frappe.InvalidColumnName: if *old_column_name* is not in the table.
        """
        table_name = get_table_name(doctype)
        columns = self.sql(f"PRAGMA table_info(`{table_name}`)", as_dict=1)
        if not any(col["name"] == old_column_name for col in columns):
            raise frappe.InvalidColumnName(f"Column {old_column_name} does not exist in table {table_name}")

        if sqlite3.sqlite_version_info >= (3, 25, 0):
            self.sql_ddl(
                f"ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`"
            )
            return

        definitions = [
            self._column_definition(col, new_column_name if col["name"] == old_column_name else None)
            for col in columns
        ]
        select_list = [
            f"`{old_column_name}` as `{new_column_name}`"
            if col["name"] == old_column_name
            else f"`{col['name']}`"
            for col in columns
        ]
        self._rebuild_table(
            table_name, columns, definitions, select_list, renamed=(old_column_name, new_column_name)
        )

    @staticmethod
    def _column_definition(col, name=None):
        """Build a column definition from a ``PRAGMA table_info`` row (NOT NULL and DEFAULT kept)."""
        definition = f"`{name or col['name']}` {col['type']}"
        if col["notnull"]:
            definition += " NOT NULL"
        if col["dflt_value"] is not None:
            # dflt_value is the SQL text of the default expression, usable verbatim.
            definition += f" DEFAULT {col['dflt_value']}"
        return definition

    def _rebuild_table(self, table_name, columns, definitions, select_list, renamed=None):
        """Recreate *table_name* with new column *definitions* and copy every row into it.

        The primary key and UNIQUE constraints are re-declared on the new table.
        Explicit ``CREATE INDEX`` statements are replayed after the swap.

        :param columns: ``PRAGMA table_info`` rows of the current table.
        :param definitions: column definitions of the new table, in column order.
        :param select_list: expressions selecting the old data in the same order.
        :param renamed: optional ``(old_name, new_name)`` of a renamed column.
        """
        temp_table = f"{table_name}_new"
        rename = dict([renamed]) if renamed else {}

        def target(name):
            return rename.get(name, name)

        constraints = []
        pk_columns = sorted((col for col in columns if col["pk"]), key=lambda col: col["pk"])
        if pk_columns:
            pk_list = ", ".join(f"`{target(col['name'])}`" for col in pk_columns)
            constraints.append(f"PRIMARY KEY ({pk_list})")

        for index in self.sql(f"PRAGMA index_list(`{table_name}`)", as_dict=True):
            # origin "u": auto-index created by a UNIQUE constraint; it has no stored SQL,
            # so it is re-declared as a table constraint.
            if index["origin"] != "u":
                continue
            index_columns = sorted(
                self.sql(f"PRAGMA index_info(`{index['name']}`)", as_dict=True),
                key=lambda row: row["seqno"],
            )
            unique_list = ", ".join(f"`{target(row['name'])}`" for row in index_columns)
            constraints.append(f"UNIQUE ({unique_list})")

        # Explicit indexes are dropped together with the old table. When renaming,
        # statements mentioning the old column name cannot be replayed verbatim.
        index_statements = [
            statement
            for statement in self.sql(
                "SELECT sql FROM sqlite_master WHERE type = 'index' AND tbl_name = %s AND sql IS NOT NULL",
                (table_name,),
                pluck=True,
            )
            if not any(old_name in statement for old_name in rename)
        ]

        self.sql_ddl(f"CREATE TABLE `{temp_table}` (\n{','.join(definitions + constraints)}\n)")
        self.sql_ddl(f"INSERT INTO `{temp_table}` SELECT {', '.join(select_list)} FROM `{table_name}`")
        self.sql_ddl(f"DROP TABLE `{table_name}`")
        self.sql_ddl(f"ALTER TABLE `{temp_table}` RENAME TO `{table_name}`")
        for statement in index_statements:
            self.sql_ddl(statement)

    def create_auth_table(self):
        self.sql_ddl(
            """CREATE TABLE IF NOT EXISTS `__Auth` (
                `doctype` TEXT NOT NULL,
                `name` TEXT NOT NULL,
                `fieldname` TEXT NOT NULL,
                `password` TEXT NOT NULL,
                `encrypted` INTEGER NOT NULL DEFAULT 0,
                PRIMARY KEY (`doctype`, `name`, `fieldname`)
            )"""
        )

    def create_global_search_table(self):
        """Create the FTS5 ``__global_search`` table unless it is already listed."""
        if "__global_search" not in self.get_tables():
            # IF NOT EXISTS: the table list may come from a stale cache.
            self.sql(
                """CREATE VIRTUAL TABLE IF NOT EXISTS __global_search USING FTS5(
                    doctype,
                    name,
                    title,
                    content,
                    route,
                    published
                )"""
            )

    def create_user_settings_table(self):
        self.sql_ddl(
            """CREATE TABLE IF NOT EXISTS __UserSettings (
                `user` TEXT NOT NULL,
                `doctype` TEXT NOT NULL,
                `data` TEXT,
                UNIQUE(user, doctype)
            )"""
        )

    @staticmethod
    def get_on_duplicate_update():
        return "ON CONFLICT DO UPDATE SET "

    def get_table_columns_description(self, table_name):
        """Return list of columns with descriptions."""
        return self.sql(f"PRAGMA table_info(`{table_name}`)", as_dict=1)

    def get_column_type(self, doctype, column):
        """Return the declared type of *column* in the doctype's table, or None."""
        table_name = get_table_name(doctype)
        result = self.sql(f"PRAGMA table_info(`{table_name}`)", as_dict=1)
        return next((row["type"] for row in result if row["name"] == column), None)

    def has_index(self, table_name, index_name):
        """Return the index-list rows of *table_name* named *index_name* (truthy if present)."""
        return self.sql(
            "SELECT * FROM pragma_index_list(%s) WHERE name = %s",
            (table_name, index_name),
        )

    def get_column_index(self, table_name: str, fieldname: str, unique: bool = False) -> frappe._dict | None:
        """Return the index whose first column is *fieldname* and whose uniqueness equals *unique*.

        Matches the MariaDB backend, which only returns unique indexes when
        ``unique=True`` and only non-unique ones when ``unique=False``.
        """
        for index in self.sql(f"PRAGMA index_list(`{table_name}`)", as_dict=True):
            if bool(index["unique"]) != bool(unique):
                continue
            index_info = self.sql(f"PRAGMA index_info(`{index['name']}`)", as_dict=True)
            if index_info and index_info[0]["name"] == fieldname:
                return index
        return None

    def add_index(self, doctype: str, fields: list, index_name: str | None = None):
        """Creates an index with given fields if not already created."""
        from frappe.custom.doctype.property_setter.property_setter import make_property_setter

        # We can't specify the length of the index in SQLite
        fields = [re.sub(r"\(.*?\)", "", field) for field in fields]

        index_name = index_name or self.get_index_name(fields)
        table_name = get_table_name(doctype)
        self.commit()
        self.sql(f"CREATE INDEX IF NOT EXISTS `{index_name}` ON `{table_name}` ({', '.join(fields)})")

        # Ensure that DB migration doesn't clear this index, assuming this is manually added
        # via code or console.
        if len(fields) == 1 and not (frappe.flags.in_install or frappe.flags.in_migrate):
            make_property_setter(
                doctype,
                fields[0],
                property="search_index",
                value="1",
                property_type="Check",
                for_doctype=False,  # Applied on docfield
            )

    def add_unique(self, doctype, fields, constraint_name=None):
        """Creates unique constraint on fields."""
        if isinstance(fields, str):
            fields = [fields]
        if not constraint_name:
            constraint_name = f"unique_{'_'.join(fields)}"
        table_name = get_table_name(doctype)

        columns = ", ".join(fields)
        sql_create_unique = (
            f"CREATE UNIQUE INDEX IF NOT EXISTS `{constraint_name}` ON `{table_name}` ({columns})"
        )
        self.commit()  # commit before creating index
        self.sql(sql_create_unique)

    def updatedb(self, doctype, meta=None):
        """Syncs a `DocType` to the table."""
        res = self.sql("SELECT issingle FROM `tabDocType` WHERE name=%s", (doctype,))
        if not res:
            raise Exception(f"Wrong doctype {doctype} in updatedb")

        if not res[0][0]:
            db_table = SQLiteTable(doctype, meta)
            db_table.validate()
            db_table.sync()
            self.commit()

    def get_database_list(self):
        return [self.db_name]

    def get_tables(self, cached=True):
        """Return list of tables."""
        to_query = not cached

        if cached:
            tables = frappe.cache.get_value("db_tables")
            to_query = not tables

        if to_query:
            tables = self.sql("SELECT name FROM sqlite_master WHERE type='table';", pluck=True)
            frappe.cache.set_value("db_tables", tables)

        return tables

    def get_row_size(self, doctype: str) -> int:
        """Get estimated max row size of any table in bytes."""
        raise NotImplementedError("SQLite does not support getting row size directly.")

    def execute_query(self, query, values=None):
        """Execute *query* on the current cursor, translating MySQL-style placeholders.

        ``%s`` becomes SQLite's ``?``. With a dict of values, each ``%(name)s`` is
        bound as a parameter: list/tuple/set values expand to ``(?, ?, ...)``, as
        pymysql does, and ``%%`` is unescaped to ``%``. The caller's dict is never
        modified, and ``None`` binds as SQL NULL.
        """
        query = query.replace("%s", "?")
        if isinstance(values, dict):
            query, params = self._bind_named_params(query, values)
        elif values:
            params = tuple(self._adapt_positional_param(value) for value in values)
        else:
            params = ()
        return self._cursor.execute(query, params)

    @classmethod
    def _bind_named_params(cls, query, values):
        """Replace ``%(name)s`` / ``%%`` tokens; return the new query and its positional params."""
        params = []

        def substitute(match):
            key = match.group(1)
            if key is None:  # "%%": escaped literal percent sign
                return "%"
            value = values[key]
            if isinstance(value, (list, tuple, set, frozenset)):
                items = [cls._adapt_named_param(item) for item in value]
                params.extend(items)
                return "(" + ", ".join("?" * len(items)) + ")"
            params.append(cls._adapt_named_param(value))
            return "?"

        return cls._PYFORMAT_TOKEN.sub(substitute, query), params

    @staticmethod
    def _adapt_positional_param(value):
        """Render date/datetime/time as ISO text, the format the registered converters parse."""
        if isinstance(value, (date, time)):
            return str(value)
        return value

    @staticmethod
    def _adapt_named_param(value):
        """Pass natively bindable values through unchanged; convert anything else with ``str()``."""
        if value is None or isinstance(value, (int, float, str, bytes, bytearray, memoryview)):
            return value
        return str(value)

    def sql(self, *args, **kwargs):
        if args:
            # since tuple is immutable
            args = list(args)
            args[0] = modify_query(args[0])
            args = tuple(args)
        elif kwargs.get("query"):
            kwargs["query"] = modify_query(kwargs.get("query"))

        return super().sql(*args, **kwargs)

    def sql_ddl(self, query, *args, **kwargs):
        """Execute DDL query."""
        super().sql_ddl(query, *args, **kwargs)
        self.commit()

    def begin(self, *, read_only=False):
        """Start a transaction, switching to a read-only or read-write connection when needed."""
        if read_only or frappe.flags.read_only:
            if self._conn:
                self._conn.close()
            self._conn = self.get_connection(read_only=True)
            self._cursor = self._conn.cursor()
            self.read_only = True

        elif hasattr(self, "read_only") and self.read_only:
            self._conn.close()
            self._conn = self.get_connection()
            self._cursor = self._conn.cursor()
            self.read_only = False

        self.sql("BEGIN")

    def commit(self):
        """Commit the current transaction on the connection and start a new one."""
        if not self._conn:
            self.connect()

        if self._disable_transaction_control:
            warnings.warn(message=TRANSACTION_DISABLED_MSG, stacklevel=2)
            return

        self.before_rollback.reset()
        self.after_rollback.reset()

        self.before_commit.run()

        self._conn.commit()
        self.transaction_writes = 0
        self.begin()  # explicitly start a new transaction

        self.after_commit.run()

    def rollback(self, *, save_point=None):
        """`ROLLBACK` current transaction. Optionally rollback to a known save_point."""
        if not self._conn:
            self.connect()
        if save_point:
            self.sql(f"rollback to savepoint {save_point}")
        elif not self._disable_transaction_control:
            self.before_commit.reset()
            self.after_commit.reset()

            self.before_rollback.run()

            self._conn.rollback()
            self.begin()

            self.after_rollback.run()
        else:
            warnings.warn(message=TRANSACTION_DISABLED_MSG, stacklevel=2)

    def get_db_table_columns(self, table) -> list[str]:
        """Return list of column names from given table."""
        key = f"table_columns::{table}"
        columns = frappe.client_cache.get_value(key)
        if columns is None:
            columns = self.sql(f"PRAGMA table_info(`{table}`)", as_dict=True)
            columns = [col["name"] for col in columns]

            if columns:
                frappe.cache.set_value(key, columns)

        return columns

    def estimate_count(self, doctype: str):
        """Get estimated count of total rows in a table."""
        from frappe.utils.data import cint

        table = get_table_name(doctype)
        try:
            if count := self.sql(f"SELECT COUNT(*) FROM `{table}`"):
                return cint(count[0][0])
        except sqlite3.OperationalError as e:
            if not self.is_table_missing(e):
                raise
        return 0

    def truncate(self, doctype: str):
        """Delete every row of the doctype's table and reset its AUTOINCREMENT counter."""
        table = get_table_name(doctype)
        self.sql_ddl(f"DELETE FROM `{table}`")
        # sqlite_sequence only exists once some table has used AUTOINCREMENT;
        # deleting from it unconditionally fails with "no such table".
        if self.sql("SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'sqlite_sequence'"):
            self.sql("DELETE FROM sqlite_sequence WHERE name = %s", (table,))
            self.commit()

    def check_implicit_commit(self, query: str, query_type: str):
        """Refuse implicitly-committing statement types while writes are pending."""
        if query_type in IMPLICIT_COMMIT_QUERY_TYPES and self.transaction_writes:
            raise ImplicitCommitError("This statement can cause implicit commit", query)
|
|
|
|
|
|
def modify_query(query):
    """Rewrite a MariaDB-flavoured query into SQLite syntax.

    - backtick-quoted identifiers become double-quoted;
    - ``LOCATE(substr, str)`` becomes ``instr(str, substr)``;
    - a bare ``from tabXxx`` table reference is double-quoted.

    The rewrites are textual and also apply inside string literals.
    """
    query = str(query).replace("`", '"')
    query = replace_locate_with_instr(query)

    # \w rather than [a-zA-Z] so table names containing digits or underscores
    # are quoted whole instead of being split ("tabX"1).
    return re.sub(r"from tab(\w*)", r'from "tab\1"', query, flags=re.IGNORECASE)
|
|
|
|
|
|
# Start of a LOCATE( call; \b keeps names such as "allocate(" from matching.
_LOCATE_CALL = re.compile(r"\blocate\(", re.IGNORECASE)


def _split_call_args(query: str, open_idx: int):
    """Split the top-level arguments of the call whose "(" is at *open_idx*.

    Quoted text ('...', "...", `...`) and nested parentheses are skipped. Returns
    ``(args, close_idx)``, or ``None`` when the parentheses are unbalanced.
    """
    args = []
    depth = 0
    quote = None
    start = open_idx + 1
    for idx in range(open_idx, len(query)):
        ch = query[idx]
        if quote:
            # Doubled quotes ('') simply toggle out and back in.
            if ch == quote:
                quote = None
            continue
        if ch in "'\"`":
            quote = ch
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                args.append(query[start:idx])
                return args, idx
        elif ch == "," and depth == 1:
            args.append(query[start:idx])
            start = idx + 1
    return None


def replace_locate_with_instr(query: str) -> str:
    """Rewrite MySQL ``LOCATE(substr, str)`` calls as SQLite ``instr(str, substr)``.

    Arguments may contain nested calls, quoted commas or further LOCATE calls.
    Three-argument ``LOCATE(substr, str, pos)`` has no direct ``instr`` equivalent
    and is left unchanged, as are calls with unbalanced parentheses.
    """
    pieces = []
    pos = 0
    while match := _LOCATE_CALL.search(query, pos):
        parsed = _split_call_args(query, match.end() - 1)
        if parsed is None or len(parsed[0]) != 2:
            pieces.append(query[pos : match.end()])
            pos = match.end()
            continue
        (needle, haystack), close_idx = parsed
        pieces.append(query[pos : match.start()])
        haystack = replace_locate_with_instr(haystack).strip()
        needle = replace_locate_with_instr(needle).strip()
        pieces.append(f"instr({haystack}, {needle})")
        pos = close_idx + 1
    pieces.append(query[pos:])
    return "".join(pieces)
|
|
|
|
|
|
def regexp(expr: str, item: str) -> bool:
|
|
"""
|
|
Define regexp implementation for SQLite manually
|
|
|
|
Although it works in the CLI - doesn't work through python
|
|
"""
|
|
return re.search(expr, item) is not None
|