How to retrain a saved model.pth with a new dataset

I trained a model on a dataset and saved the weights to a .pth file. How do I load it and use the weights to train on a new dataset?

You can load the parameters using model.load_state_dict():

import torch

# Initialize model (MyModel is your own model class)
model = MyModel()
# Load the saved state_dict into it
model.load_state_dict(torch.load('my_weights.pth'))

Have a look at the Transfer Learning Tutorial to see how you can fine-tune your model.
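A minimal sketch of the whole retraining flow follows. It assumes a small hypothetical `MyModel` and uses random stand-in data in place of your real model and new dataset: it loads the saved weights, keeps training on the new data, then saves to a new file. The file names `my_weights.pth` and `my_weights_new.pth` are placeholders.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class MyModel(nn.Module):
    # Hypothetical stand-in for your own model class
    def __init__(self, in_features=10, num_classes=3):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.fc(x)


def retrain(old_path, new_path, loader, epochs=2, lr=1e-3):
    model = MyModel()
    # Restore the previously trained parameters
    model.load_state_dict(torch.load(old_path, map_location="cpu"))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    model.train()
    for _ in range(epochs):
        for data, target in loader:
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

    # Save the fine-tuned parameters under a new file name
    torch.save(model.state_dict(), new_path)
    return model


if __name__ == "__main__":
    # Create an "old" checkpoint so this sketch runs end to end
    torch.save(MyModel().state_dict(), "my_weights.pth")

    # Random stand-in for the new dataset: 64 samples, 10 features, 3 classes
    new_data = TensorDataset(torch.randn(64, 10), torch.randint(0, 3, (64,)))
    loader = DataLoader(new_data, batch_size=16, shuffle=True)

    retrain("my_weights.pth", "my_weights_new.pth", loader)
```

Passing `map_location="cpu"` lets a checkpoint saved on a GPU load on a CPU-only machine. If you train on a GPU, move the model to your device after loading.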


Thank you for the help. After training and saving the weights to a .pth file, how do I predict probabilities on my test set from the test DataLoader?

I assume you have already created your DataLoader. If so, you could write a simple loop iterating your test_loader:

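import torch
import torch.nn.functional as F

model.eval()  # use inference behavior for dropout/batchnorm
predictions = []
with torch.no_grad():  # no gradients needed for inference
    for data in test_loader:  # if your loader yields (data, target), use: for data, _ in test_loader
        output = model(data)
        pred = F.softmax(output, dim=1)  # assuming your model outputs logits
        predictions.append(pred)
predictions = torch.cat(predictions)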

predictions will now contain the probabilities for each sample and each class.
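To check the loop end to end, here is a small sketch with a hypothetical linear model and a `TensorDataset` standing in for your test set. Because most loaders yield `(data, target)` pairs, it unpacks the batch and ignores the target. It also takes the `argmax` to turn the probabilities into predicted class indices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def predict_proba(model, loader):
    """Return softmax probabilities for every sample in loader, shape [N, C]."""
    model.eval()  # switch dropout/batchnorm to inference behavior
    predictions = []
    with torch.no_grad():  # no gradients needed for inference
        for data, _ in loader:  # assumes the loader yields (data, target)
            output = model(data)  # logits, shape [batch_size, num_classes]
            predictions.append(F.softmax(output, dim=1))
    return torch.cat(predictions)


# Hypothetical stand-ins for your trained model and test set
model = nn.Linear(10, 3)
test_set = TensorDataset(torch.randn(50, 10), torch.zeros(50, dtype=torch.long))
test_loader = DataLoader(test_set, batch_size=16)

probs = predict_proba(model, test_loader)
pred_classes = probs.argmax(dim=1)  # most likely class per sample
```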

I need the full code for retraining the model (.pt) on a new dataset and saving it to a new .pt file.

In case you would like to retrain a classification model, the ImageNet example might be a good starting point.

I solved it: I loaded the pretrained model, loaded the saved parameters into it, and used my train function to retrain it. It's working.

My input images are TIFF files in uint8 format (0-255), and the segmented images (targets) are NRRD files in uint16 (0-65535). How can I correctly load this dataset into my DataLoader for segmentation with a U-Net?
Please help me with this.

You could create a custom Dataset as described here and load the input as well as the target with any library that works for you (e.g. pynrrd for the nrrd file).
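A minimal sketch of such a Dataset, assuming each input .tif has a matching .nrrd mask and that the mask stores class indices. The class name `TiffNrrdSegDataset` and its `load_image`/`load_mask` arguments are hypothetical. For real files you could pass `tifffile.imread` and `lambda p: nrrd.read(p)[0]` (pynrrd's `nrrd.read` returns a `(data, header)` tuple). The image is scaled from uint8 to float32 in [0, 1], and the uint16 mask becomes int64, which `nn.CrossEntropyLoss` expects for class-index targets.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class TiffNrrdSegDataset(Dataset):
    """Hypothetical dataset pairing uint8 input images with uint16 masks."""

    def __init__(self, image_paths, mask_paths, load_image, load_mask):
        assert len(image_paths) == len(mask_paths)
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.load_image = load_image  # e.g. tifffile.imread
        self.load_mask = load_mask    # e.g. lambda p: nrrd.read(p)[0]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = np.asarray(self.load_image(self.image_paths[idx]))
        mask = np.asarray(self.load_mask(self.mask_paths[idx]))

        # uint8 [0, 255] -> float32 [0, 1]; add a channel dim for grayscale images
        image = image.astype(np.float32) / 255.0
        if image.ndim == 2:
            image = image[np.newaxis]  # [H, W] -> [1, H, W]

        # uint16 class indices -> int64 (torch.long) for CrossEntropyLoss
        mask = mask.astype(np.int64)

        return torch.from_numpy(image), torch.from_numpy(mask)
```

Two things are worth checking on real data. First, pynrrd reads arrays in Fortran index order by default (`index_order='F'`), so the mask may come back transposed relative to the TIFF image. Compare the shapes, or try `nrrd.read(p, index_order='C')[0]`. Second, if the mask stores values other than 0..C-1 (for example 0 and 65535), remap them to class indices before the cast.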

Bro, I have an nrrd file in 16-bit unsigned int (uint16), which is not supported in PyTorch.


My predicted segmentation looks like the first image, but I want the segmentation to look like this:

Why would it not be supported? float32 can exactly represent all integers in [0, 16777216] which also includes uint16.
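The claim is easy to check numerically. In the sketch below, every uint16 value survives a round trip through float32, while integers just above 2**24 start to collide.

```python
import numpy as np

# All 65536 possible uint16 values
all_uint16 = np.arange(65536, dtype=np.uint16)

# Cast to float32 and back; a lossless cast gives back identical values
round_trip = all_uint16.astype(np.float32).astype(np.uint16)
uint16_exact = bool(np.array_equal(all_uint16, round_trip))

# Above 2**24, float32 can no longer represent every integer
big = np.array([2**24, 2**24 + 1], dtype=np.int64)
big_as_f32 = big.astype(np.float32)
collides = bool(big_as_f32[0] == big_as_f32[1])

print(uint16_exact, collides)  # True True
```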

[Screenshot: table of supported data types, from the GeeksforGeeks article "How to Get the Data Type of a Pytorch Tensor?"]

I’m not sure where you are stuck, but why wouldn’t casting to float32 work (which is the default dtype in PyTorch)?

How can I correctly convert a uint16 image to float32, and where should I do it?

Assuming your library uses numpy for the uint16 images, use astype:

import numpy as np

arr = np.random.randint(0, 65000, (100,), dtype=np.uint16)  # stand-in uint16 data
arr = arr.astype(np.float32)  # exact for every uint16 value

Then create the tensor via torch.from_numpy.
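For a segmentation target specifically, the dtype also matters for the loss. The sketch below assumes the uint16 mask holds class indices and that your U-Net outputs one logit channel per class. `nn.CrossEntropyLoss` takes logits of shape [N, C, H, W] and a target of shape [N, H, W] with dtype `torch.long`, so the mask should become int64 rather than float32. The single `Conv2d` is a hypothetical stand-in for a U-Net, and the 3 classes and 64x64 size are arbitrary.

```python
import numpy as np
import torch
import torch.nn as nn

num_classes = 3

# Stand-in uint8 image [H, W] -> float32 tensor with batch and channel dims
image = np.random.randint(0, 256, (64, 64), dtype=np.uint8).astype(np.float32)
image_t = torch.from_numpy(image / 255.0)[None, None]  # [1, 1, H, W]

# Stand-in uint16 mask [H, W] holding class indices 0..num_classes-1
mask_u16 = np.random.randint(0, num_classes, (64, 64)).astype(np.uint16)
target = torch.from_numpy(mask_u16.astype(np.int64))[None]  # [1, H, W], torch.long

# Hypothetical stand-in for a U-Net: maps 1 input channel to num_classes logits
model = nn.Conv2d(1, num_classes, kernel_size=3, padding=1)
logits = model(image_t)  # [1, num_classes, H, W]

loss = nn.CrossEntropyLoss()(logits, target)
loss.backward()
```

If a class index in the target falls outside 0..C-1, `CrossEntropyLoss` raises an error, so any remapping of the mask values should happen before this step.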


@ptrblck bro, please give me a solution for this.


I’ve already posted the solution: transform the uint16 array to float32 and then to a tensor.


Not for that, bro.

The predicted image is wrong.

The correct image should look like this, @ptrblck.