An if statement causes a catastrophic GPU memory leak

In my model I apply self-attention and have three different pooling strategies: using the CLS token, max pooling, or a combination of max pooling and the CLS token.

Which mode the module is in is selected by a string kwarg, and these are its effects during init:

    if pool_strat == "both":
        post_dim = pre_pool_layers[-1] + attention_dim_per_head * attention_heads
    elif pool_strat == "max":
        post_dim = pre_pool_layers[-1]
    elif pool_strat == "cls":
        post_dim = attention_dim_per_head * attention_heads
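
Pulling that branch logic into a standalone function and running it with made-up sizes gives the dims I expect, so the config side seems fine:

    def post_dim_for(pool_strat, pre_pool_layers, attention_dim_per_head, attention_heads):
        # same branch logic as in init, isolated for checking
        if pool_strat == "both":
            return pre_pool_layers[-1] + attention_dim_per_head * attention_heads
        if pool_strat == "max":
            return pre_pool_layers[-1]
        if pool_strat == "cls":
            return attention_dim_per_head * attention_heads
        raise ValueError(f"unknown pool_strat: {pool_strat}")

    for strat in ("both", "max", "cls"):
        print(strat, post_dim_for(strat, [128, 64], 16, 4))
    # both 128, max 64, cls 64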

All this does is tell the first layer after pooling what its input dimension will be. In forward this is handled with:

    if self.pool_strat == "both":
        # concatenate the CLS token with the max pool
        global_embedding = torch.cat([CLS, crystal_x], dim=1)
    elif self.pool_strat == "max":
        # use only the max pool
        global_embedding = crystal_x
    elif self.pool_strat == "cls":
        # use only the CLS token
        global_embedding = CLS
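
For shape context, here is a minimal standalone version of the pooling (the sizes are made up, and the tokens-first layout with CLS at position 0 is my assumption of what matters here):

    import torch

    seq_len, batch, dim = 5, 4, 8       # made-up sizes, tokens-first layout
    x = torch.randn(seq_len, batch, dim)  # attention output, x[0] is the CLS token

    CLS = x[0]                                         # (batch, dim)
    crystal_x = torch.max(x[1:], 0, keepdim=False)[0]  # (batch, dim)
    both = torch.cat([CLS, crystal_x], dim=1)          # (batch, 2 * dim)
    print(CLS.shape, crystal_x.shape, both.shape)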

While my model works fine for "both" and "max", using "cls" causes a memory leak. crystal_x and CLS are defined as:

    crystal_x = torch.max(x[1:], 0, keepdim=False)[0]
    CLS = x[0]
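
One difference I did notice while poking at this (a quick check on random data; I'm not sure it is the explanation): the max reduction allocates a fresh tensor, while plain indexing returns a view that shares storage with all of x, so anything holding on to CLS also keeps the entire attention output alive.

    import torch

    x = torch.randn(5, 4, 8)
    crystal_x = torch.max(x[1:], 0, keepdim=False)[0]
    CLS = x[0]

    print(crystal_x._base is None)         # True: owns its own memory
    print(CLS._base is x)                  # True: CLS is a view into x
    print(CLS.data_ptr() == x.data_ptr())  # True: same underlying buffer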

I am extremely confused as to why these nearly identical options could cause a memory leak like this. Even more confusing is that concatenating CLS into the global embedding works, while pushing it forward directly does not. Is there something wrong with the way I am indexing x? The memory usage grows at the end of every epoch, increasing additively each time by the amount used at epoch 0.
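
For reference, I'm measuring along these lines at the end of each epoch (a rough sketch; the function name and exact loop are simplified):

    import torch

    def log_gpu_mem(epoch):
        # wait for queued kernels so the numbers are stable
        torch.cuda.synchronize()
        alloc = torch.cuda.memory_allocated() / 2**20
        reserved = torch.cuda.memory_reserved() / 2**20
        print(f"epoch {epoch}: allocated={alloc:.1f} MiB, reserved={reserved:.1f} MiB")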