Hi All,
TL;DR: How can I de-stack a large tensor back into a dict of tensors in PyTorch?

Let's say I compute the gradients of a model for a set of samples and concatenate them into a single tensor. Is there a way to invert this process? I have a minimal reproducible example below: is there an efficient way to convert grad_params_stack back into grads_params?
import torch
from torch import nn
from torch.func import jacrev, vmap, functional_call

class network(nn.Module):
    def __init__(self, num_input, num_hidden, num_output):
        super(network, self).__init__()
        self.num_input = num_input
        self.num_hidden = num_hidden
        self.num_output = num_output
        self.func = nn.Tanh()
        self.fc1 = nn.Linear(self.num_input, self.num_hidden)
        self.fc2 = nn.Linear(self.num_hidden, self.num_hidden)
        self.fc3 = nn.Linear(self.num_hidden, self.num_hidden)
        self.fc4 = nn.Linear(self.num_hidden, self.num_output)

    def forward(self, x):
        x = self.fc1(x)
        x = self.func(x)
        x = self.fc2(x)
        x = self.func(x)
        x = self.fc3(x)
        x = self.func(x)
        x = self.fc4(x)
        return x.squeeze(0)

batch_size = 100
num_input = 4
num_hidden = 32
num_output = 1

net = network(num_input=num_input, num_hidden=num_hidden, num_output=num_output)
num_params = sum(p.numel() for p in net.parameters())

# inputs
params = dict(net.named_parameters())
x = torch.randn(batch_size, num_input)

def fcall(params, x):
    return functional_call(net, params, x)

# per-sample gradients of the network output w.r.t. the model params, computed via torch.func
grads_params = vmap(jacrev(fcall, argnums=0), in_dims=(None, 0))(params, x)

# flatten each per-parameter gradient and concatenate into a (batch_size, num_params) tensor
grad_params_stack = torch.cat([v.flatten(start_dim=1) for v in grads_params.values()], dim=-1)
print(grad_params_stack.shape)  # batch_size by num_params

# How to invert grad_params_stack back to grads_params?
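For context, the kind of inversion I have in mind would look something like the rough sketch below (unstack is just a placeholder name I made up, and it assumes the key order of params matches the order used when concatenating): split the stacked tensor column-wise using each parameter's numel, then reshape each chunk back to (batch_size, *param.shape). I am mainly wondering whether there is a more efficient or built-in way to do this.

def unstack(grad_stack, params):
    # split the (batch_size, num_params) tensor into one chunk per parameter,
    # assuming params iterates in the same order used to build grad_stack
    sizes = [p.numel() for p in params.values()]
    chunks = torch.split(grad_stack, sizes, dim=-1)
    return {
        name: chunk.reshape(grad_stack.shape[0], *p.shape)
        for (name, p), chunk in zip(params.items(), chunks)
    }

grads_params_rec = unstack(grad_params_stack, params)
# sanity check against the original per-sample gradient dict
print(all(torch.allclose(grads_params[k], grads_params_rec[k]) for k in params))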