Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 4, 2023
1 parent be40e62 commit a3d94e2
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 13 deletions.
11 changes: 5 additions & 6 deletions lai_tldr/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Code in this file is based on https://github.com/Shivanandroy/simpleT5 by Shivanand Roy."""

import os
Expand Down Expand Up @@ -55,11 +54,11 @@ def __init__(
self.target_max_token_len = target_max_token_len

def __len__(self):
"""returns length of data."""
"""Returns length of data."""
return len(self.data)

def __getitem__(self, index: int):
"""returns dictionary of input tensors to feed into T5/MT5 model."""
"""Returns dictionary of input tensors to feed into T5/MT5 model."""

data_row = self.data.iloc[index]
source_text = data_row["source_text"]
Expand Down Expand Up @@ -170,7 +169,7 @@ def setup(self, stage=None):
)

def train_dataloader(self):
"""training dataloader."""
"""Training dataloader."""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
Expand All @@ -179,7 +178,7 @@ def train_dataloader(self):
)

def test_dataloader(self):
"""test dataloader."""
"""Test dataloader."""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
Expand All @@ -188,7 +187,7 @@ def test_dataloader(self):
)

def val_dataloader(self):
"""validation dataloader."""
"""Validation dataloader."""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
Expand Down
11 changes: 5 additions & 6 deletions lai_tldr/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Code in this file is based on https://github.com/Shivanandroy/simpleT5 by Shivanand Roy."""

from lightning.pytorch import LightningModule
Expand All @@ -42,7 +41,7 @@ def __init__(
self.save_only_last_epoch = False

def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
"""forward step."""
"""Forward step."""
output = self.model(
input_ids,
attention_mask=attention_mask,
Expand All @@ -53,7 +52,7 @@ def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None
return output.loss, output.logits

def training_step(self, batch, batch_idx):
"""training step."""
"""Training step."""
input_ids = batch["source_text_input_ids"]
attention_mask = batch["source_text_attention_mask"]
labels = batch["labels"]
Expand All @@ -70,7 +69,7 @@ def training_step(self, batch, batch_idx):
return loss

def validation_step(self, batch, batch_idx):
"""validation step."""
"""Validation step."""
input_ids = batch["source_text_input_ids"]
attention_mask = batch["source_text_attention_mask"]
labels = batch["labels"]
Expand All @@ -86,7 +85,7 @@ def validation_step(self, batch, batch_idx):
self.log("val_loss", loss, prog_bar=True)

def test_step(self, batch, batch_idx):
"""test step."""
"""Test step."""
input_ids = batch["source_text_input_ids"]
attention_mask = batch["source_text_attention_mask"]
labels = batch["labels"]
Expand All @@ -102,7 +101,7 @@ def test_step(self, batch, batch_idx):
self.log("test_loss", loss, prog_bar=True)

def configure_optimizers(self):
"""configure optimizers."""
"""Configure optimizers."""
return AdamW(self.parameters(), lr=0.0001)


Expand Down
1 change: 0 additions & 1 deletion tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def test_summarization_dataset():

counter = 0
for sample in dset:

assert isinstance(sample, dict)
keys = list(sample.keys())

Expand Down

0 comments on commit a3d94e2

Please sign in to comment.