Skip to content

Commit a80caf5

Browse files
committed
[fix] fix warm start for position retargeting with euler angle initialized dummy joint angles
1 parent b33035a commit a80caf5

File tree

5 files changed

+93
-49
lines changed

5 files changed

+93
-49
lines changed

dex_retargeting/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.4.3"
1+
__version__ = "0.4.4"

dex_retargeting/constants.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,24 @@
22
from pathlib import Path
33
from typing import Optional
44

5+
import numpy as np
6+
7+
OPERATOR2MANO_RIGHT = np.array(
8+
[
9+
[0, 0, -1],
10+
[-1, 0, 0],
11+
[0, 1, 0],
12+
]
13+
)
14+
15+
OPERATOR2MANO_LEFT = np.array(
16+
[
17+
[0, 0, -1],
18+
[1, 0, 0],
19+
[0, -1, 0],
20+
]
21+
)
22+
523

624
class RobotName(enum.Enum):
725
allegro = enum.auto()
@@ -59,3 +77,9 @@ def get_default_config_path(
5977
else:
6078
config_name = f"{robot_name_str}_{hand_type_str}.yml"
6179
return config_path / config_name
80+
81+
82+
OPERATOR2MANO = {
83+
HandType.right: OPERATOR2MANO_RIGHT,
84+
HandType.left: OPERATOR2MANO_LEFT,
85+
}

dex_retargeting/robot_wrapper.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,19 @@ def get_joint_index(self, name: str):
6060
def get_link_index(self, name: str):
6161
if name not in self.link_names:
6262
raise ValueError(f"{name} is not a link name. Valid link names: \n{self.link_names}")
63-
return self.model.getFrameId(name)
63+
return self.model.getFrameId(name, pin.BODY)
64+
65+
def get_joint_parent_child_frames(self, joint_name: str):
66+
joint_id = self.model.getFrameId(joint_name)
67+
parent_id = self.model.frames[joint_id].parent
68+
child_id = -1
69+
for idx, frame in enumerate(self.model.frames):
70+
if frame.previousFrame == joint_id:
71+
child_id = idx
72+
if child_id == -1:
73+
raise ValueError(f"Can not find child link of {joint_name}")
74+
75+
return parent_id, child_id
6476

6577
# -------------------------------------------------------------------------- #
6678
# Kinematics function

dex_retargeting/seq_retarget.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
from pytransform3d import rotations
66

7+
from dex_retargeting.constants import OPERATOR2MANO, HandType
78
from dex_retargeting.optimizer import Optimizer
89
from dex_retargeting.optimizer_utils import LPFilter
910

@@ -39,40 +40,37 @@ def __init__(
3940
# Warm started
4041
self.is_warm_started = False
4142

42-
# TODO: hack here
43-
self.scene = None
44-
45-
def warm_start(self, wrist_pos: np.ndarray, wrist_orientation: np.ndarray, global_rot: np.array):
43+
def warm_start(
44+
self,
45+
wrist_pos: np.ndarray,
46+
wrist_quat: np.ndarray,
47+
hand_type: HandType = HandType.right,
48+
is_mano_convention: bool = False,
49+
):
4650
"""
4751
Initialize the wrist joint pose using analytical computation instead of retargeting optimization.
4852
This function is specifically for position retargeting with the flying robot hand, i.e. has 6D free joint
4953
You are not expected to use this function for vector retargeting, e.g. when you are working on teleoperation
54+
5055
Args:
5156
wrist_pos: position of the hand wrist, typically from human hand pose
52-
wrist_orientation: orientation of the hand orientation, typically from human hand pose in MANO convention
53-
global_rot:
54-
57+
wrist_quat: quaternion of the hand wrist, the same convention as the operator frame definition if not is_mano_convention
58+
hand_type: hand type, used to determine the operator2mano matrix
59+
is_mano_convention: whether the wrist_quat is in mano convention
5560
"""
5661
# This function can only be used when the first joints of robot are free joints
57-
if len(wrist_pos) != 3:
58-
raise ValueError(f"Wrist pos:{wrist_pos} is not a 3-dim vector.")
59-
if len(wrist_orientation) != 3:
60-
raise ValueError(f"Wrist orientation:{wrist_orientation} is not a 3-dim vector.")
6162

62-
if np.linalg.norm(wrist_orientation) < 1e-3:
63-
mat = np.eye(3)
64-
else:
65-
mat = rotations.matrix_from_compact_axis_angle(wrist_orientation)
63+
if len(wrist_pos) != 3:
64+
raise ValueError(f"Wrist pos: {wrist_pos} is not a 3-dim vector.")
65+
if len(wrist_quat) != 4:
66+
raise ValueError(f"Wrist quat: {wrist_quat} is not a 4-dim vector.")
6667

68+
operator2mano = OPERATOR2MANO[hand_type] if is_mano_convention else np.eye(3)
6769
robot = self.optimizer.robot
68-
operator2mano = np.array([[0, 0, -1], [-1, 0, 0], [0, 1, 0]])
69-
mat = global_rot.T @ mat @ operator2mano
7070
target_wrist_pose = np.eye(4)
71-
target_wrist_pose[:3, :3] = mat
71+
target_wrist_pose[:3, :3] = rotations.matrix_from_quaternion(wrist_quat) @ operator2mano.T
7272
target_wrist_pose[:3, 3] = wrist_pos
7373

74-
wrist_link_name = self.optimizer.wrist_link_name
75-
wrist_link_id = self.optimizer.robot.get_link_index(wrist_link_name)
7674
name_list = [
7775
"dummy_x_translation_joint",
7876
"dummy_y_translation_joint",
@@ -81,6 +79,9 @@ def warm_start(self, wrist_pos: np.ndarray, wrist_orientation: np.ndarray, globa
8179
"dummy_y_rotation_joint",
8280
"dummy_z_rotation_joint",
8381
]
82+
wrist_link_id = robot.get_joint_parent_child_frames(name_list[5])[1]
83+
84+
# Set the dummy joints angles to zero
8485
old_qpos = robot.q0
8586
new_qpos = old_qpos.copy()
8687
for num, joint_name in enumerate(self.optimizer.target_joint_names):
@@ -128,6 +129,13 @@ def set_qpos(self, robot_qpos: np.ndarray):
128129
target_qpos = robot_qpos[self.optimizer.idx_pin2target]
129130
self.last_qpos = target_qpos
130131

132+
def get_qpos(self, fixed_qpos: np.ndarray | None = None):
133+
robot_qpos = np.zeros(self.optimizer.robot.dof)
134+
robot_qpos[self.optimizer.idx_pin2target] = self.last_qpos
135+
if fixed_qpos is not None:
136+
robot_qpos[self.optimizer.idx_pin2fixed] = fixed_qpos
137+
return robot_qpos
138+
131139
def verbose(self):
132140
min_value = self.optimizer.opt.last_optimum_value()
133141
print(f"Retargeting {self.num_retargeting} times takes: {self.accumulated_time}s")

example/position_retargeting/hand_robot_viewer.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,22 @@
22
from pathlib import Path
33
from typing import Dict, List
44

5-
import numpy as np
65
import cv2
7-
from tqdm import trange
6+
import numpy as np
87
import sapien
9-
import transforms3d.quaternions
8+
from hand_viewer import HandDatasetSAPIENViewer
9+
from pytransform3d import rotations
10+
from tqdm import trange
1011

1112
from dex_retargeting import yourdfpy as urdf
12-
from dex_retargeting.constants import RobotName, HandType, get_default_config_path, RetargetingType
13+
from dex_retargeting.constants import (
14+
HandType,
15+
RetargetingType,
16+
RobotName,
17+
get_default_config_path,
18+
)
1319
from dex_retargeting.retargeting_config import RetargetingConfig
1420
from dex_retargeting.seq_retarget import SeqRetargeting
15-
from hand_viewer import HandDatasetSAPIENViewer
16-
17-
ROBOT2MANO = np.array(
18-
[
19-
[0, 0, -1],
20-
[-1, 0, 0],
21-
[0, 1, 0],
22-
]
23-
)
24-
ROBOT2MANO_POSE = sapien.Pose(q=transforms3d.quaternions.mat2quat(ROBOT2MANO))
25-
26-
27-
def prepare_position_retargeting(joint_pos: np.array, link_hand_indices: np.ndarray):
28-
link_pos = joint_pos[link_hand_indices]
29-
return link_pos
30-
31-
32-
def prepare_vector_retargeting(joint_pos: np.array, link_hand_indices_pairs: np.ndarray):
33-
joint_pos = joint_pos @ ROBOT2MANO
34-
origin_link_pos = joint_pos[link_hand_indices_pairs[0]]
35-
task_link_pos = joint_pos[link_hand_indices_pairs[1]]
36-
return task_link_pos - origin_link_pos
3721

3822

3923
class RobotHandDatasetSAPIENViewer(HandDatasetSAPIENViewer):
@@ -45,6 +29,7 @@ def __init__(self, robot_names: List[RobotName], hand_type: HandType, headless=F
4529
self.robot_file_names: List[str] = []
4630
self.retargetings: List[SeqRetargeting] = []
4731
self.retarget2sapien: List[np.ndarray] = []
32+
self.hand_type = hand_type
4833

4934
# Load optimizer and filter
5035
loader = self.scene.create_urdf_loader()
@@ -126,7 +111,22 @@ def render_dexycb_data(self, data: Dict, fps=5, y_offset=0.8):
126111
robot_names = "_".join(robot_names)
127112
video_path = Path(__file__).parent.resolve() / f"data/{robot_names}_video.mp4"
128113
writer = cv2.VideoWriter(
129-
str(video_path), cv2.VideoWriter_fourcc(*"mp4v"), 30.0, (self.camera.get_width(), self.camera.get_height())
114+
str(video_path),
115+
cv2.VideoWriter_fourcc(*"mp4v"),
116+
30.0,
117+
(self.camera.get_width(), self.camera.get_height()),
118+
)
119+
120+
# Warm start
121+
hand_pose_start = hand_pose[start_frame]
122+
wrist_quat = rotations.quaternion_from_compact_axis_angle(hand_pose_start[0, 0:3])
123+
vertex, joint = self._compute_hand_geometry(hand_pose_start)
124+
for robot, retargeting, retarget2sapien in zip(self.robots, self.retargetings, self.retarget2sapien):
125+
retargeting.warm_start(
126+
joint[0, :],
127+
wrist_quat,
128+
hand_type=self.hand_type,
129+
is_mano_convention=True,
130130
)
131131

132132
# Loop rendering

0 commit comments

Comments
 (0)