You could initialize the weight matrices in your desired shape, e.g. by using torch.triu, and then zero out the gradients at the same locations during training.
Here is a small code snippet for this use case, which could be a good starter:
import torch
import torch.nn as nn

# Init: keep only the upper-triangular part of each linear weight
def triu_init(m):
    if isinstance(m, nn.Linear):
        with torch.no_grad():
            m.weight.copy_(torch.triu(m.weight))

# Zero out gradients outside the mask
def get_zero_grad_hook(mask):
    def hook(grad):
        return grad * mask
    return hook
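# Note: hooks registered via Tensor.register_hook run during the backward
# pass and their return value replaces the gradient, so the masked
# positions never receive a non-zero update.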
# Setup
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10)
)
model.apply(triu_init)

# Check params
for name, param in model.named_parameters():
    print(name, param)
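# After triu_init, every weight's lower triangle should already print as zeros.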
# Dummy inputs
x = torch.randn(10, 10)
target = torch.randn(10, 10)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Create masks for all layers (all weights share the same shape, so one mask suffices)
mask = torch.triu(torch.ones_like(model[0].weight))

# Register the hook on each linear layer's weight
model[0].weight.register_hook(get_zero_grad_hook(mask))
model[2].weight.register_hook(get_zero_grad_hook(mask))
model[4].weight.register_hook(get_zero_grad_hook(mask))
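# Note: only the weights are masked here; the bias gradients are left untouched.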
# Train for some epochs
for epoch in range(10):
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print('Epoch {}, loss {}'.format(epoch, loss.item()))
# Verify
for name, param in model.named_parameters():
    print(name, param)
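If you'd rather not eyeball the printed parameters, a minimal sketch of an explicit check (assuming the setup above) could assert that each weight still equals its upper-triangular part after training:

# Sanity check: the masked (lower-triangular) entries should still be zero
for layer in (model[0], model[2], model[4]):
    assert torch.equal(layer.weight, torch.triu(layer.weight))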