Skip to content

Commit

Permalink
fixed bug with None as TEMP target bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Diyago committed Sep 30, 2023
1 parent 0ff72b4 commit 2789612
Showing 1 changed file with 37 additions and 18 deletions.
55 changes: 37 additions & 18 deletions src/tabgan/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,26 +306,27 @@ def generate_data(
else:
self.TEMP_TARGET = None
logging.info("Fitting ForestDiffusion model")

if self.cat_cols is None:
forest_model = ForestDiffusionModel(train_df, label_y=self.TEMP_TARGET, n_t=50,
forest_model = ForestDiffusionModel(train_df.to_numpy(), label_y=self.TEMP_TARGET, n_t=50,
duplicate_K=100,
diffusion_type='flow', n_jobs=-1)
else:
forest_model = ForestDiffusionModel(train_df, label_y=self.TEMP_TARGET, n_t=50,
forest_model = ForestDiffusionModel(train_df.to_numpy(), label_y=self.TEMP_TARGET, n_t=50,
duplicate_K=100,
# todo fix bug with cat cols
#cat_indexes=self.get_column_indexes(train_df, self.cat_cols),
diffusion_type='flow', n_jobs=-1)
logging.info("Finished training ForestDiffusionModel")
generated_df = forest_model.generate(batch_size=int(self.gen_x_times*train_df.shape[0]))
generated_df = forest_model.generate(batch_size=int(self.gen_x_times*train_df.to_numpy().shape[0]))
data_dtype = train_df.dtypes.values

generated_df = pd.DataFrame(generated_df)
generated_df.columns = train_df.columns
for i in range(len(generated_df.columns)):
generated_df[generated_df.columns[i]] = generated_df[
generated_df.columns[i]
].astype(data_dtype[i])
gc.collect()
self.TEMP_TARGET = "TEMP_TARGET"
if not only_generated_data:
train_df = pd.concat([train_df, generated_df]).reset_index(drop=True)
logging.info(
Expand Down Expand Up @@ -385,28 +386,46 @@ def get_columns_if_exists(df, col) -> pd.DataFrame:

if __name__ == "__main__":
setup_logging(logging.DEBUG)
train_size = 100
train_size = 75
train = pd.DataFrame(
np.random.randint(-10, 150, size=(train_size, 4)), columns=list("ABCD")
)
logging.info(train)
target = pd.DataFrame(np.random.randint(0, 2, size=(train_size, 1)), columns=list("Y"))
test = pd.DataFrame(np.random.randint(0, 100, size=(train_size, 4)), columns=list("ABCD"))
_sampler(OriginalGenerator(gen_x_times=15), train, target, test)
# _sampler(OriginalGenerator(gen_x_times=15), train, target, test)
# _sampler(
# GANGenerator(gen_x_times=10, only_generated_data=False,
# gen_params={"batch_size": 500, "patience": 25, "epochs": 500, }), train, target, test
# )
#
# _sampler(OriginalGenerator(gen_x_times=15), train, None, train)
# _sampler(
# GANGenerator(cat_cols=["A"], gen_x_times=20, only_generated_data=True),
# train,
# None,
# train,
# )
_sampler(
ForestDiffusionGenerator(cat_cols=["A"], gen_x_times=2, only_generated_data=True),
ForestDiffusionGenerator(cat_cols=["A"], gen_x_times=1, only_generated_data=True),
train,
None,
train,
)

train['Date'] = min_date + pd.to_timedelta(np.random.randint(d, size=train_size), unit='d')
train = get_year_mnth_dt_from_date(train, 'Date')

new_train, new_target = GANGenerator(gen_x_times=1.1, cat_cols=['year'], bot_filter_quantile=0.001,
top_filter_quantile=0.999,
is_post_process=True, pregeneration_frac=2, only_generated_data=False). \
generate_data_pipe(train.drop('Date', axis=1), None,
train.drop('Date', axis=1)
)
new_train = collect_dates(new_train)
#
# min_date = pd.to_datetime('2019-01-01')
# max_date = pd.to_datetime('2021-12-31')
#
# d = (max_date - min_date).days + 1
#
# train['Date'] = min_date + pd.to_timedelta(np.random.randint(d, size=train_size), unit='d')
# train = get_year_mnth_dt_from_date(train, 'Date')
#
# new_train, new_target = GANGenerator(gen_x_times=1.1, cat_cols=['year'], bot_filter_quantile=0.001,
# top_filter_quantile=0.999,
# is_post_process=True, pregeneration_frac=2, only_generated_data=False). \
# generate_data_pipe(train.drop('Date', axis=1), None,
# train.drop('Date', axis=1)
# )
# new_train = collect_dates(new_train)

0 comments on commit 2789612

Please sign in to comment.