Hi everyone,
I have a problem with my transformer architecture (for 2D input): it already runs out of GPU memory for system sizes of ~6x6 = 36 pixels. I am using a batch size of 100, and when I decrease it the computation becomes extremely slow. Does anyone have experience with similar problems?
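In case it is useful for the discussion, here is a minimal sketch of how the peak memory per step can be measured with the torch.cuda utilities (model, loss_fn, optimizer and src below are placeholders, not my actual objects):

import torch

torch.cuda.reset_peak_memory_stats()

# one training step; model, loss_fn, optimizer and src are placeholders
out = model(src)
loss = loss_fn(out)
loss.backward()
optimizer.step()
optimizer.zero_grad()

print(f"peak GPU memory this step: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")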
I am using the PyTorch TransformerEncoder layer together with caching of the layer outputs:
minibatch = self.minibatch
repeat = int(np.ceil(batch / minibatch))
amp = []
phase = []
new_cache = []
for i in range(repeat):
    src_i = src[:, i * minibatch:(i + 1) * minibatch]
    src_i = self.encoder(src_i) * math.sqrt(self.embedding_size)  # (seq, batch, embedding)
    src_i = self.pos_encoder(src_i, self.system_size)  # (seq, batch, embedding)
    if compute_cache:
        if cache is not None: c = cache[i]
        else: c = None
        output_i, c = self.next_with_cache(src_i, self.src_mask, c)
        if c is not None: new_cache.append(c)
        psi_output = output_i
    else:
        output_i = self.transformer_encoder(src_i, self.src_mask)
        psi_output = output_i[self.seq_prefix_len - 1:]  # only use the physical degrees of freedom
    amp_i = F.log_softmax(self.amp_head(psi_output), dim=-1)  # (seq, batch, phys_dim)
    amp.append(amp_i)
together with the function:
def next_with_cache(self, tgt, mask, cache=None, idx=-1):
    output = tgt
    new_token_cache = []
    # go through each layer and apply self attention only to the last input
    for i, layer in enumerate(self.transformer_encoder.layers):
        tgt = output
        # have to merge the functions into one
        src = tgt[idx:, :, :]
        # self attention part
        src2 = layer.self_attn(
            src,  # only do attention with the last elem of the sequence
            tgt,
            tgt,
            attn_mask=mask[idx:],
            key_padding_mask=None,
        )[0]
        # straight from the torch transformer encoder code
        src = src + layer.dropout1(src2)
        src = layer.norm1(src)
        src2 = layer.linear2(layer.dropout(layer.activation(layer.linear1(src))))
        src = src + layer.dropout2(src2)
        src = layer.norm2(src)
        output = src
        new_token_cache.append(output)
        if cache is not None:
            # layers after layer 1 need to use a cache of the previous layer's output on each input
            output = torch.cat([cache[i], output], dim=0)
    # update cache with new output
    if cache is not None:
        new_cache = torch.cat([cache, torch.stack(new_token_cache, dim=0)], dim=1)
    else:
        new_cache = torch.stack(new_token_cache, dim=0)
    return output, new_cache
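For scale, here is a quick back-of-the-envelope check of how big the layer cache itself gets; the hyperparameter values below are made-up examples, not my actual settings:

import torch

# assumed example values, not my real hyperparameters
num_layers = 8
seq_len = 36 + 1      # 6x6 system plus one prefix token
batch = 100
embedding = 128

# the cache stacks one (seq, batch, embedding) tensor per layer
cache = torch.zeros(num_layers, seq_len, batch, embedding)
print(f"cache size: {cache.numel() * cache.element_size() / 1e6:.1f} MB")
# ~15 MB in float32 for these example values

So at least with these assumed values the cache alone should not explain the out-of-memory error, which is what confuses me.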