This repository has been archived by the owner on Aug 27, 2024. It is now read-only.

Commit 3e2d623

Merge pull request #669 from predictive-analytics-lab/name-arg-in-replace-data

Add an optional name argument to replace_data()
tmke8 committed May 27, 2022
2 parents 8983e1a + be6b125 commit 3e2d623
Showing 5 changed files with 33 additions and 21 deletions.
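
The change lets callers set the name in the same call instead of chaining .rename() afterwards. A minimal before/after sketch, using the train_test_split lines from the splits.py diff below (data, all_data_train, and train all come from that diff):

    # Before: replace the data, then rename in a second step.
    train = data.replace_data(data=all_data_train).rename(f"{data.name} - Train")

    # After: pass the optional name keyword directly; omitting it keeps the old name.
    train = data.replace_data(data=all_data_train, name=f"{data.name} - Train")
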
4 changes: 2 additions & 2 deletions ethicml/algorithms/preprocess/upsampler.py
@@ -177,8 +177,8 @@ def upsample(
         upsampled_dataframes = pd.concat(
             [upsampled_dataframes, df.drop(["preds"], axis="columns")], axis="index"
         ).reset_index(drop=True)
-    upsampled_datatuple = dataset.replace_data(upsampled_dataframes).rename(
-        f"{name}: {dataset.name}"
+    upsampled_datatuple = dataset.replace_data(
+        upsampled_dataframes, name=f"{name}: {dataset.name}"
     )

     assert upsampled_datatuple is not None
2 changes: 1 addition & 1 deletion ethicml/implementations/beutel.py
@@ -260,7 +260,7 @@ def encode_dataset(
     for embedding, _, _ in dataloader:
         data_to_return += enc(embedding).data.numpy().tolist()

-    return datatuple.replace(x=pd.DataFrame(data_to_return)).rename(f"Beutel: {datatuple.name}")
+    return datatuple.replace(x=pd.DataFrame(data_to_return))


 def encode_testset(enc: nn.Module, dataloader: torch.utils.data.DataLoader, testtuple: T) -> T:
2 changes: 1 addition & 1 deletion ethicml/implementations/vfae.py
@@ -83,7 +83,7 @@ def transform(model: VFAENetwork, dataset: T, flags: VfaeArgs) -> T:
         # z1 = model.reparameterize(z1_mu, z1_logvar)
         post_train += z1_mu.data.tolist()

-    return dataset.replace(x=pd.DataFrame(post_train)).rename(f"VFAE: {dataset.name}")
+    return dataset.replace(x=pd.DataFrame(post_train))


 def train_and_transform(train: DataTuple, test: T, flags: VfaeArgs) -> Tuple[DataTuple, T]:
20 changes: 10 additions & 10 deletions ethicml/preprocessing/splits.py
@@ -107,8 +107,8 @@ def train_test_split(
     all_data_test = all_data_test.reset_index(drop=True)

     # ================================== assemble train and test ==================================
-    train: DataTuple = data.replace_data(data=all_data_train).rename(f"{data.name} - Train")
-    test: DataTuple = data.replace_data(data=all_data_test).rename(f"{data.name} - Test")
+    train: DataTuple = data.replace_data(data=all_data_train, name=f"{data.name} - Train")
+    test: DataTuple = data.replace_data(data=all_data_test, name=f"{data.name} - Test")

     assert isinstance(train.x, pd.DataFrame)
     assert isinstance(test.x, pd.DataFrame)
@@ -208,12 +208,12 @@ def __call__(
             data, train_percentage=self.train_percentage, random_seed=random_seed
         )

-        train = data.replace_data(data.data.iloc[train_indices].reset_index(drop=True)).rename(
-            f"{data.name} - Train"
+        train = data.replace_data(
+            data.data.iloc[train_indices].reset_index(drop=True), name=f"{data.name} - Train"
         )

-        test = data.replace_data(data.data.iloc[test_indices].reset_index(drop=True)).rename(
-            f"{data.name} - Test"
+        test = data.replace_data(
+            data.data.iloc[test_indices].reset_index(drop=True), name=f"{data.name} - Test"
         )

         # assert that no data points got lost anywhere
@@ -294,12 +294,12 @@ def __call__(
         train_idx = np.concatenate(train_indices, axis=0)
         test_idx = np.concatenate(test_indices, axis=0)

-        train = data.replace_data(data.data.iloc[train_idx].reset_index(drop=True)).rename(
-            f"{data.name} - Train"
+        train = data.replace_data(
+            data.data.iloc[train_idx].reset_index(drop=True), name=f"{data.name} - Train"
         )

-        test = data.replace_data(data.data.iloc[test_idx].reset_index(drop=True)).rename(
-            f"{data.name} - Test"
+        test = data.replace_data(
+            data.data.iloc[test_idx].reset_index(drop=True), name=f"{data.name} - Test"
         )

         unbalanced_test_len = round(len(data) * (1 - self.train_percentage))
26 changes: 19 additions & 7 deletions ethicml/utility/data_structures.py
@@ -76,7 +76,7 @@ class SubsetMixin(ABC):
     s_column: str

     @abstractmethod
-    def replace_data(self: _S, data: pd.DataFrame) -> _S:
+    def replace_data(self: _S, data: pd.DataFrame, name: str | None = None) -> _S:
         """Make a copy of the container but change the underlying data."""

     @property
@@ -144,10 +144,12 @@ def replace(
             x=x if x is not None else self.x, s=s if s is not None else self.s, name=self.name
         )

-    def replace_data(self, data: pd.DataFrame) -> SubgroupTuple:
+    def replace_data(self, data: pd.DataFrame, name: str | None = None) -> SubgroupTuple:
         """Make a copy of the DataTuple but change the underlying data."""
         assert self.s_column in data.columns, f"column {self.s_column} not present"
-        return SubgroupTuple(data=data, s_column=self.s_column, name=self.name)
+        return SubgroupTuple(
+            data=data, s_column=self.s_column, name=self.name if name is None else name
+        )

     def rename(self, name: str) -> SubgroupTuple:
         """Change only the name."""
@@ -251,11 +253,16 @@ def rename(self, name: str) -> DataTuple:
         """Change only the name."""
         return DataTuple(data=self.data, s_column=self.s_column, y_column=self.y_column, name=name)

-    def replace_data(self, data: pd.DataFrame) -> DataTuple:
+    def replace_data(self, data: pd.DataFrame, name: str | None = None) -> DataTuple:
         """Make a copy of the DataTuple but change the underlying data."""
         assert self.s_column in data.columns, f"column {self.s_column} not present"
         assert self.y_column in data.columns, f"column {self.y_column} not present"
-        return DataTuple(data=data, s_column=self.s_column, y_column=self.y_column, name=self.name)
+        return DataTuple(
+            data=data,
+            s_column=self.s_column,
+            y_column=self.y_column,
+            name=self.name if name is None else name,
+        )

     def apply_to_joined_df(self, mapper: Callable[[pd.DataFrame], pd.DataFrame]) -> DataTuple:
         """Concatenate the dataframes in the DataTuple and then apply a function to it.
@@ -358,11 +365,16 @@ def rename(self, name: str) -> LabelTuple:
         """Change only the name."""
         return LabelTuple(data=self.data, s_column=self.s_column, y_column=self.y_column, name=name)

-    def replace_data(self, data: pd.DataFrame) -> LabelTuple:
+    def replace_data(self, data: pd.DataFrame, name: str | None = None) -> LabelTuple:
         """Make a copy of the LabelTuple but change the underlying data."""
         assert self.s_column in data.columns, f"column {self.s_column} not present"
         assert self.y_column in data.columns, f"column {self.y_column} not present"
-        return LabelTuple(data=data, s_column=self.s_column, y_column=self.y_column, name=self.name)
+        return LabelTuple(
+            data=data,
+            s_column=self.s_column,
+            y_column=self.y_column,
+            name=self.name if name is None else name,
+        )


 TestTuple: TypeAlias = Union[SubgroupTuple, DataTuple]
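
For illustration, a minimal sketch of how the new keyword behaves on a DataTuple. The toy column names and values are invented for this example, and the top-level import is an assumption (the class lives in ethicml.utility.data_structures); the constructor call and the replace_data signature are taken from the diff above:

    import pandas as pd

    from ethicml import DataTuple  # assumption: DataTuple is re-exported at the package top level

    # Toy data invented for this sketch.
    df = pd.DataFrame({"x0": [0.1, 0.2, 0.3], "s": [0, 1, 0], "y": [1, 0, 1]})
    dt = DataTuple(data=df, s_column="s", y_column="y", name="toy")

    # Omitting name keeps the original name, exactly as before this change.
    assert dt.replace_data(df).name == "toy"

    # Passing name sets the new name in the same call, replacing the old
    # replace_data(...).rename(...) two-step.
    assert dt.replace_data(df, name="toy - Train").name == "toy - Train"
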

0 comments on commit 3e2d623
