Skip to content

Commit

Permalink
Update Base Class of Metric (#60)
Browse files Browse the repository at this point in the history
* Update metric base-class

* Update column metric base class

* Update Metric Base Class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Co-authored-by: Zhongsheng Ji <9573586@qq.com>

* Feature: Add metric jsd (#66) (#71)

* Feature metric jsd (#66)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Co-authored-by: Zhongsheng Ji <9573586@qq.com>

* jsd

* jsd

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* base更新

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Delete .idea directory

---------


Co-authored-by: Jinhang Su <171846802@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: MoooCat <141886018+MooooCat@users.noreply.github.com>
Co-authored-by: Zhongsheng Ji <9573586@qq.com>


* Apply suggestions from code review

Co-authored-by: Zhongsheng Ji <9573586@qq.com>

* Update type hint and comments.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Co-authored-by: Zhongsheng Ji <9573586@qq.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------


Co-authored-by: Jinhang Su <171846802@qq.com>
Co-authored-by: sjh120 <171846802@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Zhongsheng Ji <9573586@qq.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo error

unnecessary imports are also removed.

* Add type hints and comments

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update input check methods

The format of some codes has also been adjusted.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Zhongsheng Ji <9573586@qq.com>
Co-authored-by: Jinhang Su <171846802@qq.com>
  • Loading branch information
4 people committed Dec 18, 2023
1 parent 8896a52 commit 7b66566
Show file tree
Hide file tree
Showing 7 changed files with 368 additions and 23 deletions.
25 changes: 16 additions & 9 deletions example/1_ctgan_example.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
# 运行该例子,可使用:
# ipython -i example/1_ctgan_example.py
# 并查看 sampled_data 变量
# To run this example, you can use:
# ipython - i example/1_ctgan_example.py
# then view the sampled_data

from sdgx.models.ml.single_table.ctgan import CTGAN
import numpy as np

# from sdgx.data_process.sampling.sampler import DataSamplerCTGAN
# from sdgx.data_processors.transformers.transform import DataTransformer
from sdgx.utils import *
from sdgx.metrics.column.jsd import JSD
from sdgx.models.single_table.ctgan import CTGAN
from sdgx.utils.io.csv_utils import *

# 针对 csv 格式的小规模数据
# 目前我们以 df 作为输入的数据的格式
demo_data, discrete_cols = get_demo_single_table()
JSD = JSD()

model = CTGAN(epochs=10)
model.fit(demo_data, discrete_cols)

# sampled
sampled_data = model.sample(1000)
print(sampled_data)

# selected_columns = ["education-num", "fnlwgt"]
# isDiscrete = False
selected_columns = ["workclass"]
isDiscrete = True
metrics = JSD.calculate(demo_data, sampled_data, selected_columns, isDiscrete)

print("JSD metric of column %s: %g" % (selected_columns[0], metrics))
14 changes: 0 additions & 14 deletions sdgx/metrics/base.py

This file was deleted.

83 changes: 83 additions & 0 deletions sdgx/metrics/column/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import pandas as pd

from sdgx.log import logger


class ColumnMetric(object):
"""ColumnMetric
Metrics used to evaluate the quality of synthetic data columns.
"""

upper_bound = None
lower_bound = None
metric_name = "Accuracy"

def __init__(self) -> None:
pass

@classmethod
def check_input(
cls, real_data: pd.Series | pd.DataFrame, synthetic_data: pd.Series | pd.DataFrame
):
"""Input check for column or table input.
Args:
real_data(pd.DataFrame or pd.Series): the real (original) data table / column.
synthetic_data(pd.DataFrame or pd.Series): the synthetic (generated) data table / column.
"""

# Input parameter must not contain None value
if real_data is None or synthetic_data is None:
raise TypeError("Input contains None.")

# The data type should be same
if type(real_data) is not type(synthetic_data):
raise TypeError("Data type of real_data and synthetic data should be the same.")

# Check some data-types that must not be allowed
if type(real_data) in [int, float, str]:
raise TypeError("real_data's type must not be None, int, float or str")

# if type is pd.Series, return directly
if isinstance(real_data, pd.Series) or isinstance(real_data, pd.DataFrame):
return real_data, synthetic_data

# if type is not pd.Series or pd.DataFrame tranfer it to Series
try:
real_data = pd.Series(real_data)
synthetic_data = pd.Series(synthetic_data)
return real_data, synthetic_data
except Exception as e:
logger.error(f"An error occurred while converting to pd.Series: {e}")

return None, None

@classmethod
def calculate(
cls, real_data: pd.Series | pd.DataFrame, synthetic_data: pd.Series | pd.DataFrame
):
"""Calculate the metric value between columns between real table and synthetic table.
Args:
real_data(pd.DataFrame or pd.Series): the real (original) data table / column.
synthetic_data(pd.DataFrame or pd.Series): the synthetic (generated) data table / column.
"""
# This method should first check the input
# such as:
real_data, synthetic_data = ColumnMetric.check_input(real_data, synthetic_data)

raise NotImplementedError()

@classmethod
def check_output(cls, raw_metric_value: float):
"""Check the output value.
Args:
raw_metric_value (float): the calculated raw value of the JSD metric.
"""
raise NotImplementedError()

pass
116 changes: 116 additions & 0 deletions sdgx/metrics/column/jsd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import numpy as np
import pandas as pd
from scipy.stats import entropy, gaussian_kde

from sdgx.metrics.column.base import ColumnMetric


class JSD(ColumnMetric):
"""JSD : Jensen Shannon Divergence
This class is used to calculate the Jensen Shannon divergence value betweenthe target columns of real data and synthetic data.
Currently, we support discrete and continuous columns as inputs.
"""

def __init__(self) -> None:
super().__init__()
self.lower_bound = 0
self.upper_bound = 1
self.metric_name = "jensen_shannon_divergence"

@classmethod
def calculate(
cls,
real_data: pd.DataFrame,
synthetic_data: pd.DataFrame,
cols: list[str] | None,
discrete: bool = True,
) -> pd.DataFrame:
"""
Calculate the JSD value between a real column and a synthetic column.
Args:
real_data (pd.DataFrame): The real data.
synthetic_data (pd.DataFrame): The synthetic data.
cols (list[str]): The target column to calculat JSD metric.
discrete (bool): Whether this column is a discrete column.
Returns:
JSD_val (float): The meteic value.
"""
if discrete:
# 对离散变量求
JSD.check_input(real_data, synthetic_data)
joint_pd_real = real_data.groupby(cols, dropna=False).size() / len(real_data)
joint_pd_syn = synthetic_data.groupby(cols, dropna=False).size() / len(synthetic_data)

# 对齐操作
joint_pdf_values_real, joint_pdf_values_syn = joint_pd_real.align(
joint_pd_syn, fill_value=0
)
else:
# 对连续列
# 一个非常大的问题在于求联合概率密度的数组是N^d问题,所以一旦选取3列以上求联合概率密度时间复杂度就不可接受的高,哪怕只取每个值范围只取100个点都算不完
# 离散列由于是直接用原始数据进行排列求密度,只涉及一次除法,不管多少列都算的很快
real_data_T = real_data[cols].values.T # 转置
syn_data_T = synthetic_data[cols].values.T

# 对连续列估计KDE概率密度
kde_joint_real = gaussian_kde(real_data_T)
kde_joint_syn = gaussian_kde(syn_data_T)

# 均匀取点,取值范围选取真实数据集的最大最小范围
variables_range = [np.linspace(min(col), max(col), 100) for col in real_data_T]
grid_points = np.meshgrid(*variables_range)
grid_points_flat = np.vstack([item.ravel() for item in grid_points])

# 计算概率分布数组
joint_pdf_values_real = (
kde_joint_real(grid_points_flat).reshape(grid_points[0].shape).ravel()
)
joint_pdf_values_syn = (
kde_joint_syn(grid_points_flat).reshape(grid_points[0].shape).ravel()
)

# 传入概率分布数组
JSD_val = JSD.jensen_shannon_divergence(joint_pdf_values_real, joint_pdf_values_syn)

JSD.check_output(JSD_val)

return JSD_val

@classmethod
def check_output(cls, raw_metric_value: float):
"""Check the output value.
Args:
raw_metric_value (float): the calculated raw value of the JSD metric.
"""
instance = cls()
if raw_metric_value < instance.lower_bound or raw_metric_value > instance.upper_bound:
raise ValueError

@classmethod
def jensen_shannon_divergence(cls, p: float, q: float):
"""Calculate the jensen_shannon_divergence of p and q.
Args:
p (float): the input parameter p.
q (float): the input parameter q.
"""
# Calculate the average distribution of p and q
m = 0.5 * (p + q)

# Calculate KL divergence
kl_p = entropy(p, m, base=2)
kl_q = entropy(q, m, base=2)

# Calculate Jensen Shannon divergence
js_divergence = 0.5 * (kl_p + kl_q)

return js_divergence
75 changes: 75 additions & 0 deletions sdgx/metrics/multi_table/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from sdgx.log import logger


class MultiTableMetric:
"""MultiTableMetric
Metrics used to evaluate the quality of synthetic multi-table data.
"""

upper_bound = None
lower_bound = None
metric_name = None
metadata = None
table_list = []

def __init__(self, metadata: dict) -> None:
"""Initialization
Args:
metadata(dict): This parameter accepts a metadata description dict, which is used to describe the table relations and column description information for each table.
"""
self.metadata = metadata

@classmethod
def check_input(cls, real_data: dict, synthetic_data: dict):
"""Format check for single table input.
The `real_data` and `synthetic_data` should be dict, which contains tables (in pd.DataFrame).
Args:
real_data(dict): the real (original) data table.
synthetic_data(dict): the synthetic (generated) data table.
"""
if real_data is None or synthetic_data is None:
raise TypeError("Input contains None.")

# The data type should be same
if type(real_data) is not type(synthetic_data):
raise TypeError("Data type of real_data and synthetic data should be the same.")

# if type is dict, return directly
if (
isinstance(real_data, dict)
and len(real_data.keys()) > 0
and len(synthetic_data.keys()) > 0
):
return real_data, synthetic_data

logger.error("An error occurred while checking the input.")

return None, None

# not a class method
def calculate(self, real_data: dict, synthetic_data: dict):
"""Calculate the metric value between real tables and synthetic tables.
Args:
real_data(dict): the real (original) data table.
synthetic_data(dict): the synthetic (generated) data table.
"""
raise NotImplementedError()

@classmethod
def check_output(raw_metric_value: float):
"""Check the output value.
Args:
raw_metric_value (float): the calculated raw value of the JSD metric.
"""
raise NotImplementedError()

pass
Empty file added sdgx/metrics/report/base.py
Empty file.
Loading

0 comments on commit 7b66566

Please sign in to comment.