Skip to content

Commit

Permalink
fix wandb so mypy doesn't complain
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 13, 2023
1 parent 5b67ea9 commit 492cde8
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
3 changes: 3 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ ignore_missing_imports = True
[mypy-axolotl.monkeypatch.*]
ignore_errors = True

[mypy-axolotl.utils.callbacks]
disable_error_code = attr-defined

[mypy-flash_attn.*]
ignore_missing_imports = True

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ scipy
scikit-learn==1.2.2
pynvml
art
wandb
4 changes: 2 additions & 2 deletions src/axolotl/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import pandas as pd
import torch
import torch.distributed as dist
import wandb
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
Expand All @@ -25,6 +24,7 @@
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy

import wandb
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
barrier,
Expand Down Expand Up @@ -367,7 +367,7 @@ def on_evaluate(
output_scores=False,
)

def logits_to_tokens(logits) -> str:
def logits_to_tokens(logits) -> torch.Tensor:
probabilities = torch.softmax(logits, dim=-1)
# Get the predicted token ids (the ones with the highest probability)
predicted_token_ids = torch.argmax(probabilities, dim=-1)
Expand Down

0 comments on commit 492cde8

Please sign in to comment.