Hi, thank you for your response.
Here is the code snippet:
class BatchNorm(nn.Module):
    """Batch-normalization block with an optional bias-only fallback.

    When ``use_batch_normalization`` is True, wraps :class:`torch.nn.BatchNorm1d`;
    otherwise the block only adds a learnable per-channel bias to its input.
    """

    def __init__(self, input_dim: int, use_batch_normalization: bool = True, momentum: float = 0.1,
                 track_running_stats: bool = True) -> None:
        """
        :param input_dim: Number of feature channels ``D``.
        :param use_batch_normalization: If False, skip normalization and learn only a bias.
        :param momentum: Momentum passed to ``BatchNorm1d`` for its running statistics.
        :param track_running_stats: Whether ``BatchNorm1d`` tracks running mean/variance.
        """
        super().__init__()
        self.input_dim = input_dim
        self.momentum = momentum
        self.use_batch_normalization = use_batch_normalization
        if self.use_batch_normalization:
            self.batch_norm = nn.BatchNorm1d(input_dim, momentum=momentum,
                                             track_running_stats=track_running_stats)
        else:
            # requires_grad=True is already the Parameter default; kept explicit.
            self.bias = Parameter(torch.zeros(input_dim, dtype=torch.float32), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of batch normalization block.

        :param x: Input of shape `(N, D)` or `(N, K, D)` where `N = number of points`,
            `K = number of neighbors`, and `D = number of feature channels`.
        :type x: torch.Tensor
        :return: Normalized output of the same shape as the input.
        :rtype: torch.Tensor
        """
        # Bias-only mode: no normalization, just a learnable shift.
        if not self.use_batch_normalization:
            return x + self.bias

        ndim = x.dim()
        if ndim == 2:
            return self.batch_norm(x)
        if ndim == 3:
            # BatchNorm1d wants channels in dim 1: (N, K, D) -> (N, D, K).
            normalized = self.batch_norm(x.transpose(1, 2).contiguous())
            # Restore the caller's layout: (N, D, K) -> (N, K, D).
            return normalized.transpose(1, 2).contiguous()
        raise ValueError(f"Input dimension of batch normalization block should be 2 or 3, got {x.dim()}.")

    def __repr__(self) -> str:
        only_bias = str(not self.use_batch_normalization)
        return (f'BatchNormBlock(in_feat: {self.input_dim:d},'
                f' momentum: {self.momentum:.3f}, only_bias: {only_bias:s})')
which basically originates from:
# Positional-bias MLP: named Sequential mapping 3-D coordinates to feature_dim
# channels (Linear -> BatchNorm -> ReLU -> Linear). linear_1 has bias=False,
# presumably because the following BatchNorm provides its own shift — confirm.
# NOTE(review): BatchNorm here is the custom block above wrapping BatchNorm1d.
self.linear_pos_bias = nn.Sequential(OrderedDict([
('linear_1', nn.Linear(3, self.feature_dim, bias=False)),
('bn', BatchNorm(self.feature_dim)),
('relu', nn.ReLU(inplace=True)),
('linear_2', nn.Linear(self.feature_dim, self.feature_dim))
]))
The error traceback is:
File "x:\Transformer.py", line 206, in forward
peb = self.linear_pos_bias(pos)
File "x:\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "x:\lib\site-packages\torch\nn\modules\container.py", line 204, in forward
input = module(input)
File "x:\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "x:\blocks\batch_norm.py", line 65, in forward
output = self.batch_norm(x)
File "x:\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
return forward_call(*input, **kwargs)
File "x:\lib\site-packages\torch\nn\modules\batchnorm.py", line 171, in forward
return F.batch_norm(
File "x:\lib\site-packages\torch\nn\functional.py", line 2450, in batch_norm
return torch.batch_norm(
RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR