An if statement causes a catastrophic GPU memory leak

In my model I apply self-attention and support three pooling strategies: using the CLS token, max pooling, and a combination of max pooling and the CLS token.

Which mode the module uses is set by a string kwarg, and during init it has these effects:

    if pool_strat == "both":
        post_dim = pre_pool_layers[-1] + attention_dim_per_head * attention_heads

    if pool_strat == "max":
        post_dim = pre_pool_layers[-1]

    if pool_strat == "cls":
        post_dim = attention_dim_per_head * attention_heads
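
The branches above are just dim arithmetic, so they can be checked in isolation. A sketch as a standalone helper (the name `post_pool_dim` is made up for illustration; the attribute names match my module):

```python
def post_pool_dim(pool_strat, pre_pool_layers, attention_dim_per_head, attention_heads):
    # Hypothetical helper mirroring the __init__ branches above.
    if pool_strat == "both":
        # CLS embedding concatenated with the max-pooled features
        return pre_pool_layers[-1] + attention_dim_per_head * attention_heads
    if pool_strat == "max":
        return pre_pool_layers[-1]
    if pool_strat == "cls":
        return attention_dim_per_head * attention_heads
    raise ValueError(f"unknown pool_strat: {pool_strat!r}")

print(post_pool_dim("both", [256], 32, 4))  # 256 + 32 * 4 = 384
```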

All this does is tell the first layer after pooling what its input dim will be. In forward this is handled with:

    if self.pool_strat == "both":
        # Concat CLS to max pool
        global_embedding = torch.cat([CLS, crystal_x], dim=1)
    if self.pool_strat == "max":
        # Set x to just crystal x
        global_embedding = crystal_x
    if self.pool_strat == "cls":
        # Set x to just cls
        global_embedding = CLS
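
For anyone wanting to reproduce this, the three branches run fine as a standalone snippet on dummy data (the `(tokens, batch, dim)` layout and sizes here are assumptions for illustration, with token 0 as the CLS token):

```python
import torch

# Dummy input: (tokens, batch, dim), token 0 is the CLS token (assumed layout).
x = torch.randn(5, 2, 64)

crystal_x = torch.max(x[1:], 0, keepdim=False)[0]  # max pool over non-CLS tokens -> (2, 64)
CLS = x[0]                                         # CLS token embedding -> (2, 64)

for pool_strat in ("both", "max", "cls"):
    if pool_strat == "both":
        global_embedding = torch.cat([CLS, crystal_x], dim=1)  # (2, 128)
    elif pool_strat == "max":
        global_embedding = crystal_x
    elif pool_strat == "cls":
        global_embedding = CLS
    print(pool_strat, tuple(global_embedding.shape))
```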

While my model works fine for “both” and “max”, the “cls” option causes a memory leak. `crystal_x` and `CLS` are defined as:

    crystal_x = torch.max(x[1:], 0, keepdim=False)[0]
    CLS = x[0]
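
One difference I did notice while poking at this: basic indexing like `x[0]` returns a view that shares storage with `x`, whereas the `torch.max` reduction allocates a fresh tensor. I don't know whether that is actually related to the leak, but here is a quick check (purely illustrative):

```python
import torch

x = torch.randn(5, 2, 64)

CLS = x[0]                                          # basic indexing -> a view into x
crystal_x = torch.max(x[1:], 0, keepdim=False)[0]   # reduction -> a new tensor

# The view shares x's underlying storage; the max result does not.
print(CLS.data_ptr() == x.data_ptr())        # True
print(crystal_x.data_ptr() == x.data_ptr())  # False
```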

I am extremely confused as to why these nearly identical options could cause a memory leak, and even more confused that concatenating CLS into the global embedding works while pushing it forward directly does not. Is there something wrong with how I am indexing x? The memory usage grows at the end of every epoch, each time increasing additively by the amount used at epoch 0.