I have a problem with an object classification pipeline. The task is instance segmentation on 2D grayscale images. If I try to do the instance segmentation directly with something like Fast R-CNN, the results are all over the place. Thus, we implemented a two-step classification: first a U-Net that gives us a foreground/background classification with a high degree of speed and accuracy, and then a random forest that iterates over the ROIs (extracted by thresholding the mask and then labelling it with scipy.ndimage.label). This works fine, but we need to derive features for the random forest, which also ends up consuming quite a bit of time. Thus, I was wondering what the most efficient approach to classifying objects using Torch is. Currently, I have the following pipeline:
import numpy as np
from skimage.measure import label, regionprops  # assumed source of label/regionprops


def extract_object_features(original_image, binary_mask, categorical_image):
    """Extract object features for classification"""
    # Check if inputs are 3D, and if so, add a dummy dimension
    if original_image.ndim == 3:
        original_image = np.expand_dims(original_image, axis=0)
        binary_mask = np.expand_dims(binary_mask, axis=0)
        categorical_image = np.expand_dims(categorical_image, axis=0)
    #assert original_image.ndim == 3, "Input must be a 3D stack (ZXY) or 2D (XY)"
    assert original_image.shape == binary_mask.shape == categorical_image.shape, "Input shapes must match"
    all_features = []
    all_object_labels = []
    binary_mask = np.where(binary_mask > 0.5, 1, 0)
    for z in range(original_image.shape[0]):
        # Get slices (first channel of each item in the stack)
        binary_slice = binary_mask[z, 0]
        img_slice = original_image[z, 0]
        catimage_slice = categorical_image[z, 0]
        labeled_image = label(binary_slice)
        regions = regionprops(labeled_image)
        features = []
        object_labels = []
        for region in regions:
            minr, minc, maxr, maxc = region.bbox
            rows, cols = region.coords[:, 0], region.coords[:, 1]
            # Boolean mask of this object inside its bounding box
            object_mask = np.zeros((maxr - minr, maxc - minc), dtype=bool)
            object_mask[rows - minr, cols - minc] = True
            # Crop of the image with everything outside the object set to zero
            object_image = np.zeros((maxr - minr, maxc - minc), dtype=original_image.dtype)
            object_image[object_mask] = img_slice[rows, cols]
            # Majority class among the object's pixels
            object_label = np.bincount(catimage_slice[rows, cols]).argmax()
            features.append(object_image)
            object_labels.append(object_label)
        all_features.append(features)
        all_object_labels.append(object_labels)
    # Flatten the list of lists
    all_features = [item for sublist in all_features for item in sublist]
    all_object_labels = [item for sublist in all_object_labels for item in sublist]
    return all_features, all_object_labels
import torch


def run_epoch(model, dataloader, optimizer, loss_function, device, num_classes=5, is_train=False):
    """Run one epoch of training or validation."""
    total_loss = 0.0
    model.train(is_train)
    for idx, batch in enumerate(dataloader):
        print(f'running batch {idx + 1}')
        images = np.stack(batch['image'])
        masks = np.stack(batch['binary'])
        labels = np.stack(batch['label'])
        feats, labs = extract_object_features(images, masks, labels)
        if not feats:
            continue  # no objects found in this batch
        if is_train:
            optimizer.zero_grad()
        all_outputs = []
        with torch.set_grad_enabled(is_train):
            # ROIs differ in size, so each one goes through the model on its own
            for f in feats:
                feats_tensor = torch.from_numpy(f).float().to(device)
                feats_tensor = feats_tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
                all_outputs.append(model(feats_tensor))
            outputs = torch.cat(all_outputs, dim=0)  # (num_objects, num_classes)
            labs_tensor = torch.as_tensor(np.asarray(labs), dtype=torch.long, device=device)
            labels_one_hot = torch.nn.functional.one_hot(labs_tensor, num_classes=num_classes).float()
            loss = loss_function(outputs, labels_one_hot)
        if is_train:
            loss.backward()
            optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    return avg_loss
I would say that something like this could theoretically work, but it has lots of nested loops and is very inefficient. The ROI sizes cover a wide range (from 40 to 400.000 px), so I don't know how I could build a batch. The number of objects is also considerable (~100 per 512x512 patch), so I cannot assemble a Z-stack of image copies in which every pixel is turned off except those of the labelled object for that slice. Therefore, I came here for some help and inspiration. In the meantime, I will keep working with the random forest pipeline. Also, happy to upload a sample image if it helps.
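
For reference, one direction that might get around the batching problem is to resize every masked crop to a common size, so that all objects of a patch go through the model in a single forward pass. Below is a minimal sketch of that idea, not a tested solution. The function name `rois_to_batch` and the parameter `out_size` are made up for illustration, and the sketch assumes that resizing is acceptable for the classes in question. Resizing throws away absolute scale, so the pixel area of each object is returned as well in case it is needed as an extra input.

```python
import numpy as np
import torch
import torch.nn.functional as F
from scipy import ndimage


def rois_to_batch(image, binary_mask, out_size=64):
    """Crop every connected object and resize it to a fixed size.

    image, binary_mask: 2D arrays of the same shape.
    Returns a float tensor of shape (N, 1, out_size, out_size) and a
    tensor with the pixel area of each object.
    (Hypothetical helper; name and out_size are illustrative.)
    """
    labeled, _ = ndimage.label(binary_mask > 0.5)
    crops, areas = [], []
    # find_objects returns one bounding-box slice tuple per label, in label order
    for index, bbox in enumerate(ndimage.find_objects(labeled), start=1):
        object_mask = labeled[bbox] == index
        # Zero out pixels in the bounding box that belong to other objects
        crop = np.where(object_mask, image[bbox], 0).astype(np.float32)
        crop = torch.from_numpy(crop)[None, None]  # (1, 1, h, w)
        crop = F.interpolate(crop, size=(out_size, out_size),
                             mode="bilinear", align_corners=False)
        crops.append(crop)
        areas.append(int(object_mask.sum()))
    if not crops:
        empty = torch.empty(0, 1, out_size, out_size)
        return empty, torch.empty(0, dtype=torch.long)
    return torch.cat(crops, dim=0), torch.tensor(areas)
```

With something like this, `batch, areas = rois_to_batch(img_slice, binary_slice)` followed by a single `model(batch.to(device))` call would replace the per-object loop. Whether the loss of scale and aspect ratio hurts the classification is something I would have to check on real data.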