# Deep dive into Flux.jl

Today we’ll be taking a look at some source code behind Flux.jl. If you’re here, you’ve likely tried to work with Flux, but been turned off by the lack of documentation. So, today we’ll be showing you how easy it is to read the source code itself, and that it’s actually quite simple once you cross the initial hurdles.

Most of the complicated code is related to auto-differentiation and is abstracted out into another package, Zygote.jl. You can read about the in-depth details of Zygote.jl here, but we won’t be touching it, as it’s quite advanced in both the novelty of its concept and its implementation. So let’s start with the implementations of some simple loss functions.

### Mean Square Error

Mathematically, given n-dimensional vectors x and y, we want \frac{1}{n}\sum_{i=1}^{n} (x_i-y_i)^2. We accomplish this by using the vectorised version of the subtraction operator, `.-`, along with that of the exponentiation operator, `.^`, while also providing a keyword argument `agg`. By default, `agg` returns the mean of the processed vector, but it also allows us to supply a custom aggregation function if need be.

`mse(ŷ, y; agg=mean) = agg((ŷ .- y).^2)`
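To see the one-liner in action, here’s a minimal self-contained sketch, with `mean` imported from the `Statistics` standard library (which is where Flux gets it from):

```
using Statistics  # provides `mean`, the default aggregator

# The same one-liner as in Flux's source
mse(ŷ, y; agg=mean) = agg((ŷ .- y).^2)

ŷ = [1.0, 2.0, 3.0]
y = [1.0, 2.0, 5.0]

mse(ŷ, y)           # mean of [0, 0, 4] → 1.333…
mse(ŷ, y; agg=sum)  # custom aggregation: sum of squared errors → 4.0
```

Swapping `agg=sum` in shows exactly what the keyword argument buys us: the element-wise part stays the same, only the final reduction changes.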

### Mean absolute error

Given equidimensional vectors x and y, we simply want \frac{1}{n}\sum_{i=1}^{n} |x_i-y_i|. So, we proceed as in the previous case, but apply the `abs` function *element-wise* to obtain the absolute value of *each* element before aggregating:

`mae(ŷ, y; agg=mean) = agg(abs.(ŷ .- y))`
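Again, a quick self-contained check of the definition:

```
using Statistics  # provides `mean`

# The same one-liner as in Flux's source
mae(ŷ, y; agg=mean) = agg(abs.(ŷ .- y))

ŷ = [1.0, -2.0, 3.0]
y = [2.0, 2.0, 3.0]

mae(ŷ, y)  # mean of [1, 4, 0] → 1.666…
```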

### Cross-entropy

Here we come across the final trick used in the file defining in-built loss functions. First of all, `xlogy` element-wise converts the pair (y_i, \hat{y}_i) to y_i \times \log(\hat{y}_i+\epsilon), where \epsilon is essentially a stand-in for a very small value of the same type as the elements of `ŷ`. The results are then summed along `dims` and negated, before finally being aggregated.

```
function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ))
    agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims=dims))
end
```
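To run this outside of Flux, we need stand-ins for the two helpers it relies on. The definitions below are minimal sketches of what `xlogy` and `epseltype` do, not the exact Flux source:

```
using Statistics  # provides `mean`

# Minimal stand-ins for Flux's helpers (sketches, not the exact source):
xlogy(x, y) = x == 0 ? zero(x) : x * log(y)  # avoids 0 * log(0) producing NaN
epseltype(x) = eps(float(eltype(x)))         # a tiny value matching the element type

function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ))
    agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims=dims))
end

# A one-hot target against a predicted probability vector
y = [0.0, 1.0, 0.0]
ŷ = [0.1, 0.8, 0.1]

crossentropy(ŷ, y)  # ≈ -log(0.8), the usual cross-entropy of a one-hot label
```

Note how the `ϵ` keeps `log` away from zero: without it, a predicted probability of exactly 0 on a true class would blow up to infinity.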

Now that we’re comfortable with loss functions, we’ll take a closer look at the layer we used a couple articles ago.

### Dense

This defines the most basic layer we have, a fully connected layer, or, to put it simply, a weights matrix along with a bias. The implementation, like the rest of Flux, is quite simple, consisting of a structure and a constructor.

The structure is defined to have a weight array `W`, a *vector* `b` (possibly full of zeroes), and an activation function `σ`:

```
struct Dense{F,S<:AbstractArray,T<:Union{Zeros, AbstractVector}}
    W::S
    b::T
    σ::F
end
```

Now, the default constructor is defined to accept the input and output sizes as integers, along with an activation function that defaults to the identity map. Optionally, you can also change the way weights and biases are initialised: by default, the latter becomes a null vector of the appropriate size, while the former uses the Glorot initialisation algorithm.

```
Dense(W, b) = Dense(W, b, identity)

function Dense(in::Integer, out::Integer, σ = identity;
               initW = glorot_uniform, initb = zeros, bias=true)
    return Dense(initW(out, in), create_bias(bias, initb, out), σ)
end
```
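The forward pass of the layer (defined a few lines further down in the same file) is essentially σ.(W*x .+ b). Here’s a pared-down, self-contained sketch — `MiniDense` is a name made up for this example, and it omits the `Zeros` bias case for brevity:

```
# A pared-down Dense (illustrative only; Flux's real struct also allows Zeros bias)
struct MiniDense{F,S<:AbstractMatrix,T<:AbstractVector}
    W::S
    b::T
    σ::F
end

# Small random weights, zero bias — a stand-in for glorot_uniform
MiniDense(in::Integer, out::Integer, σ=identity) =
    MiniDense(randn(out, in) .* 0.1, zeros(out), σ)

# The forward pass: matrix multiply, add bias, apply activation element-wise
(a::MiniDense)(x) = a.σ.(a.W * x .+ a.b)

layer = MiniDense(4, 2, tanh)
x = rand(4)
layer(x)  # a 2-element output vector, each entry in (-1, 1) thanks to tanh
```

Making the struct itself callable is the idiomatic Julia way to express “a layer is just a function with state”, and it’s exactly what lets layers compose inside `Chain`.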

The way we combine multiple layers is through `Chain`, so we’ll talk about that next.

### Chain

This piece of code, even though it forms the bedrock of our model, is just a few simple lines.

```
struct Chain{T<:Tuple}
    layers::T
    Chain(xs...) = new{typeof(xs)}(xs)
end

functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)
```

This speaks to the simplicity of Flux and the power of Julia’s design. All that `Chain` does is take in any number of layers and, via splatting (`...`), store them as a tuple inside the struct. The `functor` definition then exposes that tuple along with a lambda (`ls -> Chain(ls...)`) that rebuilds a `Chain` from a tuple of layers.

What is happening here is that when you call, say, `Chain(foo, bar)`, the inner constructor on the third line of the code block above collects `foo` and `bar` into a tuple and stores it in the new `Chain` object. The last line is what lets Flux walk the model afterwards: `functor` destructures a `Chain` into its tuple of layers, while the lambda at the end parses a (possibly transformed) tuple of layers back into a `Chain`. Anything that needs to traverse and rebuild the model, such as collecting parameters or moving it to the GPU, goes through this pair of operations.
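A self-contained sketch makes the destructure/rebuild pair concrete. `MiniChain` and `destructure` are names invented for this example, and the forward pass is written as a left fold, which is one equivalent way to express “thread the input through each layer in order”:

```
# A pared-down Chain, mirroring the struct above (illustrative names)
struct MiniChain{T<:Tuple}
    layers::T
    MiniChain(xs...) = new{typeof(xs)}(xs)
end

# The functor-style pair: the layers, plus a lambda that rebuilds the chain
destructure(c::MiniChain) = c.layers, ls -> MiniChain(ls...)

# Applying a chain just threads the input through each layer in order
(c::MiniChain)(x) = foldl((acc, layer) -> layer(acc), c.layers; init=x)

double(x) = 2x
inc(x) = x + 1

c = MiniChain(double, inc)
c(3)  # double then inc: 2*3 + 1 = 7

layers, rebuild = destructure(c)
rebuild(layers)(3)  # round-trips to the same chain: 7
```

The round trip at the end is the whole point of `functor`: take the model apart, do something to the pieces, and put an identical-shaped model back together.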

### Train

The heart of our machine learning efforts, the `train!` function provides a simple training loop suitable for most scenarios. It accepts a loss function `loss`, model parameters `ps`, data `data` and an optimiser `opt`. Optionally, it also takes a callback function `cb`, which gives an easy hook into the training loop and can be used both for displaying progress and for injecting custom code.

There’s not much going on inside: the `@progress for` line starts a loop over each example in the given data, and a try-catch block does gradient descent while gracefully handling any exceptions that arise. If computing the gradient succeeds, `update!` applies the gradients to the parameters via the optimiser, before the callback function is called and we move on to the next example.

```
function train!(loss, ps, data, opt; cb = () -> ())
    ps = Params(ps)
    cb = runall(cb)
    @progress for d in data
        try
            gs = gradient(ps) do
                loss(batchmemaybe(d)...)
            end
            update!(opt, ps, gs)
            cb()
        catch ex
            if ex isa StopException
                break
            elseif ex isa SkipException
                continue
            else
                rethrow(ex)
            end
        end
    end
end
```
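To see the shape of this loop without pulling in Zygote, here’s a hand-rolled sketch that fits a single weight by gradient descent, with the gradient computed analytically instead of via `gradient`. The names `mini_train!` and `StopException` here are illustrative, not Flux’s:

```
# A hand-rolled cousin of train! (illustrative; gradient computed by hand)
struct StopException <: Exception end

function mini_train!(w, data, lr; cb = () -> ())
    for (x, y) in data
        try
            # loss(w) = (w*x - y)^2, so dloss/dw = 2 * (w*x - y) * x
            g = 2 * (w[] * x - y) * x
            w[] -= lr * g  # the update! step: descend along the gradient
            cb()           # the callback hook
        catch ex
            if ex isa StopException
                break      # graceful early exit, as in Flux's loop
            else
                rethrow(ex)
            end
        end
    end
end

w = Ref(0.0)  # learn y = 2x from a few examples
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
for _ in 1:50  # what @epochs would do for us, see below
    mini_train!(w, data, 0.02)
end
w[]  # ≈ 2.0
```

Everything Flux adds on top of this skeleton — `Params`, Zygote’s `gradient`, pluggable optimisers — slots into the same three spots: compute gradients, update, call back.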

You’ll have noted that this only trains once, i.e. for only one epoch. If we want to train for multiple epochs, we use the `@epochs` macro.

### Epochs macro

Following is the source code of the `@epochs` macro, taken from here. Again, in the spirit of Flux, it’s as simple as possible. It accepts two parameters: the number of epochs we want, `n`, and the training loop `ex`. Then it simply repeats `ex` as many times as required, printing the progress to the terminal every epoch using the `@info` macro.

```
macro epochs(n, ex)
    :(@progress for i = 1:$(esc(n))
        @info "Epoch $i"
        $(esc(ex))
    end)
end
```
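The only subtle part is `esc`, which lets `n` and `ex` refer to variables from the caller’s scope rather than the macro’s. A stripped-down cousin (the name `@repeated` is made up for this example, and the `@progress`/`@info` logging is dropped) shows the same quote-and-escape pattern:

```
# A minimal cousin of @epochs, minus the logging (illustrative name)
macro repeated(n, ex)
    :(for i = 1:$(esc(n))
        $(esc(ex))
    end)
end

counter = Ref(0)              # a Ref sidesteps global soft-scope rules
@repeated 3 (counter[] += 1)  # the expression runs three times
counter[]  # 3
```

Because `ex` is escaped, the macro can repeat *any* caller expression — including a full `train!` call — which is exactly how `@epochs 10 train!(...)` works.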

Now you’re sufficiently proficient in reading Flux’s source code that a lack of documentation won’t hinder you, and you’ll be able to confidently debug your code without having to rely on StackOverflow.