diff --git a/frappe/database/mariadb/database.py b/frappe/database/mariadb/database.py index df64bdc86a..1f087a243a 100644 --- a/frappe/database/mariadb/database.py +++ b/frappe/database/mariadb/database.py @@ -310,7 +310,7 @@ class MariaDBDatabase(MariaDBConnectionUtil, MariaDBExceptionUtil, Database): ) @staticmethod - def get_on_duplicate_update(key=None): + def get_on_duplicate_update(): return "ON DUPLICATE key UPDATE " def get_table_columns_description(self, table_name): @@ -329,7 +329,8 @@ class MariaDBDatabase(MariaDBConnectionUtil, MariaDBExceptionUtil, Database): and Seq_in_index = 1 limit 1 ), 0) as 'index', - column_key = 'UNI' as 'unique' + column_key = 'UNI' as 'unique', + (is_nullable = 'NO') AS 'not_nullable' from information_schema.columns as columns where table_name = '{table_name}' """.format( table_name=table_name diff --git a/frappe/database/mariadb/schema.py b/frappe/database/mariadb/schema.py index 209535b9fd..7d8a1fc9e7 100644 --- a/frappe/database/mariadb/schema.py +++ b/frappe/database/mariadb/schema.py @@ -3,6 +3,7 @@ from pymysql.constants.ER import DUP_ENTRY import frappe from frappe import _ from frappe.database.schema import DBTable +from frappe.utils.defaults import get_not_null_defaults class MariaDBTable(DBTable): @@ -69,7 +70,7 @@ class MariaDBTable(DBTable): add_column_query = [ f"ADD COLUMN `{col.fieldname}` {col.get_definition()}" for col in self.add_column ] - columns_to_modify = set(self.change_type + self.set_default) + columns_to_modify = set(self.change_type + self.set_default + self.change_nullability) modify_column_query = [ f"MODIFY `{col.fieldname}` {col.get_definition(for_modification=True)}" for col in columns_to_modify @@ -102,12 +103,24 @@ class MariaDBTable(DBTable): if index_record := frappe.db.get_column_index(self.table_name, col.fieldname, unique=False): drop_index_query.append(f"DROP INDEX `{index_record.Key_name}`") + for col in self.change_nullability: + current_column = self.current_columns.get(col.fieldname.lower()) + if col.not_nullable: + default_value = get_not_null_defaults(col.fieldtype) + if isinstance(default_value, str): + default_value = frappe.db.escape(default_value) + query = f"UPDATE `{self.table_name}` SET `{col.fieldname}`={default_value} WHERE `{col.fieldname}` IS NULL;" + try: + frappe.db.sql(query, ignore_implicit_commit=True) + except Exception as e: + print(f"Failed to alter schema using query: {query}") + raise try: for query_parts in [add_column_query, modify_column_query, add_index_query, drop_index_query]: if query_parts: query_body = ", ".join(query_parts) query = f"ALTER TABLE `{self.table_name}` {query_body}" - frappe.db.sql(query) + frappe.db.sql(query, ignore_implicit_commit=True) except Exception as e: if query := locals().get("query"): # this weirdness is to avoid potentially unbounded vars diff --git a/frappe/database/postgres/database.py b/frappe/database/postgres/database.py index 8f330f3676..476067c435 100644 --- a/frappe/database/postgres/database.py +++ b/frappe/database/postgres/database.py @@ -392,7 +392,8 @@ class PostgresDatabase(PostgresExceptionUtil, Database): END AS type, BOOL_OR(b.index) AS index, SPLIT_PART(COALESCE(a.column_default, NULL), '::', 1) AS default, - BOOL_OR(b.unique) AS unique + BOOL_OR(b.unique) AS unique, + COALESCE(a.is_nullable = 'NO', false) AS not_nullable FROM information_schema.columns a LEFT JOIN (SELECT indexdef, tablename, @@ -402,7 +403,7 @@ class PostgresDatabase(PostgresExceptionUtil, Database): WHERE tablename='{table_name}') b ON SUBSTRING(b.indexdef, '(.*)') LIKE CONCAT('%', a.column_name, '%') WHERE a.table_name = '{table_name}' - GROUP BY a.column_name, a.data_type, a.column_default, a.character_maximum_length; + GROUP BY a.column_name, a.data_type, a.column_default, a.character_maximum_length, a.is_nullable; """.format( table_name=table_name ), diff --git a/frappe/database/schema.py b/frappe/database/schema.py index 70b2cee244..5cdd995c2c 100644 --- a/frappe/database/schema.py +++ b/frappe/database/schema.py @@ -25,6 +25,7 @@ class DBTable: self.add_column: list[DbColumn] = [] self.change_type: list[DbColumn] = [] self.change_name: list[DbColumn] = [] + self.change_nullability: list[DbColumn] = [] self.add_unique: list[DbColumn] = [] self.add_index: list[DbColumn] = [] self.drop_unique: list[DbColumn] = [] @@ -269,6 +270,10 @@ class DbColumn: ): self.table.set_default.append(self) + # nullability + if self.not_nullable is not None and (self.not_nullable != current_def["not_nullable"]): + self.table.change_nullability.append(self) + # index should be applied or dropped irrespective of type change if (current_def["index"] and not self.set_index) and column_type not in ("text", "longtext"): self.table.drop_index.append(self) diff --git a/frappe/database/utils.py b/frappe/database/utils.py index 7cdab76dda..5d1de5792f 100644 --- a/frappe/database/utils.py +++ b/frappe/database/utils.py @@ -1,7 +1,6 @@ # Copyright (c) 2022, Frappe Technologies Pvt. Ltd. and Contributors # License: MIT. See LICENSE -import typing from functools import cached_property from types import NoneType @@ -9,9 +8,6 @@ import frappe from frappe.query_builder.builder import MariaDB, Postgres from frappe.query_builder.functions import Function -if typing.TYPE_CHECKING: - from frappe.query_builder import DocType - Query = str | MariaDB | Postgres QueryValues = tuple | list | dict | NoneType @@ -27,7 +23,7 @@ NestedSetHierarchy = ( ) -def is_query_type(query: str, query_type: str | tuple[str]) -> bool: +def is_query_type(query: str, query_type: str | tuple[str, ...]) -> bool: return query.lstrip().split(maxsplit=1)[0].lower().startswith(query_type)