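Any approach works as long as you have a way to map (input, weights) to the outputs. Depending on your needs, the following may or may not be the best approach: the layers are kept only to own the parameters, while the forward pass is written with the functional API so that the weights can also be passed in explicitly.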
class LeNet(nn.Module):
    """LeNet-style CNN whose forward pass can run on externally supplied weights.

    ``forward`` evaluates the network with the module's own parameters, while
    ``forward_with_weights`` takes every weight and bias as an explicit
    argument. This lets the same computation be evaluated with an arbitrary
    set of tensors.

    ``fc1`` expects 320 = 20 * 4 * 4 input features. That is what two
    (5x5 convolution without padding, 2x2 max-pool) stages produce from a
    1 x 28 x 28 input (e.g. MNIST-sized images).

    Args:
        num_classes: Number of output logits produced by ``fc2``.
        use_dropout: If True, apply channel-wise dropout after ``conv2`` and
            element-wise dropout after ``fc1``. Both are active only while
            the module is in training mode and use PyTorch's default
            probability (0.5).
    """

    def __init__(self, num_classes=10, use_dropout=True):
        super().__init__()
        self.use_dropout = use_dropout
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, num_classes)

    def get_weights(self):
        """Return the module's own parameters in ``forward_with_weights`` order.

        Returns:
            Tuple ``(conv1_w, conv1_b, conv2_w, conv2_b, fc1_w, fc1_b,
            fc2_w, fc2_b)`` holding the parameter tensors of this module.
        """
        return (self.conv1.weight, self.conv1.bias,
                self.conv2.weight, self.conv2.bias,
                self.fc1.weight, self.fc1.bias,
                self.fc2.weight, self.fc2.bias)

    def forward_with_weights(self, x,
                             conv1_w, conv1_b,
                             conv2_w, conv2_b,
                             fc1_w, fc1_b,
                             fc2_w, fc2_b):
        """Run the network on ``x`` using the given weights and biases.

        Args:
            x: Batched input, assumed to be (N, 1, 28, 28) so that the
                flattened conv features match ``fc1_w``'s 320 inputs.
            conv1_w, conv1_b: Weight and bias for the first convolution.
            conv2_w, conv2_b: Weight and bias for the second convolution.
            fc1_w, fc1_b: Weight and bias for the hidden linear layer.
            fc2_w, fc2_b: Weight and bias for the output linear layer.

        Returns:
            Logits of shape (N, out_features of ``fc2_w``).

        Raises:
            RuntimeError: Raised by PyTorch when the flattened feature size
                does not match ``fc1_w``.
        """
        x = F.conv2d(x, conv1_w, conv1_b)
        x = F.relu(F.max_pool2d(x, 2))
        x = F.conv2d(x, conv2_w, conv2_b)
        if self.use_dropout:
            # ``training`` must be passed by keyword. The second positional
            # parameter of F.dropout2d is the drop probability ``p``, so a
            # positional ``self.training`` meant p=1.0 (every channel zeroed)
            # in training mode.
            x = F.dropout2d(x, training=self.training)
        x = F.relu(F.max_pool2d(x, 2))
        # Keep the batch dimension explicit. view(-1, 320) would silently
        # regroup wrongly sized inputs into a different batch size instead
        # of failing.
        x = x.flatten(1)
        x = F.relu(F.linear(x, fc1_w, fc1_b))
        if self.use_dropout:
            x = F.dropout(x, training=self.training)
        return F.linear(x, fc2_w, fc2_b)

    def forward(self, x):
        """Run the network on ``x`` using the module's own parameters."""
        return self.forward_with_weights(x, *self.get_weights())