@@ -314,17 +314,21 @@ def __init__(
314314 eta2 = 3e-2 ,
315315 scaling = 1.0 ,
316316 ):
317- if len (finger_tip_link_names ) < 4 or len (finger_tip_link_names ) > 5 :
318- raise ValueError (f"DexPilot optimizer can only be applied to hands with four or five fingers" )
319- is_four_finger = len (finger_tip_link_names ) == 4
320- if is_four_finger :
317+ # if len(finger_tip_link_names) < 4 or len(finger_tip_link_names) > 5:
318+ # raise ValueError(f"DexPilot optimizer can only be applied to hands with four or five fingers")
319+ self .num_fingers = len (finger_tip_link_names )
320+
321+ if self .num_fingers == 2 : # For gripper
322+ origin_link_index = [2 , 0 , 0 ]
323+ task_link_index = [1 , 1 , 2 ]
324+ elif self .num_fingers == 4 :
321325 origin_link_index = [2 , 3 , 4 , 3 , 4 , 4 , 0 , 0 , 0 , 0 ]
322326 task_link_index = [1 , 1 , 1 , 2 , 2 , 3 , 1 , 2 , 3 , 4 ]
323- self .num_fingers = 4
324- else :
327+ elif self .num_fingers == 5 :
325328 origin_link_index = [2 , 3 , 4 , 5 , 3 , 4 , 5 , 4 , 5 , 5 , 0 , 0 , 0 , 0 , 0 ]
326329 task_link_index = [1 , 1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 4 , 1 , 2 , 3 , 4 , 5 ]
327- self .num_fingers = 5
330+ else :
331+ raise NotImplementedError (f"Unsupported number of fingers: { self .num_fingers } " )
328332
329333 if target_link_human_indices is None :
330334 target_link_human_indices = (np .stack ([origin_link_index , task_link_index ], axis = 0 ) * 4 ).astype (int )
@@ -359,16 +363,23 @@ def __init__(
359363 self .opt .set_ftol_abs (1e-6 )
360364
361365 # DexPilot cache
362- if is_four_finger :
366+ if self .num_fingers == 2 :
367+ self .projected = np .zeros (1 , dtype = bool )
368+ self .s2_project_index_origin = np .array ([], dtype = int )
369+ self .s2_project_index_task = np .array ([], dtype = int )
370+ self .projected_dist = np .array ([eta1 ] * 1 )
371+ elif self .num_fingers == 4 :
363372 self .projected = np .zeros (6 , dtype = bool )
364373 self .s2_project_index_origin = np .array ([1 , 2 , 2 ], dtype = int )
365374 self .s2_project_index_task = np .array ([0 , 0 , 1 ], dtype = int )
366375 self .projected_dist = np .array ([eta1 ] * 3 + [eta2 ] * 3 )
367- else :
376+ elif self . num_fingers == 5 :
368377 self .projected = np .zeros (10 , dtype = bool )
369378 self .s2_project_index_origin = np .array ([1 , 2 , 3 , 2 , 3 , 3 ], dtype = int )
370379 self .s2_project_index_task = np .array ([0 , 0 , 0 , 1 , 1 , 2 ], dtype = int )
371380 self .projected_dist = np .array ([eta1 ] * 4 + [eta2 ] * 6 )
381+ else :
382+ raise NotImplementedError (f"Unsupported number of fingers: { self .num_fingers } " )
372383
373384 def get_objective_function (self , target_vector : np .ndarray , fixed_qpos : np .ndarray , last_qpos : np .ndarray ):
374385 qpos = np .zeros (self .num_joints )
0 commit comments