|
| 1 | +# Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the MIT License [see LICENSE for details]. |
| 4 | + |
| 5 | + |
| 6 | +""" |
| 7 | +Port ManiSkill2 demonstration replay visualization to web visualizer |
| 8 | +# Reference: https://github.com/haosulab/ManiSkill2/tree/main/examples/tutorials/imitation-learning |
| 9 | +
|
| 10 | +""" |
| 11 | + |
| 12 | + |
| 13 | +import argparse |
| 14 | +import os |
| 15 | +from typing import Optional |
| 16 | + |
| 17 | +import gymnasium as gym |
| 18 | +import h5py |
| 19 | +import numpy as np |
| 20 | +import sapien.core as sapien |
| 21 | +from mani_skill2.envs import sapien_env |
| 22 | +from mani_skill2.utils.io_utils import load_json |
| 23 | +from tqdm.auto import tqdm |
| 24 | + |
| 25 | +################################################################################################ |
| 26 | +# The following code is only used for visualizer and not presented in the original ManiSkill |
| 27 | +################################################################################################ |
| 28 | +from meshcat.servers.zmqserver import start_zmq_server_as_subprocess |
| 29 | +from sim_web_visualizer import bind_visualizer_to_sapien_scene, create_sapien_visualizer |
| 30 | + |
| 31 | + |
| 32 | +def wrapped_setup_scene(self: sapien_env.BaseEnv, scene_config: Optional[sapien.SceneConfig] = None): |
| 33 | + if scene_config is None: |
| 34 | + scene_config = self._get_default_scene_config() |
| 35 | + self._scene = self._engine.create_scene(scene_config) |
| 36 | + self._scene.set_timestep(1.0 / self._sim_freq) |
| 37 | + self._scene = bind_visualizer_to_sapien_scene(self._scene, self._engine, self._renderer) |
| 38 | + |
| 39 | + |
| 40 | +def wrapped_setup_viewer(self): |
| 41 | + self._viewer.set_scene(self._scene._scene) |
| 42 | + self._viewer.scene = self._scene |
| 43 | + self._viewer.toggle_axes(False) |
| 44 | + self._viewer.toggle_camera_lines(False) |
| 45 | + |
| 46 | + |
| 47 | +start_zmq_server_as_subprocess() |
| 48 | +# Set to True if you want to keep both the original viewer and the web visualizer. A display is needed for True |
| 49 | +keep_on_screen_renderer = True |
| 50 | + |
| 51 | +create_sapien_visualizer(port=6000, host="localhost", keep_default_viewer=keep_on_screen_renderer) |
| 52 | +sapien_env.BaseEnv._setup_scene = wrapped_setup_scene |
| 53 | +sapien_env.BaseEnv._setup_viewer = wrapped_setup_viewer |
| 54 | +################################################################################################ |
| 55 | +# End of visualizer code |
| 56 | +################################################################################################ |
| 57 | + |
| 58 | + |
| 59 | +def parse_args(args=None): |
| 60 | + parser = argparse.ArgumentParser() |
| 61 | + parser.add_argument("--traj-path", type=str, required=True) |
| 62 | + parser.add_argument("--verbose", action="store_true") |
| 63 | + parser.add_argument( |
| 64 | + "--count", |
| 65 | + type=int, |
| 66 | + default=None, |
| 67 | + help="number of demonstrations to replay before exiting. By default will replay all demonstrations", |
| 68 | + ) |
| 69 | + return parser.parse_args(args) |
| 70 | + |
| 71 | + |
| 72 | +def main(args): |
| 73 | + pbar = tqdm(position=0, leave=None, unit="step", dynamic_ncols=True) |
| 74 | + |
| 75 | + # Load HDF5 containing trajectories |
| 76 | + traj_path = args.traj_path |
| 77 | + ori_h5_file = h5py.File(traj_path, "r") |
| 78 | + |
| 79 | + # Load associated json |
| 80 | + json_path = traj_path.replace(".h5", ".json") |
| 81 | + json_data = load_json(json_path) |
| 82 | + |
| 83 | + env_info = json_data["env_info"] |
| 84 | + env_id = env_info["env_id"] |
| 85 | + ori_env_kwargs = env_info["env_kwargs"] |
| 86 | + |
| 87 | + # Create a main env for replay |
| 88 | + env_kwargs = ori_env_kwargs.copy() |
| 89 | + env_kwargs["obs_mode"] = "state" |
| 90 | + env_kwargs[ |
| 91 | + "render_mode" |
| 92 | + ] = "rgb_array" # note this only affects the videos saved as RecordEpisode wrapper calls env.render |
| 93 | + env = gym.make(env_id, **env_kwargs).unwrapped |
| 94 | + if pbar is not None: |
| 95 | + pbar.set_postfix( |
| 96 | + { |
| 97 | + "control_mode": env_kwargs.get("control_mode"), |
| 98 | + "obs_mode": env_kwargs.get("obs_mode"), |
| 99 | + } |
| 100 | + ) |
| 101 | + |
| 102 | + # Prepare for recording |
| 103 | + output_dir = os.path.dirname(traj_path) |
| 104 | + ori_traj_name = os.path.splitext(os.path.basename(traj_path))[0] |
| 105 | + suffix = "{}.{}".format(env.obs_mode, env.control_mode) |
| 106 | + new_traj_name = ori_traj_name + "." + suffix |
| 107 | + |
| 108 | + episodes = json_data["episodes"][: args.count] |
| 109 | + n_ep = len(episodes) |
| 110 | + inds = np.arange(n_ep) |
| 111 | + inds = np.array_split(inds, 1)[0] |
| 112 | + |
| 113 | + # Replay |
| 114 | + for ind in inds: |
| 115 | + ep = episodes[ind] |
| 116 | + episode_id = ep["episode_id"] |
| 117 | + traj_id = f"traj_{episode_id}" |
| 118 | + if pbar is not None: |
| 119 | + pbar.set_description(f"Replaying {traj_id}") |
| 120 | + |
| 121 | + if traj_id not in ori_h5_file: |
| 122 | + tqdm.write(f"{traj_id} does not exist in {traj_path}") |
| 123 | + continue |
| 124 | + |
| 125 | + reset_kwargs = ep["reset_kwargs"].copy() |
| 126 | + if "seed" in reset_kwargs: |
| 127 | + assert reset_kwargs["seed"] == ep["episode_seed"] |
| 128 | + else: |
| 129 | + reset_kwargs["seed"] = ep["episode_seed"] |
| 130 | + seed = reset_kwargs.pop("seed") |
| 131 | + |
| 132 | + env.reset(seed=seed, options=reset_kwargs) |
| 133 | + env.render_human() |
| 134 | + |
| 135 | + # Original actions to replay |
| 136 | + ori_actions = ori_h5_file[traj_id]["actions"][:] |
| 137 | + |
| 138 | + # Original env states to replay |
| 139 | + ori_env_states = ori_h5_file[traj_id]["env_states"][1:] |
| 140 | + |
| 141 | + info = {} |
| 142 | + |
| 143 | + # Without conversion between control modes |
| 144 | + n = len(ori_actions) |
| 145 | + if pbar is not None: |
| 146 | + pbar.reset(total=n) |
| 147 | + for t, a in enumerate(ori_actions): |
| 148 | + if pbar is not None: |
| 149 | + pbar.update() |
| 150 | + env.render_human() |
| 151 | + env.set_state(ori_env_states[t]) |
| 152 | + |
| 153 | + success = info.get("success", False) |
| 154 | + |
| 155 | + # Cleanup |
| 156 | + env.close() |
| 157 | + ori_h5_file.close() |
| 158 | + |
| 159 | + if pbar is not None: |
| 160 | + pbar.close() |
| 161 | + |
| 162 | + |
| 163 | +if __name__ == "__main__": |
| 164 | + main(parse_args()) |
0 commit comments