Shortcuts

Source code for grutopia.core.runner

import json
from typing import Dict, Optional, Tuple, Union

from omni.isaac.core import World
from omni.isaac.core.loggers import DataLogger
from omni.isaac.core.prims.xform_prim import XFormPrim
from omni.isaac.core.simulation_context import SimulationContext
from omni.isaac.core.utils.stage import add_reference_to_stage  # noqa F401
from omni.physx.scripts import utils
from pxr import Usd  # noqa

# Init
from grutopia.core.runtime import SimulatorRuntime
from grutopia.core.runtime.task_runtime import TaskRuntime
from grutopia.core.scene import delete_prim_in_stage  # noqa F401
from grutopia.core.scene import create_object, create_scene  # noqa F401
from grutopia.core.task.task import BaseTask, create_task
from grutopia.core.util import AsyncRequest, log  # noqa F401
from grutopia.core.util.clear_task import clear_stage_by_prim_path


class SimulatorRunner:
    """Drive an Isaac Sim ``World``: create it, step it, collect observations, and rotate tasks.

    The runner owns the ``World`` instance built from the simulator section of the
    ``SimulatorRuntime`` config. It keeps track of which tasks have finished and writes
    finished-task metrics to ``metrics_save_path``. That value may be ``'console'`` (print),
    ``'none'`` (discard) or a file path (one JSON object per line).
    """

    def __init__(self, simulator_runtime: SimulatorRuntime):
        self.task_runtime_manager = simulator_runtime.task_runtime_manager
        self._simulator_runtime = simulator_runtime
        simulator = self._simulator_runtime.simulator

        # dt values may be given as numbers or as expression strings (strings are eval'd).
        physics_dt = self._parse_dt(simulator.physics_dt)
        self.rendering_dt = self._parse_dt(simulator.rendering_dt)
        self.dt = physics_dt
        self.use_fabric = simulator.use_fabric
        log.info(
            f'simulator params: physics dt={self.dt}, rendering dt={self.rendering_dt}, use_fabric={self.use_fabric}'
        )

        self.metrics_config = None
        self.metrics_save_path = simulator_runtime.task_runtime_manager.metrics_save_path
        # 'console' and 'none' are sentinel destinations handled in _save_metrics; only a real
        # path needs to be created (and truncated) up front so a bad path fails fast.
        if self.metrics_save_path not in ('console', 'none'):
            try:
                with open(self.metrics_save_path, 'w'):
                    pass
            except Exception:
                log.error(f'Can not create result file at {self.metrics_save_path}.')
                raise

        self._world: World = self._create_world()
        self._scene = self._world.scene
        self._stage = self._world.stage

        # Map task_name -> env_id (stored as str).
        self.task_name_to_env_map = {}
        # Names of tasks in current_tasks that have already reported completion.
        self.finished_tasks = set()

        interval = simulator.rendering_interval
        self.render_interval = interval if interval is not None else 5
        log.info(f'rendering interval: {self.render_interval}')
        # Step counter used to gate rendering in step().
        self.render_trigger = 0
        self.loop = False
        # Whether the most recent step() rendered; exposed to callers via get_obs().
        self._render = False

    @staticmethod
    def _parse_dt(value):
        """Return ``value``, evaluating it first if it is a string expression."""
        # NOTE(review): eval() executes arbitrary code taken from the simulator config. This is
        # only acceptable if the config is trusted; consider a restricted arithmetic parser.
        return eval(value) if isinstance(value, str) else value

    def _create_world(self) -> World:
        """Build a World from this runner's physics/rendering dt and fabric setting."""
        return World(
            physics_dt=self.dt,
            rendering_dt=self.rendering_dt,
            stage_units_in_meters=1.0,
            sim_params={'use_fabric': self.use_fabric},
        )

    @property
    def current_tasks(self) -> dict[str, BaseTask]:
        """Tasks currently registered on the World, keyed by task name."""
        return self._world._current_tasks

    def warm_up(self, steps=10, render=True):
        """Step the world ``steps`` times without applying any actions."""
        for _ in range(steps):
            self._world.step(render=render)

    def reload(self):
        """Tear down the current World, build a fresh one with the same settings, and reset."""
        self._world.reset()
        self._world.clear()
        self._world.stop()
        del self._world
        # Rebuild with the same parameters as __init__ (including use_fabric).
        self._world = self._create_world()
        self._scene = self._world.scene
        self._stage = self._world.stage
        self.reset()

    def step(
        self, actions: Union[Dict, None] = None, render: bool = True
    ) -> Tuple[Dict, Dict[str, bool], Dict[str, float]]:
        """Advance the simulation by one step.

        Applies actions to the robots of active (unfinished) tasks, steps the world, and
        collects observations. It records metrics for tasks that finish on this step and
        reports per-task termination status and reward.

        Args:
            actions: Mapping ``{task_name: {robot_name: action}}``. Entries for finished or
                unknown tasks, and for robot names not in the task's ``robots``, are ignored.
                ``None`` means no actions are applied on this step.
            render: Whether rendering is allowed on this step. Rendering actually happens only
                when the render-interval counter fires (see comment in the body).

        Returns:
            ``(obs, terminated_status, reward)``:
              - obs: result of :meth:`get_obs`, keyed by task name, then robot name.
              - terminated_status: ``{task_name: bool}``; True for tasks in ``finished_tasks``.
              - reward: ``{task_name: value}``. The value is ``-1`` for finished tasks or tasks
                whose ``reward`` is None; otherwise it is ``task.reward.calc()``.

        Raises:
            Exception: any error from ``robot.apply_action`` is logged together with the
                task name, robot name and current task names, then re-raised.
        """
        # TODO: Key optimization interval
        self._apply_actions({} if actions is None else actions)

        # Render gating: the counter fires on every (render_interval + 1)-th call and restarts
        # whether or not `render` was requested.
        self.render_trigger += 1
        interval_reached = self.render_trigger > self.render_interval
        self._render = render and interval_reached
        if interval_reached:
            self.render_trigger = 0

        self._world.step(render=self._render)
        obs = self.get_obs()

        for task in self.current_tasks.values():
            # Tasks already marked finished were reported when they finished; skipping them
            # prevents re-saving their metrics on every later step until they are reset.
            if task.name in self.finished_tasks:
                continue
            if task.is_done():
                self.finished_tasks.add(task.name)
                log.info(f'Task {task.name} finished.')
                metrics_results = task.calculate_metrics()
                metrics_results['normally_end'] = True
                self._save_metrics(metrics_results)
                continue
            for metric in task.metrics.values():
                metric.update(obs[task.name])

        terminated_status: Dict[str, bool] = {}
        reward: Dict[str, float] = {}
        for task_name, task in self.current_tasks.items():
            if task_name in self.finished_tasks:
                terminated_status[task_name] = True
                reward[task_name] = -1
            else:
                terminated_status[task_name] = False
                task_reward = task.reward
                reward[task_name] = task_reward.calc() if task_reward is not None else -1
        return obs, terminated_status, reward

    def _apply_actions(self, actions: Dict) -> None:
        """Apply per-robot actions for each active task named in ``actions``."""
        for task_name, action_dict in actions.items():
            # Finished tasks no longer receive actions; unknown task names are ignored.
            if task_name in self.finished_tasks or task_name not in self.current_tasks:
                continue
            task = self.current_tasks[task_name]
            for robot_name, action in action_dict.items():
                if robot_name not in task.robots:
                    continue
                try:
                    task.robots[robot_name].apply_action(action)
                except Exception:
                    log.error('task_name : %s', task_name)
                    log.error('robot_name : %s', robot_name)
                    log.error('current_tasks : %s', list(self.current_tasks.keys()))
                    raise

    def _save_metrics(self, metrics_results: Dict) -> None:
        """Send one finished task's metrics to the destination chosen by ``metrics_save_path``."""
        if self.metrics_save_path == 'console':
            print(json.dumps(metrics_results, indent=4))
        elif self.metrics_save_path == 'none':
            return
        else:
            # Append one JSON object per line.
            with open(self.metrics_save_path, 'a') as f:
                f.write(json.dumps(metrics_results))
                f.write('\n')

    def get_obs(self) -> Dict:
        """Collect observations from every current task.

        Returns:
            Dict: ``{task_name: task.get_observations()}``. Each per-robot entry also gets a
            ``'render'`` key holding whether the last step rendered.
        """
        obs = {task_name: task.get_observations() for task_name, task in self.current_tasks.items()}
        for task_obs in obs.values():
            for robot_obs in task_obs.values():
                robot_obs['render'] = self._render
        return obs

    def stop(self):
        """Reset, clear and stop the World."""
        self._world.reset()
        self._world.clear()
        self._world.stop()

    def get_current_time_step_index(self) -> int:
        """Return the World's current time step index."""
        return self._world.current_time_step_index

    def reset(self, task: Optional[str] = None) -> Tuple[Dict, Union[TaskRuntime, None]]:
        """Reset a task by switching it to the next episode.

        Args:
            task (str): Name of the task to reset. ``None`` means this is the first reset,
                invoked before agents call step().

        Returns:
            Tuple[Dict, Union[TaskRuntime, None]]: the observations, and the new task runtime.
            The runtime is None when ``task`` is not a current task or no further episode exists.
        """
        obs = self.get_obs()
        new_task_runtime: Union[TaskRuntime, None] = None
        if task is not None and task not in self.current_tasks:
            return obs, new_task_runtime

        # Switch to the next episode.
        new_task_runtime = self.next_episode(task)
        self.finished_tasks.discard(task)
        obs = self.get_obs()

        # Finalize once there are no tasks left.
        if len(self.current_tasks) == 0:
            self._finalize()

        return obs, new_task_runtime

    def _finalize(self):
        """Hook for post-processing once all tasks are done (no-op by default)."""
        pass

    def world_clear(self):
        """Clear the World."""
        self._world.clear()

    def clear_single_task(self, task_name: str):
        """Remove one task from the World and clear its env prims from the stage.

        Args:
            task_name (str): Task name to clear. Unknown names only log a warning.
        """
        if task_name not in self.current_tasks:
            log.warning(f'Clear task {task_name} fail. The task {task_name} is not in current_tasks.')
            return
        old_task = self.current_tasks[task_name]
        old_task.cleanup()
        del self.current_tasks[task_name]
        # Force the World to rebuild task scenes and drop logged data from the removed task.
        self._world._task_scene_built = False
        self._world._data_logger = DataLogger()
        env_prim_path = f'/World/env_{self.task_name_to_env_map[task_name]}'
        log.info(f'Clear stage: {env_prim_path}')
        clear_stage_by_prim_path(env_prim_path)

    def next_episode(self, task_name: Optional[str] = None) -> Union[TaskRuntime, None]:
        """Clean up a finished task, if one is given, and start the next task runtime.

        Args:
            task_name (Optional[str]): Name of the finished task. Its env is passed to
                ``get_next_task_runtime`` so the next episode can reuse it.

        Returns:
            Union[TaskRuntime, None]: the new task runtime, or None if there is none.

        Raises:
            RuntimeError: If ``task_name`` is given but is not a current task.
        """
        runtime_env = None
        if task_name is not None:
            if task_name not in self.current_tasks:
                raise RuntimeError(f'Task with task_name {task_name} not in current task_runtime_manager.')
            old_task_runtime = self.current_tasks[task_name].runtime
            self.clear_single_task(task_name)
            runtime_env = old_task_runtime.env

        next_task_runtime: Union[TaskRuntime, None] = self.task_runtime_manager.get_next_task_runtime(runtime_env)
        if next_task_runtime is None:
            return next_task_runtime

        env_id = next_task_runtime.env.env_id
        task = create_task(next_task_runtime, self._scene)
        self._world.add_task(task)
        task.set_up_scene(self._scene)
        self._reset_sim_context()
        task.post_reset()

        new_task_name = f'{next_task_runtime.name}'
        self.task_name_to_env_map[new_task_name] = str(env_id)

        log.info('===================== episode ========================')
        log.info(f'Next episode: {new_task_name} at {str(env_id)}')
        log.info('======================================================')
        return next_task_runtime

    def _reset_sim_context(self):
        """Hard-reset the simulation context and re-finalize the scene against the physics view."""
        SimulationContext.reset(self._world, soft=False)
        self._world.scene._finalize(self._world.physics_sim_view)  # noqa

    def get_obj(self, name: str) -> XFormPrim:
        """Return the scene object registered under ``name``."""
        return self._world.scene.get_object(name)

    def remove_collider(self, prim_path: str):
        """Remove the collider from the prim at ``prim_path`` if the prim is valid."""
        build = self._world.stage.GetPrimAtPath(prim_path)
        if build.IsValid():
            utils.removeCollider(build)

    def add_collider(self, prim_path: str):
        """Set a collider on the prim at ``prim_path`` if the prim is valid."""
        build = self._world.stage.GetPrimAtPath(prim_path)
        if build.IsValid():
            utils.setCollider(build, approximationShape=None)