diff --git a/src/preset_cli/api/clients/dbt.py b/src/preset_cli/api/clients/dbt.py
index 842d7775..a83ded18 100644
--- a/src/preset_cli/api/clients/dbt.py
+++ b/src/preset_cli/api/clients/dbt.py
@@ -31,7 +31,7 @@ class PostelSchema(Schema):
     """
     Be liberal in what you accept, and conservative in what you send.
 
-    A schema that allows unknown fields. This way if they API returns new fields that
+    A schema that allows unknown fields. This way if the API returns new fields that
     the client is not expecting no errors will be thrown when validating the payload.
     """
 
diff --git a/src/preset_cli/cli/superset/sync/native/command.py b/src/preset_cli/cli/superset/sync/native/command.py
index ac54c302..b61f74cf 100644
--- a/src/preset_cli/cli/superset/sync/native/command.py
+++ b/src/preset_cli/cli/superset/sync/native/command.py
@@ -204,28 +204,40 @@ def import_resources_individually(
     database info, since it's needed), then charts, on so on. It helps troubleshoot
     problematic exports and large imports.
     """
-    asset_configs: Dict[Path, AssetConfig]
-
-    imports = [
-        ("databases", lambda config: []),
-        ("datasets", lambda config: [config["database_uuid"]]),
-        ("charts", lambda config: [config["dataset_uuid"]]),
-        ("dashboards", get_charts_uuids),
-    ]
-    related_configs: Dict[str, Dict[Path, AssetConfig]] = {}
-    for resource_name, get_related_uuids in imports:
-        for path, config in configs.items():
-            if path.parts[1] != resource_name:
-                continue
-
-            asset_configs = {path: config}
-            for uuid in get_related_uuids(config):
-                asset_configs.update(related_configs[uuid])
-
-            _logger.info("Importing %s", path.relative_to("bundle"))
-            contents = {str(k): yaml.dump(v) for k, v in asset_configs.items()}
-            import_resources(contents, client, overwrite)
-            related_configs[config["uuid"]] = asset_configs
+    # store progress in case the import stops midway
+    checkpoint_path = Path("checkpoint.log")
+    if not checkpoint_path.exists():
+        checkpoint_path.touch()
+
+    with open(checkpoint_path, "r+", encoding="utf-8") as log:
+        imported = {Path(path.strip()) for path in log.readlines()}
+        asset_configs: Dict[Path, AssetConfig]
+        imports = [
+            ("databases", lambda config: []),
+            ("datasets", lambda config: [config["database_uuid"]]),
+            ("charts", lambda config: [config["dataset_uuid"]]),
+            ("dashboards", get_charts_uuids),
+        ]
+        related_configs: Dict[str, Dict[Path, AssetConfig]] = {}
+        for resource_name, get_related_uuids in imports:
+            for path, config in configs.items():
+                if path.parts[1] != resource_name or path in imported:
+                    continue
+
+                asset_configs = {path: config}
+                for uuid in get_related_uuids(config):
+                    asset_configs.update(related_configs[uuid])
+
+                _logger.info("Importing %s", path.relative_to("bundle"))
+                contents = {str(k): yaml.dump(v) for k, v in asset_configs.items()}
+                import_resources(contents, client, overwrite)
+                related_configs[config["uuid"]] = asset_configs
+
+                imported.add(path)
+                log.write(str(path) + "\n")
+                log.flush()
+
+    os.unlink(checkpoint_path)
 
 
 def get_charts_uuids(config: AssetConfig) -> Iterator[str]:
diff --git a/tests/cli/superset/sync/native/command_test.py b/tests/cli/superset/sync/native/command_test.py
index 40515066..01195ec4 100644
--- a/tests/cli/superset/sync/native/command_test.py
+++ b/tests/cli/superset/sync/native/command_test.py
@@ -607,10 +607,10 @@ def test_import_resources_individually_retries(
         requests.exceptions.ConnectionError("Connection aborted."),
         None,
     ]
-    contents = {
+    configs = {
         Path("bundle/databases/gsheets.yaml"): {"name": "my database", "uuid": "uuid1"},
     }
-    import_resources_individually(contents, client, overwrite=True)
+    import_resources_individually(configs, client, overwrite=True)
 
     client.import_zip.side_effect = [
         requests.exceptions.ConnectionError("Connection aborted."),
@@ -620,5 +620,75 @@ def test_import_resources_individually_retries(
         requests.exceptions.ConnectionError("Connection aborted."),
     ]
     with pytest.raises(Exception) as excinfo:
-        import_resources_individually(contents, client, overwrite=True)
+        import_resources_individually(configs, client, overwrite=True)
     assert str(excinfo.value) == "Connection aborted."
+
+
+def test_import_resources_individually_checkpoint(
+    mocker: MockerFixture,
+    fs: FakeFilesystem,  # pylint: disable=unused-argument
+) -> None:
+    """
+    Test checkpoint in ``import_resources_individually``.
+    """
+    client = mocker.MagicMock()
+    configs = {
+        Path("bundle/databases/gsheets.yaml"): {"name": "my database", "uuid": "uuid1"},
+        Path("bundle/databases/psql.yaml"): {
+            "name": "my other database",
+            "uuid": "uuid2",
+        },
+    }
+    import_resources = mocker.patch(
+        "preset_cli.cli.superset.sync.native.command.import_resources",
+    )
+    import_resources.side_effect = [None, Exception("An error occurred!"), None]
+
+    with pytest.raises(Exception) as excinfo:
+        import_resources_individually(configs, client, overwrite=True)
+    assert str(excinfo.value) == "An error occurred!"
+
+    import_resources.assert_has_calls(
+        [
+            mocker.call(
+                {
+                    "bundle/databases/gsheets.yaml": yaml.dump(
+                        {"name": "my database", "uuid": "uuid1"},
+                    ),
+                },
+                client,
+                True,
+            ),
+            mocker.call(
+                {
+                    "bundle/databases/psql.yaml": yaml.dump(
+                        {"name": "my other database", "uuid": "uuid2"},
+                    ),
+                },
+                client,
+                True,
+            ),
+        ],
+    )
+
+    with open("checkpoint.log", encoding="utf-8") as log:
+        assert log.read() == "bundle/databases/gsheets.yaml\n"
+
+    # retry
+    import_resources.reset_mock()
+    import_resources_individually(configs, client, overwrite=True)
+    import_resources.assert_has_calls(
+        [
+            mocker.call(
+                {
+                    "bundle/databases/psql.yaml": yaml.dump(
+                        {"name": "my other database", "uuid": "uuid2"},
+                    ),
+                },
+                client,
+                True,
+            ),
+        ],
+    )
+
+    assert not Path("checkpoint.log").exists()
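
Note: the change to import_resources_individually above is a write-ahead checkpoint: each successfully imported path is appended to checkpoint.log and flushed immediately, paths already in the log are skipped on the next run, and the log is removed only once every asset has been imported. A minimal sketch of the same pattern in isolation (run_with_checkpoint, items, and process_one are hypothetical names for this illustration, not part of the PR):

    from pathlib import Path
    from typing import Callable, Iterable


    def run_with_checkpoint(
        items: Iterable[str],
        process_one: Callable[[str], None],
        checkpoint: Path = Path("checkpoint.log"),
    ) -> None:
        # create the checkpoint on the first run so it can be opened "r+"
        if not checkpoint.exists():
            checkpoint.touch()

        with open(checkpoint, "r+", encoding="utf-8") as log:
            # items recorded by a previous, interrupted run
            done = {line.strip() for line in log}
            # after reading, the file position is at EOF, so writes append
            for item in items:
                if item in done:
                    continue  # already processed; skip on retry
                process_one(item)  # may raise; everything before stays checkpointed
                log.write(item + "\n")  # record success only after the work is done
                log.flush()  # persist immediately in case the next item crashes

        # the loop finished without raising: the checkpoint is no longer needed
        checkpoint.unlink()

The flush() after each write is what lets progress survive an abrupt interruption, and writing the log entry only after process_one() returns guarantees that a failed item is retried rather than skipped.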