Today we write the “Hello World” of machine learning in Flux, training a simple neural net to classify hand-written digits from the MNIST database. As you’re here on Machine Learning Geek, I don’t believe you need any introduction to this, so let’s get right into the action.
First of all, we load the data using Flux’s built-in functions:
images = Flux.Data.MNIST.images();
labels = Flux.Data.MNIST.labels();
We can check out random images and labels:
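A minimal sketch of that check, assuming `images` and `labels` are loaded as above (the index `i` is just an illustrative choice):

```julia
# Inspect a random image/label pair
i = rand(1:length(images))
images[i]   # renders as a 28×28 greyscale image in a Pluto cell
labels[i]   # the matching digit, 0 through 9
```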
We’ll preprocess as usual. Let’s pick an image:
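For illustration, something like the following (the name `img` and the index 1 are arbitrary choices, not from the original):

```julia
img = images[1]   # any index will do
size(img)         # (28, 28)
```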
Now we’ll reshape it into a vector:
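A sketch of the reshape step, assuming `img` holds the 28×28 image picked above:

```julia
v = reshape(img, :)   # flatten the 28×28 matrix column-major
length(v)             # 784
```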
But currently, it’s a vector of fixed-point numbers, which won’t play well with our modeling:
So we’ll convert it into a floating-point vector:
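Mirroring the `float.` broadcast the article uses later for the test set, assuming `v` is the flattened vector from the previous step:

```julia
fv = float.(v)   # fixed-point greyscale values become floating point in [0, 1]
```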
Now we’ll apply these steps to all the images in one fell swoop, and then convert our image data into columns.
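Combining the reshape and conversion over the whole dataset, a sketch assuming `images` is loaded as above (the name `X` matches the matrix used in the loss and training code below):

```julia
# Flatten and convert every image, then stack the vectors as columns
X = hcat(float.(reshape.(images, :))...)
size(X)   # (784, 60000) — one column per training image
```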
Now we’ll use another function from Flux to one-hot encode our labels:
y = onehotbatch(labels, 0:9)
We’ll now define our model as two fully connected layers, with a softmax at the end for inference:
m = Chain(
    Dense(28^2, 40, relu),  # 784 → 40 (relu activation assumed; the text only specifies sizes)
    Dense(40, 10),          # 40 → 10
    softmax                 # probabilities over the ten digits
)
As you can see, our model takes in a flattened 28×28 image as a 784-element vector, runs it through the first layer to get a 40-dimensional vector, which the second layer then converts into a 10-dimensional vector to be fed into the softmax activation function for inference. Now that we have the ingredients, let’s put them together:
loss(X, y) = crossentropy(m(X), y)
opt = ADAM()
As is easy to figure out, we’re going to use log-loss (cross-entropy) as our loss function, and ADAM for gradient descent. Finally, we’re going to define a function to show us the loss after every epoch. This isn’t mandatory, but it definitely enhances our quality of life.
progress = () -> @show(loss(X, y)) # callback to show loss
And now we train our model, and request it to report the loss every 10 seconds.
@epochs 100 Flux.train!(loss, params(m), [(X, y)], opt, cb = throttle(progress, 10))
As you probably remember from our last article on Flux, train! only trains for one epoch at a time, so we import and use the @epochs macro to train for 100 epochs. If you’re working in Pluto notebooks like me, do note that the progress callback will print to your terminal:
Now that we’re done training, we ought to verify that our model is working well, so we check it on a random test image:
md"""
# Prediction: $(onecold(m(hcat(float.(reshape.(Flux.Data.MNIST.images(:test), :))...)[:, foobar])) - 1)
"""
Here’s how it looks in Pluto:
As a bonus, if you want to check the accuracy on our training and test sets, the following terse code will display it for you:
using Statistics: mean

accuracy(X, y) = mean(onecold(m(X)) .== onecold(y))
md"""
# Training accuracy: $(accuracy(X, y))
# Test accuracy: $(accuracy(hcat(float.(reshape.(Flux.Data.MNIST.images(:test), :))...), onehotbatch(Flux.Data.MNIST.labels(:test), 0:9)))
"""
The above two code-blocks simply apply our pre-processing steps to the test data, and then display the results in a pretty way.
Don’t worry if you don’t yet grok macros or Flux; this was just an appetizer. We’ll be hand-rolling a regression model and an SVM soon, as well as discussing macros and Flux in more detail.
As usual, you can access the whole notebook at your leisure as well: