Added entrypoint for Databricks CLI #84

Open · wants to merge 4 commits into master
3 changes: 2 additions & 1 deletion .gitignore
@@ -48,4 +48,5 @@ derby.log
 
 # Databricks typings
 typings/
-.databricks
+.databricks
+.databricks-login.json
6 changes: 6 additions & 0 deletions Makefile
@@ -0,0 +1,6 @@
+clean:
+	rm -fr build .databricks dlt_meta.egg-info
+
+dev:
+	python3 -m venv .databricks
+	.databricks/bin/python -m pip install -e .
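
For reviewers trying the branch locally, the intended workflow appears to be (assuming GNU make and a python3 on PATH):

make dev    # create the .databricks virtualenv and install the package in editable mode
make clean  # remove build artifacts and the virtualenv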
49 changes: 49 additions & 0 deletions discoverx/cli.py
@@ -0,0 +1,49 @@
+import json
+import logging
+import sys
+
+from databricks.connect.session import DatabricksSession
+from discoverx import DX
+
+
+logger = logging.getLogger('databricks.labs.discoverx')
+
+def scan(spark, from_tables: str = '*.*.*', rules: str = '*', sample_size: str = '10000', what_if: str = 'false', locale='US'):
+    logger.info(f'scan: from_tables={from_tables} rules={rules}')
+    dx = DX(spark=spark, locale=locale)
+    dx.scan(from_tables=from_tables, rules=rules, sample_size=int(sample_size), what_if='true' == what_if)
+    print(dx.scan_result.head())
+
+
+MAPPING = {
+    'scan': scan,
+}
+
+
+def main(raw):
+    console_handler = logging.StreamHandler(sys.stderr)
+    console_handler.setLevel('DEBUG')
+    logging.root.addHandler(console_handler)
+
+    payload = json.loads(raw)
+    command = payload['command']
+    if command not in MAPPING:
+        raise KeyError(f'cannot find command: {command}')
+    flags = payload['flags']
+    log_level = flags.pop('log_level')
+    if log_level != 'disabled':
+        databricks_logger = logging.getLogger("databricks")
+        databricks_logger.setLevel(log_level.upper())
+
+    kwargs = {k.replace('-', '_'): v for k, v in flags.items()}
+
+    try:
+        spark = DatabricksSession.builder.getOrCreate()
+        MAPPING[command](spark, **kwargs)
+    except Exception as e:
+        logger.error(f'ERROR: {e}')
+        logger.debug(f'Failed execution of {command}', exc_info=e)
+
+
+if __name__ == "__main__":
+    main(*sys.argv[1:])
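
The entrypoint receives a single JSON argument that names the command and carries its flags, including the log_level that main() pops before dispatch. A minimal sketch of driving it directly, with hypothetical flag values (DatabricksSession.builder.getOrCreate() needs a configured workspace connection to succeed):

import json

from discoverx.cli import main

# Payload shape mirrors what main() parses: a command name plus a flags mapping.
payload = {
    "command": "scan",
    "flags": {
        "log_level": "info",     # popped before the remaining flags become kwargs
        "from_tables": "*.*.*",  # hypothetical example values
        "rules": "*",
        "sample_size": "100",
        "what_if": "true",
        "locale": "US",
    },
}
main(json.dumps(payload))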
4 changes: 2 additions & 2 deletions discoverx/discovery.py
@@ -1,13 +1,13 @@
 from typing import Optional, List, Union
 
-from discoverx import logging
+from discoverx import logs
 from discoverx.msql import Msql
 from discoverx.table_info import TableInfo
 from discoverx.scanner import Scanner, ScanResult
 from discoverx.rules import Rules, Rule
 from pyspark.sql import SparkSession
 
-logger = logging.Logging()
+logger = logs.Logging()
 
 
 class Discovery:
5 changes: 3 additions & 2 deletions discoverx/dx.py
@@ -1,7 +1,7 @@
 import pandas as pd
 from pyspark.sql import SparkSession
 from typing import List, Optional, Union
-from discoverx import logging
+from discoverx import logs
 from discoverx.explorer import DataExplorer, InfoFetcher
 from discoverx.msql import Msql
 from discoverx.rules import Rules, Rule
@@ -37,7 +37,7 @@ def __init__(
         if spark is None:
             spark = SparkSession.getActiveSession()
         self.spark = spark
-        self.logger = logging.Logging()
+        self.logger = logs.Logging()
 
         self.rules = Rules(custom_rules=custom_rules, locale=locale)
         self.uc_enabled = self.spark.conf.get("spark.databricks.unityCatalog.enabled", "false") == "true"
@@ -49,6 +49,7 @@ def __init__(
 
     def _can_read_columns_table(self) -> bool:
        try:
+            self.logger.debug(f'Verifying if can read from {self.COLUMNS_TABLE_NAME}')
             self.spark.sql(f"SELECT * FROM {self.COLUMNS_TABLE_NAME} LIMIT 1")
             return True
         except Exception as e:
4 changes: 2 additions & 2 deletions discoverx/explorer.py
@@ -2,7 +2,7 @@
 import copy
 import re
 from typing import Optional, List
-from discoverx import logging
+from discoverx import logs
 from discoverx.common import helper
 from discoverx.discovery import Discovery
 from discoverx.rules import Rule
@@ -13,7 +13,7 @@
 from discoverx.table_info import InfoFetcher, TableInfo
 
 
-logger = logging.Logging()
+logger = logs.Logging()
 
 
 class DataExplorer:
11 changes: 6 additions & 5 deletions discoverx/logging.py → discoverx/logs.py
@@ -1,11 +1,12 @@
 import logging
 import re
 
+logger = logging.getLogger('databricks.labs.discoverx')
 
 class Logging:
     def friendly(self, message):
         print(re.sub("<[^<]+?>", "", message))
-        logging.info(message)
+        logger.info(message)
 
     def friendlyHTML(self, message):
         try:
@@ -15,15 +16,15 @@ def friendlyHTML(self, message):
         except:
             # Strip HTML classes
             print(re.sub("<[^<]+?>", "", message))
-            logging.info(message)
+            logger.info(message)
 
     def info(self, message):
         print(message)
-        logging.info(message)
+        logger.info(message)
 
     def debug(self, message):
-        logging.debug(message)
+        logger.debug(message)
 
     def error(self, message):
         print(message)
-        logging.error(message)
+        logger.error(message)
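
Switching from the root logging module to the named 'databricks.labs.discoverx' logger is what lets cli.py tune verbosity: main() sets a level on the parent "databricks" logger, and levels apply down the dotted hierarchy. A stdlib-only sketch of that behavior:

import logging

# Setting a level on the parent logger governs dotted children as well,
# which is how cli.main() controls this module's logger.
logging.getLogger('databricks').setLevel(logging.DEBUG)
child = logging.getLogger('databricks.labs.discoverx')
assert child.getEffectiveLevel() == logging.DEBUG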
4 changes: 2 additions & 2 deletions discoverx/msql.py
@@ -1,7 +1,7 @@
 """This module contains the M-SQL compiler"""
 from dataclasses import dataclass
 from functools import reduce
-from discoverx import logging
+from discoverx import logs
 from discoverx.table_info import ColumnInfo, TableInfo
 from discoverx.common.helper import strip_margin
 from fnmatch import fnmatch
@@ -40,7 +40,7 @@ def __init__(self, msql: str) -> None:
         # Extract command
         self.command = self._extract_command()
 
-        self.logger = logging.Logging()
+        self.logger = logs.Logging()
 
     def compile_msql(self, table_info: TableInfo) -> list[SQLRow]:
         """
4 changes: 2 additions & 2 deletions discoverx/scanner.py
@@ -8,11 +8,11 @@
 from pyspark.sql.utils import AnalysisException
 
 from discoverx.common.helper import strip_margin, format_regex
-from discoverx import logging
+from discoverx import logs
 from discoverx.table_info import InfoFetcher, TableInfo
 from discoverx.rules import Rules, RuleTypes
 
-logger = logging.Logging()
+logger = logs.Logging()
 
 
 @dataclass
28 changes: 28 additions & 0 deletions labs.yml
@@ -0,0 +1,28 @@
+---
+name: discoverx
+description: Multi-table operations over the Lakehouse
+install:
+  min_runtime_version: 13.1
+  require_running_cluster: true
+  require_databricks_connect: true
+entrypoint: discoverx/cli.py
+min_python: 3.10
+commands:
+  - name: scan
+    description: Scans the lakehouse
+    flags:
+      - name: locale
+        default: US
+        description: Locale for scanning
+      - name: from_tables
+        default: '*.*.*'
+        description: The tables to be scanned, in the format "catalog.schema.table"; use "*" as a wildcard. Defaults to "*.*.*".
+      - name: rules
+        default: '*'
+        description: The rule names to use for the scan; use "*" as a wildcard. Defaults to "*".
+      - name: sample_size
+        default: 10000
+        description: The number of rows to scan per table. Defaults to 10000.
+      - name: what_if
+        default: false
+        description: Whether to run the scan in what-if mode, printing the SQL commands instead of executing them. Defaults to false.
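
With this manifest in place, the command would presumably be reachable through the labs dispatcher of the Databricks CLI; the exact flag spelling below is an assumption, not something this PR confirms:

databricks labs install discoverx
databricks labs discoverx scan --from_tables 'prod.*.*' --sample_size 1000 --what_if true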
2 changes: 2 additions & 0 deletions notebooks/interaction_commands.py
@@ -1,3 +1,5 @@
+# Databricks notebook source
+
 from discoverx import dx
 
 # COMMAND ----------
4 changes: 2 additions & 2 deletions tests/unit/dx_test.py
@@ -1,10 +1,10 @@
 import pandas as pd
 import pytest
 from discoverx.dx import DX
-from discoverx import logging
+from discoverx import logs
 from pyspark.sql.functions import col
 
-logger = logging.Logging()
+logger = logs.Logging()
 
 
 @pytest.fixture(scope="module", name="dx_ip")