Custom binary decision tree module. How to calculate gradient and propagate loss?

I know there are a few torch modules out there which deal with tabular data and emulate a binary decision tree. However, for one reason or another they don't quite fit my purpose or require conflicting versions of other packages. And, really, I'd like to improve my understanding of torch in general, so I've been trying to create my own binary decision tree that attempts to learn the split value each branch applies to its feature column.

I’m not even sure if what I’m trying to do is possible, but it “feels” to me like there should be a way to do this, and I’ve been pounding on it for so long that it’s time to ask for help.

I’m sure there are a number of inefficiencies in the code; however, what I’m looking for here is just the basics: how to adjust the split_value attribute of my DecisionNode objects by calculating a gradient and passing it back through the tree in the _distribute_loss() method.

Here is what I have…

First, some features and labels to pass into the final module: I’ve generated random values (5 different class labels between 0 and 1), creating 5 features for 20 samples, with the targets drawn from the same set of values.

import torch
import numpy as np

classes = [ 0.0, 0.25, 0.5, 0.75, 1.0]
feature_data = torch.tensor( np.random.choice( classes, (20,5))).type( torch.float32)
target_data = torch.tensor( np.random.choice( classes, (20))).type( torch.float32)

The DecisionNode class, which takes all features as input but splits the samples between a true branch (node._tb) and a false branch (node._fb) using the values from its assigned feature column (node.col) and its node.split_value attribute (which is what I am trying to optimize).

class DecisionNode( torch.nn.Module):
    def __init__(self, col:int, tb=None, fb=None):
        super().__init__()
        self.col = col

        #
        # Do these all need to be Parameters? Do they all need to retain_grad()? I have no idea.
        self.split_value = torch.nn.Parameter( torch.rand(1, requires_grad=True).type( torch.float32))
        self.split_value.retain_grad()
        self.value = torch.nn.Parameter( torch.empty( 1))
        self.value.retain_grad()
        self.entropy = torch.nn.Parameter( torch.tensor( 0.0).type( torch.float32))
        self.entropy.retain_grad()
        self.variance = torch.nn.Parameter( torch.tensor( 0.0).type( torch.float32))
        self.variance.retain_grad()

        self._tb = tb
        self._fb = fb
    
    @property
    def device(self):
        return next(self.parameters()).device


    def forward( self, feature_value):
        if feature_value.dim() > 1:
            results = torch.nn.Parameter( torch.zeros( feature_value.size(0), requires_grad=True).to( feature_value.device))
            results.retain_grad()
        else:
            results = torch.nn.Parameter( torch.zeros( 1,1, requires_grad=True).to( feature_value.device))
            results.retain_grad()

        if (self._tb is None) and (self._fb is None):
            with torch.no_grad():
                results[ :] = self.value
        else:
            with torch.no_grad():
                if self._tb is not None:
                    results[ feature_value[:, self.col].ge( self.split_value)] = self._tb( feature_value[ feature_value[:, self.col].ge( self.split_value)])
                if self._fb is not None:
                    results[ feature_value[:, self.col].lt( self.split_value)] = self._fb( feature_value[ feature_value[:, self.col].lt( self.split_value)])
        results.retain_grad()
        return results

The TorchDecisionTree class, which builds a tree of DecisionNode objects. The .fit() method must be called before any training or prediction in order to assign the node.value attribute of the tree's DecisionNodes, and it would then have to be called again after each adjustment to the node.split_value attribute (I sketch the training loop I have in mind further down).

I am calculating entropy and variance error values for the nodes, because I assume I will want to use this information to determine the gradient and apply the loss back into each node. This is where I am quite stuck, and my great lack of understanding of gradients and their application has me stabbing in the dark.

class TorchDecisionTree( torch.nn.Module):
    def __init__(self, max_depth:int, n_features:int, max_split_value=None, min_split_value=None):
        super().__init__()
        self.max_depth = max_depth
        self._root_node = self._build_tree( n_features, self.max_depth)
        # if not supplied, these bounds are inferred from the training features in fit()
        self._max_split_value = None if max_split_value is None else torch.tensor( float( max_split_value))
        self._min_split_value = None if min_split_value is None else torch.tensor( float( min_split_value))
        
    """
        Function: device() - utility function so I can make sure the module is on the correct device
    """
    @property
    def device( self):
        return next( self.parameters()).device
    
    """
        Entropy function.
    """
    def _entropy( self, labels):
        values, counts = torch.unique( labels, return_counts=True, sorted=False)
        p = counts / labels.size( 0)
        ent = -(p * self._log2( p)).sum()
        return ent
    
    """
        Variance function.
    """
    def _variance( self, values):
        var = torch.var( values)
        if var.isnan():  # This really shouldn't happen but I'll try and worry about this later.
            var = torch.tensor( 0.0).to( self.device)
        return var
    
    """
        Function: log2()
    """
    def _log2( self, x:torch.Tensor):
        return torch.log( x) / torch.log( torch.tensor( [ 2.0]).to( x.device))
    
    """
        Private recursive function used to build the tree.
    """
    def _build_tree( self, n_features, depth):
        node = DecisionNode( col=np.random.randint( n_features))
        if depth > 0:
            node._tb = self._build_tree( n_features, depth - 1)
            node._fb = self._build_tree( n_features, depth - 1)
        return node
    
    """
        Private function used to set value of leaf node.
    """
    def _set_node_value( self, node, feature_value, labels):
        if labels.size(0) == 0:
            return
        
        if ( node._tb is None) and ( node._fb is None):
            values, counts = torch.unique( labels, return_counts=True, sorted=False)
            node.value = torch.nn.Parameter( torch.tensor( [(values * counts).sum() / counts.sum()]))
            node.value.retain_grad()
            node.entropy.data = self._entropy( labels)
            node.variance.data = self._variance( labels)

        else:
            if node._tb is not None:
                self._set_node_value( node._tb, feature_value[ feature_value[:, node.col].ge( node.split_value)], labels[ feature_value[:, node.col].ge( node.split_value)])
                node.entropy.data = node.entropy + node._tb.entropy
                node.variance.data = node.variance + node._tb.variance

            if node._fb is not None:
                self._set_node_value( node._fb, feature_value[ feature_value[:, node.col].lt( node.split_value)], labels[ feature_value[:, node.col].lt( node.split_value)])
                node.entropy.data = node.entropy + node._fb.entropy
                node.variance.data = node.variance + node._fb.variance
                
        return
    
    """    PLACEHOLDER!!
        TODO:
        functions to distribute the loss back and adjust node.split_value based on the error
    """
    def _distribute_node_loss( self, node, mean_grad):
        node.split_value.data -= mean_grad
        node.split_value.data = torch.maximum( torch.minimum( node.split_value.data, self._max_split_value), self._min_split_value)
        if node._tb is not None:
            self._distribute_node_loss( node._tb, mean_grad)
        if node._fb is not None:
            self._distribute_node_loss( node._fb, mean_grad)
            
        return

    """    PLACEHOLDER!!
        Function: distribute_loss()
            pass the loss through all branches of the tree  
    """
    def _distribute_loss( self, mean_grad):
        self._distribute_node_loss( self._root_node, mean_grad)
            
        return
    
    """
        Function: show_splits() 
            - utility function just so I can look at the split values for the nodes.  Just for debugging
    """
    def show_splits( self, node):
        print( f"Col: {node.col}\tSplit: {node.split_value.item()}")
        if node._tb is not None:
            self.show_splits( node._tb)
        if node._fb is not None:
            self.show_splits( node._fb)
        return
    
    """
        Function: show_values() 
            - utility function just so I can look at the entropy and variance values for the nodes.  Just for debugging
    """
    def show_values( self, node, level=0, branch='root'):
        print( f"Level: {level}\tBranch: {branch}\tEntropy: {node.entropy.item():0.5f}\tVariance: {node.variance.item():0.5f}")
        if node._tb is not None:
            self.show_values( node._tb, level=level+1, branch='_tb')
        if node._fb is not None:
            self.show_values( node._fb, level=level+1, branch='_fb')
        return
    
    """
        Function: fit()
            - set node.value based on training features and labels
    """
    def fit( self, features, labels):
        if self._max_split_value is None:
            self._max_split_value = features.max()
        if self._min_split_value is None:
            self._min_split_value = features.min()

        self._set_node_value( self._root_node, features, labels)
        return
    
    """
        Function: predict()
            - only for compatibility
    """
    def predict( self, features):
        return self.forward( features)
    
    """
        Function: forward()
            - return values based on feature inputs
    """
    def forward( self, features):
        X = self._root_node( features)
        return X

then to test the code…

tree = TorchDecisionTree( n_features=feature_data.size(1), max_depth=2)

which creates a tree with a random split_value for each node; these can be inspected via…

tree.show_splits( tree._root_node)

You can see that the same column is often selected for multiple nodes, which may not be the best approach, but it should still allow the model to run and train.

Col: 1	Split: 0.8780428171157837
Col: 4	Split: 0.5944340229034424
Col: 1	Split: 0.8516061305999756
Col: 3	Split: 0.8720638751983643
Col: 3	Split: 0.7830216884613037
Col: 3	Split: 0.570353627204895
Col: 1	Split: 0.11497271060943604
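
As far as I can tell, the per-node tensors do get registered as parameters of the tree (each _tb/_fb is assigned as a module attribute of its parent node), which I’m hoping means an optimizer could eventually see every split_value. A quick check just lists the registered parameters:

for name, p in tree.named_parameters():
    print( name, tuple( p.shape))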

To fit the model to the training data…

tree.fit( feature_data, target_data)

which will now generate predictions on features passed in…

preds = tree( feature_data)
print( preds)

showing…

Parameter containing:
tensor([0.5833, 0.6667, 0.3750, 0.3750, 0.5833, 0.5833, 0.2500, 0.5833, 0.5833,
        0.2500, 0.5833, 0.6667, 0.5833, 0.5833, 0.5833, 0.6667, 0.2500, 0.5833,
        0.5833, 0.5833], requires_grad=True)
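
For what it’s worth, my naive first attempt was to just compute a loss on these predictions and call backward(), hoping the split_value gradients would be populated. As far as I can tell nothing reaches them, which I assume is down to the no_grad blocks and the in-place assignments in forward():

loss = torch.nn.functional.mse_loss( preds, target_data)
loss.backward()
print( tree._root_node.split_value.grad)    # None, as far as I can tell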

And just to check that the entropy and variance attributes of the nodes are being set (not that I’m sure summing the child values into the branch nodes is the right thing to do; I’m not yet certain what the correct approach is here, but the values are at least populated and I should be able to use them to adjust the split_value). Ideally, for the leaf nodes, I’m looking for low entropy and low variance (if I understand binary trees correctly, which I may not): a pure leaf has entropy 0, while a leaf holding two equally common labels has entropy -(0.5*log2(0.5) + 0.5*log2(0.5)) = 1.0.

tree.show_values( tree._root_node)

which shows…

Level: 0	Branch: root	Entropy: 5.28742	Variance: 0.82860
Level: 1	Branch: _tb	    Entropy: 1.91830	Variance: 0.61458
Level: 2	Branch: _tb	    Entropy: 1.00000	Variance: 0.28125
Level: 2	Branch: _fb	    Entropy: 0.91830	Variance: 0.33333
Level: 1	Branch: _fb	    Entropy: 3.36912	Variance: 0.21402
Level: 2	Branch: _tb	    Entropy: 1.58496	Variance: 0.06250
Level: 2	Branch: _fb	    Entropy: 1.78416	Variance: 0.15152
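
To make the intent concrete, the training loop I have in mind looks roughly like the sketch below. The value passed to _distribute_loss() is just a stand-in constant; computing a real per-node gradient is exactly the part I can’t figure out:

for epoch in range( 10):
    tree.fit( feature_data, target_data)                        # re-assign leaf values for the current splits
    preds = tree( feature_data)                                  # forward pass
    loss = torch.nn.functional.mse_loss( preds, target_data)    # some loss on the predictions
    print( f"epoch {epoch}  loss: {loss.item():0.5f}")
    tree._distribute_loss( torch.tensor( 0.01))                  # stand-in "gradient"; this is what I'm asking about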

The main questions…

  1. How do I calculate/accumulate the gradient for node.split_value w.r.t. an entropy or variance loss?
  2. How do I then apply the gradient in order to adjust the node.split_value?
  3. Can this be done with the help of torch.autograd, or am I completely on my own when needing to distribute the loss?

If you’ve read and followed this far then I want to thank you very much. Any help is greatly appreciated.