I’ve built a neural network model and I’d like to incorporate two custom functions, `encode_image` and `encode_text`, for pre-processing data. Ideally, I want these functions to be callable both during model definition and after training (post-build). However, including them directly within the model definition restricts their use to before Just-In-Time (JIT) compilation: calls made after the model is built fail because the functions are undefined.
# The custom methods I want to add to the model
def encode_image(self, image):
    """Encode ``image`` with the model's visual backbone.

    Written as a method: ``self`` must provide a callable ``visual`` and a
    ``dtype``. The image is converted with ``image.type(self.dtype)`` and the
    converted value is fed to ``self.visual``; its result is returned as-is.
    """
    casted_image = image.type(self.dtype)
    return self.visual(casted_image)
def encode_text(self, text):
    """Encode a batch of token-id sequences into one feature vector each.

    Written as a method: ``self`` must provide ``token_embedding``,
    ``positional_embedding``, ``transformer``, ``ln_final``,
    ``text_projection`` and ``dtype``.

    The end-of-text (EOT) token is assumed to have the highest id in each
    sequence, so ``text.argmax(dim=-1)`` gives its position. The returned
    tensor holds, for every sequence, the final normalized hidden state at
    that position multiplied by ``self.text_projection``.
    """
    dtype = self.dtype

    # Token plus positional embeddings: [batch_size, n_ctx, d_model].
    tokens = self.token_embedding(text).type(dtype)
    hidden = tokens + self.positional_embedding.type(dtype)

    # The transformer runs in sequence-first layout: NLD -> LND -> NLD.
    hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
    hidden = self.ln_final(hidden).type(dtype)

    # One feature row per sequence, taken at the EOT token position.
    batch_index = torch.arange(hidden.shape[0])
    eot_index = text.argmax(dim=-1)
    return hidden[batch_index, eot_index] @ self.text_projection
# Image Classifier Neural Network
class ImageClassifier(nn.Module):
    """Sequential classifier built from ``qlayer`` and ``ClassicalLayer(2)``.

    ``qlayer`` and ``ClassicalLayer`` are looked up in the enclosing module
    scope when the instance is constructed; they are not defined here.
    """

    def __init__(self, n_qubits, n_layers, encode_image):
        super().__init__()
        # NOTE(review): n_qubits, n_layers and encode_image are accepted but
        # never used by this constructor - confirm whether qlayer is meant to
        # be built from them and whether encode_image should be stored.
        stages = (qlayer, ClassicalLayer(2))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.model(x)
# Export step: trace the classifier with a dummy input, then write both the
# TorchScript archive and the plain state dict to disk.
batch_size = 28  # NOTE(review): not used by the visible export code
channels = 1  # NOTE(review): not used by the visible export code
height, width = 28, 28

# NOTE: torch.jit.trace runs and records only clf.forward. Any other method
# (e.g. an encode_image / encode_text helper) is NOT part of traced_model, so
# it is undefined after loading the archive. To keep extra methods, trace them
# explicitly with torch.jit.trace_module(clf, {"forward": ..., "<name>": ...})
# or compile with torch.jit.script and mark them @torch.jit.export.
# NOTE(review): example_input has no batch/channel dims - confirm this is the
# shape clf.forward expects.
example_input = torch.randn(height, width)
traced_model = torch.jit.trace(clf, example_input)

# Save JIT archive
traced_model.save('qunn.pt')

# Save the raw parameters separately so they can be reloaded into an eager
# (non-traced) instance of the model.
with open('model_state.pt', 'wb') as f:
    save(clf.state_dict(), f)