I combine a conventional UNET module with a custom function in my forward pass.
The custom function cannot take a batch as input, so I use a for loop to feed the samples one at a time.
However, loss_total.backward() raises an error when batch_size >= 2 (it works when batch_size = 1).
(It also works for batch_size >= 2 when I use only custom_func, without the UNET.)
So I suspect I'm doing something wrong by calling the module multiple times in a single forward pass.
Can anyone help, please? Thanks!!
______________________________
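import torch
import torch.optim as optim
model_UNET = UNET()  # conventional UNET
optvars = [{'params': custom_vars, 'lr': lr_custom}]  # custom_vars are used in custom_func
optimizer1 = optim.Adam(optvars)
optimizer2 = optim.Adam(model_UNET.parameters(), lr=lr_Unet)  # parameters of the model instance
# My forward pass
optimizer1.zero_grad()  # clear gradients left over from the previous iteration
optimizer2.zero_grad()
loss_total = 0
for aa in range(batch_size):
    sample = batch[aa, :, :]                 # one sample at a time: custom_func can't take a batch
    temp = custom_func(sample, custom_vars)
    output = model_UNET(temp)
    loss_total = loss_total + loss(output)   # accumulate the per-sample losses
# backward & update
custom_vars.retain_grad()  # <-- Maybe I need something like this for parameters in UNET?
loss_total.backward()      # <-- This line gives the error
optimizer1.step()
optimizer2.step()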
___________________________
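For comparison, here is a minimal self-contained sketch of the same pattern. It assumes a tiny two-layer conv net (`TinyNet`, hypothetical) standing in for UNET, a hypothetical elementwise `custom_func`, and a placeholder squared-output loss. Autograd does support calling one module several times inside a single forward pass and then calling `backward()` once on the summed loss; gradients from every call accumulate into the module's parameters. Leaf tensors created with `requires_grad=True` (and `nn.Parameter`s) get `.grad` populated without `retain_grad()`. If a sketch like this runs but the real code does not, one possibility is that the difference lies in `custom_func` itself or in the tensor shapes fed to the UNET.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

class TinyNet(nn.Module):
    """Hypothetical stand-in for UNET: maps [N, 1, H, W] -> [N, 1, H, W]."""
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(4, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.body(x)

def custom_func(sample, custom_vars):
    """Hypothetical per-sample function: scales and shifts a 2D [H, W] input."""
    return sample * custom_vars[0] + custom_vars[1]

batch_size, H, W = 4, 8, 8
batch = torch.randn(batch_size, H, W)

model_UNET = TinyNet()
custom_vars = torch.tensor([1.0, 0.0], requires_grad=True)  # leaf tensor, no retain_grad needed

optimizer1 = optim.Adam([{"params": [custom_vars], "lr": 1e-3}])
optimizer2 = optim.Adam(model_UNET.parameters(), lr=1e-3)

optimizer1.zero_grad()
optimizer2.zero_grad()

loss_total = torch.zeros(())
for aa in range(batch_size):
    sample = batch[aa]                              # [H, W]
    temp = custom_func(sample, custom_vars)         # [H, W]
    output = model_UNET(temp[None, None])           # add batch and channel dims -> [1, 1, H, W]
    loss_total = loss_total + output.pow(2).mean()  # placeholder per-sample loss

loss_total.backward()  # a single backward through all batch_size calls of the module
optimizer1.step()
optimizer2.step()

print(custom_vars.grad is not None)                              # True
print(all(p.grad is not None for p in model_UNET.parameters()))  # True
```

Summing the per-sample losses and calling `backward()` once is equivalent, gradient-wise, to calling `backward()` on each sample's loss in turn without zeroing gradients in between.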