Hi,
I have a pre-trained CNN model on MNIST, and each time I load the trained weights and biases to run inference. Is there a way to skip zero-valued operations in the conv and FC layers?
Since MNIST images are sparse, I would expect a much lower execution time if zero operations were skipped.
Thank you.
@rzjhdi ,
PyTorch provides support for sparse tensors, which can be more efficient when working with sparse data. Convert your input tensors to a sparse layout and make sure the operations you apply accept sparse inputs.
Alternatively, modify the forward pass of your convolutional and fully connected layers so that zero-valued elements are skipped: implement a custom forward function that detects zeros and avoids the corresponding computation.
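For instance, here is a minimal sketch of a linear layer that converts its input to a sparse COO tensor and multiplies it with the dense weight via torch.sparse.mm (the class name SparseInputLinear is just illustrative, and it assumes a 2D batch-by-features input):

import torch
import torch.nn as nn

class SparseInputLinear(nn.Linear):
    def forward(self, input):
        # Convert the (dense) activation to a sparse COO tensor; assumes input is 2D.
        sparse_input = input.to_sparse()
        # torch.sparse.mm expects (sparse, dense), so the weight stays dense.
        out = torch.sparse.mm(sparse_input, self.weight.t())
        return out + self.bias

Note that only sparse-times-dense matrix products like this are supported; convolution with sparse inputs is generally not accepted by F.conv2d.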
Some libraries are designed to handle sparse computations more efficiently; if you are running on a GPU, look into NVIDIA’s cuSPARSE or other hardware-specific optimizations.
Prune your model to zero out unimportant weights; this reduces the number of effective computations during inference. PyTorch provides pruning utilities in torch.nn.utils.prune.
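As a rough sketch with the built-in pruning utilities (a single layer and a 30% sparsity amount chosen only for illustration):

import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(256, 120)
# Zero out the 30% of weights with the smallest L1 magnitude.
prune.l1_unstructured(layer, name="weight", amount=0.3)
# Fold the pruning mask into the weight tensor so the zeros become permanent.
prune.remove(layer, "weight")

Keep in mind that pruning only zeroes weights; on its own it does not make the dense conv/linear kernels run faster unless paired with sparse execution.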
You can also use libraries such as torch-sparse, which provide efficient implementations of sparse tensor operations.
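For illustration, a sparse-dense matrix product with torch_sparse (this assumes the spmm(index, value, m, n, matrix) signature from that package; check the documentation for your installed version):

import torch
from torch_sparse import spmm

# 3 x 3 sparse matrix in COO form: row/col indices and values of the non-zeros.
index = torch.tensor([[0, 0, 1, 2, 2],
                      [0, 2, 1, 0, 1]])
value = torch.tensor([1., 2., 4., 1., 3.])
# Dense 3 x 2 matrix to multiply with.
matrix = torch.tensor([[1., 4.], [2., 5.], [3., 6.]])

out = spmm(index, value, 3, 3, matrix)  # dense 3 x 2 result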
By employing these strategies, you can optimize your CNN model to skip zero operations, thereby improving inference time on sparse data like MNIST. Be sure to benchmark the performance improvements to validate the effectiveness of these methods in your specific use case.
I want to run on CPU. I used PyTorch’s to_sparse() for the input tensors, weights, and biases, but it returns an error; it seems PyTorch’s conv operation only works on dense tensors. I load the trained weights as:
model = LeNet5_sparse().to(device)
model.load_state_dict(torch.load(os.path.join(float_model, 'trained_model.pkl')))
The model and my custom convolution operation are:
class SparseConv2d(nn.Conv2d):
    def forward(self, input):
        output = torch.nn.functional.conv2d(input.to_sparse(), self.weight.to_sparse(), self.bias.to_sparse(),
                                            self.stride, self.padding, self.dilation, self.groups)
        return output

class SparseLinear(nn.Linear):
    def forward(self, input):
        output = torch.nn.functional.linear(input.to_sparse(), self.weight.to_sparse(), self.bias.to_sparse())
        return output

class LeNet5_sparse(nn.Module):
    def __init__(self):
        super(LeNet5_sparse, self).__init__()
        self.conv1 = SparseConv2d(1, 6, kernel_size=5, stride=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = SparseConv2d(6, 16, kernel_size=5, stride=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = SparseLinear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = SparseLinear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = SparseLinear(84, 10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y
I was not previously familiar with the different sparse tensor techniques, so I assumed that simply passing a sparse tensor to PyTorch’s conv2d would make it perform the convolution only on the non-zero values (for instance, by using their indices). Could you please point out where my mistake is? Thanks