How to create a confusion matrix in PyTorch

This is a short tutorial on how to create a confusion matrix in PyTorch. I’ve often seen people struggle to build one, yet it is a helpful metric: it shows how well your model performs on each class of your dataset and which classes it tends to confuse with each other.
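To make the idea concrete: entry (i, j) of the matrix counts the samples whose true class is i and whose predicted class is j, so correct predictions sit on the diagonal. The sketch below is not taken from the author’s notebook; it builds the matrix in plain PyTorch with `torch.bincount`, and the function names `confusion_matrix_from_labels` and `per_class_accuracy` are hypothetical.

```python
import torch


def confusion_matrix_from_labels(y_true: torch.Tensor, y_pred: torch.Tensor,
                                 num_classes: int) -> torch.Tensor:
    """Return a (num_classes, num_classes) matrix: rows = true class, cols = predicted class."""
    # Encode each (true, predicted) pair as one index into a flattened matrix.
    flat_index = y_true.long() * num_classes + y_pred.long()
    # Count each pair; minlength makes sure every cell exists even if its count is 0.
    counts = torch.bincount(flat_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


def per_class_accuracy(cm: torch.Tensor) -> torch.Tensor:
    """Fraction of each true class predicted correctly (diagonal / row sum)."""
    # clamp avoids division by zero for classes that never occur.
    return cm.diag().float() / cm.sum(dim=1).clamp(min=1).float()


if __name__ == "__main__":
    y_true = torch.tensor([0, 0, 1, 1, 2, 2])
    y_pred = torch.tensor([0, 1, 1, 1, 2, 0])
    cm = confusion_matrix_from_labels(y_true, y_pred, num_classes=3)
    print(cm)
    print(per_class_accuracy(cm))
```

For the toy labels above, the matrix is [[1, 1, 0], [0, 2, 0], [1, 0, 1]]: class 1 is always recognized, while classes 0 and 2 are each right only half of the time, and the off-diagonal cells show exactly which classes they were mistaken for.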

Confusion Matrix MNIST-FASHION dataset

If you are only interested in the code for the matrix, jump directly to “Build confusion matrix” at the end of this article. You will also find a link to my code on GitHub at the end.

If you want to use TensorBoard instead, go to:

For everyone else, first things first: let’s start at the beginning. Download the FashionMNIST dataset and train a simple convolutional neural network.

Load the data:

Loading the FashionMNIST dataset.

The Conv-Net:

This is a simple Conv-Net architecture. Not fancy, but it works!
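As an illustration, here is one possible small Conv-Net for 1×28×28 grayscale images and 10 classes. The layer sizes are an assumption in the spirit of the classic LeNet-style examples, not necessarily the author’s architecture. The comments trace the tensor shapes so you can check that the flattened size (16 × 4 × 4 = 256) matches the first linear layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small CNN for 1x28x28 inputs (hypothetical layout, not the author's exact model)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)   # 28x28 -> 24x24
        self.pool = nn.MaxPool2d(2, 2)                # halves height and width
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)  # 12x12 -> 8x8
        self.fc1 = nn.Linear(16 * 4 * 4, 120)         # after the second pool: 16 x 4 x 4
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))  # -> (N, 6, 12, 12)
        x = self.pool(F.relu(self.conv2(x)))  # -> (N, 16, 4, 4)
        x = torch.flatten(x, 1)               # -> (N, 256)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                    # raw logits, one per class
```

The network returns raw logits. `nn.CrossEntropyLoss` applies the softmax internally, and taking `argmax` over the logits is enough to get the predicted class.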

Train the network:

Feed the data to the Conv-Net. Reduce the number of epochs if you have a slow CPU.

You will get something like this:

Epoch-1 lr: 0.001 
Training loss 1.7553951950371265 Steps: 999
Training loss 0.8032121055871249 Steps: 1999
Training loss 0.6876848935596644 Steps: 2999
Training loss 0.6286471617035567 Steps: 3999
Training loss 0.5757026640549302 Steps: 4999
Training loss 0.5538745389506221 Steps: 5999
Training loss 0.5094906052090228 Steps: 6999
Epoch-2 lr: 0.001
Training loss 0.72671784534771 Steps: 8499
....
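A training loop that produces a log of this kind could look roughly like the sketch below. It assumes cross-entropy loss and SGD with momentum at `lr=0.001`; the optimizer choice, the momentum value, and the helper name `train` are assumptions, and the printed format only loosely mirrors the output above (its step counter restarts each epoch).

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train(model: nn.Module, loader: DataLoader, epochs: int = 2, lr: float = 0.001,
          log_every: int = 1000) -> float:
    """Hypothetical loop: cross-entropy + SGD; returns the mean loss of the last epoch."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    model.train()
    epoch_loss = 0.0
    for epoch in range(epochs):
        print(f"Epoch-{epoch + 1} lr: {lr}")
        running_loss, epoch_loss = 0.0, 0.0
        for step, (inputs, labels) in enumerate(loader):
            optimizer.zero_grad()                     # clear gradients from the previous step
            loss = criterion(model(inputs), labels)   # forward pass + loss
            loss.backward()                           # backpropagate
            optimizer.step()                          # update the weights
            running_loss += loss.item()
            epoch_loss += loss.item()
            if (step + 1) % log_every == 0:
                # Average loss over the last `log_every` batches.
                print(f"Training loss {running_loss / log_every} Steps: {step}")
                running_loss = 0.0
        epoch_loss /= max(len(loader), 1)
    return epoch_loss
```

With a model and loader like those sketched earlier, this would be called as `train(Net(), train_loader, epochs=2)`.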

Build confusion matrix:

Finally, this is why you are here.

You should see:

Congrats, you got it!
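For reference, a confusion matrix like the one shown in this article can also be accumulated in plain PyTorch while iterating over the test set, without collecting all predictions in Python lists first. This is a sketch under assumptions, not the code from the linked notebook: the names `build_confusion_matrix` and `normalize_rows` are hypothetical, and libraries such as scikit-learn offer ready-made alternatives.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


@torch.no_grad()  # no gradients are needed for evaluation
def build_confusion_matrix(model: nn.Module, loader: DataLoader, num_classes: int,
                           device: str = "cpu") -> torch.Tensor:
    """Accumulate counts: rows = true class, cols = predicted class."""
    model.eval()  # switch layers such as dropout/batch-norm to evaluation behavior
    cm = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for inputs, labels in loader:
        logits = model(inputs.to(device))
        preds = logits.argmax(dim=1).cpu()  # predicted class per sample
        flat = labels.long() * num_classes + preds
        cm += torch.bincount(flat, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return cm


def normalize_rows(cm: torch.Tensor) -> torch.Tensor:
    """Divide each row by its total so row i shows how samples of class i were predicted."""
    return cm.float() / cm.sum(dim=1, keepdim=True).clamp(min=1).float()
```

The row-normalized matrix puts each class’s recall on the diagonal, which makes classes with very different sample counts comparable when you plot it as a heatmap.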

You can download the full notebook here:

https://github.com/cbernecker/medium/blob/main/confusion_matrix.ipynb

Software Developer and Data Scientist