Skip to content

Commit

Permalink
Add testing for JSD metrics (#100)
Browse files Browse the repository at this point in the history
* test_jsd

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* using the fixture

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: MoooCat <141886018+MooooCat@users.noreply.github.com>
  • Loading branch information
3 people committed Jan 16, 2024
1 parent 3264d06 commit f18b552
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions tests/metrics/test_jsd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from __future__ import annotations

import random

import numpy as np
import pandas as pd
import pytest

from sdgx.metrics.column.jsd import JSD


# 创建测试数据
@pytest.fixture
def dummy_data(dummy_single_table_path):
yield pd.read_csv(dummy_single_table_path)


@pytest.fixture
def test_data():
role_set = ["admin", "user", "guest"]
df = pd.DataFrame(
{
"role": [random.choice(role_set) for _ in range(10)],
"feature_x": [random.random() for _ in range(10)],
}
)
return df


@pytest.fixture
def jsd_instance():
return JSD()


def test_jsd_discrete(dummy_data, test_data, jsd_instance):
cols = ["role"]
result = jsd_instance.calculate(dummy_data, test_data, cols, discrete=True)
result1 = jsd_instance.calculate(dummy_data, dummy_data, cols, discrete=True)
result2 = jsd_instance.calculate(test_data, dummy_data, cols, discrete=True)

assert result >= 0
assert result <= 1
assert result1 == 0
assert result2 == result


def test_jsd_continuous(dummy_data, test_data, jsd_instance):
cols = ["feature_x"]
result = jsd_instance.calculate(dummy_data, test_data, cols, discrete=False)
result1 = jsd_instance.calculate(dummy_data, dummy_data, cols, discrete=False)

assert result >= 0
assert result <= 1
assert result1 == 0


if __name__ == "__main__":
pytest.main(["-vv", "-s", __file__])

0 comments on commit f18b552

Please sign in to comment.