
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 1, 2024
1 parent 8a9f822 commit c9fa82b
Showing 2 changed files with 13 additions and 15 deletions.
17 changes: 8 additions & 9 deletions lai_tldr/data.py
@@ -19,7 +19,6 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
-
"""Code in this file is based on https://github.com/Shivanandroy/simpleT5 by Shivanand Roy."""

import os
@@ -55,11 +54,11 @@ def __init__(
self.target_max_token_len = target_max_token_len

def __len__(self):
"""returns length of data."""
"""Returns length of data."""
return len(self.data)

def __getitem__(self, index: int):
"""returns dictionary of input tensors to feed into T5/MT5 model."""
"""Returns dictionary of input tensors to feed into T5/MT5 model."""

data_row = self.data.iloc[index]
source_text = data_row["source_text"]
@@ -85,9 +84,9 @@ def __getitem__(self, index: int):
)

labels = target_text_encoding["input_ids"]
-labels[
-    labels == 0
-] = -100  # to make sure we have correct labels for T5 text generation
+labels[labels == 0] = (
+    -100
+)  # to make sure we have correct labels for T5 text generation

return dict(
source_text_input_ids=source_text_encoding["input_ids"].flatten(),
@@ -170,7 +169,7 @@ def setup(self, stage=None):
)

def train_dataloader(self):
"""training dataloader."""
"""Training dataloader."""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
@@ -179,7 +178,7 @@ def train_dataloader(self):
)

def test_dataloader(self):
"""test dataloader."""
"""Test dataloader."""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
@@ -188,7 +187,7 @@ def test_dataloader(self):
)

def val_dataloader(self):
"""validation dataloader."""
"""Validation dataloader."""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
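Note on the only substantive hunk above: the label-masking assignment is merely reformatted and its behavior is unchanged. Pad positions in the target ids are set to -100 because that is the index the cross-entropy loss in Hugging Face transformers ignores, so padding never contributes to the T5/MT5 fine-tuning loss. A minimal sketch of the same masking step (the token ids below are invented for illustration):

import torch

labels = torch.tensor([[4763, 19, 3, 9, 1, 0, 0]])  # toy target ids; trailing zeros are T5's pad id
labels[labels == 0] = -100  # -100 is skipped by the loss, exactly as in data.py above
print(labels)  # pad positions now read -100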
11 changes: 5 additions & 6 deletions lai_tldr/module.py
@@ -19,7 +19,6 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
-
"""Code in this file is based on https://github.com/Shivanandroy/simpleT5 by Shivanand Roy."""

from lightning.pytorch import LightningModule
@@ -42,7 +41,7 @@ def __init__(
self.save_only_last_epoch = False

def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
"""forward step."""
"""Forward step."""
output = self.model(
input_ids,
attention_mask=attention_mask,
@@ -53,7 +52,7 @@ def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None
return output.loss, output.logits

def training_step(self, batch, batch_idx):
"""training step."""
"""Training step."""
input_ids = batch["source_text_input_ids"]
attention_mask = batch["source_text_attention_mask"]
labels = batch["labels"]
@@ -70,7 +69,7 @@ def training_step(self, batch, batch_idx):
return loss

def validation_step(self, batch, batch_idx):
"""validation step."""
"""Validation step."""
input_ids = batch["source_text_input_ids"]
attention_mask = batch["source_text_attention_mask"]
labels = batch["labels"]
@@ -86,7 +85,7 @@ def validation_step(self, batch, batch_idx):
self.log("val_loss", loss, prog_bar=True)

def test_step(self, batch, batch_idx):
"""test step."""
"""Test step."""
input_ids = batch["source_text_input_ids"]
attention_mask = batch["source_text_attention_mask"]
labels = batch["labels"]
@@ -102,7 +101,7 @@ def test_step(self, batch, batch_idx):
self.log("test_loss", loss, prog_bar=True)

def configure_optimizers(self):
"""configure optimizers."""
"""Configure optimizers."""
return AdamW(self.parameters(), lr=0.0001)


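The module.py changes are docstring capitalization only, but the surrounding context shows how the two files fit together: the dataset masks pad labels, and the LightningModule passes input ids, attention masks, and those labels straight to a T5/MT5 model, logging the loss it returns. A hedged, self-contained sketch of that forward call with the transformers API (the checkpoint name "t5-small" and the example texts are assumptions, not taken from this repository):

from transformers import T5ForConditionalGeneration, T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

source = tokenizer(
    "summarize: pre-commit hooks apply automatic formatting fixes to a repository.",
    max_length=64, padding="max_length", truncation=True, return_tensors="pt",
)
target = tokenizer(
    "hooks auto-format code",
    max_length=16, padding="max_length", truncation=True, return_tensors="pt",
)

labels = target["input_ids"]
labels[labels == 0] = -100  # ignore padding in the loss, as in data.py

output = model(
    input_ids=source["input_ids"],
    attention_mask=source["attention_mask"],
    decoder_attention_mask=target["attention_mask"],
    labels=labels,
)
print(output.loss, output.logits.shape)  # scalar loss and (1, 16, vocab_size) logits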
