Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pre-commit.ci] pre-commit suggestions #13

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.6.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
Expand All @@ -25,7 +25,7 @@ repos:
- id: detect-private-key

- repo: https://github.com/asottile/pyupgrade
rev: v2.37.3
rev: v3.16.0
hooks:
- id: pyupgrade
args: [--py36-plus]
Expand All @@ -35,8 +35,8 @@ repos:
app.py
)

- repo: https://github.com/myint/docformatter
rev: v1.5.0
- repo: https://github.com/PyCQA/docformatter
rev: v1.7.5
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=120, --wrap-descriptions=120]
Expand All @@ -46,7 +46,7 @@ repos:
)

- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
args: [--profile, black]
Expand All @@ -56,7 +56,7 @@ repos:
)

- repo: https://github.com/psf/black
rev: 22.6.0
rev: 24.4.2
hooks:
- id: black
name: Black code
Expand All @@ -66,7 +66,7 @@ repos:
)

- repo: https://github.com/executablebooks/mdformat
rev: 0.7.15
rev: 0.7.17
hooks:
- id: mdformat
additional_dependencies:
Expand All @@ -80,7 +80,7 @@ repos:
)

- repo: https://github.com/asottile/yesqa
rev: v1.4.0
rev: v1.5.0
hooks:
- id: yesqa
exclude: |
Expand All @@ -89,7 +89,7 @@ repos:
)

- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
rev: 7.1.0
hooks:
- id: flake8
exclude: |
Expand Down
17 changes: 8 additions & 9 deletions lai_tldr/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Code in this file is based on https://github.com/Shivanandroy/simpleT5 by Shivanand Roy."""

import os
Expand Down Expand Up @@ -55,11 +54,11 @@ def __init__(
self.target_max_token_len = target_max_token_len

def __len__(self):
"""returns length of data."""
"""Returns length of data."""
return len(self.data)

def __getitem__(self, index: int):
"""returns dictionary of input tensors to feed into T5/MT5 model."""
"""Returns dictionary of input tensors to feed into T5/MT5 model."""

data_row = self.data.iloc[index]
source_text = data_row["source_text"]
Expand All @@ -85,9 +84,9 @@ def __getitem__(self, index: int):
)

labels = target_text_encoding["input_ids"]
labels[
labels == 0
] = -100 # to make sure we have correct labels for T5 text generation
labels[labels == 0] = (
-100
) # to make sure we have correct labels for T5 text generation

return dict(
source_text_input_ids=source_text_encoding["input_ids"].flatten(),
Expand Down Expand Up @@ -170,7 +169,7 @@ def setup(self, stage=None):
)

def train_dataloader(self):
"""training dataloader."""
"""Training dataloader."""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
Expand All @@ -179,7 +178,7 @@ def train_dataloader(self):
)

def test_dataloader(self):
"""test dataloader."""
"""Test dataloader."""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
Expand All @@ -188,7 +187,7 @@ def test_dataloader(self):
)

def val_dataloader(self):
"""validation dataloader."""
"""Validation dataloader."""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
Expand Down
11 changes: 5 additions & 6 deletions lai_tldr/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Code in this file is based on https://github.com/Shivanandroy/simpleT5 by Shivanand Roy."""

from lightning.pytorch import LightningModule
Expand All @@ -42,7 +41,7 @@ def __init__(
self.save_only_last_epoch = False

def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
"""forward step."""
"""Forward step."""
output = self.model(
input_ids,
attention_mask=attention_mask,
Expand All @@ -53,7 +52,7 @@ def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None
return output.loss, output.logits

def training_step(self, batch, batch_idx):
"""training step."""
"""Training step."""
input_ids = batch["source_text_input_ids"]
attention_mask = batch["source_text_attention_mask"]
labels = batch["labels"]
Expand All @@ -70,7 +69,7 @@ def training_step(self, batch, batch_idx):
return loss

def validation_step(self, batch, batch_idx):
"""validation step."""
"""Validation step."""
input_ids = batch["source_text_input_ids"]
attention_mask = batch["source_text_attention_mask"]
labels = batch["labels"]
Expand All @@ -86,7 +85,7 @@ def validation_step(self, batch, batch_idx):
self.log("val_loss", loss, prog_bar=True)

def test_step(self, batch, batch_idx):
"""test step."""
"""Test step."""
input_ids = batch["source_text_input_ids"]
attention_mask = batch["source_text_attention_mask"]
labels = batch["labels"]
Expand All @@ -102,7 +101,7 @@ def test_step(self, batch, batch_idx):
self.log("test_loss", loss, prog_bar=True)

def configure_optimizers(self):
"""configure optimizers."""
"""Configure optimizers."""
return AdamW(self.parameters(), lr=0.0001)


Expand Down
Loading