Sure! Thanks for your reply! Here is the source code for the model that raises the error:
import math
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


class MultiHeadLinear(nn.Module):
    def __init__(self, in_feats, out_feats, n_heads, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(size=(n_heads, in_feats, out_feats)))
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(size=(n_heads, 1, out_feats)))
        else:
            self.bias = None

    def reset_parameters(self) -> None:
        for head, weight in enumerate(self.weight):
            nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
            if self.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                nn.init.uniform_(self.bias[head], -bound, bound)

    # def reset_parameters(self):
    #     gain = nn.init.calculate_gain("relu")
    #     for weight in self.weight:
    #         nn.init.xavier_uniform_(weight, gain=gain)
    #     if self.bias is not None:
    #         nn.init.zeros_(self.bias)

    def forward(self, x):
        # input size: [N, d_in] or [N, H, d_in]
        # output size: [N, H, d_out]
        if len(x.shape) == 3:
            x = x.transpose(0, 1)
        x = torch.matmul(x, self.weight)
        if self.bias is not None:
            x += self.bias
        return x.transpose(0, 1)
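Just to clarify the intended shapes, here is a tiny example of how I call MultiHeadLinear on its own (the sizes are arbitrary placeholders):

lin = MultiHeadLinear(in_feats=8, out_feats=4, n_heads=3)
lin.reset_parameters()        # weights are allocated uninitialized in __init__
x = torch.randn(10, 8)        # [N, d_in]
print(lin(x).shape)           # torch.Size([10, 3, 4]) -> [N, H, d_out]
x = torch.randn(10, 3, 8)     # [N, H, d_in]
print(lin(x).shape)           # torch.Size([10, 3, 4])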
class MultiHeadBatchNorm(nn.Module):
    def __init__(self, n_heads, in_feats, momentum=0.1, affine=True, device=None,
                 dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        assert in_feats % n_heads == 0
        self._in_feats = in_feats
        self._n_heads = n_heads
        self._momentum = momentum
        self._affine = affine
        if affine:
            self.weight = nn.Parameter(torch.empty(size=(n_heads, in_feats // n_heads)))
            self.bias = nn.Parameter(torch.empty(size=(n_heads, in_feats // n_heads)))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        self.register_buffer("running_mean", torch.zeros(size=(n_heads, in_feats // n_heads)))
        self.register_buffer("running_var", torch.ones(size=(n_heads, in_feats // n_heads)))
        self.running_mean: Optional[Tensor]
        self.running_var: Optional[Tensor]
        self.reset_parameters()
        self.eps = 1e-5

    def reset_parameters(self):
        self.running_mean.zero_()  # type: ignore[union-attr]
        self.running_var.fill_(1)  # type: ignore[union-attr]
        if self._affine:
            nn.init.zeros_(self.bias)
            for weight in self.weight:
                nn.init.ones_(weight)

    def forward(self, x):
        assert x.shape[1] == self._in_feats
        x = x.view(-1, self._n_heads, self._in_feats // self._n_heads)
        self.running_mean = self.running_mean.to(x.device)
        self.running_var = self.running_var.to(x.device)
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)
        if bn_training:
            mean = x.mean(dim=0, keepdim=True)
            var = x.var(dim=0, unbiased=False, keepdim=True)
            out = (x - mean) * torch.rsqrt(var + self.eps)
            self.running_mean = (1 - self._momentum) * self.running_mean + self._momentum * mean.detach()
            self.running_var = (1 - self._momentum) * self.running_var + self._momentum * var.detach()
        else:
            out = (x - self.running_mean) * torch.rsqrt(self.running_var + self.eps)
        if self._affine:
            out = out * self.weight + self.bias
        return out
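Similarly, MultiHeadBatchNorm takes the flattened [N, n_heads * feats_per_head] input and normalizes each head's chunk separately (again, arbitrary placeholder sizes):

bn = MultiHeadBatchNorm(n_heads=3, in_feats=12)   # 4 features per head
x = torch.randn(10, 12)                           # [N, n_heads * 4]
print(bn(x).shape)                                # torch.Size([10, 3, 4])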
class GroupMLP(nn.Module):
    def __init__(self, in_feats, hidden, out_feats, n_heads, n_layers, dropout,
                 input_drop=0., residual=False, normalization="batch"):
        super(GroupMLP, self).__init__()
        self._residual = residual
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self._n_heads = n_heads
        self._n_layers = n_layers
        self.input_drop = nn.Dropout(input_drop)
        if self._n_layers == 1:
            self.layers.append(MultiHeadLinear(in_feats, out_feats, n_heads))
        else:
            self.layers.append(MultiHeadLinear(in_feats, hidden, n_heads))
            if normalization == "batch":
                self.norms.append(MultiHeadBatchNorm(n_heads, hidden * n_heads))
                # self.norms.append(nn.BatchNorm1d(hidden * n_heads))
            if normalization == "layer":
                self.norms.append(nn.GroupNorm(n_heads, hidden * n_heads))
            if normalization == "none":
                self.norms.append(nn.Identity())
            for i in range(self._n_layers - 2):
                self.layers.append(MultiHeadLinear(hidden, hidden, n_heads))
                if normalization == "batch":
                    self.norms.append(MultiHeadBatchNorm(n_heads, hidden * n_heads))
                    # self.norms.append(nn.BatchNorm1d(hidden * n_heads))
                if normalization == "layer":
                    self.norms.append(nn.GroupNorm(n_heads, hidden * n_heads))
                if normalization == "none":
                    self.norms.append(nn.Identity())
            self.layers.append(MultiHeadLinear(hidden, out_feats, n_heads))
        if self._n_layers > 1:
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(dropout)

        for head in range(self._n_heads):
            for layer in self.layers:
                nn.init.kaiming_uniform_(layer.weight[head], a=math.sqrt(5))
                if layer.bias is not None:
                    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(layer.weight[head])
                    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                    nn.init.uniform_(layer.bias[head], -bound, bound)
        self.reset_parameters()

    def reset_parameters(self):
        gain = nn.init.calculate_gain("relu")
        for head in range(self._n_heads):
            for layer in self.layers:
                nn.init.xavier_uniform_(layer.weight[head], gain=gain)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias[head])
        for norm in self.norms:
            norm.reset_parameters()
        # for norm in self.norms:
        #     norm.moving_mean[head].zero_()
        #     norm.moving_var[head].fill_(1)
        #     if norm._affine:
        #         nn.init.ones_(norm.scale[head])
        #         nn.init.zeros_(norm.offset[head])
        # print(self.layers[0].weight[0])

    def forward(self, x):
        x = self.input_drop(x)
        if len(x.shape) == 2:
            x = x.view(-1, 1, x.shape[1])
        if self._residual:
            prev_x = x
        for layer_id, layer in enumerate(self.layers):
            x = layer(x)
            if layer_id < self._n_layers - 1:
                shape = x.shape
                x = x.flatten(1, -1)
                x = self.dropout(self.relu(self.norms[layer_id](x)))
                x = x.reshape(shape=shape)
            if self._residual:
                if x.shape[2] == prev_x.shape[2]:
                    x += prev_x
                prev_x = x
        return x
And here is the code for my custom network:
class Net(nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes, hidden_channels_label, num_layers, num_layers_label, num_nodes, **kwargs):
        super(Net, self).__init__()
        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(num_features, hidden_channels))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, num_classes))
        self.num_classes = num_classes
        self.num_nodes = num_nodes
        self.label_model = GroupMLP(num_classes,
                                    hidden_channels_label,
                                    num_classes,
                                    args.num_heads,
                                    num_layers_label,
                                    args.dropout_label,
                                    residual=args.label_residual)
It seems the error happens when I initialize label_model, i.e. in the MultiHeadBatchNorm inside GroupMLP.
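In case it helps, this is the isolated call that seems to trigger it; the concrete numbers below are just placeholders standing in for my actual num_classes, hidden_channels_label and args.* values:

label_model = GroupMLP(in_feats=40,      # num_classes (placeholder)
                       hidden=256,       # hidden_channels_label (placeholder)
                       out_feats=40,     # num_classes (placeholder)
                       n_heads=4,        # args.num_heads (placeholder)
                       n_layers=2,       # num_layers_label (placeholder)
                       dropout=0.5)      # args.dropout_label (placeholder)
x = torch.randn(128, 40)                 # [batch_of_nodes, num_classes]
print(label_model(x).shape)              # expected torch.Size([128, 4, 40])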