Performance (accuracy) difference between desktop and Android

Environment

  • Desktop (Colab)
    • PyTorch version: 1.7.0
  • Android
    • torch version: 1.7.0
    • torchvision version: 1.7.0

I am working on an image classification task with MobileNetV2 on an Android device.
First, I train and save the model in Colab with the code snippet below (I have five labels).

import torch
import torchvision

# Model definition with five labels
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = torchvision.models.mobilenet_v2(pretrained=True)
model.classifier[1] = torch.nn.Linear(in_features=model.classifier[1].in_features, out_features=5)
model.eval()
model.to(device)
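# Save model with quantization and optimization

from torch.utils.mobile_optimizer import optimize_for_mobile

def model_save(model, name):
    # NOTE: convert() only swaps modules that carry a qconfig; if no qconfig was set and
    # torch.quantization.prepare() was not run, the layers stay in float
    q_model = torch.quantization.convert(model)
    # If the model is on CUDA here, the traced and saved module keeps CUDA tensors
    traced_script_module = torch.jit.trace(q_model, torch.rand(1, 3, 224, 224).to(device))
    opt_model = optimize_for_mobile(traced_script_module)
    torch.jit.save(opt_model, name)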
# Training step

best_test_loss = float('inf')  # initialize once, before the epoch loop

for epoch in range(num_epochs):
  epoch_loss = 0
  test_loss = 0

  model.train()
  for i, samples in enumerate(train_loader):
      imgs, annotations = samples
      imgs, annotations = imgs.to(device), annotations.to(device)
      
      output = model(imgs)
      loss = loss_func(output,annotations)
      
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      print(f'Iteration: {i+1}/{len(train_loader)}, Loss: {loss.item()}')
      epoch_loss += loss.item()

  avg_train_loss = epoch_loss/len(train_loader)

  # Save result for plotting
  train_loss_list.append(avg_train_loss)

  # Print epoch's train loss
  print(f'Epoch {epoch} train loss: {avg_train_loss}')

  # Validation
  model.eval()
  with torch.no_grad():
    for i, test_samples in enumerate(test_loader):
      test_imgs, test_annotations = test_samples
      test_imgs, test_annotations = test_imgs.to(device), test_annotations.to(device)

      test_output = model(test_imgs)
      loss = loss_func(test_output, test_annotations)
      test_loss += loss.item()

  avg_test_loss = test_loss/len(test_loader)

  # save best
  if best_test_loss > avg_test_loss:
    best_test_loss = avg_test_loss
    model_save(model, 'quantized_best.pt')
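A "save best" rule only works if the running best survives across epochs. If `best_test_loss` is re-initialized inside the epoch loop (for example to 1), every epoch is compared against that constant instead of against the previous best, so the checkpoint on disk ends up coming from the last epoch whose loss was below 1. The pure-Python sketch below uses made-up loss values and a hypothetical helper `checkpoint_epochs` (not part of PyTorch) to show which epochs would overwrite the checkpoint under each placement.

```python
def checkpoint_epochs(val_losses, reset_each_epoch):
    """Return the epochs at which a 'save best' rule would write a checkpoint.

    reset_each_epoch=True re-initializes the running best to 1 at the start of
    every epoch (the pitfall); False initializes it once to +inf.
    """
    saved = []
    best = float("inf")
    for epoch, loss in enumerate(val_losses):
        if reset_each_epoch:
            best = 1  # reset inside the loop: each epoch is compared against 1
        if best > loss:
            best = loss
            saved.append(epoch)
    return saved


# Hypothetical validation losses, for illustration only
losses = [0.9, 0.5, 0.7, 0.4, 0.8]
print(checkpoint_epochs(losses, reset_each_epoch=True))   # [0, 1, 2, 3, 4]
print(checkpoint_epochs(losses, reset_each_epoch=False))  # [0, 1, 3]
```

With the reset, the final file comes from epoch 4 (loss 0.8) rather than epoch 3 (loss 0.4).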

Then I convert this model manually on my local machine (Mac), because if I don't, I get the error below on the Android device.

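import torch

# Load with map_location='cpu' so any CUDA tensors are remapped to the CPU, then re-save
quantized_torchmodel = torch.jit.load("quantized_best.pt", map_location='cpu')
torch.jit.save(quantized_torchmodel, "best_cpu_quantized.pt")
# Error that appears on Android when I don't convert the model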
2020-11-26 13:33:34.877 416-750/org.pytorch.demo E/PyTorchDemo: Error during image analysis
    com.facebook.jni.CppException: Could not run 'aten::empty_strided' with arguments from the 'CUDA' backend. 'aten::empty_strided' is only available for these backends: [CPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, Tracer, Autocast, Batched, VmapMode].

Then I load this model as a Module in the Android code and run inference on some images.

mModule = Module.load(moduleFileAbsoluteFilePath);

...

final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();

However, the inference output (scores) for the same image is quite different between mobile and desktop (Colab).

I used the same image and the same transformations (resize, normalization, …), but the output is different.
Strangely, the emulator's output is the same as Colab's.
How is that possible?

Could you use a static input (e.g. all ones) and compare the output on both platforms? If these results are equal, I would guess that the image loading (decoding) and processing might differ between the approaches.
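If the static-input outputs do match, one quick way to check the preprocessing is to push a single known pixel through each pipeline and compare the numbers. The sketch below assumes the standard ImageNet mean/std (RGB order, pixels scaled to [0, 1] first) that torchvision's pretrained models expect; the thread does not say which constants were actually used, and `normalize_pixel` is a hypothetical helper. On Android, the same constants are available as `TensorImageUtils.TORCHVISION_NORM_MEAN_RGB` and `TensorImageUtils.TORCHVISION_NORM_STD_RGB`.

```python
# Assumed ImageNet normalization constants (RGB order), as used by torchvision models
MEAN_RGB = (0.485, 0.456, 0.406)
STD_RGB = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Map one 8-bit RGB pixel to the normalized float values the model sees."""
    return [(c / 255.0 - m) / s for c, m, s in zip(rgb, MEAN_RGB, STD_RGB)]


# Print the same pixel on both platforms and compare
print([round(v, 4) for v in normalize_pixel((255, 255, 255))])  # [2.2489, 2.4286, 2.64]
print([round(v, 4) for v in normalize_pixel((0, 0, 0))])        # [-2.1179, -2.0357, -1.8044]
```

A mismatch here (for example swapped channel order or a missing division by 255) would point at preprocessing rather than the model.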


@ptrblck
Thank you for your reply.
Unfortunately, the outputs are different. I tried the following three environments:

  • Colab
  • Android Emulator (Pixel XL API 30)
  • Actual Android Device (Samsung Galaxy S8)

  1. Colab
# Colab code snippet for testing
import torch

model = torch.jit.load(MYPATH, map_location='cpu')
model.eval()
one = torch.ones([1, 3, 224, 224])
with torch.no_grad():
    print(model(one))
# Output in Colab
tensor([[ 2.7785,  1.4376, -0.8041, -1.5155, -0.8308]])

  2. Android Emulator & Actual Device
// Android code snippet for testing
mModule = Module.load(MYPATH);

...

// Fill the input tensor with ones
for (int i = 0; i < 3 * 224 * 224; i++) {
    mInputTensorBuffer.put(i, 1.0f);
}

// Get output
final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();
final float[] scores = outputTensor.getDataAsFloatArray();
Log.d("DEBUG", Arrays.toString(scores));

// Output of Emulator
2020-11-29 12:56:51.328 3030-3030/org.pytorch.demo D/DEBUG: [2.778521, 1.4375764, -0.80405545, -1.5155371, -0.83080786]
// Output of Actual Device
2020-11-29 12:58:08.625 725-725/org.pytorch.demo D/DEBUG: [1.224158, -0.65208256, -0.3742584, -0.123134516, -0.085230656]

The emulator's output is the same as Colab's, but the actual device's output is not.
I could not figure out why this happens…

I also tried the very latest releases of PyTorch and PyTorch Android.
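To make the comparison concrete, the scores from the three logs above can be checked against a tolerance instead of by eye. This is a pure-Python sketch; the `abs_tol` of 1e-4 is an assumption chosen to absorb the 4-decimal rounding in Colab's printout, and the helper names are hypothetical.

```python
import math

# Scores copied from the logs above (Colab prints only 4 decimals)
COLAB = [2.7785, 1.4376, -0.8041, -1.5155, -0.8308]
EMULATOR = [2.778521, 1.4375764, -0.80405545, -1.5155371, -0.83080786]
DEVICE = [1.224158, -0.65208256, -0.3742584, -0.123134516, -0.085230656]


def outputs_match(a, b, abs_tol=1e-4):
    """True if every score agrees within abs_tol (absolute tolerance only)."""
    return len(a) == len(b) and all(
        math.isclose(x, y, rel_tol=0.0, abs_tol=abs_tol) for x, y in zip(a, b)
    )


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two score vectors."""
    return max(abs(x - y) for x, y in zip(a, b))


def argmax(scores):
    """Index of the highest score, i.e. the predicted label."""
    return max(range(len(scores)), key=lambda i: scores[i])


for name, scores in (("emulator", EMULATOR), ("device", DEVICE)):
    print(name, outputs_match(COLAB, scores), max_abs_diff(COLAB, scores), argmax(scores))
```

With these numbers the emulator stays within 5e-5 of Colab, while the device differs by more than 2 in one logit; the top-scoring label happens to be index 0 in all three.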

I assume the models are all in eval mode? If so, this might be a new bug so could you create a GitHub issue so that we could track and reproduce it, please?

Yes, all models are in eval mode.
I will create an issue about this in the PyTorch repo.
Thank you.

Link to GitHub issue