Yep.
The code snippets:
The model:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class deeplab_v3_separation_dev(nn.Module):
    def __init__(self):
        super(deeplab_v3_separation_dev, self).__init__()
        # Pretrained DeepLabV3-ResNet101; reuse its backbone, ASPP module, and classifier head.
        deepLabV3ResNet101 = models.segmentation.deeplabv3_resnet101(pretrained=True, aux_loss=None)
        ClassifierLayers = deepLabV3ResNet101.classifier
        ASSP_features = ClassifierLayers[0]  # the ASPP module
        ResNet101_features = list(deepLabV3ResNet101.backbone.children())
        semantic_layers = list(ClassifierLayers)

        self.ResNet101_features = nn.Sequential(*ResNet101_features)
        self.ASSP_features = ASSP_features
        self.Semantic_head = nn.Sequential(*semantic_layers[1:])  # conv/BN/ReLU + 21-class 1x1 conv
        self.SFTMX = nn.Sigmoid()  # a sigmoid despite the name, to pair with BCELoss in training
        # Earlier variants that also concatenated the upsampled ASPP features (x2),
        # hence the 280/283 input channels:
        # self.Background_convs = nn.Sequential(nn.Conv2d(280, 3, kernel_size=(1, 1), stride=(1, 1), bias=False))
        # self.Separation_convs = nn.Sequential(nn.Conv2d(283, 3, kernel_size=(1, 1), stride=(1, 1), bias=False))
        # 21 semantic channels + 3 input channels = 24 into the background head
        self.Background_convs = nn.Sequential(
            nn.Conv2d(24, 24, 1),
            nn.BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.Conv2d(24, 3, 1))
        # 21 semantic + 3 background + 3 input channels = 27 into the separation head
        self.Separation_convs = nn.Sequential(
            nn.Conv2d(27, 27, 1),
            nn.BatchNorm2d(27, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.Conv2d(27, 3, 1))

    def forward(self, x):
        input_shape = x.shape[-2:]
        x1 = self.ResNet101_features(x)  # backbone features
        x2 = self.ASSP_features(x1)      # ASPP features
        y1 = self.Semantic_head(x2)      # semantic logits
        y1 = self.SFTMX(y1)              # sigmoid probabilities for BCELoss
        y1 = F.interpolate(y1, size=input_shape, mode='bilinear', align_corners=False)
        x2 = F.interpolate(x2, size=input_shape, mode='bilinear', align_corners=False)  # only used by the commented variants above
        y2 = self.Background_convs(torch.cat((y1, x), 1))      # optionally concatenate x2 here too
        y3 = self.Separation_convs(torch.cat((y1, y2, x), 1))  # optionally concatenate x2 here too
        return y1, y2, y3
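
A quick shape check (my addition, not part of the original snippet) to confirm that all three outputs come back at input resolution; the 21 semantic channels assume the default 21-class DeepLabV3 head:

import torch

net = deeplab_v3_separation_dev().eval()
with torch.no_grad():
    dummy = torch.randn(2, 3, 128, 128)  # batch of two RGB images
    y1, y2, y3 = net(dummy)
print(y1.shape, y2.shape, y3.shape)
# expected: (2, 21, 128, 128), (2, 3, 128, 128), (2, 3, 128, 128)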
The layer-freezing code:
for i, param in enumerate(net.parameters()):
    if i < 340:
        param.requires_grad = False
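
As a sanity check (again my addition, assuming `net` is the model above), you can count how many parameter tensors actually got frozen, and build the optimizer over the trainable ones only, so Adam keeps no state for frozen weights:

import torch.optim as optim

n_frozen = sum(1 for p in net.parameters() if not p.requires_grad)
n_total = sum(1 for p in net.parameters())
print(f'frozen {n_frozen}/{n_total} parameter tensors')

# Pass only trainable parameters to the optimizer (hyperparameters as in train_net below).
trainable = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable, lr=0.05, betas=(0.7, 0.999), weight_decay=1e-3)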
And the training script:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# BasicDataset, dir_img, dir_mask, and ssim come from the surrounding project.


def train_net(net,
              device,
              epochs=5,
              batch_size=2,
              lr=0.05,
              val_percent=0,
              save_cp=True,
              img_scale=0.25):
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)

    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.7, 0.999), eps=1e-08, weight_decay=1e-3, amsgrad=False)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.3)
    L1_criterion = nn.L1Loss(reduction='none')
    semantic_criterion = nn.BCELoss(reduction='none')  # the model already applies a sigmoid; BCEWithLogitsLoss would fold it in

    for epoch in range(epochs):
        net.train()
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                mixedImgs = batch['mixedImage']
                if mixedImgs.size()[0] == 1:
                    continue  # skip size-1 batches so BatchNorm statistics stay valid
                BackgroundImage = batch['BackgroundImage']
                ReflectionImage = batch['ReflectionImage']
                BackgtoundSemantics = batch['BackgtoundSemantics']

                mixedImgs = mixedImgs.to(device=device)
                BackgroundImage = BackgroundImage.to(device=device)
                ReflectionImage = ReflectionImage.to(device=device)
                BackgtoundSemantics = BackgtoundSemantics.to(device=device, dtype=torch.float32)

                semantic_pred, BG_img_pred, R_img_pred = net(mixedImgs)
                # Per-element losses (reduction='none'); each task is then scaled by
                # the variance of its own loss, in the spirit of uncertainty weighting.
                loss1 = semantic_criterion(semantic_pred, BackgtoundSemantics)
                var1 = torch.var(loss1)
                loss2 = L1_criterion(BG_img_pred, BackgroundImage)
                # loss += 0.0003 * LA.norm(cv2.Canny(BG_img_pred, 100, 200) - cv2.Canny(BackgroundImage, 100, 200))  # edge term, left disabled
                loss3 = 0.6 * (1 - ssim.ssim(BG_img_pred, BackgroundImage))
                loss4 = 0.8 * L1_criterion(R_img_pred, ReflectionImage)
                var2 = torch.var(loss4)
                Loss1 = loss1 / (2 * var1.square())
                Loss2 = (loss2 + loss3 + loss4) / (2 * var2.square())
                loss = torch.mean(Loss2) + torch.mean(Loss1)  # the original 0* factors zeroed the loss and blocked all gradients

                writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(mixedImgs.shape[0])
                global_step += 1
                if global_step % (len(dataset) // (10 * batch_size)) == 0:
                    scheduler.step()
                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
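
For completeness, a sketch of how these pieces fit together (this driver is my addition; `dir_img`, `dir_mask`, `BasicDataset`, and the other globals come from the surrounding project):

import torch

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = deeplab_v3_separation_dev()
    # Freeze the first 340 parameter tensors, as in the snippet above.
    for i, param in enumerate(net.parameters()):
        if i < 340:
            param.requires_grad = False
    net.to(device)
    train_net(net, device=device, epochs=5, batch_size=2, lr=0.05, img_scale=0.25)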