Skip to content

Commit

Permalink
One less bridge (facebookresearch#418)
Browse files Browse the repository at this point in the history
* One less bridge
  • Loading branch information
erikwijmans authored Jul 9, 2020
1 parent 0cbaea7 commit 5877ffd
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 45 deletions.
1 change: 1 addition & 0 deletions habitat/core/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ class Simulator:
r"""Basic simulator class for habitat. New simulators to be added to habtiat
must derive from this class and implement the abstarct methods.
"""
habitat_config: Config

@property
def sensor_suite(self) -> SensorSuite:
Expand Down
2 changes: 1 addition & 1 deletion habitat/datasets/pointnav/pointnav_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def generate_pointnav_episode(

episode = _create_episode(
episode_id=episode_count,
scene_id=sim.config.SCENE,
scene_id=sim.habitat_config.SCENE,
start_position=source_position,
start_rotation=source_rotation,
target_position=target_position,
Expand Down
75 changes: 34 additions & 41 deletions habitat/sims/habitat_simulator/habitat_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def get_observation(self, sim_obs):


@registry.register_simulator(name="Sim-v0")
class HabitatSim(Simulator):
class HabitatSim(habitat_sim.Simulator, Simulator):
r"""Simulator wrapper over habitat-sim
habitat-sim repo: https://github.com/facebookresearch/habitat-sim
Expand All @@ -166,12 +166,12 @@ class HabitatSim(Simulator):
"""

def __init__(self, config: Config) -> None:
self.config = config
self.habitat_config = config
agent_config = self._get_agent_config()

sim_sensors = []
for sensor_name in agent_config.SENSORS:
sensor_cfg = getattr(self.config, sensor_name)
sensor_cfg = getattr(self.habitat_config, sensor_name)
sensor_type = registry.get_sensor(sensor_cfg.TYPE)

assert sensor_type is not None, "invalid sensor type {}".format(
Expand All @@ -182,7 +182,7 @@ def __init__(self, config: Config) -> None:
self._sensor_suite = SensorSuite(sim_sensors)
self.sim_config = self.create_sim_config(self._sensor_suite)
self._current_scene = self.sim_config.sim_cfg.scene.id
self._sim = habitat_sim.Simulator(self.sim_config)
super().__init__(self.sim_config)
self._action_space = spaces.Discrete(
len(self.sim_config.agents[0].action_space)
)
Expand All @@ -193,9 +193,10 @@ def create_sim_config(
) -> habitat_sim.Configuration:
sim_config = habitat_sim.SimulatorConfiguration()
overwrite_config(
config_from=self.config.HABITAT_SIM_V0, config_to=sim_config
config_from=self.habitat_config.HABITAT_SIM_V0,
config_to=sim_config,
)
sim_config.scene.id = self.config.SCENE
sim_config.scene.id = self.habitat_config.SCENE
agent_config = habitat_sim.AgentConfiguration()
overwrite_config(
config_from=self._get_agent_config(), config_to=agent_config
Expand All @@ -217,14 +218,14 @@ def create_sim_config(
# accessing child attributes through parent interface
sim_sensor_cfg.sensor_type = sensor.sim_sensor_type # type: ignore
sim_sensor_cfg.gpu2gpu_transfer = (
self.config.HABITAT_SIM_V0.GPU_GPU
self.habitat_config.HABITAT_SIM_V0.GPU_GPU
)
sensor_specifications.append(sim_sensor_cfg)

agent_config.sensor_specifications = sensor_specifications
agent_config.action_space = registry.get_action_space_configuration(
self.config.ACTION_SPACE_CONFIG
)(self.config).get()
self.habitat_config.ACTION_SPACE_CONFIG
)(self.habitat_config).get()

return habitat_sim.Configuration(sim_config, [agent_config])

Expand All @@ -238,7 +239,7 @@ def action_space(self) -> Space:

def _update_agents_state(self) -> bool:
is_updated = False
for agent_id, _ in enumerate(self.config.AGENTS):
for agent_id, _ in enumerate(self.habitat_config.AGENTS):
agent_cfg = self._get_agent_config(agent_id)
if agent_cfg.IS_SET_START_STATE:
self.set_agent_state(
Expand All @@ -251,15 +252,15 @@ def _update_agents_state(self) -> bool:
return is_updated

def reset(self):
sim_obs = self._sim.reset()
sim_obs = super().reset()
if self._update_agents_state():
sim_obs = self._sim.get_sensor_observations()
sim_obs = self.get_sensor_observations()

self._prev_sim_obs = sim_obs
return self._sensor_suite.get_observations(sim_obs)

def step(self, action):
sim_obs = self._sim.step(action)
sim_obs = super().step(action)
self._prev_sim_obs = sim_obs
observations = self._sensor_suite.get_observations(sim_obs)
return observations
Expand All @@ -273,7 +274,7 @@ def render(self, mode: str = "rgb") -> Any:
Returns:
rendered frame according to the mode
"""
sim_obs = self._sim.get_sensor_observations()
sim_obs = self.get_sensor_observations()
observations = self._sensor_suite.get_observations(sim_obs)

output = observations.get(mode)
Expand All @@ -285,19 +286,15 @@ def render(self, mode: str = "rgb") -> Any:

return output

def seed(self, seed):
self._sim.seed(seed)

def reconfigure(self, config: Config) -> None:
def reconfigure(self, habitat_config: Config) -> None:
# TODO(maksymets): Switch to Habitat-Sim more efficient caching
is_same_scene = config.SCENE == self._current_scene
self.config = config
is_same_scene = habitat_config.SCENE == self._current_scene
self.habitat_config = habitat_config
self.sim_config = self.create_sim_config(self._sensor_suite)
if not is_same_scene:
self._current_scene = config.SCENE
self._sim.close()
del self._sim
self._sim = habitat_sim.Simulator(self.sim_config)
self._current_scene = habitat_config.SCENE
self.close()
super().reconfigure(self.sim_config)

self._update_agents_state()

Expand All @@ -319,7 +316,7 @@ def geodesic_distance(

path.requested_start = np.array(position_a, dtype=np.float32)

self._sim.pathfinder.find_path(path)
self.pathfinder.find_path(path)

if episode is not None:
episode._shortest_path_cache = path
Expand Down Expand Up @@ -356,14 +353,14 @@ def get_straight_shortest_path_points(self, position_a, position_b):
path = habitat_sim.ShortestPath()
path.requested_start = position_a
path.requested_end = position_b
self._sim.pathfinder.find_path(path)
self.pathfinder.find_path(path)
return path.points

def sample_navigable_point(self):
return self._sim.pathfinder.get_random_navigable_point().tolist()
return self.pathfinder.get_random_navigable_point().tolist()

def is_navigable(self, point: List[float]):
return self._sim.pathfinder.is_navigable(point)
return self.pathfinder.is_navigable(point)

def semantic_annotations(self):
r"""
Expand Down Expand Up @@ -395,23 +392,20 @@ def semantic_annotations(self):
for region in level.regions:
for obj in region.objects:
"""
return self._sim.semantic_scene

def close(self):
self._sim.close()
return self.semantic_scene

def _get_agent_config(self, agent_id: Optional[int] = None) -> Any:
if agent_id is None:
agent_id = self.config.DEFAULT_AGENT_ID
agent_name = self.config.AGENTS[agent_id]
agent_config = getattr(self.config, agent_name)
agent_id = self.habitat_config.DEFAULT_AGENT_ID
agent_name = self.habitat_config.AGENTS[agent_id]
agent_config = getattr(self.habitat_config, agent_name)
return agent_config

def get_agent_state(self, agent_id: int = 0) -> habitat_sim.AgentState:
assert agent_id == 0, "No support of multi agent in {} yet.".format(
self.__class__.__name__
)
return self._sim.get_agent(agent_id).get_state()
return self.get_agent(agent_id).get_state()

def set_agent_state(
self,
Expand All @@ -437,8 +431,7 @@ def set_agent_state(
True if the set was successful else moves the agent back to its
original pose and returns false.
"""
agent = self._sim.get_agent(agent_id)
original_state = self.get_agent_state(agent_id)
agent = self.get_agent(agent_id)
new_state = self.get_agent_state(agent_id)
new_state.position = position
new_state.rotation = rotation
Expand Down Expand Up @@ -467,7 +460,7 @@ def get_observations_at(
)

if success:
sim_obs = self._sim.get_sensor_observations()
sim_obs = self.get_sensor_observations()

self._prev_sim_obs = sim_obs

Expand All @@ -483,12 +476,12 @@ def get_observations_at(
return None

def distance_to_closest_obstacle(self, position, max_search_radius=2.0):
return self._sim.pathfinder.distance_to_closest_obstacle(
return self.pathfinder.distance_to_closest_obstacle(
position, max_search_radius
)

def island_radius(self, position):
return self._sim.pathfinder.island_radius(position)
return self.pathfinder.island_radius(position)

@property
def previous_step_collided(self):
Expand Down
6 changes: 3 additions & 3 deletions habitat/tasks/nav/shortest_path_follower.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,16 @@ def __init__(
self._stop_on_error = stop_on_error

def _build_follower(self):
if self._current_scene != self._sim.config.SCENE:
self._follower = self._sim._sim.make_greedy_follower(
if self._current_scene != self._sim.habitat_config.SCENE:
self._follower = self._sim.make_greedy_follower(
0,
self._goal_radius,
stop_key=HabitatSimActions.STOP,
forward_key=HabitatSimActions.MOVE_FORWARD,
left_key=HabitatSimActions.TURN_LEFT,
right_key=HabitatSimActions.TURN_RIGHT,
)
self._current_scene = self._sim.config.SCENE
self._current_scene = self._sim.habitat_config.SCENE

def _get_return_value(self, action) -> Union[int, np.array]:
if self._return_one_hot:
Expand Down

0 comments on commit 5877ffd

Please sign in to comment.