Skip to content

Commit

Permalink
Update parent array function in ABA
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jan 18, 2024
1 parent 874c16c commit 1114f14
Showing 1 changed file with 24 additions and 35 deletions.
59 changes: 24 additions & 35 deletions src/adam/core/rbd_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,14 +481,7 @@ def aba(
joint_accelerations (T): The joints acceleration
"""
model = self.model.reduce(self.model.actuated_joints)

joints = list(
filter(
lambda joint: joint.name in self.model.actuated_joints,
self.model.joints.values(),
)
)

joints = list(model.joints.values())
joints.sort(key=lambda joint: joint.idx)

NB = model.N
Expand All @@ -506,10 +499,8 @@ def aba(
sdd = self.math.factory.zeros(NB, 1, 1)
B_X_W = self.math.adjoint_mixed(base_transform)

if self.model.floating_base:
IA[0] = self.model.tree.get_node_from_name(
self.root_link
).link.spatial_inertia()
if model.floating_base:
IA[0] = model.tree.get_node_from_name(self.root_link).link.spatial_inertia()
v[0] = B_X_W @ base_velocity
pA[0] = (
self.math.spatial_skew_star(v[0]) @ IA[0] @ v[0]
Expand All @@ -522,7 +513,7 @@ def get_tree_transform(self, joints) -> "Array":
Array: the tree transform
"""
relative_transform = lambda j: self.math.inv(
self.model.tree.graph[j.parent].parent_arc.spatial_transform(0)
model.tree.graph[j.child].parent_arc.spatial_transform(0)
) @ j.spatial_transform(0)

return self.math.vertcat(
Expand All @@ -539,19 +530,6 @@ def get_tree_transform(self, joints) -> "Array":

tree_transform = get_tree_transform(self, joints)

find_parent = (
lambda j: find_parent(model.tree.get_node_from_name(j.parent).parent_arc)
if model.tree.get_node_from_name(j.parent).parent_arc.idx is None
else model.tree.get_node_from_name(j.parent).parent_arc.idx
)

p = [-1] + [
model.tree.get_idx_from_name(i.parent)
if model.tree.get_idx_from_name(i.parent) < NB
else find_parent(i)
for i in joints
]

# Pass 1
for i, joint in enumerate(joints[1:], start=1):
q = joint_positions[i]
Expand All @@ -561,8 +539,11 @@ def get_tree_transform(self, joints) -> "Array":
i_X_pi[i] = joint.spatial_transform(q) @ tree_transform[i]
v_J = joint.motion_subspace() * q_dot

v[i] = i_X_pi[i] @ v[p[i]] + v_J
c[i] = i_X_pi[i] @ c[p[i]] + self.math.spatial_skew(v[i]) @ v_J
# TODO: reassign idx after reducing the model
pi = model.tree.get_idx_from_name(joint.child)

v[i] = i_X_pi[i] @ v[pi] + v_J
c[i] = i_X_pi[i] @ c[pi] + self.math.spatial_skew(v[i]) @ v_J

IA[i] = model.tree.get_node_from_name(joint.parent).link.spatial_inertia()

Expand All @@ -579,26 +560,34 @@ def get_tree_transform(self, joints) -> "Array":
):
U[i] = IA[i] @ joint.motion_subspace()
D[i] = joint.motion_subspace().T @ U[i]
u[i] = self.math.vertcat(tau[joint.idx]) - joint.motion_subspace().T @ pA[i]
u[i] = (
self.math.vertcat(tau[joint.idx]) - joint.motion_subspace().T @ pA[i]
if joint.idx is not None
else 0.0
)

Ia = IA[i] - U[i] / D[i] @ U[i].T
pa = pA[i] + Ia @ c[i] + U[i] * u[i] / D[i]

if joint.parent != self.root_link or not self.model.floating_base:
IA[p[i]] += i_X_pi[i].T @ Ia @ i_X_pi[i]
pA[p[i]] += i_X_pi[i].T @ pa
pi = model.tree.get_idx_from_name(joint.child)

if joint.parent != self.root_link or not model.floating_base:
IA[pi] += i_X_pi[i].T @ Ia @ i_X_pi[i]
pA[pi] += i_X_pi[i].T @ pa
continue

a[0] = B_X_W @ g if self.model.floating_base else self.math.solve(-IA[0], pA[0])
a[0] = B_X_W @ g if model.floating_base else self.math.solve(-IA[0], pA[0])

# Pass 3
for i, joint in enumerate(joints[1:], start=1):
if joint.parent == self.root_link:
continue

pi = model.tree.get_idx_from_name(joint.child)

sdd[i - 1] = (u[i] - U[i].T @ a[i]) / D[i]

a[i] += i_X_pi[i].T @ a[p[i]] + joint.motion_subspace() * sdd[i - 1] + c[i]
a[i] += i_X_pi[i].T @ a[pi] + joint.motion_subspace() * sdd[i - 1] + c[i]

# Squeeze sdd
s_ddot = self.math.vertcat(*[sdd[i] for i in range(sdd.shape[0])])
Expand All @@ -613,7 +602,7 @@ def get_tree_transform(self, joints) -> "Array":
return self.math.horzcat(
self.math.vertcat(
self.math.solve(B_X_W, a[0]) + g
if self.model.floating_base
if model.floating_base
else self.math.zeros(6, 1),
),
s_ddot,
Expand Down

0 comments on commit 1114f14

Please sign in to comment.