Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Base Class of Metric #60

Merged
merged 13 commits into from
Dec 18, 2023
25 changes: 16 additions & 9 deletions example/1_ctgan_example.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
# 运行该例子,可使用:
# ipython -i example/1_ctgan_example.py
# 并查看 sampled_data 变量
# To run this example, you can use:
# ipython - i example/1_ctgan_example.py
# then view the sampled_data

from sdgx.models.ml.single_table.ctgan import CTGAN
import numpy as np

# from sdgx.data_process.sampling.sampler import DataSamplerCTGAN
# from sdgx.data_processors.transformers.transform import DataTransformer
from sdgx.utils import *
from sdgx.metrics.column.jsd import JSD
from sdgx.models.single_table.ctgan import CTGAN
from sdgx.utils.io.csv_utils import *

# 针对 csv 格式的小规模数据
# 目前我们以 df 作为输入的数据的格式
demo_data, discrete_cols = get_demo_single_table()
JSD = JSD()

model = CTGAN(epochs=10)
model.fit(demo_data, discrete_cols)

# sampled
sampled_data = model.sample(1000)
print(sampled_data)

# selected_columns = ["education-num", "fnlwgt"]
# isDiscrete = False
selected_columns = ["workclass"]
isDiscrete = True
metrics = JSD.calculate(demo_data, sampled_data, selected_columns, isDiscrete)

print("JSD metric of column %s: %g" % (selected_columns[0], metrics))
14 changes: 0 additions & 14 deletions sdgx/metrics/base.py

This file was deleted.

83 changes: 83 additions & 0 deletions sdgx/metrics/column/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import pandas as pd

from sdgx.log import logger


class ColumnMetric(object):
"""ColumnMetric
Metrics used to evaluate the quality of synthetic data columns.
"""

upper_bound = None
lower_bound = None
metric_name = "Accuracy"

def __init__(self) -> None:
pass

@classmethod
def check_input(
cls, real_data: pd.Series | pd.DataFrame, synthetic_data: pd.Series | pd.DataFrame
):
"""Input check for column or table input.
Args:
real_data(pd.DataFrame or pd.Series): the real (original) data table / column.
synthetic_data(pd.DataFrame or pd.Series): the synthetic (generated) data table / column.
"""

# Input parameter must not contain None value
if real_data is None or synthetic_data is None:
raise TypeError("Input contains None.")

# The data type should be same
if type(real_data) is not type(synthetic_data):
raise TypeError("Data type of real_data and synthetic data should be the same.")

# Check some data-types that must not be allowed
if type(real_data) in [int, float, str]:
raise TypeError("real_data's type must not be None, int, float or str")

# if type is pd.Series, return directly
if isinstance(real_data, pd.Series) or isinstance(real_data, pd.DataFrame):
return real_data, synthetic_data

# if type is not pd.Series or pd.DataFrame tranfer it to Series
try:
real_data = pd.Series(real_data)
synthetic_data = pd.Series(synthetic_data)
return real_data, synthetic_data
except Exception as e:
logger.error(f"An error occurred while converting to pd.Series: {e}")

return None, None

@classmethod
def calculate(
cls, real_data: pd.Series | pd.DataFrame, synthetic_data: pd.Series | pd.DataFrame
):
"""Calculate the metric value between columns between real table and synthetic table.
Args:
real_data(pd.DataFrame or pd.Series): the real (original) data table / column.
synthetic_data(pd.DataFrame or pd.Series): the synthetic (generated) data table / column.
"""
# This method should first check the input
# such as:
real_data, synthetic_data = ColumnMetric.check_input(real_data, synthetic_data)

raise NotImplementedError()

@classmethod
def check_output(cls, raw_metric_value: float):
"""Check the output value.
Args:
raw_metric_value (float): the calculated raw value of the JSD metric.
"""
raise NotImplementedError()

pass
116 changes: 116 additions & 0 deletions sdgx/metrics/column/jsd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import numpy as np
import pandas as pd
from scipy.stats import entropy, gaussian_kde

from sdgx.metrics.column.base import ColumnMetric


class JSD(ColumnMetric):
"""JSD : Jensen Shannon Divergence
This class is used to calculate the Jensen Shannon divergence value betweenthe target columns of real data and synthetic data.
Currently, we support discrete and continuous columns as inputs.
"""

def __init__(self) -> None:
super().__init__()
self.lower_bound = 0
self.upper_bound = 1
self.metric_name = "jensen_shannon_divergence"

@classmethod
def calculate(
cls,
real_data: pd.DataFrame,
synthetic_data: pd.DataFrame,
cols: list[str] | None,
discrete: bool = True,
) -> pd.DataFrame:
"""
Calculate the JSD value between a real column and a synthetic column.
Args:
real_data (pd.DataFrame): The real data.
synthetic_data (pd.DataFrame): The synthetic data.
cols (list[str]): The target column to calculat JSD metric.
discrete (bool): Whether this column is a discrete column.
Returns:
JSD_val (float): The meteic value.
"""
if discrete:
# 对离散变量求
JSD.check_input(real_data, synthetic_data)
joint_pd_real = real_data.groupby(cols, dropna=False).size() / len(real_data)
joint_pd_syn = synthetic_data.groupby(cols, dropna=False).size() / len(synthetic_data)

# 对齐操作
joint_pdf_values_real, joint_pdf_values_syn = joint_pd_real.align(
joint_pd_syn, fill_value=0
)
else:
# 对连续列
# 一个非常大的问题在于求联合概率密度的数组是N^d问题,所以一旦选取3列以上求联合概率密度时间复杂度就不可接受的高,哪怕只取每个值范围只取100个点都算不完
# 离散列由于是直接用原始数据进行排列求密度,只涉及一次除法,不管多少列都算的很快
real_data_T = real_data[cols].values.T # 转置
syn_data_T = synthetic_data[cols].values.T

# 对连续列估计KDE概率密度
kde_joint_real = gaussian_kde(real_data_T)
kde_joint_syn = gaussian_kde(syn_data_T)

# 均匀取点,取值范围选取真实数据集的最大最小范围
variables_range = [np.linspace(min(col), max(col), 100) for col in real_data_T]
grid_points = np.meshgrid(*variables_range)
grid_points_flat = np.vstack([item.ravel() for item in grid_points])

# 计算概率分布数组
joint_pdf_values_real = (
kde_joint_real(grid_points_flat).reshape(grid_points[0].shape).ravel()
)
joint_pdf_values_syn = (
kde_joint_syn(grid_points_flat).reshape(grid_points[0].shape).ravel()
)

# 传入概率分布数组
JSD_val = JSD.jensen_shannon_divergence(joint_pdf_values_real, joint_pdf_values_syn)

JSD.check_output(JSD_val)

return JSD_val

@classmethod
def check_output(cls, raw_metric_value: float):
"""Check the output value.
Args:
raw_metric_value (float): the calculated raw value of the JSD metric.
"""
instance = cls()
if raw_metric_value < instance.lower_bound or raw_metric_value > instance.upper_bound:
raise ValueError

@classmethod
def jensen_shannon_divergence(cls, p: float, q: float):
"""Calculate the jensen_shannon_divergence of p and q.
Args:
p (float): the input parameter p.
q (float): the input parameter q.
"""
# Calculate the average distribution of p and q
m = 0.5 * (p + q)

# Calculate KL divergence
kl_p = entropy(p, m, base=2)
kl_q = entropy(q, m, base=2)

# Calculate Jensen Shannon divergence
js_divergence = 0.5 * (kl_p + kl_q)

return js_divergence
75 changes: 75 additions & 0 deletions sdgx/metrics/multi_table/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from sdgx.log import logger


class MultiTableMetric:
"""MultiTableMetric
Metrics used to evaluate the quality of synthetic multi-table data.
"""

upper_bound = None
lower_bound = None
metric_name = None
metadata = None
table_list = []

def __init__(self, metadata: dict) -> None:
"""Initialization
Args:
metadata(dict): This parameter accepts a metadata description dict, which is used to describe the table relations and column description information for each table.
"""
self.metadata = metadata

@classmethod
def check_input(cls, real_data: dict, synthetic_data: dict):
"""Format check for single table input.
The `real_data` and `synthetic_data` should be dict, which contains tables (in pd.DataFrame).
Args:
real_data(dict): the real (original) data table.
synthetic_data(dict): the synthetic (generated) data table.
"""
if real_data is None or synthetic_data is None:
raise TypeError("Input contains None.")

# The data type should be same
if type(real_data) is not type(synthetic_data):
raise TypeError("Data type of real_data and synthetic data should be the same.")

# if type is dict, return directly
if (
isinstance(real_data, dict)
and len(real_data.keys()) > 0
and len(synthetic_data.keys()) > 0
):
return real_data, synthetic_data

logger.error("An error occurred while checking the input.")

return None, None

# not a class method
def calculate(self, real_data: dict, synthetic_data: dict):
"""Calculate the metric value between real tables and synthetic tables.
Args:
real_data(dict): the real (original) data table.
synthetic_data(dict): the synthetic (generated) data table.
"""
raise NotImplementedError()

@classmethod
def check_output(raw_metric_value: float):
"""Check the output value.
Args:
raw_metric_value (float): the calculated raw value of the JSD metric.
"""
raise NotImplementedError()

pass
Empty file added sdgx/metrics/report/base.py
Empty file.
Loading