Skip to content

Commit 0d9736c

Browse files
authored
Merge pull request #24 from zhaoyi11/rp1m
Lift the requirement of human fingering with RP1M
2 parents d9cde23 + ad4ca3d commit 0d9736c

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

robopianist/suite/tasks/piano_with_shadow_hands.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import List, Optional, Sequence, Tuple
1818

1919
import numpy as np
20+
from scipy.optimize import linear_sum_assignment
2021
from dm_control import mjcf
2122
from dm_control.composer import variation as base_variation
2223
from dm_control.composer.observation import observable
@@ -134,6 +135,11 @@ def _set_rewards(self) -> None:
134135
)
135136
if not self._disable_fingering_reward:
136137
self._reward_fn.add("fingering_reward", self._compute_fingering_reward)
138+
else:
139+
# use OT based fingering
140+
print('Fingering is unavailable. OT fingering reward is used.')
141+
self._reward_fn.add("ot_fingering_reward", self._compute_ot_fingering_reward)
142+
137143
if not self._disable_forearm_reward:
138144
self._reward_fn.add("forearm_reward", self._compute_forearm_reward)
139145

@@ -324,6 +330,44 @@ def _distance_finger_to_key(
324330
)
325331
return float(np.mean(rews))
326332

333+
def _compute_ot_fingering_reward(self, physics: mjcf.Physics) -> float:
    """Reward fingertips for being close to the keys that should be pressed.

    OT fingering reward from RP1M (https://arxiv.org/abs/2408.11048). No
    fingering annotation is used: fingertips are matched to the active keys
    by solving a linear sum assignment over the pairwise Euclidean
    distances, and only the matched distances are scored.

    Args:
        physics: Physics instance used to look up the world positions of
            the fingertip sites and the key geoms.

    Returns:
        1.0 when no key is active in the current goal; otherwise the mean
        tolerance reward over the matched (fingertip, key) pairs.
    """
    # World positions of every fingertip, left hand first, then right hand.
    tips = [
        physics.bind(site).xpos.copy()
        for hand in (self.left_hand, self.right_hand)
        for site in hand.fingertip_sites
    ]

    # Indices of the keys that are active in the current goal. The last goal
    # entry is excluded (presumably the sustain pedal -- TODO confirm).
    active_keys = np.flatnonzero(self._goal_current[:-1])
    if active_keys.size == 0:
        # Nothing to press at this step: full reward.
        return 1.0

    # Target point for each active key: the position of the key's first
    # geom, offset along z and x by fractions of the geom size.
    # NOTE(review): presumably this moves the target towards the key
    # surface / playable end of the key -- confirm against the piano model.
    targets = []
    for key_idx in active_keys:
        bound_geom = physics.bind(self.piano.keys[key_idx].geom[0])
        target = bound_geom.xpos.copy()
        target[-1] += 0.5 * bound_geom.size[2]
        target[0] += 0.35 * bound_geom.size[0]
        targets.append(target)

    # Pairwise cost matrix: rows are fingertips, columns are target keys.
    # Every entry is overwritten below; 100. is only an initial placeholder.
    cost = np.full((len(tips), len(targets)), 100.)
    for row, tip in enumerate(tips):
        cost[row] = [np.linalg.norm(target - tip) for target in targets]

    # Minimum-total-distance matching between fingertips and keys.
    finger_ids, target_ids = linear_sum_assignment(cost)
    matched_dist = cost[finger_ids, target_ids]

    # Full reward within the "close enough" threshold, decaying with a
    # gaussian sigmoid over a margin of ten times that threshold.
    rewards = tolerance(
        matched_dist,
        bounds=(0, _FINGER_CLOSE_ENOUGH_TO_KEY),
        margin=(_FINGER_CLOSE_ENOUGH_TO_KEY * 10),
        sigmoid="gaussian",
    )
    return float(np.mean(rewards))
370+
327371
def _update_goal_state(self) -> None:
328372
# Observable callables get called after `after_step` but before
329373
# `should_terminate_episode`. Since we increment `self._t_idx` in `after_step`,

0 commit comments

Comments
 (0)