-
Notifications
You must be signed in to change notification settings - Fork 541
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update metric base-class * Update column metric base class * Update Metric Base Class * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: Zhongsheng Ji <9573586@qq.com> * Feature: Add metric jsd (#66) (#71) * Feature metric jsd (#66) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: Zhongsheng Ji <9573586@qq.com> * jsd * jsd * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * base更新 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Delete .idea directory --------- Co-authored-by: Jinhang Su <171846802@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: MoooCat <141886018+MooooCat@users.noreply.github.com> Co-authored-by: Zhongsheng Ji <9573586@qq.com> * Apply suggestions from code review Co-authored-by: Zhongsheng Ji <9573586@qq.com> * Update type hint and comments. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: Zhongsheng Ji <9573586@qq.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Jinhang Su <171846802@qq.com> Co-authored-by: sjh120 <171846802@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Zhongsheng Ji <9573586@qq.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo error unnecessary imports are also removed. 
* Add type hints and comments * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update input check methods The format of some codes has also been adjusted. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Zhongsheng Ji <9573586@qq.com> Co-authored-by: Jinhang Su <171846802@qq.com>
- Loading branch information
1 parent
8896a52
commit 7b66566
Showing
7 changed files
with
368 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,27 @@ | ||
# 运行该例子,可使用: | ||
# ipython -i example/1_ctgan_example.py | ||
# 并查看 sampled_data 变量 | ||
# To run this example, you can use: | ||
# ipython -i example/1_ctgan_example.py | ||
# then view the sampled_data | ||
|
||
from sdgx.models.ml.single_table.ctgan import CTGAN | ||
import numpy as np | ||
|
||
# from sdgx.data_process.sampling.sampler import DataSamplerCTGAN | ||
# from sdgx.data_processors.transformers.transform import DataTransformer | ||
from sdgx.utils import * | ||
from sdgx.metrics.column.jsd import JSD | ||
from sdgx.models.single_table.ctgan import CTGAN | ||
from sdgx.utils.io.csv_utils import * | ||
|
||
# For small-scale data in CSV format | ||
# Currently a DataFrame (df) is used as the format of the input data | ||
demo_data, discrete_cols = get_demo_single_table() | ||
JSD = JSD() | ||
|
||
model = CTGAN(epochs=10) | ||
model.fit(demo_data, discrete_cols) | ||
|
||
# sampled | ||
sampled_data = model.sample(1000) | ||
print(sampled_data) | ||
|
||
# selected_columns = ["education-num", "fnlwgt"] | ||
# isDiscrete = False | ||
selected_columns = ["workclass"] | ||
isDiscrete = True | ||
metrics = JSD.calculate(demo_data, sampled_data, selected_columns, isDiscrete) | ||
|
||
print("JSD metric of column %s: %g" % (selected_columns[0], metrics)) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import pandas as pd | ||
|
||
from sdgx.log import logger | ||
|
||
|
||
class ColumnMetric(object): | ||
"""ColumnMetric | ||
Metrics used to evaluate the quality of synthetic data columns. | ||
""" | ||
|
||
upper_bound = None | ||
lower_bound = None | ||
metric_name = "Accuracy" | ||
|
||
def __init__(self) -> None: | ||
pass | ||
|
||
@classmethod | ||
def check_input( | ||
cls, real_data: pd.Series | pd.DataFrame, synthetic_data: pd.Series | pd.DataFrame | ||
): | ||
"""Input check for column or table input. | ||
Args: | ||
real_data(pd.DataFrame or pd.Series): the real (original) data table / column. | ||
synthetic_data(pd.DataFrame or pd.Series): the synthetic (generated) data table / column. | ||
""" | ||
|
||
# Input parameter must not contain None value | ||
if real_data is None or synthetic_data is None: | ||
raise TypeError("Input contains None.") | ||
|
||
# The data type should be same | ||
if type(real_data) is not type(synthetic_data): | ||
raise TypeError("Data type of real_data and synthetic data should be the same.") | ||
|
||
# Check some data-types that must not be allowed | ||
if type(real_data) in [int, float, str]: | ||
raise TypeError("real_data's type must not be None, int, float or str") | ||
|
||
# if type is pd.Series, return directly | ||
if isinstance(real_data, pd.Series) or isinstance(real_data, pd.DataFrame): | ||
return real_data, synthetic_data | ||
|
||
# if type is not pd.Series or pd.DataFrame tranfer it to Series | ||
try: | ||
real_data = pd.Series(real_data) | ||
synthetic_data = pd.Series(synthetic_data) | ||
return real_data, synthetic_data | ||
except Exception as e: | ||
logger.error(f"An error occurred while converting to pd.Series: {e}") | ||
|
||
return None, None | ||
|
||
@classmethod | ||
def calculate( | ||
cls, real_data: pd.Series | pd.DataFrame, synthetic_data: pd.Series | pd.DataFrame | ||
): | ||
"""Calculate the metric value between columns between real table and synthetic table. | ||
Args: | ||
real_data(pd.DataFrame or pd.Series): the real (original) data table / column. | ||
synthetic_data(pd.DataFrame or pd.Series): the synthetic (generated) data table / column. | ||
""" | ||
# This method should first check the input | ||
# such as: | ||
real_data, synthetic_data = ColumnMetric.check_input(real_data, synthetic_data) | ||
|
||
raise NotImplementedError() | ||
|
||
@classmethod | ||
def check_output(cls, raw_metric_value: float): | ||
"""Check the output value. | ||
Args: | ||
raw_metric_value (float): the calculated raw value of the JSD metric. | ||
""" | ||
raise NotImplementedError() | ||
|
||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import numpy as np | ||
import pandas as pd | ||
from scipy.stats import entropy, gaussian_kde | ||
|
||
from sdgx.metrics.column.base import ColumnMetric | ||
|
||
|
||
class JSD(ColumnMetric):
    """JSD : Jensen Shannon Divergence

    This class is used to calculate the Jensen Shannon divergence value between
    the target columns of real data and synthetic data.
    Currently, we support discrete and continuous columns as inputs.
    """

    # Kept at class level so the classmethods below can read them without
    # constructing an instance.
    lower_bound = 0
    upper_bound = 1
    metric_name = "jensen_shannon_divergence"

    def __init__(self) -> None:
        super().__init__()

    @classmethod
    def calculate(
        cls,
        real_data: pd.DataFrame,
        synthetic_data: pd.DataFrame,
        cols: list[str] | None,
        discrete: bool = True,
    ) -> float:
        """Calculate the JSD value between real columns and synthetic columns.

        Args:
            real_data (pd.DataFrame): The real data.
            synthetic_data (pd.DataFrame): The synthetic data.
            cols (list[str]): The target column(s) to calculate the JSD metric on.
                NOTE(review): the annotation admits ``None``, but the value is
                passed straight to ``groupby`` / column indexing — confirm the
                intended handling of ``None``.
            discrete (bool): Whether the target columns are discrete.

        Returns:
            JSD_val (float): The metric value.

        Raises:
            TypeError: from ``check_input`` when the inputs are invalid.
            ValueError: from ``check_output`` when the value is out of bounds.
        """
        # Validate the input for both the discrete and the continuous path
        real_data, synthetic_data = cls.check_input(real_data, synthetic_data)

        if discrete:
            # Discrete columns: relative frequency of every observed value
            # combination (NaN is kept as its own category).
            joint_pd_real = real_data.groupby(cols, dropna=False).size() / len(real_data)
            joint_pd_syn = synthetic_data.groupby(cols, dropna=False).size() / len(synthetic_data)

            # Align both distributions on the union of categories; a category
            # missing on one side gets probability 0.
            joint_pdf_values_real, joint_pdf_values_syn = joint_pd_real.align(
                joint_pd_syn, fill_value=0
            )
        else:
            # Continuous columns.
            # A major problem is that evaluating the joint density is an N^d
            # problem: once 3 or more columns are selected the cost becomes
            # unacceptably high, even with only 100 points per value range.
            # Discrete columns only count the raw rows and need a single
            # division, so they stay fast regardless of the number of columns.
            real_data_T = real_data[cols].values.T  # transpose: one row per column
            syn_data_T = synthetic_data[cols].values.T

            # Estimate the probability density with a Gaussian KDE
            kde_joint_real = gaussian_kde(real_data_T)
            kde_joint_syn = gaussian_kde(syn_data_T)

            # Evenly spaced sample points spanning the min/max of the real data
            variables_range = [np.linspace(col.min(), col.max(), 100) for col in real_data_T]
            grid_points = np.meshgrid(*variables_range)
            grid_points_flat = np.vstack([item.ravel() for item in grid_points])

            # Density values on the grid (already one-dimensional)
            joint_pdf_values_real = kde_joint_real(grid_points_flat)
            joint_pdf_values_syn = kde_joint_syn(grid_points_flat)

        # Feed both distributions into the divergence
        JSD_val = cls.jensen_shannon_divergence(joint_pdf_values_real, joint_pdf_values_syn)

        cls.check_output(JSD_val)

        return JSD_val

    @classmethod
    def check_output(cls, raw_metric_value: float):
        """Check that the output value lies within the metric bounds.

        Args:
            raw_metric_value (float): the calculated raw value of the JSD metric.

        Raises:
            ValueError: if the value is below ``lower_bound`` or above ``upper_bound``.
        """
        if raw_metric_value < cls.lower_bound or raw_metric_value > cls.upper_bound:
            raise ValueError(
                f"{cls.metric_name} value {raw_metric_value} is outside "
                f"[{cls.lower_bound}, {cls.upper_bound}]."
            )

    @classmethod
    def jensen_shannon_divergence(cls, p: np.ndarray | pd.Series, q: np.ndarray | pd.Series):
        """Calculate the Jensen Shannon divergence (base 2) of p and q.

        Args:
            p (np.ndarray or pd.Series): non-negative weights of the first distribution.
            q (np.ndarray or pd.Series): non-negative weights of the second distribution.

        Returns:
            float: the divergence value.
        """
        # scipy's entropy() normalises each of its arguments independently, so
        # p and q must be normalised *before* the mixture is formed; otherwise
        # m is not the midpoint of the two distributions whenever their sums
        # differ (as happens with KDE values sampled on a grid).
        p = p / np.sum(p)
        q = q / np.sum(q)

        # Average distribution of p and q
        m = 0.5 * (p + q)

        # KL divergence of each distribution from the mixture
        kl_p = entropy(p, m, base=2)
        kl_q = entropy(q, m, base=2)

        # Jensen Shannon divergence
        js_divergence = 0.5 * (kl_p + kl_q)

        return js_divergence
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from sdgx.log import logger | ||
|
||
|
||
class MultiTableMetric:
    """MultiTableMetric

    Metrics used to evaluate the quality of synthetic multi-table data.
    """

    upper_bound = None
    lower_bound = None
    metric_name = None
    metadata = None
    # NOTE(review): mutable class-level default — the same list object is seen
    # by every instance and subclass unless it is rebound.
    table_list = []

    def __init__(self, metadata: dict) -> None:
        """Initialization

        Args:
            metadata(dict): This parameter accepts a metadata description dict, which is used to describe the table relations and column description information for each table.
        """
        self.metadata = metadata

    @classmethod
    def check_input(cls, real_data: dict, synthetic_data: dict):
        """Format check for multi-table input.

        The `real_data` and `synthetic_data` should be dict, which contains tables (in pd.DataFrame).

        Args:
            real_data(dict): the real (original) data tables.
            synthetic_data(dict): the synthetic (generated) data tables.

        Returns:
            tuple: ``(real_data, synthetic_data)`` unchanged when both are
            non-empty dicts; otherwise an error is logged and ``(None, None)``
            is returned.

        Raises:
            TypeError: if either input is ``None`` or the two inputs have
                different types.
        """
        if real_data is None or synthetic_data is None:
            raise TypeError("Input contains None.")

        # Both inputs must be of exactly the same type
        if type(real_data) is not type(synthetic_data):
            raise TypeError("Data type of real_data and synthetic data should be the same.")

        # Non-empty dicts are accepted as they are
        if isinstance(real_data, dict) and real_data and synthetic_data:
            return real_data, synthetic_data

        logger.error("An error occurred while checking the input.")

        return None, None

    # not a class method
    def calculate(self, real_data: dict, synthetic_data: dict):
        """Calculate the metric value between real tables and synthetic tables.

        Args:
            real_data(dict): the real (original) data tables.
            synthetic_data(dict): the synthetic (generated) data tables.

        Raises:
            NotImplementedError: always; there is no implementation in this class.
        """
        raise NotImplementedError()

    @classmethod
    def check_output(cls, raw_metric_value: float):
        """Check the output value.

        Args:
            raw_metric_value (float): the calculated raw value of the metric.

        Raises:
            NotImplementedError: always; there is no implementation in this class.
        """
        raise NotImplementedError()
Empty file.
Oops, something went wrong.