Replace a for append loop with list extend. Create the list with values instead of creating an empty list and extending it with another list.
378 lines
11 KiB
Python
378 lines
11 KiB
Python
import re
|
|
|
|
import frappe
|
|
from frappe import _
|
|
from frappe.utils import cint, cstr, flt
|
|
|
|
# Matches any single non-word character; used to reject such characters in column names.
SPECIAL_CHAR_PATTERN = re.compile(r"[\W]", re.UNICODE)

# Captures the length N out of a "varchar(N)" column type string.
VARCHAR_CAST_PATTERN = re.compile(r"varchar\(([\d]+)\)")
|
|
|
|
|
|
class InvalidColumnName(frappe.ValidationError):
    """Validation error type for unacceptable column names.

    NOTE(review): ``validate_column_name`` in this module raises
    ``frappe.db.InvalidColumnName``; confirm whether that refers to this class.
    """
|
|
|
|
|
|
class DBTable:
    """Desired database schema of one DocType, used to create or alter its table.

    The desired columns are collected from the DocType's meta into
    ``self.columns``; the columns that currently exist in the database are
    loaded into ``self.current_columns`` by :meth:`setup_table_columns`.
    The ``add_column`` / ``change_type`` / ... lists are work queues that
    ``DbColumn.build_for_alter_table`` appends to.

    ``create`` and ``alter`` do nothing in this class.
    NOTE(review): presumably they are overridden by database-specific
    subclasses -- confirm.
    """

    def __init__(self, doctype, meta=None):
        """Collect the desired columns for ``doctype``.

        Args:
            doctype: DocType name; the table name is ``"tab" + doctype``.
            meta: Optional meta object for the DocType. When falsy it is
                fetched with ``frappe.get_meta(doctype, False)``.
        """
        self.doctype = doctype
        self.table_name = f"tab{doctype}"
        self.meta = meta or frappe.get_meta(doctype, False)
        self.columns: dict[str, DbColumn] = {}
        # Filled by setup_table_columns(); keyed by lower-cased column name.
        self.current_columns = {}

        # lists for change
        self.add_column: list[DbColumn] = []
        self.change_type: list[DbColumn] = []
        self.change_name: list[DbColumn] = []
        self.add_unique: list[DbColumn] = []
        self.add_index: list[DbColumn] = []
        self.drop_unique: list[DbColumn] = []
        self.drop_index: list[DbColumn] = []
        self.set_default: list[DbColumn] = []

        # load
        self.get_columns_from_docfields()

    def sync(self):
        """Create the table if it does not exist yet, otherwise alter it."""
        if self.meta.get("is_virtual"):
            # no schema to sync for virtual doctypes
            return

        if self.is_new():
            self.create()
        else:
            # Drop the cached column list for this table before altering it.
            frappe.cache.hdel("table_columns", self.table_name)
            self.alter()

    def create(self):
        """No-op in this class."""
        pass

    def get_column_definitions(self):
        """Return ``"`fieldname` <definition>"`` strings for the non-default columns.

        Columns named in ``frappe.db.DEFAULT_COLUMNS`` and columns whose
        ``DbColumn.get_definition()`` is empty are left out.
        """
        default_columns = set(frappe.db.DEFAULT_COLUMNS)
        definitions = []
        for fieldname, column in self.columns.items():
            if fieldname in default_columns:
                continue
            definition = column.get_definition()
            if definition:
                definitions.append(f"`{fieldname}` {definition}")
        return definitions

    def get_index_definitions(self):
        """Return ``index`` clauses for indexed, non-unique, non-text columns."""
        return [
            "index `" + key + "`(`" + key + "`)"
            for key, col in self.columns.items()
            if (
                col.set_index
                and not col.unique
                and col.fieldtype in frappe.db.type_map
                and frappe.db.type_map.get(col.fieldtype)[0] not in ("text", "longtext")
            )
        ]

    def get_columns_from_docfields(self):
        """
        get columns from docfields and custom fields
        """
        fields = self.meta.get_fieldnames_with_value(with_field_meta=True)

        # optional fields like _comments
        if not self.meta.get("istable"):
            fields.extend(
                {"fieldname": fieldname, "fieldtype": "Text"}
                for fieldname in frappe.db.OPTIONAL_COLUMNS
            )

            # add _seen column if track_seen
            if self.meta.get("track_seen"):
                fields.append({"fieldname": "_seen", "fieldtype": "Text"})

        for field in fields:
            if field.get("is_virtual"):
                continue

            fieldname = field.get("fieldname")
            self.columns[fieldname] = DbColumn(
                self,
                fieldname,
                field.get("fieldtype"),
                field.get("length"),
                field.get("default"),
                field.get("search_index"),
                field.get("options"),
                field.get("unique"),
                field.get("precision"),
            )

    def validate(self):
        """Check if change in varchar length isn't truncating the columns

        For every varchar column whose configured length differs from the
        length currently in the database, the longest stored value is
        queried. If it would not fit, the configured length is reverted to
        the current one and a message is shown. Does nothing for new tables.

        Raises (via ``frappe.throw``) when a fieldname is too long or a
        varchar length is outside 1..1000.
        """
        if self.is_new():
            return

        self.setup_table_columns()

        columns = [
            frappe._dict({"fieldname": f, "fieldtype": "Data"}) for f in frappe.db.STANDARD_VARCHAR_COLUMNS
        ]
        if self.meta.get("istable"):
            columns += [
                frappe._dict({"fieldname": f, "fieldtype": "Data"}) for f in frappe.db.CHILD_TABLE_COLUMNS
            ]
        columns.extend(self.columns.values())

        for col in columns:
            # NOTE(review): this rejects a name of exactly 64 characters, while
            # validate_column_length() uses `> frappe.db.MAX_COLUMN_LENGTH` -- confirm
            # which bound is intended.
            if len(col.fieldname) >= 64:
                frappe.throw(
                    _("Fieldname is limited to 64 characters ({0})").format(frappe.bold(col.fieldname))
                )

            if "varchar" not in frappe.db.type_map.get(col.fieldtype, ()):
                continue

            # validate length range
            new_length = cint(col.length) or cint(frappe.db.VARCHAR_LEN)
            if not (1 <= new_length <= 1000):
                frappe.throw(_("Length of {0} should be between 1 and 1000").format(col.fieldname))

            # setup_table_columns() keys current_columns by lower-cased name, so the
            # lookup must be lower-cased too; otherwise columns whose fieldname
            # contains uppercase letters silently skip the truncation check.
            current_col = self.current_columns.get(col.fieldname.lower())
            if not current_col:
                continue

            current_length = VARCHAR_CAST_PATTERN.findall(current_col["type"])
            if not current_length:
                # case when the field is no longer a varchar
                continue
            current_length = current_length[0]

            if cint(current_length) == cint(new_length):
                continue

            try:
                # check for truncation
                max_length = frappe.db.sql(
                    """SELECT MAX(CHAR_LENGTH(`{fieldname}`)) FROM `tab{doctype}`""".format(
                        fieldname=col.fieldname, doctype=self.doctype
                    )
                )
            except frappe.db.InternalError as e:
                if frappe.db.is_missing_column(e):
                    # Unknown column 'column_name' in 'field list'
                    continue
                raise

            if max_length and max_length[0][0] and max_length[0][0] > new_length:
                if col.fieldname in self.columns:
                    self.columns[col.fieldname].length = current_length
                info_message = _(
                    "Reverting length to {0} for '{1}' in '{2}'. Setting the length as {3} will cause truncation of data."
                ).format(current_length, col.fieldname, self.doctype, new_length)
                frappe.msgprint(info_message)

    def is_new(self):
        """Return True when the table is not among ``frappe.db.get_tables()``."""
        return self.table_name not in frappe.db.get_tables()

    def setup_table_columns(self):
        """Load the existing column descriptions into ``self.current_columns``.

        Keys are the lower-cased column names.
        """
        # TODO: figure out a way to get key data
        for c in frappe.db.get_table_columns_description(self.table_name):
            self.current_columns[c.name.lower()] = c

    def alter(self):
        """No-op in this class."""
        pass
|
|
|
|
|
|
class DbColumn:
    """Desired state of a single column belonging to a table object.

    Holds the field attributes passed to the constructor and knows how to
    render its SQL definition and how to queue itself on the owning table's
    change lists (``add_column``, ``change_type``, ``add_unique``, ...).
    """

    def __init__(
        self, table, fieldname, fieldtype, length, default, set_index, options, unique, precision
    ):
        """Store the column attributes unchanged.

        Args:
            table: Owning table object; its change lists are appended to by
                :meth:`build_for_alter_table`.
            fieldname: Column name.
            fieldtype: Field type, looked up in ``frappe.db.type_map``.
            length: Configured length, if any.
            default: Configured default value, if any.
            set_index: Truthy when the column should be indexed.
            options: Field options (stored only; not used in this class).
            unique: Truthy when the column should be unique.
            precision: Configured precision, if any.
        """
        self.table = table
        self.fieldname = fieldname
        self.fieldtype = fieldtype
        self.length = length
        self.set_index = set_index
        self.default = default
        self.options = options
        self.unique = unique
        self.precision = precision

    def get_definition(self, for_modification=False):
        """Return the SQL column definition (type, default, uniqueness).

        Args:
            for_modification: When truthy, `` unique`` is never appended.

        Returns:
            The definition string, or the falsy result of the module-level
            ``get_definition`` when the field type has no column type.
        """
        column_type = get_definition(self.fieldtype, precision=self.precision, length=self.length)

        if not column_type:
            return column_type

        column_def = column_type

        if self.fieldtype in ("Check", "Int"):
            default_value = cint(self.default) or 0
            column_def += f" not null default {default_value}"

        elif self.fieldtype in ("Currency", "Float", "Percent"):
            default_value = flt(self.default) or 0
            column_def += f" not null default {default_value}"

        elif (
            self.default
            and (self.default not in frappe.db.DEFAULT_SHORTCUTS)
            and not cstr(self.default).startswith(":")
        ):
            column_def += f" default {frappe.db.escape(self.default)}"

        # Compare the bare column type, not the definition built so far: once a
        # default clause has been appended, "text default ..." would no longer
        # match and a text column would wrongly be declared unique.
        if self.unique and not for_modification and column_type not in ("text", "longtext"):
            column_def += " unique"

        return column_def

    def build_for_alter_table(self, current_def):
        """Queue this column on the owning table's change lists.

        Args:
            current_def: Description of the existing database column (read
                via the keys ``"type"``, ``"unique"``, ``"default"`` and
                ``"index"``), or a falsy value when the column does not
                exist yet.
        """
        column_type = get_definition(self.fieldtype, self.precision, self.length)

        # no columns
        if not column_type:
            return

        is_text = column_type in ("text", "longtext")

        # to add?
        if not current_def:
            self.fieldname = validate_column_name(self.fieldname)
            self.table.add_column.append(self)

            if not is_text:
                if self.unique:
                    self.table.add_unique.append(self)
                if self.set_index:
                    self.table.add_index.append(self)
            return

        # type
        if current_def["type"] != column_type:
            self.table.change_type.append(self)

        # unique
        if (self.unique and not current_def["unique"]) and not is_text:
            self.table.add_unique.append(self)
        elif (current_def["unique"] and not self.unique) and not is_text:
            self.table.drop_unique.append(self)

        # default
        if (
            self.default_changed(current_def)
            and (self.default not in frappe.db.DEFAULT_SHORTCUTS)
            and not cstr(self.default).startswith(":")
        ):
            self.table.set_default.append(self)

        # index should be applied or dropped irrespective of type change
        if (current_def["index"] and not self.set_index) and not is_text:
            self.table.drop_index.append(self)

        elif (not current_def["index"] and self.set_index) and not is_text:
            self.table.add_index.append(self)

    def default_changed(self, current_def):
        """Return True when the configured default differs from the stored one.

        Decimal columns are delegated to :meth:`default_changed_for_decimal`.
        For other columns a stored default of ``"NULL"`` or ``None`` counts
        as no default, surrounding single quotes are stripped, and numeric
        field types are compared after ``cint`` / ``flt`` conversion.
        """
        if "decimal" in current_def["type"]:
            return self.default_changed_for_decimal(current_def)

        cur_default = current_def["default"]
        new_default = self.default

        if cur_default == "NULL" or cur_default is None:
            cur_default = None
        else:
            # Strip quotes from default value
            # eg. database returns default value as "'System Manager'"
            cur_default = cur_default.strip("'")

        if self.fieldtype in ("Int", "Check"):
            cur_default = cint(cur_default)
            new_default = cint(new_default)
        elif self.fieldtype in ("Currency", "Float", "Percent"):
            cur_default = flt(cur_default)
            new_default = flt(new_default)

        return cur_default != new_default

    def default_changed_for_decimal(self, current_def):
        """Return True when the default of a decimal column has changed.

        Empty (``""`` / ``None``) defaults on both sides mean "unchanged".
        Going from an empty stored default to a new value counts as a change
        only when the new value parses as a float. Otherwise both values are
        compared as floats; when either cannot be converted the default is
        reported as changed.
        """
        try:
            current_default = current_def["default"]
            current_is_empty = current_default in ("", None)
            new_is_empty = self.default in ("", None)

            if current_is_empty and new_is_empty:
                # both none, empty
                return False

            if current_is_empty:
                try:
                    # check if new default value is valid
                    float(self.default)
                except ValueError:
                    return False
                return True

            if new_is_empty:
                # new default value is empty
                return True

            # float() raises ValueError for non-numeric strings (e.g. a stored
            # default reported as "NULL"); handled below together with TypeError.
            return float(current_default) != float(self.default)
        except (TypeError, ValueError):
            return True
|
|
|
|
|
|
def validate_column_name(n):
    """Return ``n`` unchanged if it contains no special characters.

    Any character matched by ``SPECIAL_CHAR_PATTERN`` causes ``frappe.throw``
    to be called with ``frappe.db.InvalidColumnName`` and a message listing
    the offending characters.
    """
    offending = SPECIAL_CHAR_PATTERN.findall(n)
    if offending:
        quoted = ", ".join(f'"{c}"' for c in offending)
        message = _("Fieldname {0} cannot have special characters like {1}").format(
            frappe.bold(cstr(n)), quoted
        )
        frappe.throw(message, frappe.db.InvalidColumnName)
    return n
|
|
|
|
|
|
def validate_column_length(fieldname):
    """Call ``frappe.throw`` when ``fieldname`` is longer than ``frappe.db.MAX_COLUMN_LENGTH``."""
    if len(fieldname) <= frappe.db.MAX_COLUMN_LENGTH:
        return

    message = _("Fieldname is limited to 64 characters ({0})").format(fieldname)
    frappe.throw(message)
|
|
|
|
|
|
def get_definition(fieldtype, precision=None, length=None):
    """Return the SQL column type for ``fieldtype``, e.g. ``"varchar(140)"``.

    Args:
        fieldtype: Field type, looked up in ``frappe.db.type_map``.
        precision: Optional precision; above 6 widens Float/Currency/Percent.
        length: Optional length; overrides the size of varchar columns and of
            int columns when it is below 11.

    Returns:
        The column type with its size in parentheses when it has one, or
        ``None`` when the field type has no entry in the type map.
    """
    type_info = frappe.db.type_map.get(fieldtype)
    if not type_info:
        return None

    if fieldtype == "Int" and length and length > 11:
        # convert int to long int if the length of the int is greater than 11
        type_info = frappe.db.type_map.get("Long Int")

    coltype = type_info[0]
    size = type_info[1] or None

    if size:
        # This check needs to exist for backward compatibility.
        # Till V13, default size used for float, currency and percent are (18, 6).
        if fieldtype in ("Float", "Currency", "Percent") and cint(precision) > 6:
            size = "21,9"

        # A custom length applies to varchar always, and to int only when it is
        # less than 11.
        # NOTE: the int case will only be applicable for mariadb as frappe implements
        # int in postgres as bigint (as seen in type_map)
        if length and (coltype == "varchar" or (coltype == "int" and length < 11)):
            size = length

    if size is None:
        return coltype

    return f"{coltype}({size})"
|
|
|
|
|
|
def add_column(
    doctype, column_name, fieldtype, precision=None, length=None, default=None, not_null=False
):
    """Add a column to the table of ``doctype`` unless it already exists.

    Commits the current transaction, then runs an ``alter table ... add
    column`` statement built from the arguments.

    Args:
        doctype: DocType name; the table altered is ``"tab" + doctype``.
        column_name: Name of the column to add.
        fieldtype: Field type, resolved through ``get_definition``.
        precision: Optional precision passed to ``get_definition``.
        length: Optional length passed to ``get_definition``.
        default: Optional default value; added only when truthy.
        not_null: When truthy, the column is declared ``not null``.
    """
    existing_columns = frappe.db.get_table_columns(doctype)
    if column_name in existing_columns:
        # already exists
        return

    frappe.db.commit()

    column_type = get_definition(fieldtype, precision, length)
    statement = [f"alter table `tab{doctype}` add column {column_name} {column_type}"]

    if not_null:
        statement.append(" not null")

    if default:
        # NOTE(review): the default is interpolated without escaping and the column
        # name without quoting -- callers must pass trusted values.
        statement.append(f" default '{default}'")

    frappe.db.sql("".join(statement))
|