I am struggling with quantile regression using a neural network (NN).
The input data has shape (n_data, 308), and the last feature (the quantile information) should not be normalized.
How can I apply batch normalization to only some of the features, excluding the others?
-------EDIT
Thanks to fs5ss1, I solved it; please refer to the code below:
class Net(nn.Module):
    """Residual MLP for quantile regression with partial batch normalization.

    Architecture: ``linear1`` -> partial BN -> LeakyReLU -> ``linear2`` ->
    residual add of the input -> partial BN -> LeakyReLU -> single-output
    head. Each "partial BN" batch-normalizes every column except the last
    one, which is concatenated back unchanged.

    NOTE(review): the question text says the last *input* column carries the
    quantile and must not be normalized. The input itself is never
    batch-normed here. The column skipped by ``bn1`` is the last output of
    ``linear1``, which is a learned combination of all input features rather
    than the raw quantile. The column skipped by ``bn2`` is the last output
    of ``linear2`` plus the input's last column. If the raw quantile must
    reach later layers untouched, it would need to be split off and
    re-concatenated explicitly (a different architecture). Confirm which
    behavior is intended.

    Args:
        in_features: Width of the input and of the residual branch. The
            default of 308 reproduces the original hard-coded model.
        hidden_features: Width of the hidden layer. The default of 1101
            reproduces the original hard-coded model.

    Raises:
        ValueError: If either width is below 2, because each partial BN
            covers ``width - 1`` features.
    """

    def __init__(self, in_features=308, hidden_features=1101):
        super().__init__()
        if in_features < 2 or hidden_features < 2:
            raise ValueError(
                "in_features and hidden_features must both be >= 2, got "
                f"{in_features} and {hidden_features}"
            )
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.build_model()

    def build_model(self):
        """Create all layers, deriving every size from the two stored widths."""
        self.short_cut = nn.Identity()
        self.linear1 = nn.Linear(self.in_features, self.hidden_features)
        # One fewer feature than linear1 emits: the last column bypasses BN.
        self.bn1 = nn.BatchNorm1d(self.hidden_features - 1)
        self.leakyrelu = nn.LeakyReLU()
        # Projects back to in_features so the residual add with x is valid.
        self.linear2 = nn.Linear(self.hidden_features, self.in_features)
        self.bn2 = nn.BatchNorm1d(self.in_features - 1)
        self.linear_mean_output = nn.Linear(self.in_features, 1)

    @staticmethod
    def _bn_except_last(bn, t):
        """Apply ``bn`` to every column of 2-D ``t`` except the last, which is kept as-is."""
        # t[:, -1:] keeps the column dimension, so no view/reshape is needed.
        return torch.cat([bn(t[:, :-1]), t[:, -1:]], dim=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape ``(batch, in_features)``. Per the question
                text, the last column is expected to be the quantile.
                In training mode ``BatchNorm1d`` needs ``batch > 1``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        out = self.linear1(x)
        out = self.leakyrelu(self._bn_except_last(self.bn1, out))
        # Residual connection; out-of-place add avoids mutating the
        # linear2 output that autograd has recorded.
        out = self.linear2(out) + self.short_cut(x)
        out = self.leakyrelu(self._bn_except_last(self.bn2, out))
        return self.linear_mean_output(out)