I have the following issue where I need to pass multiple args in the module.apply
callable but for some reason it keeps throwing an error.
ValueError: unsupported format character 'b' (0x62) at index 1
Any ideas?
Here’s a MWE:
dummy_params = []
def init_buffers(module, params, flag=True):
for name in list(module._parameters.keys()):
if module._parameters[name] is None:
continue
data = module._parameters[name].data
module._parameters.pop(name)
module.register_buffer("%alpha" % name, torch.zeros(data.size()))
module.register_buffer("%beta" % name, torch.zeros(data.size()))
if flag is False:
module.register_buffer("%delta" % name, torch.zeros(0, data.numel()))
params.append((module, name))
class Mynet(torch.nn.Module):
def __init__(self):
super(Mynet, self).__init__()
self.fc1 = torch.nn.Linear(2,2)
self.fc2 = torch.nn.Linear(2,2)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
net = Mynet()
net.apply(lambda module: init_buffers(module=module, params=params))