Pure PyTorch Multilabel Confusion Matrix?

Hi everyone -

Is there a pure-PyTorch (i.e., without converting tensors to NumPy arrays) implementation of a multilabel confusion matrix, like sklearn.metrics.confusion_matrix?
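For reference, here is a minimal sketch of the full n x n matrix that sklearn.metrics.confusion_matrix returns, computed with torch.bincount only. It assumes "multilabel" here means single-label multiclass targets (the case sklearn.metrics.confusion_matrix covers) with integer labels in [0, num_classes); the function name confusion_matrix_torch is hypothetical. Each (true, predicted) pair is encoded as one index, so a single bincount fills the whole matrix and everything stays on the tensor's device.

```python
import torch


def confusion_matrix_torch(target, prediction, num_classes):
    """Hypothetical helper: multiclass confusion matrix in pure PyTorch.

    Row i, column j counts samples whose true label is i and whose
    predicted label is j (the layout sklearn.metrics.confusion_matrix uses).
    Assumes integer labels in [0, num_classes).
    """
    target = target.reshape(-1).long()
    prediction = prediction.reshape(-1).long()
    # Encode each (true, predicted) pair as a single flat index.
    index = target * num_classes + prediction
    # Count every index; minlength guarantees num_classes**2 bins.
    counts = torch.bincount(index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


target = torch.tensor([0, 1, 2, 1, 1, 0])
prediction = torch.tensor([0, 2, 2, 1, 0, 0])
cm = confusion_matrix_torch(target, prediction, num_classes=3)
print(cm.tolist())  # [[2, 0, 0], [1, 1, 1], [0, 0, 1]]
```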


I just remembered this question was still open, so here is my solution. It works on both tensors and np.arrays without any changes. I've also added methods for precision, recall and F-beta, plus a stats function that returns all of the above.
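The same code can run on both types because it only uses elementwise ==, the bitwise operators & and ~ on boolean arrays, and .sum(), all of which NumPy arrays and PyTorch tensors support. A minimal sketch of one-vs-rest counting for a single label value (the name one_vs_rest_counts is hypothetical, and it leaves out the ignore_value handling from the code below):

```python
import numpy as np
import torch


def one_vs_rest_counts(target, prediction, value):
    """Hypothetical helper: (tp, fp, fn, tn) for `value` vs. every other label.

    Uses only ==, &, ~ and .sum(), so it accepts NumPy arrays or PyTorch tensors.
    """
    pred_pos = prediction == value   # predicted as the positive class
    true_pos = target == value       # actually the positive class
    counts = (
        (pred_pos & true_pos).sum(),    # tp
        (pred_pos & ~true_pos).sum(),   # fp
        (~pred_pos & true_pos).sum(),   # fn
        (~pred_pos & ~true_pos).sum(),  # tn
    )
    # .item() turns a 0-d tensor or a NumPy scalar into a plain Python int.
    return tuple(int(c.item()) for c in counts)


target = [0, 1, 2, 1, 1, 0]
prediction = [0, 2, 2, 1, 0, 0]
print(one_vs_rest_counts(np.array(target), np.array(prediction), 1))        # (1, 0, 2, 3)
print(one_vs_rest_counts(torch.tensor(target), torch.tensor(prediction), 1))  # (1, 0, 2, 3)
```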

Note: division by zero can occur when calculating precision (tp + fp == 0) or recall (tp + fn == 0). This happens either (1) when every sample is a true negative, or (2) when the denominator is zero but there are other errors (e.g., nothing was predicted positive, yet some positives were missed). Case (1) is good and case (2) is bad, so I return 1.0 and False respectively, but you can change the config to handle this however you wish.

Here’s a gist: https://gist.github.com/brookisme/8f9f06286251af02bb9372fc35bb7fd8

Here is some code:

# CONFIG
#
BETA=2
RETURN_CMATRIX=True
INVALID_ZERO_DIVISON=False
VALID_ZERO_DIVISON=1.0



#
# METHODS
#     
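def confusion_matrix(target,prediction,value,ignore_value=None):
    # one-vs-rest counts: `value` is the positive class, every other label is negative
    pred_pos=(prediction==value)
    true_pos=(target==value)
    tp=(pred_pos & true_pos)
    fp=(pred_pos & ~true_pos)    # predicted `value`, target is something else
    fn=(~pred_pos & true_pos)    # target is `value`, predicted something else
    tn=(~pred_pos & ~true_pos)
    if ignore_value is not None:
        # skip every element whose target is the ignore label
        keep=(target!=ignore_value)
        tp,fp,fn,tn=tp&keep,fp&keep,fn&keep,tn&keep
    return _get_items(tp.sum(), fp.sum(), fn.sum(), tn.sum())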


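def precision(tp,fp,fn):
    # tp / (tp + fp); fn picks the fallback when tp + fp == 0
    return _precision_recall(tp,fp,fn)


def recall(tp,fn,fp):
    # tp / (tp + fn); fp picks the fallback when tp + fn == 0
    return _precision_recall(tp,fn,fp)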


def fbeta(p,r,beta=BETA):
    # F-beta = (1 + beta^2) * p * r / (beta^2 * p + r); beta > 1 weights recall more
    beta_sq=beta**2
    denominator=(beta_sq*p + r)
    if denominator:
        return (1+beta_sq)*(p*r)/denominator
    else:
        return 0

      
def stats(
        target,
        prediction,
        value,
        ignore_value=None,
        beta=BETA,
        return_cmatrix=RETURN_CMATRIX):
    tp, fp, fn, tn=confusion_matrix(
        target,
        prediction,
        value,
        ignore_value=ignore_value)
    p=precision(tp,fp,fn)
    r=recall(tp,fn,fp)
    stat_values=[p,r]
    if not _is_false(beta):
        stat_values.append(fbeta(p,r,beta=beta))
    if return_cmatrix:
        stat_values+=[tp, fp, fn, tn]
    return stat_values


#
# INTERNAL
#
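def _precision_recall(a,b,c):
    # a / (a + b); when a + b == 0, c == 0 means everything was a
    # true negative (valid), otherwise the result is invalid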
    if (a+b):
        return a/(a+b)
    else:
        if c:
            return INVALID_ZERO_DIVISON
        else:
            return VALID_ZERO_DIVISON


def _is_false(value):
    return value in [False,None]


def _get_items(*args):
    # convert 0-d tensors / NumPy scalars to plain Python numbers
    try:
        return list(map(lambda s: s.item(),args))
    except AttributeError:
        return args
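If you already have the n x n matrix (rows = true labels, columns = predicted labels), precision, recall and F-beta for every class can be computed at once instead of calling the functions above per value. Below is a minimal vectorized sketch; per_class_stats is a hypothetical name, and it uses 0.0 rather than False as the invalid fallback because the result is a float tensor. It applies the same rule as _precision_recall and the same formula as fbeta, F_beta = (1 + beta^2) * p * r / (beta^2 * p + r).

```python
import torch


def per_class_stats(cmatrix, beta=2.0, valid_zero_division=1.0,
                    invalid_zero_division=0.0):
    """Hypothetical helper: per-class precision, recall, F-beta.

    cmatrix[i, j] = number of samples with true label i predicted as j.
    """
    cmatrix = cmatrix.double()
    tp = cmatrix.diag()
    fp = cmatrix.sum(dim=0) - tp  # predicted as class k, true label differs
    fn = cmatrix.sum(dim=1) - tp  # true label k, predicted as something else

    def safe_ratio(num, den, other):
        # den == 0: valid if `other` is also zero (all true negatives)
        fallback = torch.where(other == 0,
                               torch.full_like(num, valid_zero_division),
                               torch.full_like(num, invalid_zero_division))
        # clamp only avoids dividing by zero; those entries use the fallback
        return torch.where(den > 0, num / den.clamp(min=1), fallback)

    precision = safe_ratio(tp, tp + fp, fn)
    recall = safe_ratio(tp, tp + fn, fp)
    beta_sq = beta ** 2
    den = beta_sq * precision + recall
    fbeta = torch.where(den > 0,
                        (1 + beta_sq) * precision * recall / den.clamp(min=1e-12),
                        torch.zeros_like(den))
    return precision, recall, fbeta


cm = torch.tensor([[2, 0, 0], [1, 1, 1], [0, 0, 1]])
precision, recall, f2 = per_class_stats(cm, beta=2.0)
print(precision, recall, f2)
```

torch.where evaluates both branches, which is why the denominators are clamped before dividing; the clamped values are never selected.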