-
Notifications
You must be signed in to change notification settings - Fork 541
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into 0.1.0-ctgan
- Loading branch information
Showing
10 changed files
with
381 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,27 @@ | ||
# 运行该例子,可使用: | ||
# ipython -i example/1_ctgan_example.py | ||
# 并查看 sampled_data 变量 | ||
# To run this example, you can use: | ||
# ipython -i example/1_ctgan_example.py | ||
# then view the sampled_data | ||
|
||
from sdgx.models.ml.single_table.ctgan import CTGAN | ||
import numpy as np | ||
|
||
# from sdgx.data_process.sampling.sampler import DataSamplerCTGAN | ||
# from sdgx.data_processors.transformers.transform import DataTransformer | ||
from sdgx.utils import * | ||
from sdgx.metrics.column.jsd import JSD | ||
from sdgx.models.single_table.ctgan import CTGAN | ||
from sdgx.utils.io.csv_utils import * | ||
|
||
# 针对 csv 格式的小规模数据 | ||
# 目前我们以 df 作为输入的数据的格式 | ||
demo_data, discrete_cols = get_demo_single_table() | ||
JSD = JSD() | ||
|
||
model = CTGAN(epochs=10) | ||
model.fit(demo_data, discrete_cols) | ||
|
||
# sampled | ||
sampled_data = model.sample(1000) | ||
print(sampled_data) | ||
|
||
# selected_columns = ["education-num", "fnlwgt"] | ||
# isDiscrete = False | ||
selected_columns = ["workclass"] | ||
isDiscrete = True | ||
metrics = JSD.calculate(demo_data, sampled_data, selected_columns, isDiscrete) | ||
|
||
print("JSD metric of column %s: %g" % (selected_columns[0], metrics)) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import pandas as pd | ||
|
||
from sdgx.log import logger | ||
|
||
|
||
class ColumnMetric(object): | ||
"""ColumnMetric | ||
Metrics used to evaluate the quality of synthetic data columns. | ||
""" | ||
|
||
upper_bound = None | ||
lower_bound = None | ||
metric_name = "Accuracy" | ||
|
||
def __init__(self) -> None: | ||
pass | ||
|
||
@classmethod | ||
def check_input( | ||
cls, real_data: pd.Series | pd.DataFrame, synthetic_data: pd.Series | pd.DataFrame | ||
): | ||
"""Input check for column or table input. | ||
Args: | ||
real_data(pd.DataFrame or pd.Series): the real (original) data table / column. | ||
synthetic_data(pd.DataFrame or pd.Series): the synthetic (generated) data table / column. | ||
""" | ||
|
||
# Input parameter must not contain None value | ||
if real_data is None or synthetic_data is None: | ||
raise TypeError("Input contains None.") | ||
|
||
# The data type should be same | ||
if type(real_data) is not type(synthetic_data): | ||
raise TypeError("Data type of real_data and synthetic data should be the same.") | ||
|
||
# Check some data-types that must not be allowed | ||
if type(real_data) in [int, float, str]: | ||
raise TypeError("real_data's type must not be None, int, float or str") | ||
|
||
# if type is pd.Series, return directly | ||
if isinstance(real_data, pd.Series) or isinstance(real_data, pd.DataFrame): | ||
return real_data, synthetic_data | ||
|
||
# if type is not pd.Series or pd.DataFrame tranfer it to Series | ||
try: | ||
real_data = pd.Series(real_data) | ||
synthetic_data = pd.Series(synthetic_data) | ||
return real_data, synthetic_data | ||
except Exception as e: | ||
logger.error(f"An error occurred while converting to pd.Series: {e}") | ||
|
||
return None, None | ||
|
||
@classmethod | ||
def calculate( | ||
cls, real_data: pd.Series | pd.DataFrame, synthetic_data: pd.Series | pd.DataFrame | ||
): | ||
"""Calculate the metric value between columns between real table and synthetic table. | ||
Args: | ||
real_data(pd.DataFrame or pd.Series): the real (original) data table / column. | ||
synthetic_data(pd.DataFrame or pd.Series): the synthetic (generated) data table / column. | ||
""" | ||
# This method should first check the input | ||
# such as: | ||
real_data, synthetic_data = ColumnMetric.check_input(real_data, synthetic_data) | ||
|
||
raise NotImplementedError() | ||
|
||
@classmethod | ||
def check_output(cls, raw_metric_value: float): | ||
"""Check the output value. | ||
Args: | ||
raw_metric_value (float): the calculated raw value of the JSD metric. | ||
""" | ||
raise NotImplementedError() | ||
|
||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import numpy as np | ||
import pandas as pd | ||
from scipy.stats import entropy, gaussian_kde | ||
|
||
from sdgx.metrics.column.base import ColumnMetric | ||
|
||
|
||
class JSD(ColumnMetric):
    """JSD : Jensen Shannon Divergence

    This class is used to calculate the Jensen Shannon divergence value between
    the target columns of real data and synthetic data.
    Currently, we support discrete and continuous columns as inputs.
    """

    # Declared at class level as well, so the classmethods below can read the
    # bounds without creating an instance. The divergence is computed with
    # base-2 logarithms, for which its value range is [0, 1].
    lower_bound = 0
    upper_bound = 1
    metric_name = "jensen_shannon_divergence"

    def __init__(self) -> None:
        super().__init__()
        self.lower_bound = 0
        self.upper_bound = 1
        self.metric_name = "jensen_shannon_divergence"

    @classmethod
    def calculate(
        cls,
        real_data: pd.DataFrame,
        synthetic_data: pd.DataFrame,
        cols: list[str] | None,
        discrete: bool = True,
    ) -> float:
        """Calculate the JSD value between real columns and synthetic columns.

        Args:
            real_data (pd.DataFrame): The real data.
            synthetic_data (pd.DataFrame): The synthetic data.
            cols (list[str] | None): The target columns to calculate the JSD
                metric on. ``None`` selects every column of ``real_data``.
            discrete (bool): Whether the target columns are discrete columns.

        Returns:
            float: The metric value.

        Raises:
            TypeError: if the inputs fail ``check_input``.
            ValueError: if the computed value lies outside the metric bounds.
        """
        # Validate the inputs for both the discrete and the continuous path.
        cls.check_input(real_data, synthetic_data)

        if cols is None:
            cols = list(real_data.columns)

        if discrete:
            # Discrete columns: the joint distribution is the relative frequency
            # of every observed value combination. This is a single group-by and
            # one division, so it stays fast for any number of columns.
            freq_real = real_data.groupby(cols, dropna=False).size() / len(real_data)
            freq_syn = synthetic_data.groupby(cols, dropna=False).size() / len(synthetic_data)

            # Align both distributions on the union of observed combinations;
            # a combination missing on one side gets probability 0.
            pdf_real, pdf_syn = freq_real.align(freq_syn, fill_value=0)
        else:
            # Continuous columns: the joint density is evaluated on a grid with
            # 100 points per column, i.e. 100**d points for d columns. The cost
            # therefore grows exponentially, and selecting three or more columns
            # is impractically slow.
            real_t = real_data[cols].values.T  # transpose: one row per column
            syn_t = synthetic_data[cols].values.T

            # Estimate the joint probability density with a Gaussian KDE.
            kde_real = gaussian_kde(real_t)
            kde_syn = gaussian_kde(syn_t)

            # Sample uniformly; the range of every axis is the min/max of the
            # corresponding column in the real data.
            axes = [np.linspace(col.min(), col.max(), 100) for col in real_t]
            mesh = np.meshgrid(*axes)
            grid = np.vstack([axis.ravel() for axis in mesh])

            # Density of both estimates at the same grid points.
            pdf_real = kde_real(grid)
            pdf_syn = kde_syn(grid)

        jsd_val = cls.jensen_shannon_divergence(pdf_real, pdf_syn)

        cls.check_output(jsd_val)

        return jsd_val

    @classmethod
    def check_output(cls, raw_metric_value: float):
        """Check the output value.

        Args:
            raw_metric_value (float): the calculated raw value of the JSD metric.

        Raises:
            ValueError: if the value is below ``lower_bound`` or above ``upper_bound``.
        """
        if raw_metric_value < cls.lower_bound or raw_metric_value > cls.upper_bound:
            raise ValueError(
                f"{cls.metric_name} value {raw_metric_value} is outside "
                f"[{cls.lower_bound}, {cls.upper_bound}]."
            )

    @classmethod
    def jensen_shannon_divergence(
        cls, p: np.ndarray | pd.Series, q: np.ndarray | pd.Series
    ) -> float:
        """Calculate the Jensen Shannon divergence (base 2) of p and q.

        Args:
            p (array-like): non-negative weights of the first distribution.
            q (array-like): non-negative weights of the second distribution,
                given at the same positions as ``p``.

        Returns:
            float: the Jensen Shannon divergence of the two distributions.
        """
        p = np.asarray(p, dtype=float)
        q = np.asarray(q, dtype=float)

        # Normalise first. ``scipy.stats.entropy`` normalises each argument on
        # its own, so a mixture built from unnormalised inputs with different
        # totals (e.g. KDE densities sampled on a grid) would not be the
        # average of the two distributions.
        p = p / p.sum()
        q = q / q.sum()

        # Average distribution of p and q.
        m = 0.5 * (p + q)

        # KL divergence of each distribution from the average.
        kl_p = entropy(p, m, base=2)
        kl_q = entropy(q, m, base=2)

        # The Jensen Shannon divergence is the mean of the two KL divergences.
        return float(0.5 * (kl_p + kl_q))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from sdgx.log import logger | ||
|
||
|
||
class MultiTableMetric:
    """MultiTableMetric

    Metrics used to evaluate the quality of synthetic multi-table data.
    ``calculate`` and ``check_output`` are placeholders that raise
    ``NotImplementedError`` and are meant to be overridden.
    """

    # Bounds and name of the metric; left as None in the base class.
    upper_bound = None
    lower_bound = None
    metric_name = None
    # Metadata description dict; set per instance in __init__.
    metadata = None
    # Class-level default only; every instance receives its own copy in __init__.
    table_list = []

    def __init__(self, metadata: dict) -> None:
        """Initialization

        Args:
            metadata(dict): This parameter accepts a metadata description dict,
                which is used to describe the table relations and column
                description information for each table.
        """
        self.metadata = metadata
        # Copy the class-level default so that mutating one instance's list
        # does not leak into other instances or into the class attribute.
        self.table_list = list(type(self).table_list)

    @classmethod
    def check_input(cls, real_data: dict, synthetic_data: dict):
        """Format check for multi-table input.

        The `real_data` and `synthetic_data` should be dict, which contains
        tables (in pd.DataFrame).

        Args:
            real_data(dict): the real (original) data tables.
            synthetic_data(dict): the synthetic (generated) data tables.

        Returns:
            tuple: ``(real_data, synthetic_data)`` when both are non-empty
            dicts; otherwise an error is logged and ``(None, None)`` is returned.

        Raises:
            TypeError: if either input is None or the two inputs have different types.
        """
        if real_data is None or synthetic_data is None:
            raise TypeError("Input contains None.")

        # The data type should be the same.
        if type(real_data) is not type(synthetic_data):
            raise TypeError("Data type of real_data and synthetic data should be the same.")

        # Non-empty dicts are returned directly.
        if isinstance(real_data, dict) and real_data and synthetic_data:
            return real_data, synthetic_data

        logger.error("An error occurred while checking the input.")

        return None, None

    # Not a class method: implementations may need the instance metadata.
    def calculate(self, real_data: dict, synthetic_data: dict):
        """Calculate the metric value between real tables and synthetic tables.

        Args:
            real_data(dict): the real (original) data tables.
            synthetic_data(dict): the synthetic (generated) data tables.

        Raises:
            NotImplementedError: always; concrete metrics must override this method.
        """
        raise NotImplementedError()

    @classmethod
    def check_output(cls, raw_metric_value: float):
        """Check the output value.

        Args:
            raw_metric_value (float): the calculated raw value of the metric.

        Raises:
            NotImplementedError: always; concrete metrics must override this method.
        """
        raise NotImplementedError()
Empty file.
Oops, something went wrong.