From bcffd3e7414a2c708cec9489ec09cd1daf5288ed Mon Sep 17 00:00:00 2001
From: William Conti
Date: Mon, 30 Oct 2023 23:03:48 +0100
Subject: [PATCH 1/2] improving error message for not null constraints violated

---
 src/databricks/labs/ucx/framework/crawlers.py | 14 +++++++++++---
 tests/unit/framework/test_crawlers.py         | 19 +++++++++++++------
 2 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/src/databricks/labs/ucx/framework/crawlers.py b/src/databricks/labs/ucx/framework/crawlers.py
index ee63016e80..5b4e3da353 100644
--- a/src/databricks/labs/ucx/framework/crawlers.py
+++ b/src/databricks/labs/ucx/framework/crawlers.py
@@ -45,14 +45,22 @@ def _schema_for(cls, klass):
             fields.append(f"{f.name} {spark_type}{not_null}")
         return ", ".join(fields)
 
+    from dataclasses import dataclass, fields, asdict
+
     @classmethod
-    def _filter_none_rows(cls, rows):
+    def _filter_none_rows(cls, rows, klass):
         if len(rows) == 0:
             return rows
 
+
         results = []
+        class_fields = dataclasses.fields(klass)
         for row in rows:
             if row is None:
                 continue
+            for field in class_fields:
+                if field.default is not None and getattr(row, field.name) is None:
+                    raise ValueError(f"Not null constraint violated for column {field.name}, "
+                                     f"row = {dataclasses.asdict(row)}")
             results.append(row)
         return results
 
@@ -75,7 +83,7 @@ def save_table(self, full_name: str, rows: list[any], klass: dataclasses.datacla
         if mode == "overwrite":
             msg = "Overwrite mode is not yet supported"
             raise NotImplementedError(msg)
-        rows = self._filter_none_rows(rows)
+        rows = self._filter_none_rows(rows, klass)
         self.create_table(full_name, klass)
         if len(rows) == 0:
             return
@@ -126,7 +134,7 @@ def fetch(self, sql) -> Iterator[any]:
         return self._spark.sql(sql).collect()
 
     def save_table(self, full_name: str, rows: list[any], klass: dataclasses.dataclass, mode: str = "append"):
-        rows = self._filter_none_rows(rows)
+        rows = self._filter_none_rows(rows, klass)
         if len(rows) == 0:
             self.create_table(full_name, klass)
 
diff --git a/tests/unit/framework/test_crawlers.py b/tests/unit/framework/test_crawlers.py
index 466ba3930a..bb16344f40 100644
--- a/tests/unit/framework/test_crawlers.py
+++ b/tests/unit/framework/test_crawlers.py
@@ -213,6 +213,16 @@ def test_runtime_backend_save_table(mocker):
 def test_runtime_backend_save_table_with_row_containing_none_with_nullable_class(mocker):
     from unittest import mock
 
+    @dataclass
+    class TestClass:
+        id: str
+        value: str = None
+
+    rows = [
+        TestClass("1", "test"),
+        TestClass("2", None),
+        TestClass(None, "value")
+    ]
     with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
         pyspark_sql_session = mocker.Mock()
@@ -220,10 +230,7 @@ def test_runtime_backend_save_table_with_row_containing_none_with_nullable_class
 
         rb = RuntimeBackend()
 
-        rb.save_table("a.b.c", [Baz("aaa", "ccc"), Baz("bbb", None)], Bar)
+        with pytest.raises(Exception) as exc_info:
+            rb.save_table("a.b.c", rows, TestClass)
 
-        rb._spark.createDataFrame.assert_called_with(
-            [Baz(first="aaa", second="ccc"), Baz(first="bbb", second=None)],
-            "first STRING NOT NULL, second STRING",
-        )
-        rb._spark.createDataFrame().write.saveAsTable.assert_called_with("a.b.c", mode="append")
+    assert str(exc_info.value) == "Not null constraint violated for column id, row = {'id': None, 'value': 'value'}"

From c803fb2cd3923281b0bdbc57d077ffb3c954d736 Mon Sep 17 00:00:00 2001
From: William Conti
Date: Tue, 31 Oct 2023 09:46:13 +0100
Subject: [PATCH 2/2] fixing tests, keeping column preservation order

---
 src/databricks/labs/ucx/framework/crawlers.py |  9 +++--
 tests/unit/framework/test_crawlers.py         | 36 +++++++++++--------
 2 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/src/databricks/labs/ucx/framework/crawlers.py b/src/databricks/labs/ucx/framework/crawlers.py
index 5b4e3da353..54f7ff7a4c 100644
--- a/src/databricks/labs/ucx/framework/crawlers.py
+++ b/src/databricks/labs/ucx/framework/crawlers.py
@@ -45,7 +45,7 @@ def _schema_for(cls, klass):
             fields.append(f"{f.name} {spark_type}{not_null}")
         return ", ".join(fields)
 
-    from dataclasses import dataclass, fields, asdict
+    from dataclasses import asdict, dataclass, fields
 
     @classmethod
     def _filter_none_rows(cls, rows, klass):
@@ -58,9 +58,12 @@ def _filter_none_rows(cls, rows, klass):
             if row is None:
                 continue
             for field in class_fields:
+                if not hasattr(row, field.name):
+                    logger.debug(f"Field {field.name} not present in row {dataclasses.asdict(row)}")
+                    continue
                 if field.default is not None and getattr(row, field.name) is None:
-                    raise ValueError(f"Not null constraint violated for column {field.name}, "
-                                     f"row = {dataclasses.asdict(row)}")
+                    msg = f"Not null constraint violated for column {field.name}, row = {dataclasses.asdict(row)}"
+                    raise ValueError(msg)
             results.append(row)
         return results
 
diff --git a/tests/unit/framework/test_crawlers.py b/tests/unit/framework/test_crawlers.py
index bb16344f40..da078ae122 100644
--- a/tests/unit/framework/test_crawlers.py
+++ b/tests/unit/framework/test_crawlers.py
@@ -1,6 +1,7 @@
 import os
 import sys
 from dataclasses import dataclass
+from unittest import mock
 
 import pytest
 from databricks.sdk.service import sql
@@ -163,8 +164,6 @@ def test_statement_execution_backend_save_table_in_batches_of_two(mocker):
 
 
 def test_runtime_backend_execute(mocker):
-    from unittest import mock
-
     with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
         pyspark_sql_session = mocker.Mock()
         sys.modules["pyspark.sql.session"] = pyspark_sql_session
@@ -177,8 +176,6 @@ def test_runtime_backend_fetch(mocker):
-    from unittest import mock
-
     with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
         pyspark_sql_session = mocker.Mock()
         sys.modules["pyspark.sql.session"] = pyspark_sql_session
@@ -194,8 +191,6 @@ def test_runtime_backend_save_table(mocker):
-    from unittest import mock
-
     with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
         pyspark_sql_session = mocker.Mock()
         sys.modules["pyspark.sql.session"] = pyspark_sql_session
@@ -212,17 +207,28 @@ def test_runtime_backend_save_table_with_row_containing_none_with_nullable_class(mocker):
-    from unittest import mock
+    with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
+        pyspark_sql_session = mocker.Mock()
+        sys.modules["pyspark.sql.session"] = pyspark_sql_session
+
+        rb = RuntimeBackend()
+
+        rb.save_table("a.b.c", [Baz("aaa", "ccc"), Baz("bbb", None)], Baz)
+
+        rb._spark.createDataFrame.assert_called_with(
+            [Baz(first="aaa", second="ccc"), Baz(first="bbb", second=None)],
+            "first STRING NOT NULL, second STRING",
+        )
+        rb._spark.createDataFrame().write.saveAsTable.assert_called_with("a.b.c", mode="append")
+
+
+def test_save_table_with_not_null_constraint_violated(mocker):
     @dataclass
     class TestClass:
-        id: str
+        key: str
         value: str = None
 
-    rows = [
-        TestClass("1", "test"),
-        TestClass("2", None),
-        TestClass(None, "value")
-    ]
+    rows = [TestClass("1", "test"), TestClass("2", None), TestClass(None, "value")]
 
     with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
         pyspark_sql_session = mocker.Mock()
@@ -233,4 +239,6 @@ class TestClass:
         with pytest.raises(Exception) as exc_info:
             rb.save_table("a.b.c", rows, TestClass)
 
-    assert str(exc_info.value) == "Not null constraint violated for column id, row = {'id': None, 'value': 'value'}"
+    assert (
+        str(exc_info.value) == "Not null constraint violated for column key, row = {'key': None, 'value': 'value'}"
+    )
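
Note: the constraint check introduced in _filter_none_rows rests on two documented
behaviors of the dataclasses module: fields() returns fields in declaration order
(which is what keeps the column order stable, per the second commit), and a field
declared without a default carries the dataclasses.MISSING sentinel as its default,
which is not None, so the "field.default is not None" test treats it as a required
column. Only fields declared with an explicit default of None are nullable. A
minimal standalone sketch of that logic follows; the Foo dataclass and the
module-level filter_none_rows function are illustrative stand-ins, not code from
the patch:

    import dataclasses
    from dataclasses import dataclass


    @dataclass
    class Foo:
        key: str           # no default -> dataclasses.MISSING -> treated as NOT NULL
        value: str = None  # explicit default of None -> treated as nullable


    def filter_none_rows(rows: list, klass: type) -> list:
        # fields() preserves declaration order, matching the column order
        # that _schema_for generates for the table DDL
        class_fields = dataclasses.fields(klass)
        results = []
        for row in rows:
            if row is None:
                continue  # rows that are entirely None are silently dropped
            for field in class_fields:
                # MISSING is not None, so a field without a default is required;
                # only default=None marks a column as nullable
                if field.default is not None and getattr(row, field.name) is None:
                    msg = f"Not null constraint violated for column {field.name}, row = {dataclasses.asdict(row)}"
                    raise ValueError(msg)
            results.append(row)
        return results


    # Foo("2", None) passes because value is nullable; the bare None row is dropped
    assert len(filter_none_rows([Foo("1", "test"), Foo("2", None), None], Foo)) == 2

    try:
        filter_none_rows([Foo(None, "value")], Foo)
    except ValueError as err:
        print(err)  # Not null constraint violated for column key, row = {'key': None, 'value': 'value'}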