|
| 1 | +# Copyright 2021 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +"""Implementation of PyReach Gym I/O Device.""" |
| 15 | + |
| 16 | +import collections.abc as collections_abc |
| 17 | +import copy |
| 18 | +import sys |
| 19 | +from typing import Any, cast, Dict, List, Set, Tuple, Optional, Union |
| 20 | + |
| 21 | +import gym # type: ignore |
| 22 | + |
| 23 | +import pyreach |
| 24 | +from pyreach import arm as pyreach_arm |
| 25 | +from pyreach import core as pyreach_core |
| 26 | +from pyreach import digital_output as pyreach_digital_output |
| 27 | +from pyreach import snapshot as lib_snapshot |
| 28 | +from pyreach.gyms import core as gyms_core |
| 29 | +from pyreach.gyms import io_element |
| 30 | +from pyreach.gyms.devices import reach_device |
| 31 | + |
| 32 | +# Common type hits: |
| 33 | +DigOutput = pyreach_digital_output.DigitalOutput |
| 34 | +DigOutState = pyreach_digital_output.DigitalOutputState |
| 35 | +DigOutPinState = pyreach_digital_output.DigitalOutputPinState |
| 36 | +ImmutableDict = pyreach_core.ImmutableDictionary |
| 37 | +ReachDigitalOutput = io_element.ReachIODigitalOutput |
| 38 | + |
| 39 | + |
| 40 | +class ReachDeviceIO(reach_device.ReachDevice): |
| 41 | + """Represents a Reach Io system.""" |
| 42 | + |
| 43 | + def __init__(self, io_config: io_element.ReachIO) -> None: |
| 44 | + """Initialize a Io actuator. |
| 45 | +
|
| 46 | + Args: |
| 47 | + io_config: The io configuration information. |
| 48 | + """ |
| 49 | + reach_name: str = io_config.reach_name |
| 50 | + is_synchronous: bool = io_config.is_synchronous |
| 51 | + digital_outputs_config: Dict[str, ReachDigitalOutput] = {} |
| 52 | + |
| 53 | + io_action_config: Dict[str, gym.spaces.Space] = {} |
| 54 | + io_observation_config: Dict[str, gym.spaces.Space] = {} |
| 55 | + # Eventually, there will be "digital_inputs", "analog_inputs" |
| 56 | + # and "analog_outputs" as well. |
| 57 | + if hasattr(io_config, "digital_outputs"): |
| 58 | + # Make private copy that will not be accidentally changed. |
| 59 | + digital_outputs_config = cast(Dict[str, ReachDigitalOutput], |
| 60 | + getattr(io_config, "digital_outputs")) |
| 61 | + if not isinstance(digital_outputs_config, dict): |
| 62 | + raise pyreach.PyReachError("digital_outputs config is not a dict") |
| 63 | + digital_outputs_config = copy.deepcopy(digital_outputs_config) |
| 64 | + self._digital_outputs_config = digital_outputs_config |
| 65 | + |
| 66 | + # Verify digital_outputs_config types since it easy to make a mistake. |
| 67 | + action_digital_outputs: Dict[str, gym.spaces.Discrete] = {} |
| 68 | + observation_digital_outputs: Dict[str, gym.spaces.Dict] = {} |
| 69 | + pin_name: str |
| 70 | + digital_output_config: ReachDigitalOutput |
| 71 | + for pin_name, digital_output_config in digital_outputs_config.items(): |
| 72 | + if pin_name in action_digital_outputs: |
| 73 | + raise pyreach.PyReachError(f"'{pin_name}' is a duplicate") |
| 74 | + if not isinstance(pin_name, str): |
| 75 | + raise pyreach.PyReachError(f"{pin_name} is not a str") |
| 76 | + if not isinstance(digital_output_config, ReachDigitalOutput): |
| 77 | + raise pyreach.PyReachError( |
| 78 | + f"{digital_output_config} is not a ReachDigitalOutput") |
| 79 | + |
| 80 | + # Each action can be 0=>set 0, 1=>set 1, or 2=> no_change. |
| 81 | + action_digital_outputs[pin_name] = gym.spaces.Discrete(3) |
| 82 | + # Each observation is a time stamp and a single boolean value: |
| 83 | + observation_digital_outputs[pin_name] = gym.spaces.Dict({ |
| 84 | + "state": gym.spaces.Discrete(2), |
| 85 | + "ts": gym.spaces.Box(low=0, high=sys.maxsize, shape=()), |
| 86 | + }) |
| 87 | + io_action_config["digital_outputs"] = ( |
| 88 | + gym.spaces.Dict(action_digital_outputs)) |
| 89 | + io_observation_config["digital_outputs"] = ( |
| 90 | + gym.spaces.Dict(observation_digital_outputs)) |
| 91 | + |
| 92 | + # Do the final configuration. |
| 93 | + action_space: gym.spaces.Dict = gym.spaces.Dict(io_action_config) |
| 94 | + observation_space: gym.spaces.Dict = ( |
| 95 | + gym.spaces.Dict(io_observation_config)) |
| 96 | + super().__init__(reach_name, action_space, observation_space, |
| 97 | + is_synchronous) |
| 98 | + self._digital_outputs_table: Optional[Dict[str, Tuple[DigOutput, |
| 99 | + str]]] = None |
| 100 | + self._all_digital_outputs: Tuple[DigOutput, ...] = () |
| 101 | + |
| 102 | + def _get_digital_outputs_table( |
| 103 | + self, host: pyreach.Host) -> Dict[str, Tuple[DigOutput, str]]: |
| 104 | + """Return "gym_pin_name" => (DigitalOutput, "pin_name") table.""" |
| 105 | + with self._timers_select({"!agent*", "gym.io"}): |
| 106 | + all_digital_outputs: Dict[int, DigOutput] = {} # Key is id(DigOutput) |
| 107 | + if not self._digital_outputs_table: |
| 108 | + # Be paranoid building this table since it is really easy to make a |
| 109 | + # configuration error. Give diagnostice error messages. |
| 110 | + gym_pin_name: str |
| 111 | + # ("arm", "capability", "pin") |
| 112 | + digital_outputs_table: Dict[str, Tuple[DigOutput, str]] = {} |
| 113 | + |
| 114 | + reach_digital_output: ReachDigitalOutput |
| 115 | + for gym_pin_name, reach_digital_output in ( |
| 116 | + self._digital_outputs_config.items()): |
| 117 | + if not isinstance(reach_digital_output, ReachDigitalOutput): |
| 118 | + raise pyreach.PyReachError( |
| 119 | + f"{reach_digital_output} is not a ReachIODigitalOutput") |
| 120 | + arm_name: str = reach_digital_output.reach_name |
| 121 | + capability_type: str = reach_digital_output.capability_type |
| 122 | + pin_name: str = reach_digital_output.pin_name |
| 123 | + |
| 124 | + if arm_name not in host.arms: |
| 125 | + raise pyreach.PyReachError( |
| 126 | + f"Arm '{arm_name}' is not one of {tuple(host.arms.keys())}") |
| 127 | + arm: pyreach_arm.Arm = host.arms[arm_name] |
| 128 | + |
| 129 | + digital_outputs: ImmutableDict[ |
| 130 | + ImmutableDict[DigOutput]] = arm.digital_outputs |
| 131 | + if capability_type not in digital_outputs: |
| 132 | + raise pyreach.PyReachError( |
| 133 | + f"Capability type: '{capability_type}' is not one of " |
| 134 | + f"{tuple(digital_outputs.keys())}") |
| 135 | + |
| 136 | + capabilities: ImmutableDict[DigOutput] = ( |
| 137 | + digital_outputs[capability_type]) |
| 138 | + if pin_name not in capabilities: |
| 139 | + raise pyreach.PyReachError( |
| 140 | + f"Pin name '{pin_name}' " |
| 141 | + f"is not one of {tuple(capabilities.keys())}") |
| 142 | + digital_output: DigOutput = capabilities[pin_name] |
| 143 | + all_digital_outputs[id(digital_output)] = digital_output |
| 144 | + # TODO(gramlich): Should this be conditioned on is_synchronous? |
| 145 | + digital_output.start_streaming() |
| 146 | + |
| 147 | + digital_outputs_table[gym_pin_name] = (digital_output, pin_name) |
| 148 | + |
| 149 | + self._digital_outputs_table = digital_outputs_table |
| 150 | + self._all_digital_outputs = tuple(all_digital_outputs.values()) |
| 151 | + return self._digital_outputs_table |
| 152 | + |
| 153 | + def validate(self, host: pyreach.Host) -> str: |
| 154 | + """Validate that io is operable.""" |
| 155 | + with self._timers_select({"!agent*", "gym.io"}): |
| 156 | + try: |
| 157 | + _ = self._get_digital_outputs_table(host) |
| 158 | + except pyreach.PyReachError as pyreach_error: |
| 159 | + return str(pyreach_error) |
| 160 | + return "" |
| 161 | + |
| 162 | + def get_observation(self, |
| 163 | + host: pyreach.Host) -> reach_device.ObservationSnapshot: |
| 164 | + """Return the Reach Io actuator Gym observation. |
| 165 | +
|
| 166 | + Args: |
| 167 | + host: The reach host to use. |
| 168 | +
|
| 169 | + Returns: |
| 170 | + Returns a Tuple containing the Gym Observation, a tuple of |
| 171 | + SnapshotReference objects and a tuple of SnapshotResponse objects. |
| 172 | + The next observation is a Gym Dict Space with "ts" and "state" values. |
| 173 | +
|
| 174 | + """ |
| 175 | + snapshots: List[lib_snapshot.SnapshotReference] = [] |
| 176 | + with self._timers_select({"!agent*", "gym.io"}): |
| 177 | + # Fetch the state for each digital output. |
| 178 | + observation: gyms_core.Observation = {} |
| 179 | + digital_outputs_table: Dict[str, Tuple[DigOutput, str]] = ( |
| 180 | + self._get_digital_outputs_table(host)) |
| 181 | + |
| 182 | + # Eventually, there will be "digital_inputs", "analog_inputs" |
| 183 | + # and "analog_outputs" as well. |
| 184 | + if digital_outputs_table: |
| 185 | + state: Optional[DigOutState] |
| 186 | + pin_state: DigOutPinState |
| 187 | + # Key: (id(digital_output), "pin_name") |
| 188 | + states: Dict[Tuple[int, str], Tuple[DigOutState, DigOutPinState]] = {} |
| 189 | + digital_output: DigOutput |
| 190 | + for digital_output in self._all_digital_outputs: |
| 191 | + if self.is_synchronous: |
| 192 | + state = digital_output.fetch_state() |
| 193 | + else: |
| 194 | + state = digital_output.state |
| 195 | + if state is None: |
| 196 | + raise pyreach.PyReachError( |
| 197 | + "No state available for " |
| 198 | + f"{digital_output.robot_name}.{digital_output.type}") |
| 199 | + assert isinstance(state, DigOutState), state |
| 200 | + for pin_state in state.pin_states: |
| 201 | + # The cast should not be needed, but mypy is complaining, so... |
| 202 | + states[(id(digital_output), |
| 203 | + pin_state.name)] = (cast(DigOutState, state), pin_state) |
| 204 | + snapshots.append( |
| 205 | + lib_snapshot.SnapshotReference( |
| 206 | + time=state.time, sequence=state.sequence)) |
| 207 | + |
| 208 | + # Construct the observation. |
| 209 | + digital_outputs_config: Dict[str, ReachDigitalOutput] = ( |
| 210 | + self._digital_outputs_config) |
| 211 | + gym_pin_name: str |
| 212 | + digital_outputs: Dict[str, Dict[str, Union[int, Any]]] = {} |
| 213 | + reach_digital_output: ReachDigitalOutput |
| 214 | + for gym_pin_name, reach_digital_output in ( |
| 215 | + digital_outputs_config.items()): |
| 216 | + pin_name: str = reach_digital_output.pin_name |
| 217 | + digital_output, pin_name = digital_outputs_table[gym_pin_name] |
| 218 | + state, pin_state = states[(id(digital_output), pin_name)] |
| 219 | + state_value: Optional[bool] = pin_state.state |
| 220 | + if not isinstance(state_value, bool): |
| 221 | + raise pyreach.PyReachError( |
| 222 | + f"Pin {gym_pin_name}.{pin_name} has no value.") |
| 223 | + digital_outputs[gym_pin_name] = { |
| 224 | + "state": int(state_value), |
| 225 | + "ts": gyms_core.Timestamp.new(state.time), |
| 226 | + } |
| 227 | + assert isinstance(observation, dict), observation |
| 228 | + observation["digital_outputs"] = digital_outputs |
| 229 | + return observation, tuple(snapshots), () |
| 230 | + |
| 231 | + def synchronize(self, host: pyreach.Host) -> None: |
| 232 | + """Synchronously update the io state.""" |
| 233 | + digital_output: DigOutput |
| 234 | + for digital_output in self._all_digital_outputs: |
| 235 | + digital_output.fetch_state() |
| 236 | + |
| 237 | + def do_action( |
| 238 | + self, action: gyms_core.Action, |
| 239 | + host: pyreach.Host) -> Tuple[lib_snapshot.SnapshotGymAction, ...]: |
| 240 | + """Set/Clear the io. |
| 241 | +
|
| 242 | + Args: |
| 243 | + action: The Gym Action Space to process to process. (See API document.) |
| 244 | + host: The reach host to use. |
| 245 | +
|
| 246 | + Returns: |
| 247 | + The list of gym action snapshots. |
| 248 | + """ |
| 249 | + |
| 250 | + with self._timers_select({"!agent*", "gym.io"}): |
| 251 | + try: |
| 252 | + action_dict: gyms_core.ActionDict = self._get_action_dict(action) |
| 253 | + except pyreach.PyReachError as runtime_error: |
| 254 | + raise pyreach.PyReachError from runtime_error |
| 255 | + acceptable_types: Set[str] = set(("digital_outputs",)) |
| 256 | + actual_types: Set[str] = set(action_dict.keys()) |
| 257 | + extra_types: Set[str] = actual_types - acceptable_types |
| 258 | + if extra_types: |
| 259 | + raise pyreach.PyReachError( |
| 260 | + f"{extra_types} do not match {acceptable_types}") |
| 261 | + |
| 262 | + snapshots: Tuple[lib_snapshot.SnapshotGymAction, ...] = () |
| 263 | + allowed_io_types: Tuple[str, ...] = ("digital_outputs",) |
| 264 | + if not isinstance(action, collections_abc.Mapping): |
| 265 | + raise pyreach.PyReachError(f"{action} is not a dict") |
| 266 | + io_type: Any |
| 267 | + io_dict: Any |
| 268 | + for io_type, io_dict in action.items(): |
| 269 | + if io_type not in allowed_io_types: |
| 270 | + raise pyreach.PyReachError( |
| 271 | + f"io type {io_type} device key is one of {allowed_io_types}") |
| 272 | + if not isinstance(io_dict, collections_abc.Mapping): |
| 273 | + raise pyreach.PyReachError(f"{io_type} value is not a dict") |
| 274 | + |
| 275 | + if "digital_outputs" in action: |
| 276 | + digital_outputs_dict: Any = action["digital_outputs"] |
| 277 | + if not isinstance(digital_outputs_dict, collections_abc.Mapping): |
| 278 | + raise pyreach.PyReachError("io.digital_outputs is not dict") |
| 279 | + snapshots += self._do_digital_outputs_action(digital_outputs_dict, |
| 280 | + host) |
| 281 | + return snapshots |
| 282 | + |
| 283 | + def _do_digital_outputs_action( |
| 284 | + self, digital_outputs_action: gyms_core.Action, |
| 285 | + host: pyreach.Host) -> Tuple[lib_snapshot.SnapshotGymAction, ...]: |
| 286 | + """Perform digital outputs action.""" |
| 287 | + digital_outputs_table: Dict[str, Tuple[DigOutput, str]] = ( |
| 288 | + self._get_digital_outputs_table(host)) |
| 289 | + |
| 290 | + # Collect all pin operations on the same DigOutput together: |
| 291 | + # Use id(DigOutput) as the key for this table: |
| 292 | + Operation = Tuple[DigOutput, str, bool] # (DigOutput, "pin", True/False) |
| 293 | + arm_operations: Dict[int, List[Operation]] = {} |
| 294 | + |
| 295 | + digital_output: DigOutput |
| 296 | + gym_pin_name: Any |
| 297 | + pin_value: Any |
| 298 | + pin_name: str |
| 299 | + if not isinstance(digital_outputs_action, collections_abc.Mapping): |
| 300 | + raise pyreach.PyReachError( |
| 301 | + f"action is not a dictionary: {digital_outputs_action}") |
| 302 | + for gym_pin_name, pin_value in digital_outputs_action.items(): |
| 303 | + if gym_pin_name not in digital_outputs_table: |
| 304 | + raise pyreach.PyReachError( |
| 305 | + f"io.digital_io: {gym_pin_name} is not a one of " |
| 306 | + f"{tuple(digital_outputs_table.keys())}") |
| 307 | + if not isinstance(pin_value, int) and 0 <= pin_value <= 2: |
| 308 | + raise pyreach.PyReachError(f"io.digital_io.{gym_pin_name}:" |
| 309 | + f"{pin_value} is not int in range 0-2") |
| 310 | + |
| 311 | + if pin_value < 2: |
| 312 | + digital_output, pin_name = digital_outputs_table[gym_pin_name] |
| 313 | + if id(digital_output) not in arm_operations: |
| 314 | + arm_operations[id(digital_output)] = [] |
| 315 | + arm_operations[id(digital_output)].append( |
| 316 | + (digital_output, pin_name, pin_value == 1)) |
| 317 | + |
| 318 | + # Perform pin operations on a per digital output basis: |
| 319 | + operations: List[Tuple[DigOutput, str, bool]] |
| 320 | + for operations in arm_operations.values(): |
| 321 | + if operations: |
| 322 | + value: bool |
| 323 | + states_list: List[Tuple[str, bool]] = [] |
| 324 | + for operation in operations: |
| 325 | + digital_output, pin_name, value = operation |
| 326 | + states_list.append((pin_name, value)) |
| 327 | + if self._is_synchronous: |
| 328 | + digital_output.set_pin_states(tuple(states_list)) |
| 329 | + else: |
| 330 | + digital_output.async_set_pin_states(tuple(states_list)) |
| 331 | + |
| 332 | + return () |
| 333 | + |
| 334 | + def start_observation(self, host: pyreach.Host) -> bool: |
| 335 | + """Start a synchronous observation.""" |
| 336 | + return True |
0 commit comments