Kaiming Init — A Consolidation (fastai)

Greetings! In this article, I will be consolidating my thoughts on Jeremy Howard’s implementation of Kaiming Init from fastai’s Lesson 9. When I first encountered this implementation, there were a few things I had trouble understanding:

  • How does the code relate to the math from the paper?
  • What is gain(a)?
  • Why is he multiplying std by √3 to get bound?
  • How does he derive fan_in and fan_out?

Code vs Paper

The code for Kaiming Init in the fully connected layer (as presented in Lesson 8) is very different from that in the convolutional layer. I guess this was what threw me off a bit.

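```python
import math
import torch

# Kaiming Init for fully connected layer
m, nh = 784, 50  # illustrative sizes: m inputs (e.g. 28*28 MNIST pixels), nh hidden units
w1 = torch.randn(m, nh) * math.sqrt(2 / m)
b1 = torch.zeros(nh)
w2 = torch.randn(nh, 1) * math.sqrt(2 / nh)
b2 = torch.zeros(1)

# Kaiming Init for convolutional layer
def gain(a): return math.sqrt(2.0 / (1 + a**2))

def kaiming2(x, a, use_fan_out=False):
    nf, ni, *_ = x.shape              # out_channels, in_channels, then kernel dims
    rec_fs = x[0, 0].shape.numel()    # receptive field size: elements in one k x k kernel slice
    fan = nf * rec_fs if use_fan_out else ni * rec_fs
    std = gain(a) / math.sqrt(fan)
    bound = math.sqrt(3.) * std
    x.data.uniform_(-bound, bound)    # fill the weights in place from U(-bound, bound)
```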

To relate the code to the paper (He et al., “Delving Deep into Rectifiers”), I first read through the relevant section, 2.2 Initialization of Filter Weights for Rectifiers. I quickly got bogged down by the math and searched YouTube hoping for a video explanation. I came across Andrew Ng’s video called Weight Initialization in a Deep Network.

At about 3 minutes in, I was puzzled by how much simpler Andrew’s math was than what I had seen in the paper. Where the paper had multiple lines of math, Andrew simply showed this:

Var(w_i) = 1/n

The 1/n was for Xavier Init (the predecessor to Kaiming Init), and Kaiming Init differs only in having a 2 instead of a 1 in the numerator. What I took away from this video was simply that the idea behind both weight initialisations is to set the variance of the weights to a specific value, chosen so that the activations neither shrink nor blow up as they pass through the layers. With this in mind, I went back to the research paper and it started making more sense.
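
To see why the numerator matters with ReLU, here is a rough stdlib-only sketch (no PyTorch) that pushes one random input through a toy stack of 10 fully connected ReLU layers, 256 units wide and without biases. The helper forward_mean_square and all the sizes are hypothetical choices for illustration. It tracks the mean square of the activations (the quantity E[x²] that the paper follows from layer to layer) rather than the variance, since ReLU outputs are not zero-mean.

```python
import math
import random


def forward_mean_square(std_fn, width=256, depth=10, seed=0):
    """Push one random input through `depth` linear + ReLU layers whose weights
    are drawn from N(0, std**2), and return (final mean square) / (input mean square)."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(width)]
    start = sum(v * v for v in x) / width
    for _ in range(depth):
        std = std_fn(width)  # every layer here has fan_in == width
        # one output unit per row of a (width x width) weight matrix, sampled on the fly
        y = [sum(rng.gauss(0.0, std) * xi for xi in x) for _ in range(width)]
        x = [max(0.0, v) for v in y]  # ReLU
    end = sum(v * v for v in x) / width
    return end / start


one_over_n = forward_mean_square(lambda n: math.sqrt(1 / n))  # Var(w) = 1/n
two_over_n = forward_mean_square(lambda n: math.sqrt(2 / n))  # Var(w) = 2/n (Kaiming)
print(f"Var(w)=1/n: {one_over_n:.4f}   Var(w)=2/n: {two_over_n:.4f}")
```

The exact numbers depend on the seed, but since ReLU zeroes roughly half of its inputs, the 1/n run should lose about half its mean square per layer (ending near 1/2¹⁰), while the 2/n run should stay within a small factor of 1.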

In the blue box, the authors make the case for why Xavier Init fails with ReLU; only in the red box do they introduce their own initialisation. I manipulated the final variance equation (equation 10), and it already started to “look” more like the code.
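
Written out, the manipulation of equation 10 goes roughly as follows, with n_l the number of inputs to layer l (k²c for a convolutional layer) and Var[w_l] the variance of that layer’s weights:

```latex
\frac{1}{2}\, n_l \,\mathrm{Var}[w_l] = 1
\quad\Longrightarrow\quad
\mathrm{Var}[w_l] = \frac{2}{n_l}
\quad\Longrightarrow\quad
\mathrm{std}[w_l] = \sqrt{\frac{2}{n_l}} = \frac{\sqrt{2}}{\sqrt{n_l}}
```

The last form is what gain(a) / math.sqrt(fan) reduces to when a = 0, with fan playing the role of n_l.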

For presentation purposes, I shall factor the math.sqrt(2) numerator out of gain(a) to show that we already have something very similar to the equation in the paper.

std = math.sqrt(2.0) / math.sqrt(fan) * math.sqrt(1 / (1 + a**2))

The n in the paper corresponds to fan in the code. With that in mind, we can see that the only difference is the extra factor √(1 / (1 + a**2)) multiplying the whole expression.
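
A quick numerical check of that rearrangement, assuming nothing beyond the lesson’s gain(a) definition (std_code and std_factored are hypothetical helper names): both ways of writing std agree, and with a = 0 the extra factor is 1, leaving the paper’s √(2/n).

```python
import math


def gain(a):
    # gain as defined in the lesson's code
    return math.sqrt(2.0 / (1 + a**2))


def std_code(a, fan):
    # std exactly as computed inside kaiming2
    return gain(a) / math.sqrt(fan)


def std_factored(a, fan):
    # the same std with sqrt(2) pulled out of gain(a):
    # the paper's sqrt(2/n) times the extra factor sqrt(1 / (1 + a^2))
    return math.sqrt(2.0) / math.sqrt(fan) * math.sqrt(1 / (1 + a**2))


for a in (0.0, 0.01, 1.0):
    for fan in (25, 784):
        assert math.isclose(std_code(a, fan), std_factored(a, fan))

# with a = 0 (plain ReLU) this is just the paper's sqrt(2/n)
print(std_code(0.0, 25), math.sqrt(2 / 25))
```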


I rewatched Jeremy’s lesson video to find out whether he had explained what a meant.

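It turns out that a is the negative slope of a Leaky ReLU: negative inputs are multiplied by a instead of being zeroed, so a = 0 gives the ordinary ReLU. In other searches on this topic, I had also come across this Medium article by SG, which gives a nice breakdown of what a means when it takes on different values.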
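
Where the (1 + a²) comes from can be checked with a small Monte Carlo sketch (stdlib only; the sample size, seed and the particular values of a are arbitrary choices). A Leaky ReLU with negative slope a passes the positive half of a zero-mean, symmetric input unchanged and scales the negative half by a, so the mean square of its output is (1 + a²)/2 times that of its input. Compensating for that factor is what turns the 2 in Var(w) = 2/n into 2/(1 + a²), i.e. gain(a)².

```python
import random


def leaky_relu(y, a):
    # a is the negative slope: a = 0 is plain ReLU, a = 1 is the identity
    return y if y > 0 else a * y


rng = random.Random(0)
samples = [rng.gauss(0.0, 1.0) for _ in range(100_000)]  # zero-mean, symmetric input
input_ms = sum(y * y for y in samples) / len(samples)    # sample E[y^2]

ratios = {}
for a in (0.0, 0.25, 1.0):
    out_ms = sum(leaky_relu(y, a) ** 2 for y in samples) / len(samples)
    ratios[a] = out_ms / input_ms  # should be close to (1 + a^2) / 2
    print(f"a={a}: ratio {ratios[a]:.3f} vs (1 + a^2)/2 = {(1 + a**2) / 2:.3f}")
```

If you would rather not define gain by hand, torch.nn.init.calculate_gain('leaky_relu', a) computes the same √(2 / (1 + a²)).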

I then realised, almost by luck while scrolling past the only section we had been told to read, that this was actually in the paper.
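
For the PReLU/Leaky ReLU case, the paper gives a version of equation 10 with an extra (1 + a²) term (setting a = 0 recovers equation 10). Solving it for the standard deviation, with n_l the number of inputs to layer l, lands exactly on the code’s gain(a) / math.sqrt(fan):

```latex
\frac{1}{2}\,(1 + a^2)\, n_l \,\mathrm{Var}[w_l] = 1
\quad\Longrightarrow\quad
\mathrm{std}[w_l] = \sqrt{\frac{2}{(1 + a^2)\, n_l}}
= \frac{1}{\sqrt{n_l}} \sqrt{\frac{2}{1 + a^2}}
= \frac{\mathrm{gain}(a)}{\sqrt{n_l}}
```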

math.sqrt(3.) * std

After answering my first two questions, the third one kind of just fell into place. For a uniform distribution centred at 0, we can derive what the bounds must be to produce a given standard deviation.

I had previously taken MIT OCW 18.05, so I knew Wikipedia would give me the formula for the variance of any distribution in a nice table. And there it was: for a uniform distribution on [a, b], the variance is (b - a)²/12.

Substituting b = -a so that the mean is 0, the variance becomes (2b)²/12 = b²/3, so the standard deviation is b/√3. In other words, for a uniform distribution centred at 0 to have a given standard deviation, the absolute value of its bounds must be that standard deviation multiplied by √3.
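
A small stdlib-only check of the √3 rule, assuming an illustrative target std of √(2/25) (what gain(0) / math.sqrt(fan) gives for fan = 25); uniform_bound is a hypothetical helper name. It draws from U(-bound, bound) with Python’s random.uniform and compares the sample standard deviation with the target.

```python
import math
import random


def uniform_bound(std):
    # For U(-b, b): Var = (b - (-b))**2 / 12 = b**2 / 3, so std = b / sqrt(3)
    return math.sqrt(3.0) * std


target_std = math.sqrt(2 / 25)  # illustrative: gain(0) / sqrt(fan) with fan = 25
bound = uniform_bound(target_std)

rng = random.Random(0)
draws = [rng.uniform(-bound, bound) for _ in range(100_000)]
mean = sum(draws) / len(draws)
sample_std = math.sqrt(sum((v - mean) ** 2 for v in draws) / len(draws))
print(f"target std {target_std:.4f}, bound {bound:.4f}, sample std {sample_std:.4f}")
```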

fan_in and fan_out

The first three cells were simple to understand. The second cell gives us the number of elements in our convolution kernel. The third cell gives us the number of in_channels and out_channels (a.k.a. the number of kernels). What I didn’t get was the fourth cell.

For the fully connected layer it was more intuitive, since you could roughly visualise the matrix multiplication: fan_in is the number of input neurons and fan_out is the number of output neurons. The new idea in Lesson 9 is that convolutional layers can also be visualised as matrix multiplication, which Matthew Kleinsmith explores in detail. This idea is also represented in the paper itself:

Initially, I tried to work out the dimensions of the supposed weight matrix if we were to view the convolutional layer as matrix multiplication, just as Kleinsmith did. This got really messy, especially since the matrices are huge. But even before reaching the final dimensions, I noticed that the numbers weren’t going to match…

Relooking at the paper, I realised that the weight matrix we apply Kaiming Init to is different from the one Kleinsmith pictures. It’s not the whole image (flattened into a vector) that’s multiplying a weight matrix. What the paper actually describes is a weight matrix that

  • has the weights of all the filters
  • multiplies a single k x k patch of the image, taken across all c input channels
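
To make the paper’s view concrete, here is a stdlib-only sketch with made-up sizes (d = 4 filters, c = 3 channels, k = 5, an 8 x 8 image) and hypothetical helper names. It computes one output position of a convolution two ways: directly, and as a d x k²c weight matrix (one flattened filter per row) multiplying a single flattened k x k x c patch. The two results agree, and the matrix has k²c columns, which is the fan_in discussed next.

```python
import math
import random


def conv_at(image, kernels, row, col):
    """Direct cross-correlation (what conv layers compute) at one output
    position. image is [c][H][W]; kernels is [d][c][k][k]."""
    c, k = len(kernels[0]), len(kernels[0][0])
    return [
        sum(f[ch][i][j] * image[ch][row + i][col + j]
            for ch in range(c) for i in range(k) for j in range(k))
        for f in kernels
    ]


def patch_as_matmul(image, kernels, row, col):
    """The paper's view: flatten the k x k patch across all c channels into a
    vector x of length k*k*c, stack each flattened filter as one row of a
    d x (k*k*c) matrix W, and compute y = W x."""
    c, k = len(kernels[0]), len(kernels[0][0])
    x = [image[ch][row + i][col + j]
         for ch in range(c) for i in range(k) for j in range(k)]
    W = [[f[ch][i][j] for ch in range(c) for i in range(k) for j in range(k)]
         for f in kernels]
    y = [sum(w * xi for w, xi in zip(w_row, x)) for w_row in W]
    return y, W


rng = random.Random(0)
d, c, k, H = 4, 3, 5, 8  # illustrative sizes
image = [[[rng.gauss(0, 1) for _ in range(H)] for _ in range(H)] for _ in range(c)]
kernels = [[[[rng.gauss(0, 1) for _ in range(k)] for _ in range(k)]
            for _ in range(c)] for _ in range(d)]

y_direct = conv_at(image, kernels, 2, 1)
y_matmul, W = patch_as_matmul(image, kernels, 2, 1)
assert all(math.isclose(p, q, abs_tol=1e-12) for p, q in zip(y_direct, y_matmul))
print(len(W), len(W[0]))  # prints: 4 75  (d rows, k*k*c columns)
```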

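This means that for a weight matrix in a convolutional layer, fan_in, i.e. the number of input neurons, is the number of elements in x, which is k²c. This corresponds to 5*5*1=25 in the fastai example (where k=5 and c=1). fan_out comes from the paper’s backward-propagation case, where the gradient vector covers k x k pixels in d channels, so fan_out is k²d. This is exactly nf*rec_fs in the code, and it corresponds to 32*5*5=800 (where d=32). Because c=1 in this example, d·k²c happens to give the same 800, but in general fan_out does not depend on c.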
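
The same fan logic as kaiming2, rewritten as a hypothetical helper that works on a plain shape tuple (d, c, k, k) so it runs without PyTorch. The second call uses a made-up 3-channel input to show that fan_in picks up c while fan_out stays k²d.

```python
import math


def fans(weight_shape):
    """Mirror of the fan logic in kaiming2, on a shape tuple
    (out_channels d, in_channels c, k, k) instead of a tensor."""
    nf, ni, *kernel = weight_shape
    rec_fs = math.prod(kernel)       # k * k, elements per kernel slice
    return ni * rec_fs, nf * rec_fs  # (fan_in = k^2 c, fan_out = k^2 d)


print(fans((32, 1, 5, 5)))  # the fastai example above: (25, 800)
print(fans((32, 3, 5, 5)))  # hypothetical RGB input: (75, 800)
```

As far as I can tell, PyTorch’s own torch.nn.init functions compute fans for conv weights the same way: in_channels or out_channels times the receptive field size.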

Closing off

I hope this article helps someone who might be facing the same troubles I did. I also want to encourage others who might feel discouraged about needing more time to understand new things. Just for this block of code, I went back and forth between a variety of resources (some not even linked here, to avoid adding to whatever confusion this article might already cause). This is also my first Medium article, so I’d appreciate any feedback on my writing. Feel free to send me any questions about what I’ve shared. Thank you.
