diff --git a/src/adam/model/tree.py b/src/adam/model/tree.py index e8447871..0b72dd4e 100644 --- a/src/adam/model/tree.py +++ b/src/adam/model/tree.py @@ -1,4 +1,6 @@ import dataclasses +import logging + from typing import Dict, Iterable, List, Tuple, Union from adam.model.abc_factories import Joint, Link @@ -89,32 +91,65 @@ def reduce(self, considered_joint_names: List[str]) -> "Tree": if joint.name not in considered_joint_names } ) - # TODO: neck_2, l_wrist_1, r_wrist_1 don't get lumped - while nodes_to_lump != []: - node = self.graph[nodes_to_lump.pop()] - parent_node = self.graph[node.parent.name] - - # lump the inertial properties - parent_node.link = node.parent.lump( # r_hip_1 - other=node.link, # r_hip_2 - relative_transform=node.parent_arc.spatial_transform(0), - ) - - # update the parent - node.parent = parent_node.link - # update the children - if node.name in parent_node.children: - parent_node.children.remove(node.name) - parent_node.children.append(node.children) - - # update the arcs - if node.parent_arc.name not in considered_joint_names: - parent_node.arcs.remove(node.parent_arc) - parent_node.arcs.append(node.arcs) + relative_transform = ( + lambda node: node.link.math.inv( + self.graph[node.parent.name].parent_arc.spatial_transform(0) + ) + @ node.parent_arc.spatial_transform(0) + if node.parent.name != self.root + else node.parent_arc.spatial_transform(0) + ) - # remove the node - self.graph.pop(node.name) + last = [] + leaves = [node for node in self.graph.values() if node.children == last] + + while all(leaf.name != self.root for leaf in leaves): + for leaf in leaves: + if leaf is self.graph[self.root]: + continue + + if leaf.parent_arc.name not in considered_joint_names: + # create the new node + new_node = Node( + name=leaf.parent.name, + link=None, + arcs=[], + children=None, + parent=None, + parent_arc=None, + ) + + # update the link + new_node.link = leaf.parent.lump( + other=leaf.link, + relative_transform=relative_transform(leaf), + ) + + # update the parents + new_node.parent = self.graph[leaf.parent.name].parent + new_node.parent_arc = self.graph[new_node.name].parent_arc + new_node.parent_arc.parent = ( + leaf.children[0].parent_arc.name if leaf.children != [] else [] + ) + + # update the children + new_node.children = leaf.children + + # update the arcs + if leaf.arcs != []: + for arc in leaf.arcs: + if arc.name in considered_joint_names: + new_node.arcs.append(arc) + + logging.debug(f"Removing {leaf.name}") + self.graph.pop(leaf.name) + self.graph[new_node.name] = new_node + leaves = [ + self.get_node_from_name((leaf.parent.name)) + for leaf in leaves + if leaf.name != self.root + ] return Tree(self.graph, self.root)