From 01f3bb1f7a5f1324ffa3a6c2401664e8329e1428 Mon Sep 17 00:00:00 2001 From: Thomas MK Date: Fri, 27 May 2022 18:07:14 +0200 Subject: [PATCH 1/2] Add an optional name argument to replace_data() --- ethicml/algorithms/preprocess/upsampler.py | 4 ++-- ethicml/implementations/beutel.py | 2 +- ethicml/implementations/vfae.py | 2 +- ethicml/preprocessing/splits.py | 20 ++++++++--------- ethicml/utility/data_structures.py | 26 ++++++++++++++++------ 5 files changed, 33 insertions(+), 21 deletions(-) diff --git a/ethicml/algorithms/preprocess/upsampler.py b/ethicml/algorithms/preprocess/upsampler.py index df5e4ab2..a9942cd0 100644 --- a/ethicml/algorithms/preprocess/upsampler.py +++ b/ethicml/algorithms/preprocess/upsampler.py @@ -177,8 +177,8 @@ def upsample( upsampled_dataframes = pd.concat( [upsampled_dataframes, df.drop(["preds"], axis="columns")], axis="index" ).reset_index(drop=True) - upsampled_datatuple = dataset.replace_data(upsampled_dataframes).rename( - f"{name}: {dataset.name}" + upsampled_datatuple = dataset.replace_data( + upsampled_dataframes, name=f"{name}: {dataset.name}" ) assert upsampled_datatuple is not None diff --git a/ethicml/implementations/beutel.py b/ethicml/implementations/beutel.py index 3b5144c3..98dc6e64 100644 --- a/ethicml/implementations/beutel.py +++ b/ethicml/implementations/beutel.py @@ -260,7 +260,7 @@ def encode_dataset( for embedding, _, _ in dataloader: data_to_return += enc(embedding).data.numpy().tolist() - return datatuple.replace(x=pd.DataFrame(data_to_return)).rename(f"Beutel: {datatuple.name}") + return datatuple.replace(x=pd.DataFrame(data_to_return), name=f"Beutel: {datatuple.name}") def encode_testset(enc: nn.Module, dataloader: torch.utils.data.DataLoader, testtuple: T) -> T: diff --git a/ethicml/implementations/vfae.py b/ethicml/implementations/vfae.py index 7c7aa6ec..33106cac 100644 --- a/ethicml/implementations/vfae.py +++ b/ethicml/implementations/vfae.py @@ -83,7 +83,7 @@ def transform(model: VFAENetwork, dataset: T, flags: VfaeArgs) -> T: # z1 = model.reparameterize(z1_mu, z1_logvar) post_train += z1_mu.data.tolist() - return dataset.replace(x=pd.DataFrame(post_train)).rename(f"VFAE: {dataset.name}") + return dataset.replace(x=pd.DataFrame(post_train), name=f"VFAE: {dataset.name}") def train_and_transform(train: DataTuple, test: T, flags: VfaeArgs) -> Tuple[DataTuple, T]: diff --git a/ethicml/preprocessing/splits.py b/ethicml/preprocessing/splits.py index 1d11e268..3833a5f7 100644 --- a/ethicml/preprocessing/splits.py +++ b/ethicml/preprocessing/splits.py @@ -107,8 +107,8 @@ def train_test_split( all_data_test = all_data_test.reset_index(drop=True) # ================================== assemble train and test ================================== - train: DataTuple = data.replace_data(data=all_data_train).rename(f"{data.name} - Train") - test: DataTuple = data.replace_data(data=all_data_test).rename(f"{data.name} - Test") + train: DataTuple = data.replace_data(data=all_data_train, name=f"{data.name} - Train") + test: DataTuple = data.replace_data(data=all_data_test, name=f"{data.name} - Test") assert isinstance(train.x, pd.DataFrame) assert isinstance(test.x, pd.DataFrame) @@ -208,12 +208,12 @@ def __call__( data, train_percentage=self.train_percentage, random_seed=random_seed ) - train = data.replace_data(data.data.iloc[train_indices].reset_index(drop=True)).rename( - f"{data.name} - Train" + train = data.replace_data( + data.data.iloc[train_indices].reset_index(drop=True), name=f"{data.name} - Train" ) - test = data.replace_data(data.data.iloc[test_indices].reset_index(drop=True)).rename( - f"{data.name} - Test" + test = data.replace_data( + data.data.iloc[test_indices].reset_index(drop=True), name=f"{data.name} - Test" ) # assert that no data points got lost anywhere @@ -294,12 +294,12 @@ def __call__( train_idx = np.concatenate(train_indices, axis=0) test_idx = np.concatenate(test_indices, axis=0) - train = data.replace_data(data.data.iloc[train_idx].reset_index(drop=True)).rename( - f"{data.name} - Train" + train = data.replace_data( + data.data.iloc[train_idx].reset_index(drop=True), name=f"{data.name} - Train" ) - test = data.replace_data(data.data.iloc[test_idx].reset_index(drop=True)).rename( - f"{data.name} - Test" + test = data.replace_data( + data.data.iloc[test_idx].reset_index(drop=True), name=f"{data.name} - Test" ) unbalanced_test_len = round(len(data) * (1 - self.train_percentage)) diff --git a/ethicml/utility/data_structures.py b/ethicml/utility/data_structures.py index 8098df2d..4ff512fc 100644 --- a/ethicml/utility/data_structures.py +++ b/ethicml/utility/data_structures.py @@ -76,7 +76,7 @@ class SubsetMixin(ABC): s_column: str @abstractmethod - def replace_data(self: _S, data: pd.DataFrame) -> _S: + def replace_data(self: _S, data: pd.DataFrame, name: str | None = None) -> _S: """Make a copy of the container but change the underlying data.""" @property @@ -144,10 +144,12 @@ def replace( x=x if x is not None else self.x, s=s if s is not None else self.s, name=self.name ) - def replace_data(self, data: pd.DataFrame) -> SubgroupTuple: + def replace_data(self, data: pd.DataFrame, name: str | None = None) -> SubgroupTuple: """Make a copy of the DataTuple but change the underlying data.""" assert self.s_column in data.columns, f"column {self.s_column} not present" - return SubgroupTuple(data=data, s_column=self.s_column, name=self.name) + return SubgroupTuple( + data=data, s_column=self.s_column, name=self.name if name is None else name + ) def rename(self, name: str) -> SubgroupTuple: """Change only the name.""" @@ -251,11 +253,16 @@ def rename(self, name: str) -> DataTuple: """Change only the name.""" return DataTuple(data=self.data, s_column=self.s_column, y_column=self.y_column, name=name) - def replace_data(self, data: pd.DataFrame) -> DataTuple: + def replace_data(self, data: pd.DataFrame, name: str | None = None) -> DataTuple: """Make a copy of the DataTuple but change the underlying data.""" assert self.s_column in data.columns, f"column {self.s_column} not present" assert self.y_column in data.columns, f"column {self.y_column} not present" - return DataTuple(data=data, s_column=self.s_column, y_column=self.y_column, name=self.name) + return DataTuple( + data=data, + s_column=self.s_column, + y_column=self.y_column, + name=self.name if name is None else name, + ) def apply_to_joined_df(self, mapper: Callable[[pd.DataFrame], pd.DataFrame]) -> DataTuple: """Concatenate the dataframes in the DataTuple and then apply a function to it. @@ -358,11 +365,16 @@ def rename(self, name: str) -> LabelTuple: """Change only the name.""" return LabelTuple(data=self.data, s_column=self.s_column, y_column=self.y_column, name=name) - def replace_data(self, data: pd.DataFrame) -> LabelTuple: + def replace_data(self, data: pd.DataFrame, name: str | None = None) -> LabelTuple: """Make a copy of the LabelTuple but change the underlying data.""" assert self.s_column in data.columns, f"column {self.s_column} not present" assert self.y_column in data.columns, f"column {self.y_column} not present" - return LabelTuple(data=data, s_column=self.s_column, y_column=self.y_column, name=self.name) + return LabelTuple( + data=data, + s_column=self.s_column, + y_column=self.y_column, + name=self.name if name is None else name, + ) TestTuple: TypeAlias = Union[SubgroupTuple, DataTuple] From be6b125be327cd9bb9d4cf34612898e7b85f2adf Mon Sep 17 00:00:00 2001 From: Thomas MK Date: Fri, 27 May 2022 18:10:13 +0200 Subject: [PATCH 2/2] Don't confuse replace() and replace_data() --- ethicml/implementations/beutel.py | 2 +- ethicml/implementations/vfae.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ethicml/implementations/beutel.py b/ethicml/implementations/beutel.py index 98dc6e64..94e1e5a1 100644 --- a/ethicml/implementations/beutel.py +++ b/ethicml/implementations/beutel.py @@ -260,7 +260,7 @@ def encode_dataset( for embedding, _, _ in dataloader: data_to_return += enc(embedding).data.numpy().tolist() - return datatuple.replace(x=pd.DataFrame(data_to_return), name=f"Beutel: {datatuple.name}") + return datatuple.replace(x=pd.DataFrame(data_to_return)) def encode_testset(enc: nn.Module, dataloader: torch.utils.data.DataLoader, testtuple: T) -> T: diff --git a/ethicml/implementations/vfae.py b/ethicml/implementations/vfae.py index 33106cac..333991c4 100644 --- a/ethicml/implementations/vfae.py +++ b/ethicml/implementations/vfae.py @@ -83,7 +83,7 @@ def transform(model: VFAENetwork, dataset: T, flags: VfaeArgs) -> T: # z1 = model.reparameterize(z1_mu, z1_logvar) post_train += z1_mu.data.tolist() - return dataset.replace(x=pd.DataFrame(post_train), name=f"VFAE: {dataset.name}") + return dataset.replace(x=pd.DataFrame(post_train)) def train_and_transform(train: DataTuple, test: T, flags: VfaeArgs) -> Tuple[DataTuple, T]: