Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Episode iterator upgrades #216

Merged
merged 18 commits into from
Oct 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions configs/test/habitat_all_sensors_test.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
ENVIRONMENT:
MAX_EPISODE_STEPS: 10
ITERATOR_OPTIONS:
SHUFFLE: False
SIMULATOR:
AGENT_0:
SENSORS: ['RGB_SENSOR', 'DEPTH_SENSOR']
Expand Down
3 changes: 3 additions & 0 deletions configs/test/habitat_mp3d_eqa_test.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
TASK:
TYPE: EQA-v0
ENVIRONMENT:
ITERATOR_OPTIONS:
SHUFFLE: False
SIMULATOR:
SCENE: data/scene_datasets/mp3d/17DRP5sb8fy/17DRP5sb8fy.glb
FORWARD_STEP_SIZE: 0.1
Expand Down
5 changes: 3 additions & 2 deletions habitat/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
_C.ENVIRONMENT.MAX_EPISODE_SECONDS = 10000000
_C.ENVIRONMENT.ITERATOR_OPTIONS = CN()
_C.ENVIRONMENT.ITERATOR_OPTIONS.CYCLE = True
_C.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False
_C.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = True
_C.ENVIRONMENT.ITERATOR_OPTIONS.GROUP_BY_SCENE = True
_C.ENVIRONMENT.ITERATOR_OPTIONS.NUM_EPISODE_SAMPLE = -1
_C.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT = -1
_C.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_EPISODES = -1
_C.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = int(1e4)
# -----------------------------------------------------------------------------
# TASK
# -----------------------------------------------------------------------------
Expand Down
130 changes: 103 additions & 27 deletions habitat/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,10 @@ def __init__(
cycle: bool = True,
shuffle: bool = False,
group_by_scene: bool = True,
max_scene_repeat: int = -1,
max_scene_repeat_episodes: int = -1,
max_scene_repeat_steps: int = -1,
num_episode_sample: int = -1,
step_repetition_range: float = 0.2,
):
r"""..

Expand All @@ -290,32 +292,48 @@ def __init__(
effect if cycle is set to :py:`False`. Will shuffle grouped scenes
if :p:`group_by_scene` is :py:`True`.
:param group_by_scene: if :py:`True`, group episodes from same scene.
:param max_scene_repeat: threshold of how many episodes from the same
:param max_scene_repeat_episodes: threshold of how many episodes from the same
scene can be loaded consecutively. :py:`-1` for no limit
:param max_scene_repeat_steps: threshold of how many steps from the same
scene can be taken consecutively. :py:`-1` for no limit
:param num_episode_sample: number of episodes to be sampled. :py:`-1`
for no sampling.
:param step_repetition_range: The maximum number of steps within each scene is
uniformly drawn from
[1 - step_repetition_range, 1 + step_repetition_range] * max_scene_repeat_steps
on each scene switch. This stops all workers from swapping scenes at
the same time.
"""

# sample episodes
if num_episode_sample >= 0:
episodes = np.random.choice(
episodes, num_episode_sample, replace=False
)

self.episodes = episodes
self.cycle = cycle
self.group_by_scene = group_by_scene
if group_by_scene:
num_scene_groups = len(
list(groupby(episodes, key=lambda x: x.scene_id))
)
num_unique_scenes = len(set([e.scene_id for e in episodes]))
if num_scene_groups >= num_unique_scenes:
self.episodes = sorted(self.episodes, key=lambda x: x.scene_id)
self.max_scene_repetition = max_scene_repeat
self.shuffle = shuffle
self._rep_count = 0

if shuffle:
random.shuffle(self.episodes)

if group_by_scene:
self.episodes = sorted(self.episodes, key=lambda x: x.scene_id)

self.max_scene_repetition_episodes = max_scene_repeat_episodes
self.max_scene_repetition_steps = max_scene_repeat_steps

self._rep_count = -1 # 0 corresponds to first episode already returned
self._step_count = 0
self._prev_scene_id = None

self._iterator = iter(self.episodes)

self.step_repetition_range = step_repetition_range
self._set_shuffle_intervals()

def __iter__(self):
return self

Expand All @@ -324,40 +342,98 @@ def __next__(self):

:return: next episode.
"""
self._forced_scene_switch_if()

next_episode = next(self._iterator, None)
if next_episode is None:
if not self.cycle:
raise StopIteration

self._iterator = iter(self.episodes)

if self.shuffle:
self._shuffle_iterator()
self._shuffle()

next_episode = next(self._iterator)

if self._prev_scene_id == next_episode.scene_id:
self._rep_count += 1
if (
self.max_scene_repetition > 0
and self._rep_count >= self.max_scene_repetition - 1
self._prev_scene_id != next_episode.scene_id
and self._prev_scene_id is not None
):
self._shuffle_iterator()
self._rep_count = 0
self._step_count = 0

self._prev_scene_id = next_episode.scene_id
return next_episode

def _shuffle_iterator(self) -> None:
def _forced_scene_switch(self) -> None:
r"""Internal method to switch the scene. Moves remaining episodes
from current scene to the end and switch to next scene episodes.
"""
# groupby only merges *consecutive* items with the same key, so this drains
# what is left of the iterator into runs of same-scene episodes, keeping
# their current order.
grouped_episodes = [
abhiskk marked this conversation as resolved.
Show resolved Hide resolved
list(g)
for k, g in groupby(self._iterator, key=lambda x: x.scene_id)
]

# With a single run there is no other scene to move to, so nothing is rotated.
if len(grouped_episodes) > 1:
# Ensure we swap by moving the current group to the end
grouped_episodes = grouped_episodes[1:] + grouped_episodes[0:1]

# Flatten the runs back into one iterator over the remaining episodes.
self._iterator = iter(sum(grouped_episodes, []))
abhiskk marked this conversation as resolved.
Show resolved Hide resolved

def _shuffle(self) -> None:
r"""Internal method that shuffles the remaining episodes.
If self.group_by_scene is true, then shuffle groups of scenes.
"""
# Drain whatever is left in the iterator into a list so it can be reordered.
episodes = list(self._iterator)

random.shuffle(episodes)

if self.group_by_scene:
grouped_episodes = [
list(g)
for k, g in groupby(self._iterator, key=lambda x: x.scene_id)
]
random.shuffle(grouped_episodes)
self._iterator = iter(sum(grouped_episodes, []))
# NOTE(review): sorted() is stable, so after this line the scenes come out in
# scene_id order and only the order *within* each scene stays random. That
# does not look like "shuffle groups of scenes" as the docstring says —
# confirm which behaviour is intended.
episodes = sorted(episodes, key=lambda x: x.scene_id)

self._iterator = iter(episodes)

def step_taken(self):
    r"""Record that one more step was taken since the last scene change.

    The running total is what gets compared against the per-scene step
    limit when the next episode is requested.
    """
    self._step_count = self._step_count + 1

@staticmethod
def _randomize_value(value, value_range):
    r"""Draw a random integer close to :p:`value`.

    :param value: nominal value to jitter.
    :param value_range: relative half-width of the jitter interval, e.g.
        :py:`0.2` draws from ``[0.8 * value, 1.2 * value]``.
    :return: integer drawn uniformly from the (truncated) interval, never
        smaller than :py:`1`.
    """
    bounds = (
        int(value * (1 - value_range)),
        int(value * (1 + value_range)),
    )
    # Truncation (or a range above 1) can push the lower bound to 0 or below.
    # A limit of 0 is already reached before a single step is taken, so it
    # would trigger a switch on every request; keep the result positive.
    # min/max also keep randint from raising when the range is negative.
    low = max(1, min(bounds))
    high = max(low, max(bounds))
    return random.randint(low, high)

def _set_shuffle_intervals(self):
# Derives the limits that are actually checked from the configured
# thresholds. A non-positive threshold is stored as None.
if self.max_scene_repetition_episodes > 0:
self._max_rep_episode = self.max_scene_repetition_episodes
else:
self._max_rep_episode = None

# The step limit is drawn anew on every call instead of being fixed, using
# step_repetition_range as the relative spread around the configured value.
if self.max_scene_repetition_steps > 0:
self._max_rep_step = self._randomize_value(
self.max_scene_repetition_steps, self.step_repetition_range
)
else:
episodes = list(self._iterator)
random.shuffle(episodes)
self._iterator = iter(episodes)
self._max_rep_step = None

def _forced_scene_switch_if(self):
# Counts one more episode request and checks both per-scene limits; a limit
# of None means that check is disabled.
do_switch = False
self._rep_count += 1

# Switch scene once _rep_count has reached _max_rep_episode
if (
self._max_rep_episode is not None
and self._rep_count >= self._max_rep_episode
):
do_switch = True

# Switch scene once _step_count has reached _max_rep_step
if (
self._max_rep_step is not None
and self._step_count >= self._max_rep_step
):
do_switch = True

if do_switch:
self._forced_scene_switch()
self._set_shuffle_intervals()
7 changes: 6 additions & 1 deletion habitat/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from gym.spaces.dict_space import Dict as SpaceDict

from habitat.config import Config
from habitat.core.dataset import Dataset, Episode
from habitat.core.dataset import Dataset, Episode, EpisodeIterator
from habitat.core.embodied_task import EmbodiedTask, Metrics
from habitat.core.simulator import Observations, Simulator
from habitat.datasets import make_dataset
Expand Down Expand Up @@ -213,6 +213,11 @@ def _update_step_stats(self) -> None:
if self._past_limit():
self._episode_over = True

if self.episode_iterator is not None and isinstance(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While self.episode_iterator is marked as Optional, I don't see Env functioning without it here. Maybe check for a subclass of EpisodeIterator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A separate PR probably makes sense for changing that functionality (dataset is also marked as optional).

self.episode_iterator, EpisodeIterator
):
self.episode_iterator.step_taken()

def step(
self, action: Union[int, str, Dict[str, Any]], **kwargs
) -> Observations:
Expand Down
111 changes: 100 additions & 11 deletions test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,24 +240,113 @@ def test_iterator_shuffle():
assert len(first_round_scene_groups) == len(set(first_round_scene_groups))


def test_iterator_scene_switching():
def test_iterator_scene_switching_episodes():
# Checks that with max_scene_repeat_episodes set, the iterator leaves a scene
# after max_repeat consecutive episodes, both on the first pass and after
# cycling.
total_ep = 1000
max_repeat = 25
dataset = _construct_dataset(total_ep)

episode_iter = dataset.get_episode_iterator(max_scene_repeat=max_repeat)
episode_iter = dataset.get_episode_iterator(
max_scene_repeat_episodes=max_repeat, shuffle=False, cycle=True
)
episodes = sorted(dataset.episodes, key=lambda x: x.scene_id)

# episodes before max_repeat reached should be identical
for i in range(max_repeat):
episode = next(episode_iter)
assert episode.episode_id == episodes.pop(0).episode_id
assert (
episode.episode_id == episodes.pop(0).episode_id
), "episodes before max_repeat reached should be identical"

episode = next(episode_iter)
assert (
episode.scene_id != episodes.pop(0).scene_id
), "After max_repeat episodes a scene switch doesn't happen."

remaining_episodes = list(islice(episode_iter, total_ep - max_repeat - 1))
assert len(remaining_episodes) == len(
episodes
), "Remaining episodes should be identical."

# NOTE(review): both sides of this comparison are computed from
# remaining_episodes, so the assertion can never fail. The right-hand side
# presumably should be built from `episodes` — confirm and fix.
assert len(set(e.scene_id for e in remaining_episodes)) == len(
set(map(lambda ep: ep.scene_id, remaining_episodes))
), "Next episodes should still include all scenes."

# Four further passes must only ever yield the original total_ep episodes.
cycled_episodes = list(islice(episode_iter, 4 * total_ep))
assert (
len(set(map(lambda x: x.episode_id, cycled_episodes))) == total_ep
), "Some episodes leaked after cycling."

grouped_episodes = [
list(g) for k, g in groupby(cycled_episodes, key=lambda x: x.scene_id)
]
assert (
len(sum(grouped_episodes, [])) == 4 * total_ep
), "Cycled episode iterator returned unexpected number of episodes."
assert (
len(grouped_episodes) == 4 * total_ep / max_repeat
), "The number of scene switches is unexpected."

assert all(
[len(group) == max_repeat for group in grouped_episodes]
), "Not all scene switches are equal to required number."


def test_iterator_scene_switching_episodes_without_shuffle_cycle():
    r"""Check one non-shuffled, non-cycled pass over the dataset.

    The pass must yield ``total_ep`` episodes, split into consecutive
    same-scene runs of exactly ``max_repeat`` episodes each.
    """
    total_ep = 1000
    max_repeat = 25
    dataset = _construct_dataset(total_ep)
    episode_iter = dataset.get_episode_iterator(
        max_scene_repeat_episodes=max_repeat, shuffle=False, cycle=False
    )

    # groupby merges only consecutive episodes, so each group is one
    # uninterrupted visit to a scene.
    grouped_episodes = [
        list(group)
        for _, group in groupby(episode_iter, key=lambda x: x.scene_id)
    ]
    assert (
        sum(len(group) for group in grouped_episodes) == total_ep
    ), "The episode iterator returned unexpected number of episodes."
    assert (
        len(grouped_episodes) == total_ep / max_repeat
    ), "The number of scene switches is unexpected."

    assert all(
        len(group) == max_repeat for group in grouped_episodes
    ), "Not all scene switches are equal to required number."


def test_iterator_scene_switching_steps():
# Checks that reporting max_repeat_steps steps through step_taken() makes the
# next requested episode deviate from the plain scene-sorted order.
total_ep = 1000
max_repeat_steps = 250
dataset = _construct_dataset(total_ep)

remaining_episodes = list(islice(episode_iter, total_ep - max_repeat))
# remaining episodes should be same but in different order
assert len(remaining_episodes) == len(episodes)
assert remaining_episodes != episodes
assert sorted(remaining_episodes) == sorted(episodes)
# step_repetition_range=0.0 removes the random spread, so the step limit is
# exactly max_repeat_steps.
episode_iter = dataset.get_episode_iterator(
max_scene_repeat_steps=max_repeat_steps,
shuffle=False,
step_repetition_range=0.0,
)
episodes = sorted(dataset.episodes, key=lambda x: x.scene_id)

# NOTE(review): this first check is that the very first episode follows the
# sorted order; its failure message looks copied from the scene-switch
# assertion further down and does not describe this check — confirm.
episode = next(episode_iter)
assert (
episode.episode_id == episodes.pop(0).episode_id
), "After max_repeat_steps episodes a scene switch doesn't happen."

# next episodes should still be grouped by scene (before next switching)
assert len(set([e.scene_id for e in remaining_episodes[:max_repeat]])) == 1
# episodes before max_repeat reached should be identical
for _ in range(max_repeat_steps):
episode_iter.step_taken()

episode = next(episode_iter)
assert (
episode.episode_id != episodes.pop(0).episode_id
), "After max_repeat_steps episodes a scene switch doesn't happen."

remaining_episodes = list(islice(episode_iter, total_ep - 2))
assert len(remaining_episodes) == len(
episodes
), "Remaining episodes numbers aren't equal."

# Equal counts mean no scene appears in two separate consecutive runs.
assert len(set(e.scene_id for e in remaining_episodes)) == len(
list(groupby(remaining_episodes, lambda ep: ep.scene_id))
), (
"Next episodes should still be grouped by scene (before next "
"switching)."
)
16 changes: 5 additions & 11 deletions test/test_habitat_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,12 @@ def test_task_actions():
}
)
agent_state = env.sim.get_agent_state()
assert (
np.allclose(
np.array(TELEPORT_POSITION, dtype=np.float32), agent_state.position
)
is True
assert np.allclose(
np.array(TELEPORT_POSITION, dtype=np.float32), agent_state.position
), "mismatch in position after teleport"
assert (
np.allclose(
np.array(TELEPORT_ROTATION, dtype=np.float32),
np.array([*agent_state.rotation.imag, agent_state.rotation.real]),
)
is True
assert np.allclose(
np.array(TELEPORT_ROTATION, dtype=np.float32),
np.array([*agent_state.rotation.imag, agent_state.rotation.real]),
), "mismatch in rotation after teleport"
env.step("TURN_RIGHT")
env.close()
Expand Down