Hello, I am new to PyTorch.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
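# One shape per tensor argument of forward; defined before it is passed to summary.
input_size = [(14, 1, 1024), (14, 1, 1024), (14, 1, 1024), (14, 1, 1024), (14, 1, 384), (14, 1, 768), (14, 1, 768), (14, 1, 768), (14, 1, 768), (14, 1, 768), (14, 1, 9), (1, 14)]
summary(model, input_size=input_size)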
However, the model's forward method needs the parameters r1, r2, r3, r4, x1, x2, x3, o1, o2, qmask, umask, att2, and return_hidden, so the number of inputs will exceed two. How can I fix this problem? Thanks, best wishes.
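
One possible workaround, sketched under assumptions rather than tested against the model above: as far as I understand torchsummary, it turns each shape in input_size into a random tensor with an added batch dimension and passes those tensors to the model as positional arguments. A non-tensor argument such as return_hidden could then be fixed by a thin wrapper, so that forward only receives tensors. ToyModel, TensorOnlyWrapper, and the shapes below are hypothetical stand-ins, not the real model.

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical stand-in: forward takes several tensors plus a boolean flag."""

    def __init__(self):
        super().__init__()
        self.fc_a = nn.Linear(8, 4)
        self.fc_b = nn.Linear(6, 4)

    def forward(self, a, b, mask, return_hidden=False):
        hidden = (self.fc_a(a) + self.fc_b(b)) * mask
        out = hidden.sum(dim=-1)
        return (out, hidden) if return_hidden else out


class TensorOnlyWrapper(nn.Module):
    """Fixes the non-tensor argument so forward accepts only positional tensors."""

    def __init__(self, model, return_hidden=False):
        super().__init__()
        self.model = model
        self.return_hidden = return_hidden

    def forward(self, *tensors):
        return self.model(*tensors, return_hidden=self.return_hidden)


# One shape per tensor argument, without the batch dimension (assumed shapes).
input_size = [(3, 8), (3, 6), (3, 1)]
wrapped = TensorOnlyWrapper(ToyModel())

# Build one random tensor per shape with a leading batch dimension of 2
# and pass them all positionally, one per forward argument.
dummy_inputs = [torch.rand(2, *shape) for shape in input_size]
output = wrapped(*dummy_inputs)
print(output.shape)
```

If this runs, the corresponding torchsummary call would presumably be summary(wrapped, input_size=input_size, device="cpu"), with one tuple per tensor argument and no entry for return_hidden. Whether the shapes in the question should include the batch dimension is something I have not verified.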