Add mutual information metric #101

Merged: 64 commits, Jan 16, 2024
Changes from 19 commits
Commits
3dcd842
test
Z712023 Jan 4, 2024
34af206
test_v2
Z712023 Jan 4, 2024
9bf9108
no-test
Z712023 Jan 5, 2024
7fd6c75
pair_v1
Z712023 Jan 9, 2024
a6ad779
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
05223ea
remove_old_mi_sim
Z712023 Jan 10, 2024
a692a54
remove mi_sim in columns
Z712023 Jan 10, 2024
730bd9b
modify single&multi_table MISim
Z712023 Jan 10, 2024
b100dd9
modify single_mi_sim by using pair_sim instance
Z712023 Jan 10, 2024
40a19c5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
0031062
modify multi_mi_sim by using pair_sim instance
Z712023 Jan 10, 2024
88eaa2a
modify multi_mi_sim by using pair_sim instance
Z712023 Jan 10, 2024
1c4026b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
8c333dd
change_class_name_err
Z712023 Jan 10, 2024
032df09
Merge branch 'feature-metric-mutual_information' of github.com:hitsz-…
Z712023 Jan 10, 2024
0ebf11d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
844c13d
modify_paircolumn
Z712023 Jan 10, 2024
ca53f1a
mi only needs dataframe
Z712023 Jan 10, 2024
e4efe41
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
93018e0
Merge branch 'main' into feature-metric-mutual_information
MooooCat Jan 16, 2024
f3ffab7
modify based on review
Z712023 Jan 16, 2024
8583aae
test
Z712023 Jan 16, 2024
a1fb0ec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
b0c4282
complete test_mi_sim
Z712023 Jan 16, 2024
f9024e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
fe7b080
modify test file
Z712023 Jan 16, 2024
df1e572
change_var_name
Z712023 Jan 16, 2024
dd08734
Update sdgx/metrics/multi_table/multitable_mi_sim.py
Z712023 Jan 16, 2024
3264704
add MULTI_TABLE_DEMO_DATA
Z712023 Jan 16, 2024
63854ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
6e2ee7b
modify comments
Z712023 Jan 16, 2024
e85571e
JSD->MISIM
Z712023 Jan 16, 2024
a0ac893
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
aafa3cc
modify base of pair_column
Z712023 Jan 16, 2024
84efc35
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
c55b98d
add cls
Z712023 Jan 16, 2024
d69af4e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
0ddde82
change self into cls instance
Z712023 Jan 16, 2024
89d2bca
Merge branch 'feature-metric-mutual_information' of github.com:hitsz-…
Z712023 Jan 16, 2024
52720f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
94fb07a
change cls
Z712023 Jan 16, 2024
0190305
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
8384b06
series2array
Z712023 Jan 16, 2024
ac4ad49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
c1503cd
test
Z712023 Jan 16, 2024
bb83518
Merge branch 'feature-metric-mutual_information' of github.com:hitsz-…
Z712023 Jan 16, 2024
5481df2
test
Z712023 Jan 16, 2024
c649868
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
110b63a
add label_encoder for category in mi_sim
Z712023 Jan 16, 2024
107587a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
7ffa05b
use series.array
Z712023 Jan 16, 2024
5847e35
change le_fit
Z712023 Jan 16, 2024
cee6dac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
96e2c0d
change transform type to np.array instead of list
Z712023 Jan 16, 2024
3df03a8
add astype
Z712023 Jan 16, 2024
7a8f766
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
03c9fd2
series2array
Z712023 Jan 16, 2024
d4c949a
foo
Z712023 Jan 16, 2024
fe956d5
change test_suit
Z712023 Jan 16, 2024
15cdebc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
e2a0db5
all right?
Z712023 Jan 16, 2024
0e21332
all right
Z712023 Jan 16, 2024
9244401
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
8020a9a
Merge branch 'main' into feature-metric-mutual_information
MooooCat Jan 16, 2024
2 changes: 0 additions & 2 deletions sdgx/metrics/column/base.py
@@ -59,10 +59,8 @@ def calculate(
        cls, real_data: pd.Series | pd.DataFrame, synthetic_data: pd.Series | pd.DataFrame
    ):
        """Calculate the metric value between columns of the real table and the synthetic table.

        Args:
            real_data(pd.DataFrame or pd.Series): the real (original) data table / column.

            synthetic_data(pd.DataFrame or pd.Series): the synthetic (generated) data table / column.
        """
        # This method should first check the input
87 changes: 87 additions & 0 deletions sdgx/metrics/multi_table/multitable_mi_sim.py
@@ -0,0 +1,87 @@
import numpy as np
import pandas as pd

from sdgx.metrics.multi_table.base import MultiTableMetric

# Aliased so that the class defined below does not shadow the pair-column metric.
from sdgx.metrics.pair_column.mi_sim import MISim as PairMISim


class MISim(MultiTableMetric):
    """MISim: Mutual Information Similarity

    This class calculates the Mutual Information Similarity between the columns of the real data and the synthetic data.

    Discrete columns are supported directly; continuous columns are discretized first.
    """

    # Class-level attributes, so that the classmethods below can read them through cls.
    lower_bound = 0
    upper_bound = 1
    metric_name = "mutual_information_similarity"
    numerical_bins = 50

    def __init__(self) -> None:
        super().__init__()

    @classmethod
    def calculate(
        cls, real_data: pd.DataFrame, synthetic_data: pd.DataFrame, metadata: dict
    ) -> float:
        """
        Calculate the MI similarity between the real data and the synthetic data, averaged over all column pairs.

        Args:
            real_data (pd.DataFrame): The real data.
            synthetic_data (pd.DataFrame): The synthetic data.
            metadata(dict): The metadata that describes the data type of each column

        Returns:
            MI_similarity (float): The metric value.
        """
        columns = synthetic_data.columns
        n = len(columns)
        mi_sim_instance = PairMISim()
        nMI_sim = np.zeros((n, n))

        for i in range(n):
            for j in range(n):
                if i == j:
                    # A column paired with itself has a normalized MI of 1 in both
                    # tables, so the similarity is 1. Handling it here also avoids
                    # building a pair with duplicate column names.
                    nMI_sim[i][j] = 1.0
                    continue
                # Use separate names so that the full input tables are not overwritten.
                syn_pair = pd.concat(
                    [synthetic_data[columns[i]], synthetic_data[columns[j]]], axis=1
                )
                real_pair = pd.concat([real_data[columns[i]], real_data[columns[j]]], axis=1)

                nMI_sim[i][j] = mi_sim_instance.calculate(real_pair, syn_pair, metadata)

        MI_sim = np.sum(nMI_sim) / n / n
        cls.check_output(MI_sim)

        return MI_sim

    @classmethod
    def check_output(cls, raw_metric_value: float):
        """Check the output value.

        Args:
            raw_metric_value (float): the calculated raw value of the MI similarity metric.
        """
        if raw_metric_value < cls.lower_bound or raw_metric_value > cls.upper_bound:
            raise ValueError(
                f"Metric value {raw_metric_value} is outside [{cls.lower_bound}, {cls.upper_bound}]."
            )

    # @classmethod
    # def normalized_mutual_information(cls, p: float, q: float):
    #     """Calculate the normalized mutual information of p and q.

    #     Args:
    #         p (float): the input parameter p.

    #         q (float): the input parameter q.
    #     """
    #     n_MI = None

    #     return n_MI
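A classmethod receives the class rather than an instance, so any bounds it reads through `cls` have to be defined on the class itself; attributes assigned in `__init__` exist only on instances. The following is a minimal sketch of that pattern. `BoundedMetric` and `InstanceBoundMetric` are hypothetical names used for illustration and are not part of sdgx.

```python
class BoundedMetric:
    """Hypothetical metric class, used only to illustrate the pattern."""

    # Class-level bounds are visible to classmethods through cls.
    lower_bound = 0.0
    upper_bound = 1.0

    @classmethod
    def check_output(cls, raw_metric_value: float) -> float:
        """Return the value unchanged if it lies within [lower_bound, upper_bound]."""
        if not (cls.lower_bound <= raw_metric_value <= cls.upper_bound):
            raise ValueError(
                f"Metric value {raw_metric_value} is outside "
                f"[{cls.lower_bound}, {cls.upper_bound}]."
            )
        return raw_metric_value


class InstanceBoundMetric:
    """Counter-example: bounds set in __init__ exist only on instances."""

    def __init__(self) -> None:
        self.lower_bound = 0.0
        self.upper_bound = 1.0


if __name__ == "__main__":
    # Works without creating an instance.
    print(BoundedMetric.check_output(0.5))
    # The class itself has no such attribute, so a classmethod could not read it.
    print(hasattr(InstanceBoundMetric, "lower_bound"))
```

A subclass can narrow or widen the range by overriding the two class attributes, without touching `check_output`.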
79 changes: 79 additions & 0 deletions sdgx/metrics/pair_column/base.py
@@ -0,0 +1,79 @@
import pandas as pd

from sdgx.log import logger


class PairMetric(object):
    """PairMetric

    Metrics used to evaluate the quality of pairs of synthetic data columns.
    """

    upper_bound = None
    lower_bound = None
    metric_name = "Correlation"

    def __init__(self) -> None:
        pass

    @classmethod
    def check_input(
        cls,
        real_data: pd.DataFrame,
        synthetic_data: pd.DataFrame,
        real_metadata=None,
        syn_metadata=None,
    ):
        """Input check for table input.

        Args:
            real_data(pd.DataFrame): the real (original) data table.
            synthetic_data(pd.DataFrame): the synthetic (generated) data table.
            real_metadata: optional mapping from column name to data type for the real table.
            syn_metadata: optional mapping from column name to data type for the synthetic table.
        """
        # Input parameters must not be None
        if real_data is None or synthetic_data is None:
            raise TypeError("Input contains None.")

        # If an input is not a pd.DataFrame, try to convert it before reading its columns
        if not isinstance(real_data, pd.DataFrame) or not isinstance(
            synthetic_data, pd.DataFrame
        ):
            try:
                real_data = pd.DataFrame(real_data)
                synthetic_data = pd.DataFrame(synthetic_data)
            except Exception as e:
                logger.error(f"An error occurred while converting to pd.DataFrame: {e}")
                return None, None

        # check column names
        real_cols = real_data.columns
        syn_cols = synthetic_data.columns
        if set(real_cols) != set(syn_cols):
            raise TypeError("Columns of the DataFrames are different.")

        # check column types (only when metadata is given for both tables)
        if real_metadata is not None and syn_metadata is not None:
            for col in real_cols:
                if real_metadata[col] != syn_metadata[col]:
                    raise TypeError("Column types of the DataFrames are different.")

        return real_data, synthetic_data

    @classmethod
    def calculate(cls, real_data: pd.DataFrame, synthetic_data: pd.DataFrame):
        """Calculate the metric value between a pair of columns of the real table and the synthetic table.

        Args:
            real_data(pd.DataFrame or pd.Series): the real (original) data pair.

            synthetic_data(pd.DataFrame or pd.Series): the synthetic (generated) data pair.
        """
        # This method should first check the input
        # such as:
        real_data, synthetic_data = cls.check_input(real_data, synthetic_data)

        raise NotImplementedError()

    @classmethod
    def check_output(cls, raw_metric_value: float):
        """Check the output value.

        Args:
            raw_metric_value (float): the calculated raw value of the metric.
        """
        raise NotImplementedError()
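The input check in `PairMetric.check_input` comes down to three rules: neither input may be `None`, both tables must have the same set of column names, and each column must have the same declared type on both sides. The sketch below restates those rules on plain column lists and type dictionaries so that it runs without pandas. `check_pair_input` and its parameters are hypothetical and only illustrate the logic; treating the metadata as optional is an assumption of this sketch.

```python
def check_pair_input(real_columns, syn_columns, real_types=None, syn_types=None):
    """Validate column names and (optionally) column types of two tables.

    real_columns / syn_columns: iterables of column names.
    real_types / syn_types: optional dicts mapping column name -> data type.
    Returns the real column names as a list when every check passes.
    """
    # Rule 1: no input may be None.
    if real_columns is None or syn_columns is None:
        raise TypeError("Input contains None.")

    real_columns = list(real_columns)
    syn_columns = list(syn_columns)

    # Rule 2: both tables must expose the same set of column names.
    if set(real_columns) != set(syn_columns):
        raise TypeError("Columns of the tables are different.")

    # Rule 3: when types are given for both tables, they must agree per column.
    if real_types is not None and syn_types is not None:
        mismatched = [c for c in real_columns if real_types[c] != syn_types[c]]
        if mismatched:
            raise TypeError(f"Column types differ for: {mismatched}")

    return real_columns


if __name__ == "__main__":
    types = {"age": "numerical", "city": "category"}
    print(check_pair_input(["age", "city"], ["city", "age"], types, types))
```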
120 changes: 120 additions & 0 deletions sdgx/metrics/pair_column/mi_sim.py
@@ -0,0 +1,120 @@
import time

import pandas as pd
from sklearn.metrics.cluster import normalized_mutual_info_score

from sdgx.metrics.pair_column.base import PairMetric


def Jaccard_index(A, B):
    """Ratio of the smaller value to the larger one."""
    # Both values are zero: treat them as identical instead of dividing by zero.
    if max(A, B) == 0:
        return 1.0
    return min(A, B) / max(A, B)


def time2int(datetime, form):
    """Convert a datetime string in the given format to an integer timestamp."""
    time_array = time.strptime(datetime, form)
    time_stamp = int(time.mktime(time_array))
    return time_stamp


class MISim(PairMetric):
    """MISim: Mutual Information Similarity

    This class calculates the Mutual Information Similarity between a pair of columns of the real data and the same pair of columns of the synthetic data.

    Discrete columns are supported directly; continuous columns are discretized first.
    """

    # Class-level attributes, so that the classmethods below can read them through cls.
    lower_bound = 0
    upper_bound = 1
    metric_name = "mutual_information_similarity"
    numerical_bins = 50
    # Assumed default datetime format; adjust it to match the data.
    datetime_format = "%Y-%m-%d %H:%M:%S"

    def __init__(self) -> None:
        super().__init__()

    @classmethod
    def calculate(
        cls,
        real_data: pd.DataFrame,
        synthetic_data: pd.DataFrame,
        metadata: dict,
    ) -> float:
        """
        Calculate the MI similarity between a pair of real columns and the same pair of synthetic columns.

        Args:
            real_data (pd.DataFrame): The real data (two columns).
            synthetic_data (pd.DataFrame): The synthetic data (two columns).
            metadata(dict): The metadata that describes the data type of each column
        Returns:
            MI_similarity (float): The metric value.
        """
        # Work on copies so that the caller's tables are not modified.
        real_data = real_data.copy()
        synthetic_data = synthetic_data.copy()

        columns = synthetic_data.columns
        bins = cls.numerical_bins

        for col in columns:
            data_type = metadata[col]
            if data_type == "datetime":
                # Convert datetime strings to timestamps, then bin them like numerical values.
                real_data[col] = real_data[col].apply(
                    lambda t: time2int(t, cls.datetime_format)
                )
                synthetic_data[col] = synthetic_data[col].apply(
                    lambda t: time2int(t, cls.datetime_format)
                )

            if data_type in ("numerical", "datetime"):
                # Discretize continuous values into equal-width bins.
                real_data[col] = pd.cut(real_data[col], bins, labels=range(bins))
                synthetic_data[col] = pd.cut(synthetic_data[col], bins, labels=range(bins))

        syn_MI = normalized_mutual_info_score(
            synthetic_data[columns[0]], synthetic_data[columns[1]]
        )
        real_MI = normalized_mutual_info_score(real_data[columns[0]], real_data[columns[1]])

        MI_sim = Jaccard_index(syn_MI, real_MI)
        cls.check_output(MI_sim)
        return MI_sim

    @classmethod
    def check_output(cls, raw_metric_value: float):
        """Check the output value.

        Args:
            raw_metric_value (float): the calculated raw value of the MI similarity metric.
        """
        if raw_metric_value < cls.lower_bound or raw_metric_value > cls.upper_bound:
            raise ValueError(
                f"Metric value {raw_metric_value} is outside [{cls.lower_bound}, {cls.upper_bound}]."
            )

    # @classmethod
    # def normalized_mutual_information(cls, p: float, q: float):
    #     """Calculate the normalized mutual information of p and q.

    #     Args:
    #         p (float): the input parameter p.

    #         q (float): the input parameter q.
    #     """
    #     n_MI = None

    #     return n_MI
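The pair metric compares how strongly two columns depend on each other in the real table with how strongly they depend on each other in the synthetic table, then takes the ratio of the smaller normalized MI to the larger one. The sketch below works through that computation on already-discretized labels using only the standard library. It is a hand-rolled stand-in for `normalized_mutual_info_score`, assuming normalization by the arithmetic mean of the two entropies; the function names are hypothetical and the PR itself calls scikit-learn.

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy (natural log) of a sequence of discrete labels."""
    n = len(labels)
    return -sum((c / n) * math.log(c / n) for c in Counter(labels).values())


def mutual_information(x, y):
    """Mutual information (natural log) between two equally long label sequences."""
    n = len(x)
    joint = Counter(zip(x, y))
    px, py = Counter(x), Counter(y)
    mi = sum(
        (c / n) * math.log(c * n / (px[a] * py[b])) for (a, b), c in joint.items()
    )
    # Clamp tiny negative values caused by floating-point rounding.
    return max(mi, 0.0)


def normalized_mi(x, y):
    """MI divided by the arithmetic mean of the two entropies."""
    hx, hy = entropy(x), entropy(y)
    if hx == 0 and hy == 0:
        # Both columns are constant: treat this as a perfect match.
        return 1.0
    return mutual_information(x, y) / ((hx + hy) / 2)


def mi_similarity(real_pair, syn_pair):
    """Ratio of the smaller normalized MI to the larger one."""
    real_mi = normalized_mi(*real_pair)
    syn_mi = normalized_mi(*syn_pair)
    largest = max(real_mi, syn_mi)
    if largest == 0:
        # Neither table shows any dependence: treat them as identical.
        return 1.0
    return min(real_mi, syn_mi) / largest


if __name__ == "__main__":
    a = [0, 0, 1, 1]
    dependent = [0, 0, 1, 1]    # fully determined by a
    independent = [0, 1, 0, 1]  # carries no information about a
    print(mi_similarity((a, dependent), (a, dependent)))
    print(mi_similarity((a, dependent), (a, independent)))
```

A synthetic table that keeps the dependence seen in the real table scores close to 1, while one that loses it scores close to 0.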
84 changes: 84 additions & 0 deletions sdgx/metrics/single_table/single_mi_sim.py
@@ -0,0 +1,84 @@
import numpy as np
import pandas as pd

from sdgx.metrics.pair_column.mi_sim import MISim
from sdgx.metrics.single_table.base import SingleTableMetric


class SinTabMISim(SingleTableMetric):
    """MISim: Mutual Information Similarity

    This class calculates the Mutual Information Similarity between the columns of the real data and the synthetic data.

    Discrete columns are supported directly; continuous columns are discretized first.
    """

    # Class-level attributes, so that the classmethods below can read them through cls.
    lower_bound = 0
    upper_bound = 1
    metric_name = "mutual_information_similarity"
    numerical_bins = 50

    def __init__(self) -> None:
        super().__init__()

    @classmethod
    def calculate(
        cls, real_data: pd.DataFrame, synthetic_data: pd.DataFrame, metadata: dict
    ) -> float:
        """
        Calculate the MI similarity between the real table and the synthetic table, averaged over all column pairs.

        Args:
            real_data (pd.DataFrame): The real data.
            synthetic_data (pd.DataFrame): The synthetic data.
            metadata(dict): The metadata that describes the data type of each column
        Returns:
            MI_similarity (float): The metric value.
        """
        columns = synthetic_data.columns
        n = len(columns)
        mi_sim_instance = MISim()
        nMI_sim = np.zeros((n, n))

        for i in range(n):
            for j in range(n):
                if i == j:
                    # A column paired with itself has a normalized MI of 1 in both
                    # tables, so the similarity is 1. Handling it here also avoids
                    # building a pair with duplicate column names.
                    nMI_sim[i][j] = 1.0
                    continue
                # Use separate names so that the full input tables are not overwritten.
                syn_pair = pd.concat(
                    [synthetic_data[columns[i]], synthetic_data[columns[j]]], axis=1
                )
                real_pair = pd.concat([real_data[columns[i]], real_data[columns[j]]], axis=1)

                nMI_sim[i][j] = mi_sim_instance.calculate(real_pair, syn_pair, metadata)

        MI_sim = np.sum(nMI_sim) / n / n
        cls.check_output(MI_sim)

        return MI_sim

    @classmethod
    def check_output(cls, raw_metric_value: float):
        """Check the output value.

        Args:
            raw_metric_value (float): the calculated raw value of the MI similarity metric.
        """
        if raw_metric_value < cls.lower_bound or raw_metric_value > cls.upper_bound:
            raise ValueError(
                f"Metric value {raw_metric_value} is outside [{cls.lower_bound}, {cls.upper_bound}]."
            )

    # @classmethod
    # def normalized_mutual_information(cls, p: float, q: float):
    #     """Calculate the normalized mutual information of p and q.

    #     Args:
    #         p (float): the input parameter p.

    #         q (float): the input parameter q.
    #     """
    #     n_MI = None

    #     return n_MI
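The table-level score is the mean of the pairwise scores over the full n x n grid of column pairs, including each column paired with itself. The sketch below shows only that aggregation step. `mean_pairwise_score` and the `pair_score` callable are hypothetical; in the PR the per-pair value comes from the pair-column `MISim`.

```python
from itertools import product


def mean_pairwise_score(columns, pair_score):
    """Average pair_score(a, b) over every ordered pair of columns (n * n pairs)."""
    columns = list(columns)
    n = len(columns)
    if n == 0:
        raise ValueError("At least one column is required.")
    total = sum(pair_score(a, b) for a, b in product(columns, repeat=2))
    return total / (n * n)


if __name__ == "__main__":
    # Toy scorer (an assumption for the demo): 1.0 for a column paired with
    # itself, 0.5 for any two different columns.
    def toy_score(a, b):
        return 1.0 if a == b else 0.5

    print(mean_pairwise_score(["age", "city"], toy_score))
```

Because the diagonal pairs always contribute the maximum score in a setup like this, the mean is pulled upward most for tables with few columns; averaging over off-diagonal pairs only would be an alternative design.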