Trying to Implement SLIDE (Sub-LInear Deep learning Engine) on CPU

Hello all :slightly_smiling_face:
I’m trying to implement the algorithm described in SLIDE: In Defense of Smart Algorithms over Hardware Acceleration for Large-Scale Deep Learning Systems (https://github.com/keroro824/HashingDeepLearning) as a standalone torch.nn.Module layer. In one line, the algorithm is a modified version of sampled softmax where the sampled nodes are different for each input in the batch and are identified using hash functions; the weights undergo Hogwild-style updates. The algorithm is designed to use multithreading to reduce training time on the CPU.
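
For illustration, here is a rough sketch (my own simplification, not code from the paper's repo) of how the active output nodes per sample might be selected with a SimHash-style random-projection table; all sizes and the single-table lookup are assumptions for the example.

import torch

def simhash_buckets(vectors, projections):
    # Sign pattern of random projections -> binary code -> integer bucket id
    bits = (vectors @ projections > 0).long()             # (n, num_bits)
    powers = 2 ** torch.arange(bits.shape[1])
    return (bits * powers).sum(dim=1)                      # (n,)

in_dim, out_dim, num_bits = 128, 10_000, 8                 # placeholder sizes
projections = torch.randn(in_dim, num_bits)
weight = torch.randn(out_dim, in_dim)                      # stand-in for linear.weight

weight_buckets = simhash_buckets(weight, projections)      # bucket id per output node
x = torch.randn(4, in_dim)                                 # small batch of inputs
x_buckets = simhash_buckets(x, projections)                # bucket id per sample

# Active output nodes for sample 0: nodes whose weight rows share its bucket
active = (weight_buckets == x_buckets[0]).nonzero(as_tuple=True)[0]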

I’ve tried 2 implementations for the sparse multiplication so far.

import concurrent.futures

import torch
import torch.nn as nn

class SlideLayer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(SlideLayer, self).__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward_1(self, in_values, active_out_indices):
        # in_values (batch x in_dim)   // dense input
        # active_out_indices (batch x num_active_out_nodes)
        # output_values (batch x num_active_out_nodes)  // sparse output
        # assume fixed num_active_out_nodes for each input in batch

        # Advanced indexing copies the selected rows:
        #   active_weights: (batch, num_active_out_nodes, in_dim)
        #   active_biases:  (batch, num_active_out_nodes)
        active_weights = self.linear.weight[active_out_indices]
        active_biases = self.linear.bias[active_out_indices]
        output_values = torch.bmm(active_weights, in_values.unsqueeze(-1)).squeeze(-1) + active_biases
        return output_values

    def forward_2(self, in_values, active_out_indices):
        batch_size = active_out_indices.shape[0]
        num_active_out_indices = active_out_indices.shape[1]

        def thread_fun(sample_num):
            torch.set_num_threads(1)  # limit intra-op threads per worker
            cur_output_values = torch.stack(
                [torch.dot(in_values[sample_num], self.linear.weight[i]) + self.linear.bias[i]
                 for i in active_out_indices[sample_num]])
            output_values[sample_num] = cur_output_values

        output_values = [0]*batch_size
        with concurrent.futures.ThreadPoolExecutor(max_workers=min(64, batch_size)) as executor:
            executor.map(thread_fun, range(batch_size))
        return output_values

Profiling indicates that the main bottleneck in forward_1 is the data copy caused by the advanced indexing: it takes about 4x as long as the batch matrix multiplication itself (and the bmm only pays off when the number of active nodes is small).
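
A minimal standalone timing sketch (placeholder dimensions, not the exact sizes I use) that separates the indexing copy from the bmm cost could look like this:

import time
import torch
import torch.nn as nn

in_dim, out_dim, batch, k = 1024, 100_000, 128, 512        # placeholder sizes
layer = nn.Linear(in_dim, out_dim)
x = torch.randn(batch, in_dim)
idx = torch.randint(out_dim, (batch, k))                   # fake active indices

t0 = time.perf_counter()
active_w = layer.weight[idx]                               # copy: (batch, k, in_dim)
active_b = layer.bias[idx]                                 # copy: (batch, k)
t1 = time.perf_counter()
out = torch.bmm(active_w, x.unsqueeze(-1)).squeeze(-1) + active_b
t2 = time.perf_counter()
print(f"indexing copy: {t1 - t0:.4f}s, bmm: {t2 - t1:.4f}s")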

I tried forward_2 to avoid the data copy, but it is about 1000x slower than forward_1. I assume this is because of GIL contention?

What is the best way to go about implementing this algorithm? Can it be done efficiently in pure Python? Would DistributedDataParallel help? Should I look into C++ extensions?

Thank You
Disclaimer: I’m a beginner in PyTorch and this is my first topic here. I apologise for the length of the post and any obvious mistakes.

Your concern about the GIL might be valid (I haven’t verified it with your code, though).
Since your method seems similar to Hogwild, you could take a look at the Hogwild MNIST example, which uses multiprocessing and therefore avoids the GIL. :wink:
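
A minimal sketch of that pattern (a shared-memory model trained by several torch.multiprocessing workers; the toy model, data, and hyperparameters are placeholders that loosely follow the structure of the Hogwild example rather than copy it):

import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F

def train(model):
    # Each worker builds its own optimizer; the parameters live in shared
    # memory, so updates interleave Hogwild-style without locks.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for _ in range(100):                                   # placeholder data/loop
        x = torch.randn(32, 128)
        target = torch.randint(10, (32,))
        loss = F.cross_entropy(model(x), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

if __name__ == "__main__":
    model = nn.Linear(128, 10)                             # stand-in for the SLIDE layer
    model.share_memory()                                   # put parameters in shared memory
    workers = [mp.Process(target=train, args=(model,)) for _ in range(4)]
    for p in workers:
        p.start()
    for p in workers:
        p.join()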