Creating a Triangular weight matrix with a hidden Layer instead of a fully connected one

Hi Everyone,

I’m quite new to PyTorch and I am excited about using it for some projects.

I have a question regarding the use of a linear hidden layer. I want to create a hidden layer inside nn.Sequential( ) whose weight matrix is lower triangular or upper triangular, then apply some activations and optimize using Autograd. The problem is that nn.Linear( ) creates a fully connected layer with a dense weight matrix, which is not what I want.

Any help in this regard would be appreciated.


You could initialize the weight matrices in your desired shape, e.g. by using torch.triu, and then zero out the gradients at the masked locations so that the optimizer never updates those entries.
Here is a small code snippet for this use case, which could be a good starting point:

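import torch
import torch.nn as nn

# Init: keep only the upper-triangular part of each linear layer's weight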
def triu_init(m):
    if isinstance(m, nn.Linear):
        with torch.no_grad():
            m.weight.copy_(torch.triu(m.weight))

# Zero out gradients
def get_zero_grad_hook(mask):
    def hook(grad):
        return grad * mask
    return hook

# Setup
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10)
)
model.apply(triu_init)

# Check params
for name, param in model.named_parameters():
    print(name, param)

# Dummy inputs
x = torch.randn(10, 10)
target = torch.randn(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Create the mask (all layers have the same weight shape, so one mask is enough)
mask = torch.triu(torch.ones_like(model[0].weight))
# Register with hook
model[0].weight.register_hook(get_zero_grad_hook(mask))
model[2].weight.register_hook(get_zero_grad_hook(mask))
model[4].weight.register_hook(get_zero_grad_hook(mask))

# Train for some epochs
for epoch in range(10):
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print('Epoch {}, loss {}'.format(epoch, loss.item()))

# Verify: the entries below the diagonal of each weight should still be zero
for name, param in model.named_parameters():
    print(name, param)
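
Instead of reading the printed parameters by eye, the structure can also be checked programmatically. The following is only a minimal sketch of that idea on a single small layer; the layer size, learning rate, seed, and the names `layer` and `is_upper` are illustrative assumptions and not part of the snippet above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # assumption: fixed seed only for reproducibility

# Hypothetical single 4x4 layer, using the same technique as above
layer = nn.Linear(4, 4)
mask = torch.triu(torch.ones_like(layer.weight))

# Zero the entries below the diagonal once at initialization
with torch.no_grad():
    layer.weight.copy_(torch.triu(layer.weight))

# Mask the gradient so the zeroed entries never receive an update
layer.weight.register_hook(lambda grad: grad * mask)

optimizer = torch.optim.SGD(layer.parameters(), lr=1e-1)
x, target = torch.randn(8, 4), torch.randn(8, 4)

for _ in range(5):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(layer(x), target)
    loss.backward()
    optimizer.step()

# The weight is upper triangular if it equals its own upper-triangular part
weight = layer.weight.detach()
is_upper = torch.equal(weight, torch.triu(weight))
print(is_upper)
```

The same `torch.equal` comparison could be run on each `nn.Linear` weight of the full model after training.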

This solution really helped. Thanks so much @ptrblck.