Hi, I would like to ask some questions about how computation graphs are built.
Imagine a simple multimodal model with 3 sub-networks: A, B, and C.
INa and INb are torch tensors of arbitrary dimensionality, except that the batch dimension is dim 0 and both contain the same number of samples.
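To make it concrete, here is a toy version of what I have in mind (the layer sizes, the Linear layers and the concat-based fusion in C are just placeholders, not my real networks):

    import torch
    import torch.nn as nn

    batch_size = 4

    model_A = nn.Linear(10, 8)   # sub-network A, modality a (sizes made up)
    model_B = nn.Linear(20, 8)   # sub-network B, modality b

    class FusionC(nn.Module):
        # sub-network C: fuses OUTa and OUTb into one output
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8 + 8, 1)

        def forward(self, a, b):
            # cat on the last dim so the same module works on a single
            # sample (1-D tensors) and on a whole batch (2-D tensors)
            return self.fc(torch.cat([a, b], dim=-1))

    model_C = FusionC()

    INa = torch.randn(batch_size, 10)    # batch dimension is dim 0
    INb = torch.randn(batch_size, 20)
    target = torch.randn(batch_size, 1)  # dummy target for the loss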
Suppose the workflow for one sample is: A processes a sample of INa and B processes a sample of INb, then C takes OUTa and OUTb as input to compute OUT. This is repeated for every sample in the batch; once the whole batch has been processed, the loss is computed and backpropagated.
Is the graph built by that process equal to the graph built by the following one?
A processes the whole batch INa, B processes the whole batch INb, then C processes the whole batch OUTa, OUTb, followed by loss and backward.
In pseudo-code (version 1, then version 2):
# version 1: each sample goes through A, B and C before the next one
for i in range(batch_size):
    outa_i = model_A(INa_i)
    outb_i = model_B(INb_i)
    out_i = model_C(outa_i, outb_i)
out = stack([out_i for all i])
loss
backward
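Written out as runnable PyTorch on the toy modules above (mse_loss is just a stand-in for my real loss), version 1 would look like:

    out_list = []
    for i in range(batch_size):
        outa_i = model_A(INa[i])                  # one sample of modality a
        outb_i = model_B(INb[i])                  # one sample of modality b
        out_list.append(model_C(outa_i, outb_i))  # fuse this sample
    out = torch.stack(out_list)                   # shape (batch_size, 1)
    loss = nn.functional.mse_loss(out, target)
    loss.backward()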
# version 2: one sub-network at a time, then C on the stacked batch
for i in range(batch_size):
    outa_i = model_A(INa_i)
outa = stack([outa_i for all i])
for i in range(batch_size):
    outb_i = model_B(INb_i)
outb = stack([outb_i for all i])
out = model_C(outa, outb)
loss
backward
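And version 2, again on the toy modules (zeroing the gradients left over from version 1 first, so the two runs could afterwards be compared with something like torch.allclose on model_A.weight.grad):

    for m in (model_A, model_B, model_C):
        m.zero_grad()                      # clear grads accumulated by version 1

    outa_list = []
    for i in range(batch_size):
        outa_list.append(model_A(INa[i]))
    outa = torch.stack(outa_list)          # whole batch of OUTa

    outb_list = []
    for i in range(batch_size):
        outb_list.append(model_B(INb[i]))
    outb = torch.stack(outb_list)          # whole batch of OUTb

    out = model_C(outa, outb)              # C sees the stacked batch at once
    loss = nn.functional.mse_loss(out, target)
    loss.backward()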
Is there a good explanation of how PyTorch builds these graphs?
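From what I understand, the graph can at least be inspected by hand through grad_fn and next_functions, e.g. on the toy modules above:

    out = model_C(model_A(INa), model_B(INb))
    print(out.grad_fn)                  # last node of the graph (an Addmm backward here)
    print(out.grad_fn.next_functions)   # its parent nodes in the graph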