-
Notifications
You must be signed in to change notification settings - Fork 483
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Episode iterator upgrades #216
Changes from 12 commits
ebaa6ec
3b411d0
c5de063
bb06827
69cf09a
0256a4b
b1383d0
cccfc9b
9cba332
19b9d2a
01847cd
b892f65
96d1375
e43bdc5
974dbca
8b8309f
fb54383
ebf8358
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -278,7 +278,8 @@ def __init__( | |
cycle: bool = True, | ||
shuffle: bool = False, | ||
group_by_scene: bool = True, | ||
max_scene_repeat: int = -1, | ||
max_scene_repeat_episodes: int = -1, | ||
max_scene_repeat_steps: int = -1, | ||
num_episode_sample: int = -1, | ||
): | ||
r""".. | ||
|
@@ -290,32 +291,43 @@ def __init__( | |
effect if cycle is set to :py:`False`. Will shuffle grouped scenes | ||
if :p:`group_by_scene` is :py:`True`. | ||
:param group_by_scene: if :py:`True`, group episodes from same scene. | ||
:param max_scene_repeat: threshold of how many episodes from the same | ||
:param max_scene_repeat_episodes: threshold of how many episodes from the same | ||
scene can be loaded consecutively. :py:`-1` for no limit | ||
:param max_scene_repeat_steps: threshold of how many steps from the same | ||
scene can be taken consecutively. :py:`-1` for no limit | ||
:param num_episode_sample: number of episodes to be sampled. :py:`-1` | ||
for no sampling. | ||
""" | ||
self._repetition_rand_interval = 0.2 | ||
|
||
# sample episodes | ||
if num_episode_sample >= 0: | ||
episodes = np.random.choice( | ||
episodes, num_episode_sample, replace=False | ||
) | ||
|
||
self.episodes = episodes | ||
self.cycle = cycle | ||
self.group_by_scene = group_by_scene | ||
if group_by_scene: | ||
num_scene_groups = len( | ||
list(groupby(episodes, key=lambda x: x.scene_id)) | ||
) | ||
num_unique_scenes = len(set([e.scene_id for e in episodes])) | ||
if num_scene_groups >= num_unique_scenes: | ||
self.episodes = sorted(self.episodes, key=lambda x: x.scene_id) | ||
self.max_scene_repetition = max_scene_repeat | ||
self.shuffle = shuffle | ||
|
||
if shuffle: | ||
random.shuffle(self.episodes) | ||
|
||
if group_by_scene: | ||
self.episodes = sorted(self.episodes, key=lambda x: x.scene_id) | ||
|
||
self.max_scene_repetition_episodes = max_scene_repeat_episodes | ||
self.max_scene_repetition_steps = max_scene_repeat_steps | ||
|
||
self._rep_count = 0 | ||
self._step_count = 0 | ||
self._prev_scene_id = None | ||
|
||
self._iterator = iter(self.episodes) | ||
|
||
self._set_shuffle_intervals() | ||
|
||
def __iter__(self): | ||
return self | ||
|
||
|
@@ -324,6 +336,7 @@ def __next__(self): | |
|
||
:return: next episode. | ||
""" | ||
self._switch_scene_if() | ||
|
||
next_episode = next(self._iterator, None) | ||
if next_episode is None: | ||
|
@@ -334,14 +347,9 @@ def __next__(self): | |
self._shuffle_iterator() | ||
next_episode = next(self._iterator) | ||
|
||
if self._prev_scene_id == next_episode.scene_id: | ||
self._rep_count += 1 | ||
if ( | ||
self.max_scene_repetition > 0 | ||
and self._rep_count >= self.max_scene_repetition - 1 | ||
): | ||
self._shuffle_iterator() | ||
if self._prev_scene_id != next_episode.scene_id: | ||
self._rep_count = 0 | ||
self._step_count = 0 | ||
|
||
self._prev_scene_id = next_episode.scene_id | ||
return next_episode | ||
|
@@ -355,9 +363,60 @@ def _shuffle_iterator(self) -> None: | |
list(g) | ||
for k, g in groupby(self._iterator, key=lambda x: x.scene_id) | ||
] | ||
|
||
random.shuffle(grouped_episodes) | ||
for i in range(len(grouped_episodes)): | ||
erikwijmans marked this conversation as resolved.
Show resolved
Hide resolved
|
||
random.shuffle(grouped_episodes[i]) | ||
|
||
self._iterator = iter(sum(grouped_episodes, [])) | ||
else: | ||
episodes = list(self._iterator) | ||
random.shuffle(episodes) | ||
self._iterator = iter(episodes) | ||
|
||
def step_taken(self): | ||
self._step_count += 1 | ||
|
||
@staticmethod | ||
def _randomize_value(value, interval): | ||
return random.randint( | ||
int(value * (1 - interval)), int(value * (1 + interval)) | ||
) | ||
|
||
def _set_shuffle_intervals(self): | ||
if self.max_scene_repetition_episodes > 0: | ||
self._max_rep_episode = self._randomize_value( | ||
erikwijmans marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.max_scene_repetition_episodes, | ||
self._repetition_rand_interval, | ||
) | ||
else: | ||
self._max_rep_episode = None | ||
|
||
if self.max_scene_repetition_steps > 0: | ||
self._max_rep_step = self._randomize_value( | ||
self.max_scene_repetition_steps, self._repetition_rand_interval | ||
) | ||
else: | ||
self._max_rep_step = None | ||
|
||
def _switch_scene_if(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Question here: would it be switching too frequently at the later stage of training, with potentially fewer than 100 episodes per scene switch? Would it help to incorporate both count schemes, like a logical AND? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shuffling on steps is consistent from an optimization standpoint: you swap scenes every N parameter updates, which is why I like it :) I don't think you can switch scenes "too often". Ideally, we'd just randomly sample a new episode irrespective of the scene it is in, but this incurs the non-trivial cost of swapping the scene way too often. |
||
do_switch = False | ||
self._rep_count += 1 | ||
|
||
# Shuffle if a scene has been selected more than _max_rep_episode times in a row | ||
if ( | ||
self._max_rep_episode is not None | ||
and self._rep_count > self._max_rep_episode | ||
): | ||
do_switch = True | ||
|
||
# Shuffle if a scene has been used for more than _max_rep_step steps in a row | ||
if ( | ||
self._max_rep_step is not None | ||
and self._step_count >= self._max_rep_step | ||
): | ||
do_switch = True | ||
|
||
if do_switch: | ||
self._shuffle_iterator() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we check for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's strange as there was no option to enable There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The cycle logic doesn't rely on the shuffle |
||
self._set_shuffle_intervals() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ | |
from gym.spaces.dict_space import Dict as SpaceDict | ||
|
||
from habitat.config import Config | ||
from habitat.core.dataset import Dataset, Episode | ||
from habitat.core.dataset import Dataset, Episode, EpisodeIterator | ||
from habitat.core.embodied_task import EmbodiedTask, Metrics | ||
from habitat.core.simulator import Observations, Simulator | ||
from habitat.datasets import make_dataset | ||
|
@@ -213,6 +213,11 @@ def _update_step_stats(self) -> None: | |
if self._past_limit(): | ||
self._episode_over = True | ||
|
||
if self.episode_iterator is not None and isinstance( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A separate PR probably makes sense for changing that functionality (dataset is also marked as optional). |
||
self.episode_iterator, EpisodeIterator | ||
): | ||
self.episode_iterator.step_taken() | ||
|
||
def step( | ||
self, action: Union[int, str, Dict[str, Any]], **kwargs | ||
) -> Observations: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's move `repetition_rand_interval` to an `__init__` argument with a default value; otherwise there is no option to turn it off.