@@ -35,9 +35,9 @@ def __init__(
3535
3636 if os .path .exists (urdf_str_or_path ):
3737 with open (urdf_str_or_path ) as f :
38- chain : pk .chain .Chain = pk .build_chain_from_urdf (f .read ())
38+ self . chain : pk .chain .Chain = pk .build_chain_from_urdf (f .read ()). to ( dtype = dtype , device = device )
3939 else :
40- chain : pk .chain .Chain = pk .build_chain_from_urdf (urdf_str_or_path )
40+ self . chain : pk .chain .Chain = pk .build_chain_from_urdf (urdf_str_or_path ). to ( dtype = dtype , device = device )
4141
4242 if return_geometry and (geometry_path is None and geometry_data is None ):
4343 raise TypeError ("geometry_path or geometry_data should be set if geometry need to be returned" )
@@ -57,16 +57,15 @@ def __init__(
5757 self .geometries [name ] = mesh
5858
5959 self .end_links = end_links
60- self .serial_chains : List [pk .chain .SerialChain ] = []
60+ # self.serial_chains: List[pk.chain.SerialChain] = []
6161 self .return_geometry = return_geometry
6262 self .global_transform = global_transform
6363
64- for link_name in end_links :
65- serial_chain = pk .SerialChain (chain , link_name )
66- serial_chain = serial_chain .to (dtype = dtype , device = device )
67- self .serial_chains .append (serial_chain )
68-
69- self .dof : int = len (chain .get_joint_parameter_names ())
64+ # for link_name in end_links:
65+ # serial_chain = pk.SerialChain(self.chain, link_name)
66+ # serial_chain = serial_chain.to(dtype=dtype, device=device)
67+ # self.serial_chains.append(serial_chain)
68+ self .dof : int = len (self .chain .get_joint_parameter_names ())
7069
7170 def forward (self , qpos : torch .Tensor ):
7271 tf3ds : Dict [str , pk .Transform3d ] = {}
@@ -80,11 +79,12 @@ def forward(self, qpos: torch.Tensor):
8079 tf3ds ["palm" ] = pk .Transform3d (batch_size , matrix = identiy )
8180
8281 start = 0 if not self .global_transform else 6
83- for _ , serial_chain in enumerate (self .serial_chains ):
84- # hard code for now
85- joint_num = 4
86- tf3ds .update (serial_chain .forward_kinematics (qpos [:, start : start + joint_num ], end_only = False ))
87- start += joint_num
82+ # for _, serial_chain in enumerate(self.serial_chains):
83+ # # hard code for now
84+ # joint_num = 4
85+ # tf3ds.update(serial_chain.forward_kinematics(qpos[:, start : start + joint_num], end_only=False))
86+ # start += joint_num
87+ tf3ds .update (self .chain .forward_kinematics (qpos [:, start :]))
8888
8989 if self .global_transform :
9090 rot_mat = axis_angle_to_matrix (qpos [:, 3 :6 ])
@@ -98,7 +98,6 @@ def forward(self, qpos: torch.Tensor):
9898 for name in self .geometries :
9999 geo = self .geometries [name ].extend (batch_size ).clone ()
100100 geo = geo .update_padded (tf3ds [name ].transform_points (geo .verts_padded ()))
101- # geo = geo.update_padded(cvee.rotate(tf3ds[name].transform_points(geo.verts_padded()), debug_mat[None]))
102101 geos [name ] = geo
103102 return tf3ds , geos
104103
0 commit comments