Can someone explain what we mean by buffers in PyTorch?
What are their characteristics, and when should we use them (and when not)?
If a buffer is basically the same as a tensor, why would I even need a buffer when I can simply create my tensor inside the module?
An example would be appreciated!
Thanks.
Buffers are tensors which are registered in the module and will thus be inside its state_dict. These tensors do not require gradients and are therefore not registered as parameters. This is useful e.g. for tracking the running statistics (running_mean and running_var) in batchnorm layers, which should be stored and loaded via the state_dict of the module.
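A minimal sketch of the difference (the module and the names here are made up for illustration):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(3))             # trainable, shows up in parameters()
        self.register_buffer("running_stat", torch.zeros(3))   # not trainable, but saved in the state_dict

m = MyModule()
print(list(m.state_dict().keys()))            # ['weight', 'running_stat'] -> both are serialized
print([n for n, _ in m.named_parameters()])   # ['weight'] -> only the parameter is trainable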
I understand, thanks.
- Is there a way to list all the buffers in a model?
- In the case of batchnorm that you mentioned, how can I get the mean and std values? I got to this point, but I'm not sure how to read them out:
import torchvision

resnet = torchvision.models.resnet50()
print(resnet.layer1[0].bn1.state_dict())  # note: call state_dict(); referencing the bound method prints nothing useful
You can get all buffers via model.buffers() or model.named_buffers() (same as with .parameters() and .named_parameters()).
To get the buffers of a specific layer, you can access them directly:

model.bn.running_mean
model.bn.running_var
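Applied to the ResNet snippet above, a quick sketch (note that running_var stores the variance, so take its square root if you want the std):

import torchvision

resnet = torchvision.models.resnet50()
bn = resnet.layer1[0].bn1
print(bn.running_mean.shape)   # torch.Size([64])
print(bn.running_var.sqrt())   # std = sqrt of the tracked variance

# list every buffer in the model
for name, buf in resnet.named_buffers():
    print(name, buf.shape)     # e.g. bn1.running_mean, bn1.running_var, bn1.num_batches_tracked, ...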
Q1: Do I understand correctly that named_buffers() and buffers() return the same parameters?
Q2: Does PyTorch automatically determine the name of the parameter? Because the ctor does not take it as an argument (Parameter — PyTorch 2.1 documentation).
- named_buffers() and buffers() return the same buffers, where the first operation also yields the corresponding name for each buffer. I’m explicitly using “buffer” to avoid confusing it with parameters, which are different. Both are registered to the nn.Module, but parameters are trainable while buffers are not.
- Yes, the name of the buffer or parameter is determined by its variable name (e.g. self.my_param = nn.Parameter(...) would return “my_param” as the name) or set explicitly via self.register_buffer(name, buffer) or self.register_parameter(name, param).
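A small sketch showing both naming mechanisms side by side (module and names are made up for illustration):

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.my_param = nn.Parameter(torch.ones(2))      # name inferred from the attribute name
        self.register_buffer("my_buf", torch.zeros(2))   # name passed in explicitly

d = Demo()
print([n for n, _ in d.named_parameters()])  # ['my_param']
print([n for n, _ in d.named_buffers()])     # ['my_buf']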
It’s an interesting example of Python reflection-like behaviour: there is no need to register parameters under specific names, the attribute names are picked up automatically:
import torch
import torch.nn as nn

class NN_Network(nn.Module):
    def __init__(self, in_dim, hid, out_dim):
        super(NN_Network, self).__init__()
        self.linear1 = nn.Linear(in_dim, hid)
        self.linear2 = nn.Linear(hid, out_dim)

    def forward(self, input_array):
        h = self.linear1(input_array)
        y_pred = self.linear2(h)
        return y_pred

in_d = 5
hidn = 2
out_d = 3
model = NN_Network(in_d, hidn, out_d)

for name, param in model.named_parameters():
    print(name, "[", type(name), "]", type(param), param.size())
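For reference, this prints the four parameters registered automatically by the two linear layers:

linear1.weight [ <class 'str'> ] <class 'torch.nn.parameter.Parameter'> torch.Size([2, 5])
linear1.bias [ <class 'str'> ] <class 'torch.nn.parameter.Parameter'> torch.Size([2])
linear2.weight [ <class 'str'> ] <class 'torch.nn.parameter.Parameter'> torch.Size([3, 2])
linear2.bias [ <class 'str'> ] <class 'torch.nn.parameter.Parameter'> torch.Size([3])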
I have updated the code snippet to play with parameters and buffers from the public Python API side:
#!/usr/bin/env python3
import torch
import torch.nn as nn

class NN_Network(nn.Module):
    def __init__(self, in_dim, hid, out_dim):
        super(NN_Network, self).__init__()
        # https://pytorch.org/docs/stable/generated/torch.nn.Linear.html?highlight=torch%20nn%20linear#torch.nn.Linear
        self.linear1 = nn.Linear(in_dim, hid)
        self.linear2 = nn.Linear(hid, out_dim)
        # https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html?highlight=torch%20nn%20parameter#torch.nn.parameter.Parameter
        self.my_param_a = nn.Parameter(torch.zeros(3, 3))
        # https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=register_buffer#torch.nn.Module.register_buffer
        # Note: register_buffer() returns None, so don't assign its result;
        # the buffers become accessible as self.A and self.B.
        self.register_buffer("A", torch.zeros(3, 3))
        self.register_buffer("B", torch.zeros(4, 4))

    def forward(self, input_array):
        h = self.linear1(input_array)
        y_pred = self.linear2(h)
        return y_pred

in_d = 5
hidn = 2
out_d = 3
model = NN_Network(in_d, hidn, out_d)

print("\nNAMED PARAMS")
for name, param in model.named_parameters():
    print(" ", name, "[", type(name), "]", type(param), param.size())

print("\nNAMED BUFFERS")
for name, buf in model.named_buffers():
    print(" ", name, "[", type(name), "]", type(buf), buf.size())

print("\nSTATE DICT (KEYS ONLY)")
for k in model.state_dict().keys():
    print(" ", k)

print("\nSTATE DICT (KEY => VALUE PAIRS)")
for k, v in model.state_dict().items():
    print(" ", "(", k, "=>", v, ")")