How to create a confusion matrix in PyTorch

Christian Bernecker
Jan 27, 2021

This is a short tutorial on how to create a confusion matrix in PyTorch. I’ve often seen people have trouble creating one, yet it is a helpful tool for seeing how well each class in your dataset performs. It can also reveal which classes the model confuses with each other.
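To make the idea concrete: in a confusion matrix, row i counts the samples whose true class is i, and column j counts the samples predicted as class j (the same convention `sklearn.metrics.confusion_matrix` uses). The diagonal holds the correct predictions, so per-class accuracy is the diagonal divided by the row sums, and the largest off-diagonal cell points at the pair of classes the model mixes up most. Below is a minimal sketch on a hand-made 3-class matrix; the helper names `per_class_accuracy` and `most_confused_pair` are hypothetical and not part of the original article.

```python
import torch


def per_class_accuracy(cm: torch.Tensor) -> torch.Tensor:
    """Fraction of each true class (row) that was predicted correctly."""
    cm = cm.to(torch.float64)
    # Diagonal = correct predictions; row sum = all samples of that true class.
    # clamp avoids division by zero for classes that never occur.
    return cm.diagonal() / cm.sum(dim=1).clamp(min=1)


def most_confused_pair(cm: torch.Tensor) -> tuple[int, int, int]:
    """Return (true_class, predicted_class, count) of the largest off-diagonal cell."""
    off_diag = cm.clone()
    off_diag.fill_diagonal_(0)  # ignore the correct predictions
    flat_idx = int(off_diag.argmax())  # index into the flattened matrix
    true_cls, pred_cls = divmod(flat_idx, cm.shape[1])
    return true_cls, pred_cls, int(off_diag[true_cls, pred_cls])


# Toy matrix: rows = true class, columns = predicted class.
cm = torch.tensor([[5, 1, 0],
                   [2, 3, 1],
                   [0, 0, 4]])
print([round(v, 3) for v in per_class_accuracy(cm).tolist()])  # [0.833, 0.5, 1.0]
print(most_confused_pair(cm))  # (1, 0, 2): class 1 is most often mistaken for class 0
```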

Confusion Matrix MNIST-FASHION dataset

If you are only interested in the code for the matrix, jump directly to “Build confusion matrix” at the end of this article. You will also find a link to my code on GitHub at the end.

If you want to use TensorBoard instead, go to:

For everyone else, first things first: let’s start at the beginning. Download the FashionMNIST dataset and train a simple Convolutional Neural Network.

Load the data:

Loading the FashionMNIST dataset.

Load FashionMNIST Dataset

The Conv-Net:

This is a simple Conv-Net architecture. Nothing fancy, but it works!

Simple Convolutional Neural Network: LeNet (LeCun et al., 1998)
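A minimal LeNet-style sketch for 1x28x28 FashionMNIST images and 10 classes. The class name `Net` and the exact layer sizes are assumptions; the original LeNet-5 used tanh-style activations and subsampling (average-pooling) layers, and ReLU with max pooling here is a common modern substitution.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """LeNet-style network for 1x28x28 inputs (hypothetical layout)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # padding=2 keeps the 28x28 input at 28x28, mimicking LeNet's 32x32 input.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)  # -> 6x28x28
        self.pool = nn.MaxPool2d(2, 2)                           # halves height and width
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)             # 6x14x14 -> 16x10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))  # -> 6x14x14
        x = self.pool(F.relu(self.conv2(x)))  # -> 16x5x5
        x = torch.flatten(x, 1)               # flatten everything except the batch dim
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                    # raw logits; CrossEntropyLoss applies softmax
```

The network returns unnormalized logits on purpose: `nn.CrossEntropyLoss` expects logits, and `argmax` over them gives the predicted class directly.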

Train the model:

Feed the data to the Conv-Net. Reduce the number of epochs if you have a slow CPU.

Train Convolutional Neural Network
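A minimal training-loop sketch with cross-entropy loss. The optimizer (SGD with momentum 0.9), the function name `train`, and its parameters are assumptions, not the original gist; the learning rate of 0.001 and the print format are chosen to resemble the log below. The loop uses a GPU when one is available and otherwise falls back to the CPU.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train(model: nn.Module, loader: DataLoader, epochs: int = 2,
          lr: float = 0.001, log_every: int = 1000) -> list[float]:
    """Train with SGD and cross-entropy; print and return the average loss every `log_every` batches."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    # Assumption: SGD with momentum; the optimizer used in the original is not shown.
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    logged = []
    step = 0  # global batch counter across all epochs
    for epoch in range(epochs):
        print(f"Epoch-{epoch + 1} lr: {lr}")
        model.train()
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()              # clear gradients from the previous batch
            loss = criterion(model(inputs), labels)
            loss.backward()                    # backpropagate
            optimizer.step()                   # update the weights

            running_loss += loss.item()
            if (i + 1) % log_every == 0:
                avg = running_loss / log_every  # mean loss over the last `log_every` batches
                print(f"Training loss {avg} Steps: {step}")
                logged.append(avg)
                running_loss = 0.0
            step += 1
    return logged
```

Call it as `train(model, train_loader)`, where `model` is the Conv-Net and `train_loader` is the FashionMNIST training DataLoader from the previous steps. Lowering `epochs` is the easiest way to shorten training on a slow CPU.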

You will get something like this:

Epoch-1 lr: 0.001 
Training loss 1.7553951950371265 Steps: 999
Training loss 0.8032121055871249 Steps: 1999
Training loss 0.6876848935596644 Steps: 2999
Training loss 0.6286471617035567 Steps: 3999
Training loss 0.5757026640549302 Steps: 4999
Training loss 0.5538745389506221 Steps: 5999
Training loss 0.5094906052090228 Steps: 6999
Epoch-2 lr: 0.001
Training loss 0.72671784534771 Steps: 8499
....

Build confusion matrix:

Finally, this is why you are here.

Create Confusion Matrix of the Convolutional Neural Network
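A minimal sketch, assuming a trained model and a test-set DataLoader. It collects true and predicted labels (the predicted label is the `argmax` over the logits) and counts each (true, predicted) pair with `torch.bincount`. `sklearn.metrics.confusion_matrix(y_true, y_pred)` uses the same row/column convention and gives the same matrix if you prefer scikit-learn. The plotting helper uses matplotlib, imported inside the function so the counting code works without it. All function names here are hypothetical, and this is not necessarily how the original gist did it.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


@torch.no_grad()  # no gradients needed for evaluation
def collect_predictions(model: nn.Module, loader: DataLoader):
    """Run the model over `loader`; return (true_labels, predicted_labels) as 1-D tensors."""
    device = next(model.parameters()).device
    model.eval()  # switch dropout/batch-norm layers to evaluation behavior
    y_true, y_pred = [], []
    for inputs, labels in loader:
        logits = model(inputs.to(device))
        y_pred.append(logits.argmax(dim=1).cpu())  # class with the highest score
        y_true.append(labels.cpu())
    return torch.cat(y_true), torch.cat(y_pred)


def confusion_matrix(y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int) -> torch.Tensor:
    """cm[i, j] = number of samples of true class i that were predicted as class j."""
    # Encode each (true, pred) pair as one index in [0, num_classes**2), then count.
    idx = y_true.long() * num_classes + y_pred.long()
    counts = torch.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


def plot_confusion_matrix(cm: torch.Tensor, class_names, path: str = "confusion_matrix.png"):
    """Save a row-normalized heatmap of `cm` (each row sums to 1) to `path`."""
    import matplotlib.pyplot as plt

    norm = cm.double() / cm.sum(dim=1, keepdim=True).clamp(min=1)
    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(norm.numpy(), cmap="Blues", vmin=0.0, vmax=1.0)
    ax.set_xticks(range(len(class_names)))
    ax.set_xticklabels(class_names, rotation=45, ha="right")
    ax.set_yticks(range(len(class_names)))
    ax.set_yticklabels(class_names)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    for i in range(norm.shape[0]):  # write the fraction into every cell
        for j in range(norm.shape[1]):
            ax.text(j, i, f"{norm[i, j].item():.2f}", ha="center", va="center", fontsize=8)
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)
```

For example, `y_true, y_pred = collect_predictions(model, test_loader)`, then `cm = confusion_matrix(y_true, y_pred, num_classes=10)` and `plot_confusion_matrix(cm, classes)`, where `classes` is the list of the ten FashionMNIST class names in label order. Normalizing each row makes classes of different sizes comparable: the diagonal then shows per-class accuracy directly.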

You should see:


Thanks for reading. I hope you enjoyed it and got something useful out of it.

