@@ -314,21 +314,14 @@ def __init__(
314314 eta2 = 3e-2 ,
315315 scaling = 1.0 ,
316316 ):
317- # if len(finger_tip_link_names) < 4 or len(finger_tip_link_names) > 5:
318- # raise ValueError(f"DexPilot optimizer can only be applied to hands with four or five fingers")
317+ if len (finger_tip_link_names ) < 2 or len (finger_tip_link_names ) > 5 :
318+ raise ValueError (
319+ f"DexPilot optimizer can only be applied to hands with 2 to 5 fingers, but got "
320+ f"{ len (finger_tip_link_names )} fingers."
321+ )
319322 self .num_fingers = len (finger_tip_link_names )
320323
321- if self .num_fingers == 2 : # For gripper
322- origin_link_index = [2 , 0 , 0 ]
323- task_link_index = [1 , 1 , 2 ]
324- elif self .num_fingers == 4 :
325- origin_link_index = [2 , 3 , 4 , 3 , 4 , 4 , 0 , 0 , 0 , 0 ]
326- task_link_index = [1 , 1 , 1 , 2 , 2 , 3 , 1 , 2 , 3 , 4 ]
327- elif self .num_fingers == 5 :
328- origin_link_index = [2 , 3 , 4 , 5 , 3 , 4 , 5 , 4 , 5 , 5 , 0 , 0 , 0 , 0 , 0 ]
329- task_link_index = [1 , 1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 4 , 1 , 2 , 3 , 4 , 5 ]
330- else :
331- raise NotImplementedError (f"Unsupported number of fingers: { self .num_fingers } " )
324+ origin_link_index , task_link_index = self .generate_link_indices (self .num_fingers )
332325
333326 if target_link_human_indices is None :
334327 target_link_human_indices = (np .stack ([origin_link_index , task_link_index ], axis = 0 ) * 4 ).astype (int )
@@ -363,23 +356,55 @@ def __init__(
363356 self .opt .set_ftol_abs (1e-6 )
364357
365358 # DexPilot cache
366- if self .num_fingers == 2 :
367- self .projected = np .zeros (1 , dtype = bool )
368- self .s2_project_index_origin = np .array ([], dtype = int )
369- self .s2_project_index_task = np .array ([], dtype = int )
370- self .projected_dist = np .array ([eta1 ] * 1 )
371- elif self .num_fingers == 4 :
372- self .projected = np .zeros (6 , dtype = bool )
373- self .s2_project_index_origin = np .array ([1 , 2 , 2 ], dtype = int )
374- self .s2_project_index_task = np .array ([0 , 0 , 1 ], dtype = int )
375- self .projected_dist = np .array ([eta1 ] * 3 + [eta2 ] * 3 )
376- elif self .num_fingers == 5 :
377- self .projected = np .zeros (10 , dtype = bool )
378- self .s2_project_index_origin = np .array ([1 , 2 , 3 , 2 , 3 , 3 ], dtype = int )
379- self .s2_project_index_task = np .array ([0 , 0 , 0 , 1 , 1 , 2 ], dtype = int )
380- self .projected_dist = np .array ([eta1 ] * 4 + [eta2 ] * 6 )
381- else :
382- raise NotImplementedError (f"Unsupported number of fingers: { self .num_fingers } " )
359+ self .projected , self .s2_project_index_origin , self .s2_project_index_task , self .projected_dist = (
360+ self .set_dexpilot_cache (self .num_fingers , eta1 , eta2 )
361+ )
362+
363+ @staticmethod
364+ def generate_link_indices (num_fingers ):
365+ """
366+ Example:
367+ >>> generate_link_indices(4)
368+ ([2, 3, 4, 3, 4, 4, 0, 0, 0, 0], [1, 1, 1, 2, 2, 3, 1, 2, 3, 4])
369+ """
370+ origin_link_index = []
371+ task_link_index = []
372+
373+ # Add indices for connections between fingers
374+ for i in range (1 , num_fingers ):
375+ for j in range (i + 1 , num_fingers + 1 ):
376+ origin_link_index .append (j )
377+ task_link_index .append (i )
378+
379+ # Add indices for connections to the base (0)
380+ for i in range (1 , num_fingers + 1 ):
381+ origin_link_index .append (0 )
382+ task_link_index .append (i )
383+
384+ return origin_link_index , task_link_index
385+
386+ @staticmethod
387+ def set_dexpilot_cache (num_fingers , eta1 , eta2 ):
388+ """
389+ Example:
390+ >>> set_dexpilot_cache(4, 0.1, 0.2)
391+ (array([False, False, False, False, False, False]),
392+ [1, 2, 2],
393+ [0, 0, 1],
394+ array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]))
395+ """
396+ projected = np .zeros (num_fingers * (num_fingers - 1 ) // 2 , dtype = bool )
397+
398+ s2_project_index_origin = []
399+ s2_project_index_task = []
400+ for i in range (0 , num_fingers - 2 ):
401+ for j in range (i + 1 , num_fingers - 1 ):
402+ s2_project_index_origin .append (j )
403+ s2_project_index_task .append (i )
404+
405+ projected_dist = np .array ([eta1 ] * (num_fingers - 1 ) + [eta2 ] * ((num_fingers - 1 ) * (num_fingers - 2 ) // 2 ))
406+
407+ return projected , s2_project_index_origin , s2_project_index_task , projected_dist
383408
384409 def get_objective_function (self , target_vector : np .ndarray , fixed_qpos : np .ndarray , last_qpos : np .ndarray ):
385410 qpos = np .zeros (self .num_joints )
0 commit comments