Skip to content

Commit

Permalink
Bugfix metric mutual information (#118)
Browse files Browse the repository at this point in the history
* init bug-fix

* add time_test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add datetime_test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Z712023 and pre-commit-ci[bot] committed Jan 18, 2024
1 parent dd70515 commit 9ac9988
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
4 changes: 2 additions & 2 deletions sdgx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def get_demo_single_table(data_dir: str | Path = "./dataset"):
return pd_obj, discrete_cols


def time2int(datetime, form):
time_array = time.strptime(datetime, form)
def time2int(datetime, form="%Y-%m-%d %H:%M:%S"):
    """Convert a date/time value to an integer number of seconds since the epoch.

    The value is stringified and parsed with ``time.strptime`` using ``form``;
    the parsed fields are then passed to ``time.mktime``, which interprets
    them as local time.

    If parsing fails and the value is a ``datetime.datetime`` instance (this
    includes subclasses), its fields are used directly instead. This covers
    datetimes whose ``str()`` carries a microsecond or UTC-offset suffix that
    ``form`` does not describe. Sub-second precision is discarded.

    Args:
        datetime: A string matching ``form``, or an object whose ``str()``
            matches ``form``, or a ``datetime.datetime`` instance.
        form: ``strptime`` format used to parse the stringified value.

    Returns:
        int: Whole seconds since the epoch.

    Raises:
        ValueError: If the value is not a ``datetime.datetime`` and its string
            form does not match ``form``.
    """
    # Imported locally under an alias because the ``datetime`` parameter
    # shadows the module name inside this function.
    import datetime as _dt

    try:
        time_array = time.strptime(str(datetime), form)
    except ValueError:
        if not isinstance(datetime, _dt.datetime):
            raise
        if datetime.tzinfo is not None:
            # Aware datetimes identify an absolute instant; dropping the
            # microseconds first makes the float timestamp an exact integer.
            return int(datetime.replace(microsecond=0).timestamp())
        # Naive datetime: same local-time interpretation as the string path.
        time_array = datetime.timetuple()
    return int(time.mktime(time_array))

Expand Down
41 changes: 40 additions & 1 deletion tests/metrics/test_MISim.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import math
import random
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
Expand All @@ -11,6 +13,29 @@
# Create test data


def generate_random_time(start_date, end_date, date_format="%Y-%m-%d"):
    """Return a random ``datetime`` between two dates, at whole-second steps.

    Both bounds are parsed with ``date_format`` and are inclusive: the result
    is the start plus a uniformly drawn whole number of seconds from the
    module-level ``random`` generator.

    Args:
        start_date: Lower bound as a string matching ``date_format``.
        end_date: Upper bound as a string matching ``date_format``.
        date_format: ``strptime`` format of both bounds.

    Returns:
        datetime: A value ``v`` with ``start <= v <= end``.

    Raises:
        ValueError: If a bound does not match ``date_format`` or the end
            precedes the start.
    """
    start_datetime = datetime.strptime(start_date, date_format)
    end_datetime = datetime.strptime(end_date, date_format)

    span_seconds = int((end_datetime - start_datetime).total_seconds())
    if span_seconds < 0:
        # random.randint would also raise ValueError here, but with an
        # opaque "empty range" message; name the offending arguments.
        raise ValueError(
            f"end_date {end_date!r} is earlier than start_date {start_date!r}"
        )

    return start_datetime + timedelta(seconds=random.randint(0, span_seconds))


@pytest.fixture
def test_data_time():
    """Fixture: a 10-row DataFrame with two columns of random datetimes.

    The columns ``time_x`` and ``time_y`` are each filled with values drawn by
    ``generate_random_time`` between 1900-01-01 and 2023-12-31.
    """
    earliest, latest = "1900-01-01", "2023-12-31"

    columns = {}
    # time_x is filled completely before time_y, so the sequence of random
    # draws is the same as building the two lists one after the other.
    for column_name in ("time_x", "time_y"):
        columns[column_name] = [
            generate_random_time(earliest, latest) for _ in range(10)
        ]
    return pd.DataFrame(columns)


@pytest.fixture
def test_data_category():
role_set = ["admin", "user", "guest"]
Expand Down Expand Up @@ -56,7 +81,7 @@ def test_MISim_discrete(test_data_category, mi_sim_instance):
assert result >= 0
assert result <= 1
assert result1 == 1
assert result2 == result
assert round(result2, 9) == round(result, 9)


def test_MISim_continuous(test_data_num, mi_sim_instance):
Expand All @@ -67,6 +92,20 @@ def test_MISim_continuous(test_data_num, mi_sim_instance):
result1 = mi_sim_instance.calculate(test_data_num[col_src], test_data_num[col_src], metadata)
result2 = mi_sim_instance.calculate(test_data_num[col_tar], test_data_num[col_src], metadata)

assert result >= 0
assert result <= 1
assert result1 == 1
assert round(result2, 9) == round(result, 9)


def test_MISim_time(test_data_time, mi_sim_instance):
metadata = {"time_x": "datetime", "time_y": "datetime"}
col_src = "time_x"
col_tar = "time_y"
result = mi_sim_instance.calculate(test_data_time[col_src], test_data_time[col_tar], metadata)
result1 = mi_sim_instance.calculate(test_data_time[col_src], test_data_time[col_src], metadata)
result2 = mi_sim_instance.calculate(test_data_time[col_tar], test_data_time[col_src], metadata)

assert result >= 0
assert result <= 1
assert result1 == 1
Expand Down

0 comments on commit 9ac9988

Please sign in to comment.