My code looks like this:
def get_loss(maxpool_output, inputs, dim_bottle_neck,
             regualization=True, regualization_ratio=0.1):
    """Compute one RBF-fit loss per batch element.

    Each batch element is an independent model. Its RBF parameters are read
    from ``maxpool_output[i]`` and evaluated only at that element's own points
    ``inputs[i]``. Nothing is accumulated across samples.

    Requires PyTorch >= 0.4: plain tensors track gradients, so ``Variable``
    wrappers are not needed. Tensors are created on the device and dtype of
    the inputs, so this also runs on CPU.

    Args:
        maxpool_output: tensor indexed as ``[i, 4*j + k, 0]``. For RBF ``j``,
            channel ``4*j`` holds the weight ``d`` and channels
            ``4*j+1 .. 4*j+3`` hold ``c1, c2, c3``, which are passed to
            ``basis_function``. Example shape: (30, 256*4, 1).
        inputs: point coordinates indexed as ``[i, axis, point]``, with axis
            0/1/2 = x/y/z. The number of points is taken from this tensor,
            not hard-coded. Example shape: (30, 3, 1024).
        dim_bottle_neck: number of RBFs per sample (e.g. 256).
        regualization: if True, add a per-sample weight-norm penalty.
        regualization_ratio: multiplier for that penalty.

    Returns:
        1-D tensor of shape (batch_size,) with one loss per sample.
    """
    batch_size = maxpool_output.size(0)
    per_sample_losses = []
    for i in range(batch_size):
        # Rows 0/1/2 are x/y/z. The clone keeps `inputs` safe in case
        # basis_function modifies its arguments in place.
        x, y, z = inputs[i, :3, :].clone()
        # One row per RBF: (d, c1, c2, c3).
        rbf_params = maxpool_output[i, :4 * dim_bottle_neck, 0].reshape(
            dim_bottle_neck, 4)

        # Start from zero for every sample so no model sees another's values.
        values = torch.zeros_like(x)
        for d_ind, c1, c2, c3 in rbf_params:
            # NOTE(review): assumes basis_function returns one distance per
            # point (same length as x) -- confirm against its definition.
            distance = basis_function(c1, c2, c3, x, y, z)
            values = values + d_ind * distance ** 3

        sample_loss = torch.sum(values ** 2)
        if regualization:
            # L2 norm of this sample's weights. Tensor.norm gives a finite
            # gradient when all weights are 0, unlike sqrt(sum(d**2)).
            weight_norm = rbf_params[:, 0].norm()
            # NOTE(review): the constant -1 shifts the value but not the
            # gradient, so this just penalizes ||d||. If the intent was to
            # pull ||d|| toward 1, use (weight_norm - 1) ** 2 instead.
            sample_loss = sample_loss + regualization_ratio * (weight_norm - 1)
        per_sample_losses.append(sample_loss)

    if not per_sample_losses:
        # torch.stack rejects an empty list; return an empty loss vector.
        return maxpool_output.new_zeros(0)
    return torch.stack(per_sample_losses)