Skip to content

Commit 619030d

Browse files
committed
[update] update DexPilot teleop for hands with 2-5 fingers
1 parent b615da9 commit 619030d

File tree

3 files changed

+57
-32
lines changed

3 files changed

+57
-32
lines changed

dex_retargeting/optimizer.py

Lines changed: 55 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -314,21 +314,14 @@ def __init__(
314314
eta2=3e-2,
315315
scaling=1.0,
316316
):
317-
# if len(finger_tip_link_names) < 4 or len(finger_tip_link_names) > 5:
318-
# raise ValueError(f"DexPilot optimizer can only be applied to hands with four or five fingers")
317+
if len(finger_tip_link_names) < 2 or len(finger_tip_link_names) > 5:
318+
raise ValueError(
319+
f"DexPilot optimizer can only be applied to hands with 2 to 5 fingers, but got "
320+
f"{len(finger_tip_link_names)} fingers."
321+
)
319322
self.num_fingers = len(finger_tip_link_names)
320323

321-
if self.num_fingers == 2: # For gripper
322-
origin_link_index = [2, 0, 0]
323-
task_link_index = [1, 1, 2]
324-
elif self.num_fingers == 4:
325-
origin_link_index = [2, 3, 4, 3, 4, 4, 0, 0, 0, 0]
326-
task_link_index = [1, 1, 1, 2, 2, 3, 1, 2, 3, 4]
327-
elif self.num_fingers == 5:
328-
origin_link_index = [2, 3, 4, 5, 3, 4, 5, 4, 5, 5, 0, 0, 0, 0, 0]
329-
task_link_index = [1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 1, 2, 3, 4, 5]
330-
else:
331-
raise NotImplementedError(f"Unsupported number of fingers: {self.num_fingers}")
324+
origin_link_index, task_link_index = self.generate_link_indices(self.num_fingers)
332325

333326
if target_link_human_indices is None:
334327
target_link_human_indices = (np.stack([origin_link_index, task_link_index], axis=0) * 4).astype(int)
@@ -363,23 +356,55 @@ def __init__(
363356
self.opt.set_ftol_abs(1e-6)
364357

365358
# DexPilot cache
366-
if self.num_fingers == 2:
367-
self.projected = np.zeros(1, dtype=bool)
368-
self.s2_project_index_origin = np.array([], dtype=int)
369-
self.s2_project_index_task = np.array([], dtype=int)
370-
self.projected_dist = np.array([eta1] * 1)
371-
elif self.num_fingers == 4:
372-
self.projected = np.zeros(6, dtype=bool)
373-
self.s2_project_index_origin = np.array([1, 2, 2], dtype=int)
374-
self.s2_project_index_task = np.array([0, 0, 1], dtype=int)
375-
self.projected_dist = np.array([eta1] * 3 + [eta2] * 3)
376-
elif self.num_fingers == 5:
377-
self.projected = np.zeros(10, dtype=bool)
378-
self.s2_project_index_origin = np.array([1, 2, 3, 2, 3, 3], dtype=int)
379-
self.s2_project_index_task = np.array([0, 0, 0, 1, 1, 2], dtype=int)
380-
self.projected_dist = np.array([eta1] * 4 + [eta2] * 6)
381-
else:
382-
raise NotImplementedError(f"Unsupported number of fingers: {self.num_fingers}")
359+
self.projected, self.s2_project_index_origin, self.s2_project_index_task, self.projected_dist = (
360+
self.set_dexpilot_cache(self.num_fingers, eta1, eta2)
361+
)
362+
363+
@staticmethod
364+
def generate_link_indices(num_fingers):
365+
"""
366+
Example:
367+
>>> generate_link_indices(4)
368+
([2, 3, 4, 3, 4, 4, 0, 0, 0, 0], [1, 1, 1, 2, 2, 3, 1, 2, 3, 4])
369+
"""
370+
origin_link_index = []
371+
task_link_index = []
372+
373+
# Add indices for connections between fingers
374+
for i in range(1, num_fingers):
375+
for j in range(i + 1, num_fingers + 1):
376+
origin_link_index.append(j)
377+
task_link_index.append(i)
378+
379+
# Add indices for connections to the base (0)
380+
for i in range(1, num_fingers + 1):
381+
origin_link_index.append(0)
382+
task_link_index.append(i)
383+
384+
return origin_link_index, task_link_index
385+
386+
@staticmethod
387+
def set_dexpilot_cache(num_fingers, eta1, eta2):
388+
"""
389+
Example:
390+
>>> set_dexpilot_cache(4, 0.1, 0.2)
391+
(array([False, False, False, False, False, False]),
392+
[1, 2, 2],
393+
[0, 0, 1],
394+
array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]))
395+
"""
396+
projected = np.zeros(num_fingers * (num_fingers - 1) // 2, dtype=bool)
397+
398+
s2_project_index_origin = []
399+
s2_project_index_task = []
400+
for i in range(0, num_fingers - 2):
401+
for j in range(i + 1, num_fingers - 1):
402+
s2_project_index_origin.append(j)
403+
s2_project_index_task.append(i)
404+
405+
projected_dist = np.array([eta1] * (num_fingers - 1) + [eta2] * ((num_fingers - 1) * (num_fingers - 2) // 2))
406+
407+
return projected, s2_project_index_origin, s2_project_index_task, projected_dist
383408

384409
def get_objective_function(self, target_vector: np.ndarray, fixed_qpos: np.ndarray, last_qpos: np.ndarray):
385410
qpos = np.zeros(self.num_joints)

example/vector_retargeting/capture_webcam.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def main(video_path: str, video_capture_device: Union[str, int] = 0):
3030
if cv2.waitKey(1) & 0xFF == 27:
3131
break
3232

33-
print('Recording finished')
33+
print("Recording finished")
3434
cap.release()
3535
writer.release()
3636
cv2.destroyAllWindows()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def setup_package():
107107
python_requires=">=3.7,<3.11",
108108
zip_safe=True,
109109
include_package_data=True,
110-
package_data={'dex_retargeting': ['configs/**']},
110+
package_data={"dex_retargeting": ["configs/**"]},
111111
install_requires=core_requirements,
112112
extras_require={
113113
"dev": dev_requirements,

0 commit comments

Comments
 (0)