Deep dive into Flux.jl

Today we’ll be taking a look at some of the source code behind Flux.jl. If you’re here, you’ve likely tried to work with Flux but been turned off by the lack of documentation. So today we’ll show you how easy it is to read the source code itself, and that it’s actually quite simple once you cross the initial hurdles.

Most of the complicated code is related to automatic differentiation and is abstracted out into another package, Zygote.jl. You can read about the in-depth details of Zygote.jl here, but we won’t be touching it, as it’s quite advanced in both the novelty of its concept and its implementation. So let’s start with the implementations of some simple loss functions.

Mean squared error

Mathematically, given $n$-dimensional vectors $x$ and $y$, we want $\frac{1}{n}\sum_{i=1}^{n} (x_i-y_i)^2$. We accomplish this by using the vectorised version of the subtraction operator, .-, along with that of the exponentiation operator, .^, while also providing a keyword argument, agg. By default agg is mean, which returns the mean of the processed vector, but we can supply a custom aggregation function if need be. In the code, ŷ plays the role of $x$.

mse(ŷ, y; agg=mean) = agg((ŷ .- y).^2)
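To see what the agg keyword buys us, here is a small sketch that redefines the one-liner locally so it runs without Flux installed. The input vectors are made up for illustration.

```julia
using Statistics: mean

# Local copy of the definition quoted above, so the example is self-contained.
mse(ŷ, y; agg=mean) = agg((ŷ .- y).^2)

ŷ = [1.0, 2.0, 3.0]   # predictions (illustrative values)
y = [1.0, 2.0, 5.0]   # targets (illustrative values)

mse_mean = mse(ŷ, y)                  # default: (0 + 0 + 4) / 3
mse_sum  = mse(ŷ, y; agg=sum)         # 4.0, total squared error instead of the mean
mse_raw  = mse(ŷ, y; agg=identity)    # [0.0, 0.0, 4.0], the per-element squared errors
```

Passing identity is a handy way to inspect the per-element losses before they are reduced to a single number.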

Mean absolute error

Given equidimensional vectors $x$ and $y$, we simply want $\frac{1}{n}\sum_{i=1}^{n} |x_i-y_i|$. So we proceed as in the previous case, but apply the abs function element-wise to obtain the absolute value of each difference before aggregating.

mae(ŷ, y; agg=mean) = agg(abs.(ŷ .- y))
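As a quick sanity check, the sketch below compares the broadcast one-liner against the formula written out as an explicit sum. The definition is copied locally so that Flux is not needed, and the numbers are made up.

```julia
using Statistics: mean

# Local copy of the definition quoted above.
mae(ŷ, y; agg=mean) = agg(abs.(ŷ .- y))

ŷ = [1.0, 2.0, 3.0]   # predictions (illustrative values)
y = [2.0, 2.0, 7.0]   # targets (illustrative values)

# The same quantity, spelled out index by index as in the formula.
explicit = sum(abs(ŷ[i] - y[i]) for i in eachindex(y)) / length(y)

broadcasted = mae(ŷ, y)   # absolute differences are 1, 0 and 4, so this is 5 / 3
```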

Cross-entropy

Here we come across the final trick used in the file defining the built-in loss functions. First, xlogy is applied element-wise to y and ŷ .+ ϵ, turning each pair $(y_i,\hat{y}_i)$ into $y_i \log(\hat{y}_i+\epsilon)$, where $\epsilon$ is a very small value of the same type as the elements of ŷ, added so that the logarithm never sees an exact zero. The results are then summed along dims and negated, and finally aggregated.

function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ))
    agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims=dims))
end
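The following sketch shows the function at work on a single one-hot example. The helpers epseltype and xlogy are not shown in the excerpt above, so the versions here are simplified stand-ins written for this example (assumptions, not Flux’s own definitions).

```julia
using Statistics: mean

# Assumed stand-in: machine epsilon of the (floating-point) element type.
epseltype(x) = eps(float(eltype(x)))

# Assumed stand-in: x * log(y), but defined to be zero whenever x is zero.
xlogy(x, y) = iszero(x) ? zero(x * log(y)) : x * log(y)

# Same body as the definition quoted above.
function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ))
    agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims=dims))
end

ŷ = [0.1, 0.8, 0.1]   # predicted class probabilities (illustrative)
y = [0.0, 1.0, 0.0]   # one-hot target: the true class is the second one

loss = crossentropy(ŷ, y)      # only the true class contributes: about -log(0.8)

naive = 0.0 * log(0.0)         # NaN, because 0 * -Inf is undefined
safe  = xlogy(0.0, 0.0)        # 0.0, which is why a dedicated helper is used
```

The last two lines show the motivation for a dedicated helper: a zero target multiplied by the log of a zero prediction would otherwise poison the whole sum with NaN.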

Now that we’re comfortable with loss functions, we’ll take a closer look at the layer we used a couple articles ago.

Dense

This defines the most basic layer we have, a fully connected layer, or, to put it simply, a weights matrix along with a bias. The implementation, like the rest of Flux, is quite simple, consisting of a structure and a constructor.

The structure is defined to have an array W, a vector b (possibly a placeholder full of zeroes) and an activation σ:

struct Dense{F,S<:AbstractArray,T<:Union{Zeros, AbstractVector}}
  W::S
  b::T
  σ::F
end

Now, the default constructor is defined to accept the input and output sizes as integers, along with an activation function that defaults to the identity map. Optionally, you can also change the way weights and biases are initialised: by default, the latter becomes a zero vector of the appropriate size, while the former uses the Glorot uniform initialisation.

Dense(W, b) = Dense(W, b, identity)
function Dense(in::Integer, out::Integer, σ = identity;
               initW = glorot_uniform, initb = zeros, bias=true)
  return Dense(initW(out, in), create_bias(bias, initb, out), σ)
end
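To make the pieces concrete, here is a stripped-down sketch of such a layer that runs without Flux. MiniDense and my_glorot are hypothetical names made up for this example, the bias handling is simplified to a plain vector, and the forward pass (not shown in the excerpt) is assumed to be the usual σ.(W*x .+ b).

```julia
# Hypothetical stand-in for glorot_uniform: uniform on ±sqrt(6 / (in + out)).
my_glorot(out, in) = (rand(out, in) .- 0.5) .* sqrt(24 / (in + out))

struct MiniDense{F,S<:AbstractMatrix,T<:AbstractVector}
    W::S   # weights, size (out, in)
    b::T   # bias, length out
    σ::F   # activation function
end

# Size-based constructor with the same shape as the one quoted above.
MiniDense(in::Integer, out::Integer, σ = identity;
          initW = my_glorot, initb = zeros) =
    MiniDense(initW(out, in), initb(out), σ)

# Assumed forward pass: affine map followed by an element-wise activation.
(m::MiniDense)(x::AbstractVector) = m.σ.(m.W * x .+ m.b)

layer = MiniDense(3, 2, tanh)   # 3 inputs, 2 outputs
x = [1.0, 2.0, 3.0]
h = layer(x)                    # a 2-element vector
```

Note that the weight matrix has size (out, in), which is what lets W * x map an input of length in to an output of length out.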

The way we combine multiple layers is through Chain, so we’ll talk about that next.

Chain

This struct, even though it forms the bedrock of our model, is defined in just a few simple lines of code.

struct Chain{T<:Tuple}
  layers::T
  Chain(xs...) = new{typeof(xs)}(xs)
end

functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)

This speaks to the simplicity of Flux and the power of Julia’s design. All that Chain does is take in any number of layers as arguments, collect them into a tuple via the varargs syntax (xs...), and store that tuple in the struct’s layers field.

What is happening here is that when you call, say, Chain(foo, bar), the inner constructor on the third line of the above code block gets called, and new{typeof(xs)}(xs) builds a Chain whose type parameter records the exact type of every layer. The last line is not part of construction. It tells functor how to take a Chain apart and put it back together: it returns the tuple of layers as the children, together with an anonymous function, ls -> Chain(ls...), that splats a tuple of (possibly transformed) layers back into a new Chain. Utilities that walk over a whole model, for example to collect its parameters or move it to the GPU, rely on this pair to recurse into each layer and, where needed, rebuild the Chain afterwards.
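A minimal sketch can make the two roles easier to see. MiniChain and children_and_rebuild are hypothetical names used only here; the latter mimics the shape of the functor line without depending on any package. The forward pass is not part of the excerpt, so the fold below is an assumption about how layers are applied in order.

```julia
struct MiniChain{T<:Tuple}
    layers::T
    # Varargs inner constructor: collects the arguments into a tuple.
    MiniChain(xs...) = new{typeof(xs)}(xs)
end

# Assumed forward pass: feed x through each layer in order.
(c::MiniChain)(x) = foldl((acc, layer) -> layer(acc), c.layers; init = x)

# Hypothetical stand-in for the functor line: children plus a rebuild function.
children_and_rebuild(c::MiniChain) = c.layers, ls -> MiniChain(ls...)

double(x) = 2 .* x
addone(x) = x .+ 1

c = MiniChain(double, addone)
first_result = c([1, 2])              # addone(double([1, 2])) == [3, 5]

layers, rebuild = children_and_rebuild(c)
c2 = rebuild(reverse(layers))         # same layers, opposite order
second_result = c2([1, 2])            # double(addone([1, 2])) == [4, 6]
```

Here reverse stands in for whatever transformation a model-walking utility might apply to the children before rebuilding.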

Train

The heart of our machine learning efforts, the train! function provides a simple training loop suitable for most scenarios. It accepts a loss function loss, the model parameters ps, the data data and an optimiser opt. Optionally, it also takes a callback function cb, which gives an easy hook into the training loop and can be used both for displaying progress and for injecting custom code.

There’s not much going on inside. The @progress for line starts a loop over each element of the given data, and a try-catch block then computes the gradient of the loss with respect to the parameters while gracefully handling any exceptions that arise. If that succeeds, update! uses the gradients to update the parameters (and the optimiser’s state) before the callback is called and the loop moves on to the next element. A StopException ends training early, a SkipException skips to the next element, and anything else is rethrown.

function train!(loss, ps, data, opt; cb = () -> ())
  ps = Params(ps)
  cb = runall(cb)
  @progress for d in data
    try
      gs = gradient(ps) do
        loss(batchmemaybe(d)...)
      end
      update!(opt, ps, gs)
      cb()
    catch ex
      if ex isa StopException
        break
      elseif ex isa SkipException
        continue
      else
        rethrow(ex)
      end
    end
  end
end
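The same control flow can be sketched without Flux or Zygote. In the toy loop below, mini_train! and StopTraining are hypothetical names, the gradient is written out by hand instead of being computed automatically, and the update is plain gradient descent with a fixed step size η. It is meant to illustrate the structure of the loop, not to reproduce Flux’s behaviour.

```julia
# Hypothetical exception type playing the role of StopException.
struct StopTraining <: Exception end

# Fit y ≈ w * x, storing the single weight in a one-element vector.
function mini_train!(w::Vector{Float64}, data; η = 0.1, cb = () -> nothing)
    for (x, y) in data
        try
            # Hand-written gradient of (w*x - y)^2 with respect to w.
            g = 2 * (w[1] * x - y) * x
            w[1] -= η * g      # gradient-descent step, standing in for update!
            cb()               # the callback runs after every successful step
        catch ex
            if ex isa StopTraining
                break          # stop training early
            else
                rethrow(ex)    # anything unexpected is passed on
            end
        end
    end
    return w
end

# Data generated from y = 2x, repeated so the loop has enough steps to converge.
data = repeat([(1.0, 2.0), (2.0, 4.0)], 50)
w = mini_train!([0.0], data)

# A callback that stops training after three steps.
steps = Ref(0)
stop_after_three() = (steps[] += 1; steps[] >= 3 && throw(StopTraining()))
w_short = mini_train!([0.0], data; cb = stop_after_three)
```

Because the callback runs inside the try block, throwing from it is enough to end the loop cleanly, which is the same hook the StopException branch provides.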

You’ll have noted that this only trains once, i.e. for only one epoch. If we want to train for multiple epochs, we use the @epochs macro.

Epochs macro

Following is the source code of the @epochs macro, taken from here. Again, in the spirit of Flux, it’s as simple as possible. It accepts two parameters: the number of epochs we want, n, and the training loop, ex. It then simply repeats ex as many times as required, printing the progress to the terminal every epoch using the @info macro.

macro epochs(n, ex)
  :(@progress for i = 1:$(esc(n))
      @info "Epoch $i"
      $(esc(ex))
    end)
end
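To experiment with the idea without Flux, the sketch below defines a similar macro under the hypothetical name @repeat_epochs. It drops @progress, which comes from a separate package, and keeps only the loop and the @info message. The seen vector is just a stand-in for a real training call.

```julia
# Hypothetical variant of the macro above, without the @progress dependency.
macro repeat_epochs(n, ex)
    :(for i = 1:$(esc(n))
          @info "Epoch $i"
          $(esc(ex))      # esc makes ex run in the caller's scope
      end)
end

seen = Symbol[]                        # records one entry per epoch
@repeat_epochs 3 push!(seen, :epoch)   # stands in for a call to train!
```

The two esc calls matter: without them, names such as seen would be resolved inside the module that defines the macro rather than where the macro is used.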

Now you’re sufficiently proficient in reading Flux’s source code that a lack of documentation won’t hinder you, and you’ll be able to confidently debug your code without having to rely on StackOverflow.
