Please help — I used the function:
self.pool1 = TopKPooling(self.dimensions, ratio=0.8)
in the forward part of my class:
recon, edge_index, _, batch, _, _ = self.pool1(recon, edges, None, batch)
but after two epochs I get this error message:
--> 247 recon, edge_index, _, batch, _, _ = self.pool1(recon, edges, None, batch)
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.7/dist-packages/torch_geometric/nn/pool/topk_pool.py in forward(self, x, edge_index, edge_attr, batch, attn)
160 score = softmax(score, batch)
161
--> 162 perm = topk(score, self.ratio, batch, self.min_score)
163 x = x[perm] * score[perm].view(-1, 1)
164 x = self.multiplier * x if self.multiplier != 1 else x
/usr/local/lib/python3.7/dist-packages/torch_geometric/nn/pool/topk_pool.py in topk(x, ratio, batch, min_score, tol)
19 perm = (x > scores_min).nonzero(as_tuple=False).view(-1)
20 else:
---> 21 num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
22 batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
23
/usr/local/lib/python3.7/dist-packages/torch_scatter/scatter.py in scatter_add(src, index, dim, out, dim_size)
27 out: Optional[torch.Tensor] = None,
28 dim_size: Optional[int] = None) -> torch.Tensor:
---> 29 return scatter_sum(src, index, dim, out, dim_size)
30
31
/usr/local/lib/python3.7/dist-packages/torch_scatter/scatter.py in scatter_sum(src, index, dim, out, dim_size)
9 out: Optional[torch.Tensor] = None,
10 dim_size: Optional[int] = None) -> torch.Tensor:
---> 11 index = broadcast(index, src, dim)
12 if out is None:
13 size = list(src.size())
/usr/local/lib/python3.7/dist-packages/torch_scatter/utils.py in broadcast(src, other, dim)
10 for _ in range(src.dim(), other.dim()):
11 src = src.unsqueeze(-1)
---> 12 src = src.expand(other.size())
13 return src
RuntimeError: The expanded size of the tensor (16) must match the existing size (63) at non-singleton dimension 0. Target sizes: [16]. Tensor sizes: [63]
How can I solve this problem? Thank you.