Question on good practice in defining trainable parameters

I have a simple model, shown below, in which I define a set of parameters of shape (K, 2 * N * d) and initialize them so that, when viewed as (K, 2 * N, d), dim=1 holds pairs of “antipodal” vectors; that is, for even j, self.params[:, j+1, :] = -self.params[:, j, :]. This results in K * 2 * N * d trainable parameters, and the pairs in dim=1 are not “tied” together during training (they are only initialized to be opposite).

I would like to modify my model so that I define only K * N * d parameters, while still being able to form self.params as a (K, 2 * N * d) matrix with the property described above. What would be the best way to do so? I could perhaps create a list of parameters and then concatenate and repeat them to get the final matrix, but would that be efficient?

What would be your take on this? Thanks!

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, K, N, d):
        """...
        
        The model defines a set of parameters of shape (K, 2 * N * d) and initialises them so that,
        when viewed as (K, 2 * N, d), dim=1 holds pairs of "antipodal" vectors; that is, for even j,
        self.params[:, j+1, :] = -self.params[:, j, :].
        
        Args:
            K (int) : 
            N (int) : 
            d (int) : 
            
        """
        super().__init__()
        self.K = K
        self.N = N
        self.d = d
        
        # Build the initial values as (K, 2 * N, d): rows 2i and 2i+1 hold v_i and -v_i.
        params_init = torch.zeros(K, 2 * N, d)
        for k in range(self.K):
            vectors_k = []
            for i in range(self.N):
                v = torch.randn(1, d)
                vectors_k.extend([v, -v])
            params_init[k] = torch.cat(vectors_k)

        # Flatten to (K, 2 * N * d) and register as a single trainable parameter.
        self.params = nn.Parameter(params_init.reshape(self.K, 2 * N * d))
        
        
model = Model(K=32, N=16, d=2)
trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable_parameters)
# 2048