Skip to content

Commit

Permalink
Fix type annotation of ExperienceSource.__iter__ (#645)
Browse files Browse the repository at this point in the history
Fixes #644

__iter__ returns Iterator
  • Loading branch information
t-vi authored May 14, 2021
1 parent cb687e6 commit 5bfb846
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pl_bolts/datamodules/experience_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
from abc import ABC
from collections import deque, namedtuple
from typing import Callable, Iterable, List, Tuple
from typing import Callable, Iterator, List, Tuple

import torch
from torch.utils.data import IterableDataset
Expand All @@ -30,7 +30,7 @@ class ExperienceSourceDataset(IterableDataset):
def __init__(self, generate_batch: Callable) -> None:
self.generate_batch = generate_batch

def __iter__(self) -> Iterable:
def __iter__(self) -> Iterator:
iterator = self.generate_batch()
return iterator

Expand Down

0 comments on commit 5bfb846

Please sign in to comment.