Skip to content

Commit 4b9ba16

Browse files
committed
fix: fix bug from serial_chain
1 parent eaac9a2 commit 4b9ba16

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

cgf/robotics.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ def __init__(
3535

3636
if os.path.exists(urdf_str_or_path):
3737
with open(urdf_str_or_path) as f:
38-
chain: pk.chain.Chain = pk.build_chain_from_urdf(f.read())
38+
self.chain: pk.chain.Chain = pk.build_chain_from_urdf(f.read()).to(dtype=dtype, device=device)
3939
else:
40-
chain: pk.chain.Chain = pk.build_chain_from_urdf(urdf_str_or_path)
40+
self.chain: pk.chain.Chain = pk.build_chain_from_urdf(urdf_str_or_path).to(dtype=dtype, device=device)
4141

4242
if return_geometry and (geometry_path is None and geometry_data is None):
4343
raise TypeError("geometry_path or geometry_data should be set if geometry need to be returned")
@@ -57,16 +57,15 @@ def __init__(
5757
self.geometries[name] = mesh
5858

5959
self.end_links = end_links
60-
self.serial_chains: List[pk.chain.SerialChain] = []
60+
# self.serial_chains: List[pk.chain.SerialChain] = []
6161
self.return_geometry = return_geometry
6262
self.global_transform = global_transform
6363

64-
for link_name in end_links:
65-
serial_chain = pk.SerialChain(chain, link_name)
66-
serial_chain = serial_chain.to(dtype=dtype, device=device)
67-
self.serial_chains.append(serial_chain)
68-
69-
self.dof: int = len(chain.get_joint_parameter_names())
64+
# for link_name in end_links:
65+
# serial_chain = pk.SerialChain(self.chain, link_name)
66+
# serial_chain = serial_chain.to(dtype=dtype, device=device)
67+
# self.serial_chains.append(serial_chain)
68+
self.dof: int = len(self.chain.get_joint_parameter_names())
7069

7170
def forward(self, qpos: torch.Tensor):
7271
tf3ds: Dict[str, pk.Transform3d] = {}
@@ -80,11 +79,12 @@ def forward(self, qpos: torch.Tensor):
8079
tf3ds["palm"] = pk.Transform3d(batch_size, matrix=identiy)
8180

8281
start = 0 if not self.global_transform else 6
83-
for _, serial_chain in enumerate(self.serial_chains):
84-
# hard code for now
85-
joint_num = 4
86-
tf3ds.update(serial_chain.forward_kinematics(qpos[:, start : start + joint_num], end_only=False))
87-
start += joint_num
82+
# for _, serial_chain in enumerate(self.serial_chains):
83+
# # hard code for now
84+
# joint_num = 4
85+
# tf3ds.update(serial_chain.forward_kinematics(qpos[:, start : start + joint_num], end_only=False))
86+
# start += joint_num
87+
tf3ds.update(self.chain.forward_kinematics(qpos[:, start:]))
8888

8989
if self.global_transform:
9090
rot_mat = axis_angle_to_matrix(qpos[:, 3:6])
@@ -98,7 +98,6 @@ def forward(self, qpos: torch.Tensor):
9898
for name in self.geometries:
9999
geo = self.geometries[name].extend(batch_size).clone()
100100
geo = geo.update_padded(tf3ds[name].transform_points(geo.verts_padded()))
101-
# geo = geo.update_padded(cvee.rotate(tf3ds[name].transform_points(geo.verts_padded()), debug_mat[None]))
102101
geos[name] = geo
103102
return tf3ds, geos
104103

0 commit comments

Comments
 (0)