Say I have a noisy matrix-vector multiplication (MVM) operation provided by some novel hardware. I want to simulate how this noisy MVM will impact inference of an nn.Module. How can I trace and then replace all MVMs in a given nn.Module?
nn.Linear is perhaps the easiest to deal with, but attention and Conv layers are also based on MVMs. Is there a general way to handle arbitrary layer types?
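To make concrete why I consider conv layers MVM-based: a Conv2d forward pass can be lowered to a single matmul via im2col (F.unfold), e.g.:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(3, 8, kernel_size=3)
x = torch.randn(2, 3, 16, 16)

cols = F.unfold(x, kernel_size=3)          # (N, C*k*k, L) patches, L = 14*14
w = conv.weight.view(8, -1)                # (out_channels, C*k*k)
out = w @ cols + conv.bias.view(1, -1, 1)  # the matmul I want to make noisy
out = out.view(2, 8, 14, 14)

assert torch.allclose(out, conv(x), atol=1e-4)
```

So in principle every layer I care about bottoms out in a matmul somewhere; the question is how to intercept that systematically.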
Additionally, I suspect an nn.Module might call torch.matmul or nn.functional.linear directly instead of defining an nn.Linear submodule. How can I trace these operations as well?
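Here is a rough sketch of what I'm imagining with torch.fx (trace, noisy_mm, and noisy_linear are my own placeholder names, not an existing API): symbolically trace the module, then retarget every call_function node that computes a matmul-like op. Note this only catches functional calls, not call_module nodes for nn.Linear etc., so I doubt it's the complete answer.

```python
import torch
import torch.nn as nn
import torch.fx as fx

def noisy_mm(a, b):
    out = torch.matmul(a, b)
    return out + torch.randn_like(out)  # MVM(A, B) = A @ B + Gauss(0, 1)

def noisy_linear(x, weight, bias=None):
    out = torch.nn.functional.linear(x, weight, bias)
    return out + torch.randn_like(out)

def trace(model: nn.Module) -> fx.GraphModule:
    gm = fx.symbolic_trace(model)
    for node in gm.graph.nodes:
        if node.op == "call_function":
            if node.target is torch.matmul:
                node.target = noisy_mm
            elif node.target is torch.nn.functional.linear:
                node.target = noisy_linear
    gm.recompile()
    return gm
```

Is extending this node-rewriting idea (also covering call_module nodes for Linear/Conv/attention) the right direction, or is there a better mechanism?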
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomNN(nn.Module):
    def __init__(self, input_dim, linear_dim, attention_dim,
                 conv_in_channels, conv_out_channels, conv_kernel_size):
        super().__init__()
        self.linear = nn.Linear(input_dim, linear_dim)
        self.attention = Attention(linear_dim, attention_dim)  # custom module defined elsewhere
        self.conv = nn.Conv2d(conv_in_channels, conv_out_channels, conv_kernel_size)
        self.fc = nn.Linear(conv_out_channels, 1)

    def forward(self, x):
        x = self.linear(x)
        x, attention_weights = self.attention(x)
        x = x.view(x.size(0), -1, 1, 1)
        x = F.relu(self.conv(x))
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x, attention_weights
model = CustomNN(input_dim, linear_dim, attention_dim,
                 conv_in_channels, conv_out_channels, conv_kernel_size)
traced_model = trace(model)
# traced_model should run every MVM with error. For simplicity, say the MVM
# error is just MVM(A, B) = A @ B + Gauss(0, 1). How would I implement trace()?
# I want it to track the MVMs of self.linear, self.attention, and self.fc in
# CustomNN. More generally, trace() should ideally handle any layer type
# that uses MVMs inside.
traced_model(x)
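For reference, the only approach I know so far is a recursive module swap that covers nn.Linear alone (NoisyLinear and swap_linears are my own names). It doesn't reach bare torch.matmul / F.linear calls or the MVMs inside conv and attention layers, hence the question:

```python
import torch
import torch.nn as nn

class NoisyLinear(nn.Linear):
    def forward(self, x):
        out = super().forward(x)
        return out + torch.randn_like(out)  # MVM error: Gauss(0, 1)

def swap_linears(module: nn.Module) -> None:
    """Recursively replace every nn.Linear child with a NoisyLinear copy."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            noisy = NoisyLinear(child.in_features, child.out_features,
                                bias=child.bias is not None)
            noisy.load_state_dict(child.state_dict())
            setattr(module, name, noisy)
        else:
            swap_linears(child)
```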
Thanks