Skip to content
This repository was archived by the owner on Oct 4, 2023. It is now read-only.

Commit 3e77bba

Browse files
Pyreach sync 20220217
1 parent 9e40faf commit 3e77bba

38 files changed

+3261
-31
lines changed

pyreach/gyms/devices/annotation_device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,6 @@ def validate(self, host: pyreach.Host) -> str:
185185
return str(pyreach_error)
186186
return ""
187187

188-
def synchronize(self) -> None:
188+
def synchronize(self, host: pyreach.Host) -> None:
189189
"""Force the annotation device synchronize its observations."""
190190
pass

pyreach/gyms/devices/arm_device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def get_observation(self,
361361
ts, arm_state.sequence),)
362362
return observation, snapshot_reference, tuple(responses)
363363

364-
def synchronize(self) -> None:
364+
def synchronize(self, host: pyreach.Host) -> None:
365365
"""Synchronously update the arm state."""
366366
if self._arm:
367367
_ = self._arm.fetch_state()

pyreach/gyms/devices/color_camera_device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def get_observation(self,
239239
ts, color_frame.sequence),)
240240
return observation, snapshot_reference, ()
241241

242-
def synchronize(self) -> None:
242+
def synchronize(self, host: pyreach.Host) -> None:
243243
"""Synchronously fetch an image."""
244244
if self._color_camera:
245245
_ = self._color_camera.fetch_image()

pyreach/gyms/devices/constraints_device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def get_observation(self,
125125
self._counter, "constraints", self.config_name,
126126
lib_snapshot.SnapshotReference(0.0, counter)),)
127127

128-
def synchronize(self) -> None:
128+
def synchronize(self, host: pyreach.Host) -> None:
129129
"""Synchronously fetch constraints."""
130130
pass # Currently constraints don't change.
131131

pyreach/gyms/devices/depth_camera_device.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def __init__(
6363
depth_shape: Tuple[int, int] = shape
6464
color_shape: Tuple[int, int, int] = shape + (3,)
6565

66-
observation_space_dict: Dict[str, gym.spaces.Space]
6766
observation_dict: Dict[str, Any] = {
6867
"ts":
6968
gym.spaces.Box(low=0, high=sys.maxsize, shape=()),
@@ -279,7 +278,7 @@ def get_observation(self,
279278

280279
return observation, snapshot_reference, ()
281280

282-
def synchronize(self) -> None:
281+
def synchronize(self, host: pyreach.Host) -> None:
283282
"""Synchronously fetch an image."""
284283
if self._depth_camera:
285284
_ = self._depth_camera.fetch_image()

pyreach/gyms/devices/force_torque_sensor_device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def get_observation(self,
147147

148148
return observation, snapshot_reference, ()
149149

150-
def synchronize(self) -> None:
150+
def synchronize(self, host: pyreach.Host) -> None:
151151
"""Synchronously update the arm state."""
152152
if self._force_torque_sensor:
153153
self._force_torque_sensor.fetch_state()

pyreach/gyms/devices/io_device.py

Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Implementation of PyReach Gym I/O Device."""
15+
16+
import collections.abc as collections_abc
17+
import copy
18+
import sys
19+
from typing import Any, cast, Dict, List, Set, Tuple, Optional, Union
20+
21+
import gym # type: ignore
22+
23+
import pyreach
24+
from pyreach import arm as pyreach_arm
25+
from pyreach import core as pyreach_core
26+
from pyreach import digital_output as pyreach_digital_output
27+
from pyreach import snapshot as lib_snapshot
28+
from pyreach.gyms import core as gyms_core
29+
from pyreach.gyms import io_element
30+
from pyreach.gyms.devices import reach_device
31+
32+
# Common type hits:
33+
DigOutput = pyreach_digital_output.DigitalOutput
34+
DigOutState = pyreach_digital_output.DigitalOutputState
35+
DigOutPinState = pyreach_digital_output.DigitalOutputPinState
36+
ImmutableDict = pyreach_core.ImmutableDictionary
37+
ReachDigitalOutput = io_element.ReachIODigitalOutput
38+
39+
40+
class ReachDeviceIO(reach_device.ReachDevice):
41+
"""Represents a Reach Io system."""
42+
43+
def __init__(self, io_config: io_element.ReachIO) -> None:
44+
"""Initialize a Io actuator.
45+
46+
Args:
47+
io_config: The io configuration information.
48+
"""
49+
reach_name: str = io_config.reach_name
50+
is_synchronous: bool = io_config.is_synchronous
51+
digital_outputs_config: Dict[str, ReachDigitalOutput] = {}
52+
53+
io_action_config: Dict[str, gym.spaces.Space] = {}
54+
io_observation_config: Dict[str, gym.spaces.Space] = {}
55+
# Eventually, there will be "digital_inputs", "analog_inputs"
56+
# and "analog_outputs" as well.
57+
if hasattr(io_config, "digital_outputs"):
58+
# Make private copy that will not be accidentally changed.
59+
digital_outputs_config = cast(Dict[str, ReachDigitalOutput],
60+
getattr(io_config, "digital_outputs"))
61+
if not isinstance(digital_outputs_config, dict):
62+
raise pyreach.PyReachError("digital_outputs config is not a dict")
63+
digital_outputs_config = copy.deepcopy(digital_outputs_config)
64+
self._digital_outputs_config = digital_outputs_config
65+
66+
# Verify digital_outputs_config types since it easy to make a mistake.
67+
action_digital_outputs: Dict[str, gym.spaces.Discrete] = {}
68+
observation_digital_outputs: Dict[str, gym.spaces.Dict] = {}
69+
pin_name: str
70+
digital_output_config: ReachDigitalOutput
71+
for pin_name, digital_output_config in digital_outputs_config.items():
72+
if pin_name in action_digital_outputs:
73+
raise pyreach.PyReachError(f"'{pin_name}' is a duplicate")
74+
if not isinstance(pin_name, str):
75+
raise pyreach.PyReachError(f"{pin_name} is not a str")
76+
if not isinstance(digital_output_config, ReachDigitalOutput):
77+
raise pyreach.PyReachError(
78+
f"{digital_output_config} is not a ReachDigitalOutput")
79+
80+
# Each action can be 0=>set 0, 1=>set 1, or 2=> no_change.
81+
action_digital_outputs[pin_name] = gym.spaces.Discrete(3)
82+
# Each observation is a time stamp and a single boolean value:
83+
observation_digital_outputs[pin_name] = gym.spaces.Dict({
84+
"state": gym.spaces.Discrete(2),
85+
"ts": gym.spaces.Box(low=0, high=sys.maxsize, shape=()),
86+
})
87+
io_action_config["digital_outputs"] = (
88+
gym.spaces.Dict(action_digital_outputs))
89+
io_observation_config["digital_outputs"] = (
90+
gym.spaces.Dict(observation_digital_outputs))
91+
92+
# Do the final configuration.
93+
action_space: gym.spaces.Dict = gym.spaces.Dict(io_action_config)
94+
observation_space: gym.spaces.Dict = (
95+
gym.spaces.Dict(io_observation_config))
96+
super().__init__(reach_name, action_space, observation_space,
97+
is_synchronous)
98+
self._digital_outputs_table: Optional[Dict[str, Tuple[DigOutput,
99+
str]]] = None
100+
self._all_digital_outputs: Tuple[DigOutput, ...] = ()
101+
102+
def _get_digital_outputs_table(
103+
self, host: pyreach.Host) -> Dict[str, Tuple[DigOutput, str]]:
104+
"""Return "gym_pin_name" => (DigitalOutput, "pin_name") table."""
105+
with self._timers_select({"!agent*", "gym.io"}):
106+
all_digital_outputs: Dict[int, DigOutput] = {} # Key is id(DigOutput)
107+
if not self._digital_outputs_table:
108+
# Be paranoid building this table since it is really easy to make a
109+
# configuration error. Give diagnostice error messages.
110+
gym_pin_name: str
111+
# ("arm", "capability", "pin")
112+
digital_outputs_table: Dict[str, Tuple[DigOutput, str]] = {}
113+
114+
reach_digital_output: ReachDigitalOutput
115+
for gym_pin_name, reach_digital_output in (
116+
self._digital_outputs_config.items()):
117+
if not isinstance(reach_digital_output, ReachDigitalOutput):
118+
raise pyreach.PyReachError(
119+
f"{reach_digital_output} is not a ReachIODigitalOutput")
120+
arm_name: str = reach_digital_output.reach_name
121+
capability_type: str = reach_digital_output.capability_type
122+
pin_name: str = reach_digital_output.pin_name
123+
124+
if arm_name not in host.arms:
125+
raise pyreach.PyReachError(
126+
f"Arm '{arm_name}' is not one of {tuple(host.arms.keys())}")
127+
arm: pyreach_arm.Arm = host.arms[arm_name]
128+
129+
digital_outputs: ImmutableDict[
130+
ImmutableDict[DigOutput]] = arm.digital_outputs
131+
if capability_type not in digital_outputs:
132+
raise pyreach.PyReachError(
133+
f"Capability type: '{capability_type}' is not one of "
134+
f"{tuple(digital_outputs.keys())}")
135+
136+
capabilities: ImmutableDict[DigOutput] = (
137+
digital_outputs[capability_type])
138+
if pin_name not in capabilities:
139+
raise pyreach.PyReachError(
140+
f"Pin name '{pin_name}' "
141+
f"is not one of {tuple(capabilities.keys())}")
142+
digital_output: DigOutput = capabilities[pin_name]
143+
all_digital_outputs[id(digital_output)] = digital_output
144+
# TODO(gramlich): Should this be conditioned on is_synchronous?
145+
digital_output.start_streaming()
146+
147+
digital_outputs_table[gym_pin_name] = (digital_output, pin_name)
148+
149+
self._digital_outputs_table = digital_outputs_table
150+
self._all_digital_outputs = tuple(all_digital_outputs.values())
151+
return self._digital_outputs_table
152+
153+
def validate(self, host: pyreach.Host) -> str:
154+
"""Validate that io is operable."""
155+
with self._timers_select({"!agent*", "gym.io"}):
156+
try:
157+
_ = self._get_digital_outputs_table(host)
158+
except pyreach.PyReachError as pyreach_error:
159+
return str(pyreach_error)
160+
return ""
161+
162+
def get_observation(self,
163+
host: pyreach.Host) -> reach_device.ObservationSnapshot:
164+
"""Return the Reach Io actuator Gym observation.
165+
166+
Args:
167+
host: The reach host to use.
168+
169+
Returns:
170+
Returns a Tuple containing the Gym Observation, a tuple of
171+
SnapshotReference objects and a tuple of SnapshotResponse objects.
172+
The next observation is a Gym Dict Space with "ts" and "state" values.
173+
174+
"""
175+
snapshots: List[lib_snapshot.SnapshotReference] = []
176+
with self._timers_select({"!agent*", "gym.io"}):
177+
# Fetch the state for each digital output.
178+
observation: gyms_core.Observation = {}
179+
digital_outputs_table: Dict[str, Tuple[DigOutput, str]] = (
180+
self._get_digital_outputs_table(host))
181+
182+
# Eventually, there will be "digital_inputs", "analog_inputs"
183+
# and "analog_outputs" as well.
184+
if digital_outputs_table:
185+
state: Optional[DigOutState]
186+
pin_state: DigOutPinState
187+
# Key: (id(digital_output), "pin_name")
188+
states: Dict[Tuple[int, str], Tuple[DigOutState, DigOutPinState]] = {}
189+
digital_output: DigOutput
190+
for digital_output in self._all_digital_outputs:
191+
if self.is_synchronous:
192+
state = digital_output.fetch_state()
193+
else:
194+
state = digital_output.state
195+
if state is None:
196+
raise pyreach.PyReachError(
197+
"No state available for "
198+
f"{digital_output.robot_name}.{digital_output.type}")
199+
assert isinstance(state, DigOutState), state
200+
for pin_state in state.pin_states:
201+
# The cast should not be needed, but mypy is complaining, so...
202+
states[(id(digital_output),
203+
pin_state.name)] = (cast(DigOutState, state), pin_state)
204+
snapshots.append(
205+
lib_snapshot.SnapshotReference(
206+
time=state.time, sequence=state.sequence))
207+
208+
# Construct the observation.
209+
digital_outputs_config: Dict[str, ReachDigitalOutput] = (
210+
self._digital_outputs_config)
211+
gym_pin_name: str
212+
digital_outputs: Dict[str, Dict[str, Union[int, Any]]] = {}
213+
reach_digital_output: ReachDigitalOutput
214+
for gym_pin_name, reach_digital_output in (
215+
digital_outputs_config.items()):
216+
pin_name: str = reach_digital_output.pin_name
217+
digital_output, pin_name = digital_outputs_table[gym_pin_name]
218+
state, pin_state = states[(id(digital_output), pin_name)]
219+
state_value: Optional[bool] = pin_state.state
220+
if not isinstance(state_value, bool):
221+
raise pyreach.PyReachError(
222+
f"Pin {gym_pin_name}.{pin_name} has no value.")
223+
digital_outputs[gym_pin_name] = {
224+
"state": int(state_value),
225+
"ts": gyms_core.Timestamp.new(state.time),
226+
}
227+
assert isinstance(observation, dict), observation
228+
observation["digital_outputs"] = digital_outputs
229+
return observation, tuple(snapshots), ()
230+
231+
def synchronize(self, host: pyreach.Host) -> None:
232+
"""Synchronously update the io state."""
233+
digital_output: DigOutput
234+
for digital_output in self._all_digital_outputs:
235+
digital_output.fetch_state()
236+
237+
def do_action(
238+
self, action: gyms_core.Action,
239+
host: pyreach.Host) -> Tuple[lib_snapshot.SnapshotGymAction, ...]:
240+
"""Set/Clear the io.
241+
242+
Args:
243+
action: The Gym Action Space to process to process. (See API document.)
244+
host: The reach host to use.
245+
246+
Returns:
247+
The list of gym action snapshots.
248+
"""
249+
250+
with self._timers_select({"!agent*", "gym.io"}):
251+
try:
252+
action_dict: gyms_core.ActionDict = self._get_action_dict(action)
253+
except pyreach.PyReachError as runtime_error:
254+
raise pyreach.PyReachError from runtime_error
255+
acceptable_types: Set[str] = set(("digital_outputs",))
256+
actual_types: Set[str] = set(action_dict.keys())
257+
extra_types: Set[str] = actual_types - acceptable_types
258+
if extra_types:
259+
raise pyreach.PyReachError(
260+
f"{extra_types} do not match {acceptable_types}")
261+
262+
snapshots: Tuple[lib_snapshot.SnapshotGymAction, ...] = ()
263+
allowed_io_types: Tuple[str, ...] = ("digital_outputs",)
264+
if not isinstance(action, collections_abc.Mapping):
265+
raise pyreach.PyReachError(f"{action} is not a dict")
266+
io_type: Any
267+
io_dict: Any
268+
for io_type, io_dict in action.items():
269+
if io_type not in allowed_io_types:
270+
raise pyreach.PyReachError(
271+
f"io type {io_type} device key is one of {allowed_io_types}")
272+
if not isinstance(io_dict, collections_abc.Mapping):
273+
raise pyreach.PyReachError(f"{io_type} value is not a dict")
274+
275+
if "digital_outputs" in action:
276+
digital_outputs_dict: Any = action["digital_outputs"]
277+
if not isinstance(digital_outputs_dict, collections_abc.Mapping):
278+
raise pyreach.PyReachError("io.digital_outputs is not dict")
279+
snapshots += self._do_digital_outputs_action(digital_outputs_dict,
280+
host)
281+
return snapshots
282+
283+
def _do_digital_outputs_action(
284+
self, digital_outputs_action: gyms_core.Action,
285+
host: pyreach.Host) -> Tuple[lib_snapshot.SnapshotGymAction, ...]:
286+
"""Perform digital outputs action."""
287+
digital_outputs_table: Dict[str, Tuple[DigOutput, str]] = (
288+
self._get_digital_outputs_table(host))
289+
290+
# Collect all pin operations on the same DigOutput together:
291+
# Use id(DigOutput) as the key for this table:
292+
Operation = Tuple[DigOutput, str, bool] # (DigOutput, "pin", True/False)
293+
arm_operations: Dict[int, List[Operation]] = {}
294+
295+
digital_output: DigOutput
296+
gym_pin_name: Any
297+
pin_value: Any
298+
pin_name: str
299+
if not isinstance(digital_outputs_action, collections_abc.Mapping):
300+
raise pyreach.PyReachError(
301+
f"action is not a dictionary: {digital_outputs_action}")
302+
for gym_pin_name, pin_value in digital_outputs_action.items():
303+
if gym_pin_name not in digital_outputs_table:
304+
raise pyreach.PyReachError(
305+
f"io.digital_io: {gym_pin_name} is not a one of "
306+
f"{tuple(digital_outputs_table.keys())}")
307+
if not isinstance(pin_value, int) and 0 <= pin_value <= 2:
308+
raise pyreach.PyReachError(f"io.digital_io.{gym_pin_name}:"
309+
f"{pin_value} is not int in range 0-2")
310+
311+
if pin_value < 2:
312+
digital_output, pin_name = digital_outputs_table[gym_pin_name]
313+
if id(digital_output) not in arm_operations:
314+
arm_operations[id(digital_output)] = []
315+
arm_operations[id(digital_output)].append(
316+
(digital_output, pin_name, pin_value == 1))
317+
318+
# Perform pin operations on a per digital output basis:
319+
operations: List[Tuple[DigOutput, str, bool]]
320+
for operations in arm_operations.values():
321+
if operations:
322+
value: bool
323+
states_list: List[Tuple[str, bool]] = []
324+
for operation in operations:
325+
digital_output, pin_name, value = operation
326+
states_list.append((pin_name, value))
327+
if self._is_synchronous:
328+
digital_output.set_pin_states(tuple(states_list))
329+
else:
330+
digital_output.async_set_pin_states(tuple(states_list))
331+
332+
return ()
333+
334+
def start_observation(self, host: pyreach.Host) -> bool:
335+
"""Start a synchronous observation."""
336+
return True

0 commit comments

Comments
 (0)