Deep dive into Flux.jl
Today we’ll be taking a look at some of the source code behind Flux.jl. If you’re here, you’ve likely tried to work with Flux but been turned off by the lack of documentation. So we’ll show you how easy it is to read the source code itself, and that it’s actually quite simple once you cross the initial hurdles.
Most of the complicated code is related to auto-differentiation and is abstracted out into another package, Zygote.jl. You can read about the in-depth details of Zygote.jl here, but we won’t be touching it, as it’s quite advanced in both the novelty of its concept and its implementation. So let’s start with the implementations of some simple loss functions.
Mean Square Error
Mathematically, given n-dimensional vectors x and y, we want $\frac{1}{n}\sum_{i=1}^{n} (x_i - y_i)^2$. We accomplish this by using the vectorised (broadcast) versions of the subtraction operator, .-, and the exponentiation operator, .^, while also providing a keyword argument agg. By default, agg returns the mean of the processed vector, but it also allows us to supply a custom aggregation function if need be.
mse(ŷ, y; agg=mean) = agg((ŷ .- y).^2)
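To make this concrete, here is a small self-contained sketch that reuses the definition above; the input values are made up purely for illustration.
using Statistics: mean

mse(ŷ, y; agg = mean) = agg((ŷ .- y) .^ 2)   # the definition from above

ŷ = [0.9, 0.2, 0.4]        # some made-up predictions
y = [1.0, 0.0, 1.0]        # and targets
mse(ŷ, y)                  # mean([0.01, 0.04, 0.36]) ≈ 0.1367
mse(ŷ, y; agg = sum)       # custom aggregation: sum instead of mean, ≈ 0.41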
Mean absolute error
Given equidimensional vectors x and y, we simply want $\frac{1}{n}\sum_{i=1}^{n} |x_i - y_i|$. So we proceed like in the previous case, but also apply the abs function element-wise to obtain the absolute value of each element before aggregating.
mae(ŷ, y; agg=mean) = agg(abs.(ŷ .- y))
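The same kind of quick check works for mae (again with made-up numbers):
using Statistics: mean

mae(ŷ, y; agg = mean) = agg(abs.(ŷ .- y))    # the definition from above
mae([0.9, 0.2, 0.4], [1.0, 0.0, 1.0])        # mean([0.1, 0.2, 0.6]) ≈ 0.3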
Cross-entropy
Here we come across the final trick used in the file defining the in-built loss functions. First of all, xlogy element-wise converts the pair $(y_i, \hat{y}_i)$ to $y_i \times \log(\hat{y}_i + \epsilon)$, where $\epsilon$ is essentially a stand-in for a very small value of the same type as the elements of ŷ (it keeps the logarithm from blowing up when a prediction is exactly zero). The results are then summed along dims (the class dimension by default) and negated, and finally aggregated.
function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ))
  agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims=dims))
end
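To see it in action, here is a sketch assuming a recent Flux where the losses live in the Flux.Losses module; the probabilities and targets are made up, with y given as one-hot columns and ŷ as probabilities (e.g. the output of softmax).
using Flux.Losses: crossentropy

ŷ = [0.7  0.2;       # predicted probabilities, one column per sample
     0.3  0.8]
y = [1.0  0.0;       # one-hot targets
     0.0  1.0]

crossentropy(ŷ, y)              # sums -yᵢ log(ŷᵢ + ϵ) down each column, then takes the mean
# ≈ (-log(0.7) - log(0.8)) / 2 ≈ 0.29
crossentropy(ŷ, y; agg = sum)   # ≈ 0.58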
Now that we’re comfortable with loss functions, we’ll take a closer look at the layer we used a couple articles ago.
Dense
This defines the most basic layer we have, a fully connected layer, or, to put it simply, a weights matrix along with a bias. The implementation, like the rest of Flux, is quite simple, consisting of a structure and a constructor.
The struct is defined to hold a weight array W, a bias b (an ordinary vector, possibly full of zeroes, or the special Zeros type when the bias is disabled) and an activation function σ:
struct Dense{F,S<:AbstractArray,T<:Union{Zeros, AbstractVector}}
  W::S
  b::T
  σ::F
end
Now, the default constructor accepts the input and output sizes as integers, along with an activation function that defaults to the identity map. Optionally, you can also change how the weights and biases are initialised: by default, the bias becomes a zero vector of the appropriate size while the weights use the Glorot (uniform) initialisation algorithm.
Dense(W, b) = Dense(W, b, identity)
function Dense(in::Integer, out::Integer, σ = identity;
               initW = glorot_uniform, initb = zeros, bias=true)
  return Dense(initW(out, in), create_bias(bias, initb, out), σ)
end
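The forward pass, essentially σ.(W * x .+ b), is defined a little further down in the same file, so a constructed layer is directly callable. Here is a rough usage sketch with arbitrarily chosen sizes and activation (the field access at the end assumes the W/b field names of the version shown above):
using Flux

layer = Dense(10, 5, relu)    # 5×10 weights via glorot_uniform, zero bias of length 5, σ = relu
x = rand(Float32, 10)
layer(x)                      # applies relu.(W * x .+ b) and returns a length-5 vector
layer.W, layer.b              # the W and b fields from the struct above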
The way we combine multiple layers is through the Chain type, so we’ll talk about that next.
Chain
This type, even though it forms the bedrock of our model, is just a handful of simple lines of code.
struct Chain{T<:Tuple}
  layers::T
  Chain(xs...) = new{typeof(xs)}(xs)
end
functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)
This speaks to the simplicity of Flux and the power of Julia’s design. When you call, say, Chain(foo, bar), the inner constructor on the third line of the block above is what runs: splatting (...) slurps the layers into a tuple, which is stored in the layers field. The last line then defines functor, which tells the rest of Flux how to take a Chain apart and put it back together: it returns the tuple of layers along with a lambda, ls -> Chain(ls...), that rebuilds a Chain from a (possibly transformed) set of layers. This is what lets utilities like params and gpu walk through every layer recursively, collecting or transforming their contents, without Chain needing any special code of its own.
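Tying it together, here is a minimal usage sketch; the layer sizes are arbitrary, and the forward pass that folds the input through the layers in order is defined alongside the struct in the source.
using Flux

model = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
x = rand(Float32, 10)

model(x)             # pushes x through each layer in order; the result sums to 1 thanks to softmax
model.layers         # the tuple captured by the inner constructor
Flux.params(model)   # uses the functor machinery to recursively collect every W and b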
Train
The heart of our machine learning efforts, the train! function provides a simple training loop suitable for most scenarios. It accepts a loss function loss, model parameters ps, data data and an optimiser opt. Optionally, it also takes a callback function cb, which gives an easy hook into the training loop and can be used both for displaying progress and for injecting custom code.
There’s not much going on inside: the @progress loop iterates over each example in the given data, and a try-catch block gracefully handles any exceptions that arise. For each example, the gradient of the loss with respect to the parameters is computed, update! then uses the optimiser and that gradient to update the parameters, and finally the callback function is called before moving on to the next example.
function train!(loss, ps, data, opt; cb = () -> ())
  ps = Params(ps)
  cb = runall(cb)
  @progress for d in data
    try
      gs = gradient(ps) do
        loss(batchmemaybe(d)...)
      end
      update!(opt, ps, gs)
      cb()
    catch ex
      if ex isa StopException
        break
      elseif ex isa SkipException
        continue
      else
        rethrow(ex)
      end
    end
  end
end
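Here is a rough end-to-end sketch of calling train!; the model, toy data and hyperparameters are made up purely for illustration.
using Flux
using Flux: train!, throttle

model = Chain(Dense(10, 5, relu), Dense(5, 1))                  # a made-up model
loss(x, y) = Flux.Losses.mse(model(x), y)                       # loss closes over the model
data = [(rand(Float32, 10), rand(Float32, 1)) for _ in 1:100]   # toy dataset of (input, target) tuples
opt = Descent(0.01)                                             # plain gradient descent
ps = Flux.params(model)

# print the loss on the first example at most once every 5 seconds
train!(loss, ps, data, opt; cb = throttle(() -> @show(loss(data[1]...)), 5))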
You’ll have noted that this only trains once, i.e. for only one epoch. If we want to train for multiple epochs, we use the @epochs macro.
Epochs macro
Following is the source code of the @epochs macro, taken from here. Again, in the spirit of Flux, it’s as simple as possible. It accepts two parameters: the number of epochs we want, n, and the training expression ex (typically a call to train!). It then simply repeats ex as many times as required, logging the current epoch to the terminal via the @info macro.
macro epochs(n, ex)
  :(@progress for i = 1:$(esc(n))
      @info "Epoch $i"
      $(esc(ex))
    end)
end
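And a quick sketch of using it, continuing with the hypothetical loss, ps, data and opt from the train! sketch above:
using Flux: @epochs, train!

@epochs 10 train!(loss, ps, data, opt)
# logs "[ Info: Epoch 1" through "[ Info: Epoch 10", running the full training loop each time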
Now you’re sufficiently proficient in reading Flux’s source code that a lack of documentation won’t hinder you, and you’ll be able to confidently debug your code without having to rely on StackOverflow.