Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[0.1.0]Breaking changes: Reactoring models Part 1 #68

Merged
merged 11 commits into from
Dec 17, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,5 @@ cython_debug/
# End of https://www.toptal.com/developers/gitignore/api/macos,emacs,python

*.log

.sdgx_cache
5 changes: 3 additions & 2 deletions example/1_ctgan_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
# ipython -i example/1_ctgan_example.py
# 并查看 sampled_data 变量

from sdgx.models.single_table.ctgan import CTGAN
from sdgx.models.ml.single_table.ctgan import CTGAN

# from sdgx.data_process.sampling.sampler import DataSamplerCTGAN
# from sdgx.data_processors.transformers.transform import DataTransformer
from sdgx.utils.io.csv_utils import *
from sdgx.utils import *

# 针对 csv 格式的小规模数据
# 目前我们以 df 作为输入的数据的格式
Expand All @@ -17,3 +17,4 @@

# sampled
sampled_data = model.sample(1000)
print(sampled_data)
2 changes: 1 addition & 1 deletion example/2_guassian_copula_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# 并查看 sampled_data 变量

from sdgx.models.statistics.single_table.copula import GaussianCopulaSynthesizer
from sdgx.utils.io.csv_utils import *
from sdgx.utils import *

# 针对 csv 格式的小规模数据
# 目前我们以 df 作为输入的数据的格式
Expand Down
2 changes: 1 addition & 1 deletion sdgx/data_connectors/csv_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def identity(self):
Identity of the data source is the sha256 of the file
"""
with open(self.path, "rb") as f:
return hashlib.sha256(f.read()).hexdigest()
return f"csvfile-{hashlib.sha256(f.read()).hexdigest()}"

def __init__(
self,
Expand Down
6 changes: 4 additions & 2 deletions sdgx/data_connectors/generator_connector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import os
from functools import cached_property
from typing import Callable, Generator

import pandas as pd
Expand All @@ -21,9 +23,9 @@ class GeneratorConnector(DataConnector):
This connector is not been registered by default. So only be used with the library way.
"""

@property
@cached_property
def identity(self) -> str:
return f"{id(self.generator_caller)}"
return f"generator-{os.getpid()}-{id(self.generator_caller)}"

def __init__(
self,
Expand Down
156 changes: 156 additions & 0 deletions sdgx/data_connectors/mysql_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""
Move from utils/io/mysql.py

TODO: Implement as MySQLConnector
"""

# import os

# import pymysql
# from config import sdg_log_dir
# from pymysql.err import IntegrityError, OperationalError

# # logger 需要解决
# from sdgx.utils import logger


# class mysql_db_connector:
# def __init__(
# self,
# databaes_ip,
# database_port,
# database_username,
# database_passwd,
# database_name,
# database_type="mysql",
# ) -> None:
# # 配置log路径
# if not os.path.exists(sdg_log_dir):
# os.makedirs(sdg_log_dir)
# log_path = os.path.join(sdg_log_dir, "mysql_log.log")
# log_path_handler = get_log_file_handler(log_path)
# logger.addHandler(log_path_handler)

# # 进行初始化
# self.database_type = database_type
# self.databaes_ip = databaes_ip
# self.database_port = database_port
# self.database_username = database_username
# self.database_passwd = database_passwd
# self.database_name = database_name

# # 记录状态
# self.conn = None
# self.status = "waiting"
# self.insert_success = False
# self.del_success = False

# # 建立一个 conn,
# if self.database_type == "mysql":
# logger.info(
# "Establish Mysql Connection with %s:%s" % (self.databaes_ip, self.database_port)
# )
# try:
# self.conn = pymysql.connect(
# host=self.databaes_ip,
# port=self.database_port,
# user=self.database_username,
# password=self.database_passwd,
# database=self.database_name,
# )
# self.status = "success"
# logger.info("MySQL Connection Created.")
# except ConnectionRefusedError:
# self.status = "failed"
# logger.error("MySQL Connection Refused")
# except OperationalError:
# logger.error("MySQL OperationalError")
# self.status = "failed"
# else:
# logger.error("Not support other Database yet.")
# raise NotImplementedError("Not support other Database yet.")
# pass

# def query(self, sql):
# # 查询动作,把数据给提出来,以列表形式返回
# # assert 'select' in sql.lowercase()
# if "select" not in sql.lower():
# raise ValueError("NOT a Query SQL")
# cursor = self.conn.cursor()
# select_cnt = cursor.execute(sql) # 得到结果行数
# res = []
# for _ in range(select_cnt):
# sample = list(cursor.fetchone())
# res.append(sample)
# cursor.close()
# logger.debug('Query: "%s", Success' % sql.replace("\n", " ").strip())
# return res # 结果一定是二维数组

# def insert(self, sql):
# # INSERT INTO `fed_sql`.`b_job` (`id`, `job_id`, `job_sql`, `job_ priority`, `grammar_check`, `integrity_check`, `status`, `create_time`, `update_time`, `start_time`) VALUES ('49', '1', 'selec * from A.xx;', '2', '1', '1', 'success', '1657871901', '1657871901', '1657871901');
# if "insert" not in sql.lower() and "update" not in sql.lower():
# raise ValueError("NOT a Insert SQL")
# cursor = self.conn.cursor()
# try:
# cursor.execute(sql)
# cursor.close()
# except IntegrityError:
# logger.error('Insert: "%s", Failed' % sql.replace("\n", " "))
# self.insert_success = False
# cursor.close()
# self.conn.commit()
# if "insert" in sql.lower():
# logger.debug('Insert: "%s", Success' % sql.replace("\n", " "))
# elif "update" in sql.lower():
# logger.debug('Update: "%s", Success' % sql.replace("\n", " "))
# self.insert_success = True

# def delete(self, sql):
# # DELETE FROM `fed_sql`.`s_party` WHERE (`id` = '3');
# if "delete" not in sql.lower():
# raise ValueError("NOT a Delete SQL")
# cursor = self.conn.cursor()
# try:
# cursor.execute(sql)
# cursor.close()
# except Exception as e:
# logger.error('Delete: "%s", Failed' % sql.replace("\n", " "))
# logger.error(str(e))
# self.del_success = False
# cursor.close()
# self.conn.commit()
# logger.debug('Delete: "%s", Success' % sql.replace("\n", " "))
# self.del_success = True

# def get_tables(self):
# sql = "show tables;"
# cursor = self.conn.cursor()
# table_cnt = cursor.execute(sql)
# res = []
# for _ in range(table_cnt):
# table = cursor.fetchone()[0]
# res.append(table)
# return res

# def create_table(self, sql):
# if "create" not in sql.lower():
# raise ValueError("NOT a Create SQL")
# cursor = self.conn.cursor()
# cursor.execute(sql)
# self.conn.commit()
# cursor.close()

# def __del__(self):
# # 关闭数据库连接
# if self.status == "failed":
# logger.info("NO MySQL Connection, Closed.")
# elif self.status != "closed":
# self.conn.close()
# logger.info("MySQL Connection Closed.")

# # 手工 close
# # 一般不太用得到
# def close(self):
# self.conn.close()
# self.status = "closed"
# logger.info("MySQL Connection Closed in func <close>.")
2 changes: 1 addition & 1 deletion sdgx/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class DataLoader:
data_connector (:ref:`DataConnector`): The data connector
chunksize (int, optional): The chunksize of the cacher. Defaults to 1000.
cacher (:ref:`Cacher`, optional): The cacher. Defaults to None.
cache_mode (str, optional): The cache mode(name). Defaults to "DiskCache".
cache_mode (str, optional): The cache mode(cachers' name). Defaults to "DiskCache", more info in :ref:`DiskCache`.
cacher_kwargs (dict, optional): The kwargs for cacher. Defaults to None
"""

Expand Down
15 changes: 10 additions & 5 deletions sdgx/data_models/inspectors/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import pandas as pd
from __future__ import annotations

from typing import Any

from sdgx.data_models.inspectors.inspect_meta import InspectMeta
import pandas as pd


class Inspector:
"""
Base Inspector class

Inspector is used to inspect data and generate metadata automatically.

Parameters:
ready (bool): Ready to inspect, maybe all fields are fitted, or indicate if there is more data, inspector will be more precise.
"""

ready: bool
"""Ready to inspect, maybe all fields are fitted."""
def __init__(self, *args, **kwargs):
self.ready: bool = False

def fit(self, raw_data: pd.DataFrame):
"""Fit the inspector.
Expand All @@ -21,5 +26,5 @@ def fit(self, raw_data: pd.DataFrame):
"""
return

def inspect(self) -> InspectMeta:
def inspect(self) -> dict[str, Any]:
"""Inspect raw data and generate metadata."""
38 changes: 38 additions & 0 deletions sdgx/data_models/inspectors/discrete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from typing import Any

import pandas as pd

from sdgx.data_models.inspectors.base import Inspector
from sdgx.data_models.inspectors.extension import hookimpl


class DiscreteInspector(Inspector):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.discrete_columns: set[str] = set()

def fit(self, raw_data: pd.DataFrame):
"""Fit the inspector.

Gets the list of discrete columns from the raw data.

Args:
raw_data (pd.DataFrame): Raw data
"""

self.discrete_columns = self.discrete_columns.union(
set(raw_data.select_dtypes(include="object").columns)
)
self.ready = True

def inspect(self) -> dict[str, Any]:
"""Inspect raw data and generate metadata."""

return {"discrete_columns": list(self.discrete_columns)}


@hookimpl
def register(manager):
manager.register("DiscreteInspector", DiscreteInspector)
5 changes: 0 additions & 5 deletions sdgx/data_models/inspectors/inspect_meta.py

This file was deleted.

44 changes: 30 additions & 14 deletions sdgx/data_models/metadata.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,48 @@
from __future__ import annotations

from enum import Enum
from typing import Any
from typing import Any, Dict, List

import pandas as pd
from pydantic import BaseModel

from sdgx.data_loader import DataLoader
from sdgx.data_models.inspectors.inspect_meta import InspectMeta
from sdgx.data_models.inspectors.manager import InspectorManager
from sdgx.exceptions import MetadataInitError
from sdgx.utils import cache

# TODO: Design metadata for relationships...
# class DType(Enum):
# datetime = "datetime"
# timestamp = "timestamp"
# numeric = "numeric"
# category = "category"

class DType(Enum):
datetime = "datetime"
timestamp = "timestamp"
numeric = "numeric"
category = "category"


class Relationship:
pass
# class Relationship:
# pass


class Metadata(BaseModel):
# fields: List[str]
discrete_columns: List[str] = []
_extend: Dict[str, Any] = {}

def get(self, key: str, default=None) -> Any:
return getattr(self, key, getattr(self._extend, key, default))

def set(self, key: str, value: Any):
if key == "_extend":
raise MetadataInitError("Cannot set _extend directly")

if key in self.model_fields:
setattr(self, key, value)
else:
self._extend[key] = value

def update(self, attributes: dict[str, Any]):
for k, v in attributes.items():
self.set(k, v)

def update(self, inspect_meta: InspectMeta):
return self

@classmethod
Expand Down Expand Up @@ -68,7 +85,6 @@ def from_dataframe(

metadata = Metadata()
for inspector in inspectors:
if inspector.ready:
metadata.update(inspector.inspect())
metadata.update(inspector.inspect())

return metadata
6 changes: 6 additions & 0 deletions sdgx/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,9 @@ class CacheError(SdgxError):
"""
Exception to indicate that exception when using cache.
"""


class MetadataInitError(SdgxError):
"""
Exception to indicate that exception when initializing metadata.
"""
2 changes: 1 addition & 1 deletion sdgx/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sdgx import models
from sdgx.exceptions import InitializationError, NotFoundError, RegisterError
from sdgx.log import logger
from sdgx.utils.utils import Singleton
from sdgx.utils import Singleton


class Manager(metaclass=Singleton):
Expand Down
Loading