I’m working on a text classification problem with BERT. Training on my local machine works fine, but when I switch to the server I get the following error:
```text
<ipython-input-28-508d35ac5f5f> in flat_accuracy(preds, labels)
      5     pred_flat = np.argmax(preds, axis=1).flatten()
      6     labels_flat = labels.flatten()
----> 7     return np.sum(pred_flat == labels_flat) / len(labels_flat)
      8
      9 # Function to calculate the f1_score of our predictions vs labels

TypeError: eq() received an invalid combination of arguments - got (numpy.ndarray), but expected one of:
 * (Tensor other)
      didn't match because some of the arguments have invalid types: (numpy.ndarray)
 * (Number other)
      didn't match because some of the arguments have invalid types: (numpy.ndarray)
```
Code:
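```python
import numpy as np

# Fraction of rows whose highest-scoring class matches the true label
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)
```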
PyTorch version on the local machine: 1.4.0
PyTorch version on the server: 1.3.1
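Since the message comes from `Tensor.eq`, it looks as if `pred_flat` is a torch `Tensor` on the server rather than a NumPy array, so `==` dispatches to the tensor's comparison, which rejects the `ndarray` on the other side. One workaround worth trying, sketched here under that assumption, is to convert both inputs to NumPy before comparing. `to_numpy` is a hypothetical helper, and detecting tensors by their `.detach()` method is also an assumption, chosen so the snippet runs without importing torch.

```python
import numpy as np


def to_numpy(x):
    """Hypothetical helper: return x as a NumPy array.

    Assumption: anything with a .detach() method is treated as a torch
    Tensor, so this works without importing torch.
    """
    if hasattr(x, "detach"):
        # Drop autograd history and move to CPU before converting
        x = x.detach().cpu().numpy()
    return np.asarray(x)


def flat_accuracy(preds, labels):
    # Normalize both inputs so the comparison below is NumPy vs NumPy
    preds = to_numpy(preds)
    labels = to_numpy(labels)
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)
```

For a real tensor, this amounts to calling `preds.detach().cpu().numpy()` (and the same for `labels`) before passing them to the original function.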
Any help would be greatly appreciated!