|
17 | 17 | from typing import List, Optional, Sequence, Tuple |
18 | 18 |
|
19 | 19 | import numpy as np |
| 20 | +from scipy.optimize import linear_sum_assignment |
20 | 21 | from dm_control import mjcf |
21 | 22 | from dm_control.composer import variation as base_variation |
22 | 23 | from dm_control.composer.observation import observable |
@@ -134,6 +135,11 @@ def _set_rewards(self) -> None: |
134 | 135 | ) |
135 | 136 | if not self._disable_fingering_reward: |
136 | 137 | self._reward_fn.add("fingering_reward", self._compute_fingering_reward) |
| 138 | + else: |
| 139 | + # use OT based fingering |
| 140 | + print('Fingering is unavailable. OT fingering reward is used.') |
| 141 | + self._reward_fn.add("ot_fingering_reward", self._compute_ot_fingering_reward) |
| 142 | + |
137 | 143 | if not self._disable_forearm_reward: |
138 | 144 | self._reward_fn.add("forearm_reward", self._compute_forearm_reward) |
139 | 145 |
|
@@ -324,6 +330,44 @@ def _distance_finger_to_key( |
324 | 330 | ) |
325 | 331 | return float(np.mean(rews)) |
326 | 332 |
|
def _compute_ot_fingering_reward(self, physics: mjcf.Physics) -> float:
    """Reward fingertips for being close to the keys that must be pressed.

    Fingering-free reward based on optimal transport (OT), following RP1M
    (https://arxiv.org/abs/2408.11048). Instead of using per-note fingering
    annotations, the fingertips of both hands are matched to the currently
    active goal keys by solving a linear sum assignment over the
    fingertip-to-key Euclidean distances. The matched distances are then
    scored with a Gaussian `tolerance` and averaged.

    Args:
        physics: Physics instance used to read the world positions of the
            fingertip sites and the positions/sizes of the key geoms.

    Returns:
        1.0 when no key has to be pressed at the current step; otherwise the
        mean tolerance reward over the matched (fingertip, key) pairs.
    """
    # The last goal entry is excluded from the key lookup.
    # NOTE(review): presumably it is the sustain pedal -- confirm.
    keys_to_press = np.flatnonzero(self._goal_current[:-1])
    if keys_to_press.size == 0:
        # Nothing to reach for: bail out before touching the physics state.
        return 1.0

    # World positions of every fingertip, left hand first, then right hand.
    fingertip_sites = (
        *self.left_hand.fingertip_sites,
        *self.right_hand.fingertip_sites,
    )
    fingertip_pos = np.array(
        [physics.bind(site).xpos for site in fingertip_sites], dtype=float
    )

    # Target point for each active key: the position of the key's first geom,
    # shifted along the last axis by 0.5 * size[2] and along the first axis by
    # 0.35 * size[0] (same offsets as the annotation-based fingering reward).
    key_targets = []
    for key in keys_to_press:
        key_geom = physics.bind(self.piano.keys[key].geom[0])
        target = key_geom.xpos.copy()
        target[-1] += 0.5 * key_geom.size[2]
        target[0] += 0.35 * key_geom.size[0]
        key_targets.append(target)
    key_pos = np.array(key_targets, dtype=float)

    # Pairwise Euclidean distances, shape (num_fingertips, num_keys).
    dist = np.linalg.norm(
        fingertip_pos[:, np.newaxis, :] - key_pos[np.newaxis, :, :], axis=-1
    )

    # Minimum-total-distance one-to-one matching. For a rectangular matrix
    # only min(num_fingertips, num_keys) pairs are returned, so surplus
    # fingertips (or surplus keys) do not contribute to the reward.
    row_ind, col_ind = linear_sum_assignment(dist)
    rews = tolerance(
        dist[row_ind, col_ind],
        bounds=(0, _FINGER_CLOSE_ENOUGH_TO_KEY),
        margin=(_FINGER_CLOSE_ENOUGH_TO_KEY * 10),
        sigmoid="gaussian",
    )
    return float(np.mean(rews))
| 370 | + |
327 | 371 | def _update_goal_state(self) -> None: |
328 | 372 | # Observable callables get called after `after_step` but before |
329 | 373 | # `should_terminate_episode`. Since we increment `self._t_idx` in `after_step`, |
|
0 commit comments