Hello all
I’m trying to implement the algorithm described in SLIDE: In Defense of Smart Algorithms over Hardware Acceleration for Large-Scale Deep Learning Systems (https://github.com/keroro824/HashingDeepLearning) as a standalone torch.nn.Module layer. In one line, the algorithm is a modified version of sampled softmax in which each input in the batch has its own set of sampled (active) output nodes, identified using hash functions. Weights receive Hogwild-style updates, and the algorithm is designed to use multithreading to reduce training time on the CPU.
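For reference, here is a minimal sketch of how hash-based selection of active nodes can work, assuming a single SimHash (signed random projection) table. SLIDE itself uses several tables and supports other hash families, so the class name `SimHashTable` and its parameters are hypothetical, for illustration only. Each neuron's weight row is hashed into a bucket; an input is hashed with the same projections, and the neurons in its bucket become its active nodes.

```python
import torch


class SimHashTable:
    """Hypothetical single-table SimHash index over the rows of a weight matrix."""

    def __init__(self, weight: torch.Tensor, num_bits: int = 8, seed: int = 0):
        gen = torch.Generator().manual_seed(seed)
        # Random hyperplanes: one column per hash bit.
        self.planes = torch.randn(weight.shape[1], num_bits, generator=gen)
        # Powers of two used to pack the sign bits into one integer bucket id.
        self.powers = 2 ** torch.arange(num_bits)
        self.buckets = {}
        for row, code in enumerate(self.hash(weight).tolist()):
            self.buckets.setdefault(code, []).append(row)

    @torch.no_grad()
    def hash(self, x: torch.Tensor) -> torch.Tensor:
        # The sign of each projection gives one bit; pack the bits per row.
        bits = (x @ self.planes > 0).long()
        return (bits * self.powers).sum(dim=1)

    def query(self, x: torch.Tensor) -> list:
        # Neurons sharing an input's bucket are that input's candidate active nodes.
        return [self.buckets.get(code, []) for code in self.hash(x).tolist()]
```

The candidate lists returned by `query` have variable length, so they would need padding or truncation to a fixed count before being used with `forward_1` below, which assumes a fixed number of active nodes per input. Because the weights change during training, the table would also need to be rebuilt from time to time.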
I’ve tried two implementations of the sparse multiplication so far:
```python
import concurrent.futures

import torch
import torch.nn as nn


class SlideLayer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward_1(self, in_values, active_out_indices):
        # in_values:          (batch x in_dim), dense input
        # active_out_indices: (batch x num_active_out_nodes), LongTensor
        # returns:            (batch x num_active_out_nodes), sparse output
        # Assumes the same num_active_out_nodes for every input in the batch.
        # Advanced indexing copies the active rows: (batch x num_active x in_dim).
        active_weights = self.linear.weight[active_out_indices]
        active_biases = self.linear.bias[active_out_indices]
        output_values = torch.bmm(active_weights, in_values.unsqueeze(-1)).squeeze(-1)
        return output_values + active_biases

    def forward_2(self, in_values, active_out_indices):
        batch_size = active_out_indices.shape[0]

        def thread_fun(sample_num):
            # One dot product per active output node of this sample.
            return torch.stack([
                torch.dot(in_values[sample_num], self.linear.weight[i]) + self.linear.bias[i]
                for i in active_out_indices[sample_num]
            ])

        # set_num_threads is a process-wide function call, so set it once here.
        torch.set_num_threads(1)
        with concurrent.futures.ThreadPoolExecutor(max_workers=min(64, batch_size)) as executor:
            # list() collects results in order and re-raises worker exceptions.
            output_values = list(executor.map(thread_fun, range(batch_size)))
        return torch.stack(output_values)
```
Profiling indicates that the main bottleneck in forward_1 is the data copy caused by advanced indexing: it takes 4x longer than the batched matrix multiplication itself. Moreover, the bmm by itself only pays off when the number of active nodes is small.
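One way to shrink that copy, sketched here under the assumption that active sets overlap across the batch, is to gather each distinct active row only once and then pick out each sample's outputs. This is a swapped-in trick, not the SLIDE method, and `forward_shared` is a hypothetical name; if the active sets are disjoint it copies just as much and adds extra compute.

```python
import torch


def forward_shared(weight, bias, in_values, active_out_indices):
    """Hypothetical alternative to forward_1 that gathers each distinct row once."""
    # unique_rows: sorted distinct neuron ids; inverse maps every entry of
    # active_out_indices to its position in unique_rows (same shape as the input).
    unique_rows, inverse = torch.unique(active_out_indices, return_inverse=True)
    # One (U x in_dim) copy instead of a (batch x K x in_dim) copy.
    shared_w = weight[unique_rows]
    shared_b = bias[unique_rows]
    # Dense (batch x U) scores for every distinct active neuron.
    scores = torch.addmm(shared_b, in_values, shared_w.t())
    # Keep only each sample's own active neurons: (batch x K).
    return scores.gather(1, inverse)
```

Autograd flows through the indexing, `addmm`, and `gather`, so this can be used during training; whether it beats forward_1 depends on how much the active sets overlap and is worth profiling.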
I tried forward_2 to avoid the data copy, but it is 1000x slower than forward_1. I assume this is because of the GIL?
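Besides the GIL, a likely contributor is per-operator overhead: forward_2 issues one tiny `torch.dot` call per active node, and each call pays a fixed dispatch cost that dwarfs the arithmetic. A minimal, hypothetical benchmark comparing that pattern with a single vectorized call per sample could look like this (the function names and default sizes are arbitrary choices for illustration):

```python
import time

import torch


def per_row_loop(weight, bias, x, idx):
    # One tiny operator call per active neuron, as in forward_2's inner loop.
    return torch.stack([torch.dot(x, weight[i]) + bias[i] for i in idx])


def vectorized(weight, bias, x, idx):
    # One gather plus one matrix-vector product for the whole sample.
    return weight[idx] @ x + bias[idx]


def compare(num_active=512, in_dim=128, out_dim=10000, repeats=5):
    torch.manual_seed(0)
    weight, bias = torch.randn(out_dim, in_dim), torch.randn(out_dim)
    x = torch.randn(in_dim)
    idx = torch.randint(0, out_dim, (num_active,))
    timings = {}
    for name, fn in [("loop", per_row_loop), ("vectorized", vectorized)]:
        start = time.perf_counter()
        for _ in range(repeats):
            fn(weight, bias, x, idx)
        # Average wall-clock seconds per call.
        timings[name] = (time.perf_counter() - start) / repeats
    return timings
```

Calling `print(compare())` shows the gap on a given machine; the exact numbers will vary with hardware and PyTorch version.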
What is the best way to implement this algorithm? Can it be done efficiently in Python? Would DistributedDataParallel help? Should I look into C++ extensions?
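For the Hogwild-style updates mentioned in the description of the algorithm, a minimal pure-Python sketch is to let worker threads apply lock-free, in-place updates to only the active rows. This assumes a hypothetical squared-error loss on the active outputs, and the function names are illustrative; it relies on PyTorch operators releasing the GIL while they run, so any speedup depends on each per-sample operation being large enough.

```python
import concurrent.futures

import torch


def sparse_sgd_step(weight, bias, x, target, active, lr):
    """Hypothetical per-sample step on 0.5 * ||out - target||^2 over active neurons.

    Updates are applied in place without locks (Hogwild-style), so concurrent
    threads may overwrite each other's updates to shared rows.
    """
    with torch.no_grad():
        out = weight[active] @ x + bias[active]  # (K,)
        grad_out = out - target                  # gradient w.r.t. out
        # Scatter gradients back into only the active rows.
        weight.index_add_(0, active, -lr * torch.outer(grad_out, x))
        bias.index_add_(0, active, -lr * grad_out)


def hogwild_epoch(weight, bias, inputs, targets, actives, lr=0.01, workers=4):
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(sparse_sgd_step, weight, bias, x, t, a, lr)
                   for x, t, a in zip(inputs, targets, actives)]
        for f in futures:
            f.result()  # re-raise any exception from a worker thread
```

Rows outside every sample's active set are never written, which is where the savings over a dense update come from.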
Thank you
Disclaimer: I’m a beginner in PyTorch and this is my first topic here. I apologise for the length of the post and any obvious mistakes.