Skip to content

Commit

Permalink
Merge pull request #14 from hitsz-ids/bugfix_func-arg-default-list
Browse files Browse the repository at this point in the history
bugfix: Fix function's arg default List
  • Loading branch information
MooooCat committed Aug 22, 2023
2 parents 19638e0 + b3f6694 commit 701abcc
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
5 changes: 5 additions & 0 deletions sdgx/models/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Optional

import numpy as np
import torch

Expand All @@ -13,6 +15,9 @@ def __init__(self, transformer=None, sampler=None) -> None:
# self.epochs = epochs
self._device = "CPU"

def fit(self, input_df, discrete_cols: Optional[List] = None):
raise NotImplementedError

def set_device(self, device):
"""Set the `device` to be used ('GPU' or 'CPU')."""
self._device = device
Expand Down
18 changes: 13 additions & 5 deletions sdgx/models/single_table/ctgan.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from typing import List, Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -269,7 +270,9 @@ def _cond_loss(self, data, c, m):
ed = st + span_info.dim
ed_c = st_c + span_info.dim
tmp = functional.cross_entropy(
data[:, st:ed], torch.argmax(c[:, st_c:ed_c], dim=1), reduction="none"
data[:, st:ed],
torch.argmax(c[:, st_c:ed_c], dim=1),
reduction="none",
)
loss.append(tmp)
st = ed
Expand Down Expand Up @@ -305,7 +308,7 @@ def _validate_discrete_columns(self, train_data, discrete_columns):
raise ValueError(f"Invalid columns found: {invalid_columns}")

@random_state
def fit(self, train_data, discrete_columns=(), epochs=None):
def fit(self, train_data, discrete_columns: Optional[List] = None, epochs=None):
"""Fit the CTGAN Synthesizer models to the training data.
Args:
Expand All @@ -317,7 +320,8 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
contain the integer indices of the columns. Otherwise, if it is
a ``pandas.DataFrame``, this list should contain the column names.
"""

if not discrete_cols:
discrete_cols = []
# 离散列检查
self._validate_discrete_columns(train_data, discrete_columns)

Expand Down Expand Up @@ -350,11 +354,15 @@ def fit(self, train_data, discrete_columns=(), epochs=None):

# sampler 作为参数给到 Generator 以及 Discriminator
self._generator = Generator(
self._embedding_dim + self._data_sampler.dim_cond_vec(), self._generator_dim, data_dim
self._embedding_dim + self._data_sampler.dim_cond_vec(),
self._generator_dim,
data_dim,
).to(self._device)

discriminator = Discriminator(
data_dim + self._data_sampler.dim_cond_vec(), self._discriminator_dim, pac=self.pac
data_dim + self._data_sampler.dim_cond_vec(),
self._discriminator_dim,
pac=self.pac,
).to(self._device)

# 初始化 optimizer G 以及 D
Expand Down
5 changes: 4 additions & 1 deletion sdgx/transform/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

from collections import namedtuple
from typing import List, Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -101,7 +102,7 @@ def _fit_discrete(self, data):
output_dimensions=num_categories,
)

def fit(self, raw_data, discrete_columns=()):
def fit(self, raw_data, discrete_columns: Optional[List] = None):
"""Fit the ``DataTransformer``.
Fits a ``ClusterBasedNormalizer`` for continuous columns and a
Expand All @@ -112,6 +113,8 @@ def fit(self, raw_data, discrete_columns=()):
self.output_info_list = []
self.output_dimensions = 0
self.dataframe = True
if not discrete_columns:
discrete_columns = []

if not isinstance(raw_data, pd.DataFrame):
self.dataframe = False
Expand Down

0 comments on commit 701abcc

Please sign in to comment.