From 2789612676e0afcadcf687cee0c6ca46bcd520ce Mon Sep 17 00:00:00 2001 From: Insaf Ashrapov Date: Sat, 30 Sep 2023 22:01:30 +0300 Subject: [PATCH] Fix bug with None as TEMP_TARGET --- src/tabgan/sampler.py | 55 +++++++++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/src/tabgan/sampler.py b/src/tabgan/sampler.py index 023723f..7b679ac 100644 --- a/src/tabgan/sampler.py +++ b/src/tabgan/sampler.py @@ -306,26 +306,27 @@ def generate_data( else: self.TEMP_TARGET = None logging.info("Fitting ForestDiffusion model") - if self.cat_cols is None: - forest_model = ForestDiffusionModel(train_df, label_y=self.TEMP_TARGET, n_t=50, + forest_model = ForestDiffusionModel(train_df.to_numpy(), label_y=self.TEMP_TARGET, n_t=50, duplicate_K=100, diffusion_type='flow', n_jobs=-1) else: - forest_model = ForestDiffusionModel(train_df, label_y=self.TEMP_TARGET, n_t=50, + forest_model = ForestDiffusionModel(train_df.to_numpy(), label_y=self.TEMP_TARGET, n_t=50, duplicate_K=100, # todo fix bug with cat cols #cat_indexes=self.get_column_indexes(train_df, self.cat_cols), diffusion_type='flow', n_jobs=-1) logging.info("Finished training ForestDiffusionModel") - generated_df = forest_model.generate(batch_size=int(self.gen_x_times*train_df.shape[0])) + generated_df = forest_model.generate(batch_size=int(self.gen_x_times*train_df.to_numpy().shape[0])) data_dtype = train_df.dtypes.values - + generated_df = pd.DataFrame(generated_df) + generated_df.columns = train_df.columns for i in range(len(generated_df.columns)): generated_df[generated_df.columns[i]] = generated_df[ generated_df.columns[i] ].astype(data_dtype[i]) gc.collect() + self.TEMP_TARGET = "TEMP_TARGET" if not only_generated_data: train_df = pd.concat([train_df, generated_df]).reset_index(drop=True) logging.info( @@ -385,28 +386,46 @@ def get_columns_if_exists(df, col) -> pd.DataFrame: if __name__ == "__main__": setup_logging(logging.DEBUG) - train_size = 100 + train_size = 75 
train = pd.DataFrame( np.random.randint(-10, 150, size=(train_size, 4)), columns=list("ABCD") ) logging.info(train) target = pd.DataFrame(np.random.randint(0, 2, size=(train_size, 1)), columns=list("Y")) test = pd.DataFrame(np.random.randint(0, 100, size=(train_size, 4)), columns=list("ABCD")) - _sampler(OriginalGenerator(gen_x_times=15), train, target, test) + # _sampler(OriginalGenerator(gen_x_times=15), train, target, test) + # _sampler( + # GANGenerator(gen_x_times=10, only_generated_data=False, + # gen_params={"batch_size": 500, "patience": 25, "epochs": 500, }), train, target, test + # ) + # + # _sampler(OriginalGenerator(gen_x_times=15), train, None, train) + # _sampler( + # GANGenerator(cat_cols=["A"], gen_x_times=20, only_generated_data=True), + # train, + # None, + # train, + # ) _sampler( - ForestDiffusionGenerator(cat_cols=["A"], gen_x_times=2, only_generated_data=True), + ForestDiffusionGenerator(cat_cols=["A"], gen_x_times=1, only_generated_data=True), train, None, train, ) - train['Date'] = min_date + pd.to_timedelta(np.random.randint(d, size=train_size), unit='d') - train = get_year_mnth_dt_from_date(train, 'Date') - - new_train, new_target = GANGenerator(gen_x_times=1.1, cat_cols=['year'], bot_filter_quantile=0.001, - top_filter_quantile=0.999, - is_post_process=True, pregeneration_frac=2, only_generated_data=False). \ - generate_data_pipe(train.drop('Date', axis=1), None, - train.drop('Date', axis=1) - ) - new_train = collect_dates(new_train) + # + # min_date = pd.to_datetime('2019-01-01') + # max_date = pd.to_datetime('2021-12-31') + # + # d = (max_date - min_date).days + 1 + # + # train['Date'] = min_date + pd.to_timedelta(np.random.randint(d, size=train_size), unit='d') + # train = get_year_mnth_dt_from_date(train, 'Date') + # + # new_train, new_target = GANGenerator(gen_x_times=1.1, cat_cols=['year'], bot_filter_quantile=0.001, + # top_filter_quantile=0.999, + # is_post_process=True, pregeneration_frac=2, only_generated_data=False). 
\ + # generate_data_pipe(train.drop('Date', axis=1), None, + # train.drop('Date', axis=1) + # ) + # new_train = collect_dates(new_train)