Acyr Locatelli for a living

Porting research code

Posted at — Jun 23, 2020

Having spent some time implementing ideas from papers, I have found that reproducing results and baselines can be quite painful.

For starters, there are some non-code-related issues you can run into: experiments are not described particularly well, the odd hyper-parameter goes undocumented, etc.

If the code is released (why this has not become the standard is something I don’t understand), it might be implemented in a framework you don’t work with. In most cases it’s a matter of porting code between PyTorch and TensorFlow.

Here are a few things I ran into while trying to reproduce results.

CrossEntropy loss

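PyTorch’s CrossEntropyLoss works directly on logits: it combines LogSoftmax and NLLLoss in a single class. This improves numerical stability, because the log-softmax can be computed with the log-sum-exp trick instead of taking the log of a softmax probability that may already have underflowed to zero.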
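To make the stability argument concrete, here is a minimal plain-Python sketch (no framework involved; the function names are hypothetical) comparing a naive log(softmax(x)) with a log-softmax computed via the log-sum-exp trick, which is the kind of computation a logits-based loss can do internally.

```python
import math


def naive_cross_entropy(logits, target):
    """Softmax first, then log: tiny probabilities can underflow to 0.0."""
    exps = [math.exp(z) for z in logits]
    prob = exps[target] / sum(exps)
    # log(0) is undefined, so an underflowed probability yields an infinite loss
    return -math.log(prob) if prob > 0.0 else math.inf


def stable_cross_entropy(logits, target):
    """Log-softmax via log-sum-exp: subtract the largest logit before exponentiating."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


if __name__ == "__main__":
    logits = [0.0, -800.0]
    print(naive_cross_entropy(logits, 1))   # inf: exp(-800) underflows to 0.0
    print(stable_cross_entropy(logits, 1))  # 800.0
```

The naive version has the opposite problem for large positive logits as well: math.exp(1000.0) raises an OverflowError, while the shifted version never exponentiates anything larger than zero.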

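This is not done by default in TensorFlow. If we look at both the class and the function:

tf.keras.losses.CategoricalCrossentropy(
	from_logits=False, label_smoothing=0, reduction=losses_utils.ReductionV2.AUTO,
	name='categorical_crossentropy'
)

tf.keras.losses.categorical_crossentropy(
    y_true, y_pred, from_logits=False, label_smoothing=0
)

the parameter from_logits is set to False by default.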

We just need to make sure from_logits is set correctly. In particular, we need to avoid the lazy model.compile call, since the string shortcut resolves to the loss with its default from_logits=False:

model.compile(loss="categorical_crossentropy", ... )

Note that if from_logits=False and the last operation of the model happens to be a Softmax, TensorFlow will automatically strip it and apply the loss to the logits. See the Keras backend implementation of categorical_crossentropy for more details.
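The flag has to agree with what the model actually returns. Here is a plain-Python sketch (helper names are hypothetical, no framework involved) of the mistake opposite to the lazy compile call: a model that already ends in a softmax is fed to a loss that expects logits, so the softmax is applied twice. Because the second softmax sees inputs in [0, 1], it pushes the predicted distribution towards uniform and inflates the loss of a confident, correct prediction.

```python
import math


def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy_from_logits(logits, target):
    # the kind of loss from_logits=True expects: softmax is applied internally
    return -math.log(softmax(logits)[target])


if __name__ == "__main__":
    logits = [4.0, 0.0, 0.0]  # raw model outputs, true class 0
    probs = softmax(logits)   # what a model ending in a Softmax layer returns
    correct = cross_entropy_from_logits(logits, 0)
    doubled = cross_entropy_from_logits(probs, 0)  # softmax applied twice
    print(f"correct: {correct:.3f}, doubled: {doubled:.3f}")
```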

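The same applies to binary cross-entropy, with one catch: torch.nn.BCELoss expects probabilities, and the PyTorch loss that works directly on logits is torch.nn.BCEWithLogitsLoss. On the TensorFlow side, tf.keras.losses.BinaryCrossentropy also defaults to from_logits=False.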
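For the binary case the same idea applies. Below is a minimal plain-Python sketch (function names are hypothetical) of the numerically stable formulation given in the TensorFlow documentation for tf.nn.sigmoid_cross_entropy_with_logits, max(z, 0) - z * y + log(1 + exp(-|z|)), next to a naive sigmoid-then-log version.

```python
import math


def naive_bce(z, y):
    """Sigmoid first, then log; y is 0 or 1."""
    p = 1.0 / (1.0 + math.exp(-z))  # saturates to exactly 1.0 for large z
    q = p if y == 1 else 1.0 - p    # probability assigned to the true label
    return -math.log(q) if q > 0.0 else math.inf


def stable_bce_with_logits(z, y):
    """max(z, 0) - z * y + log(1 + exp(-|z|)): never exponentiates a positive number."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


if __name__ == "__main__":
    print(naive_bce(40.0, 0))               # inf: sigmoid(40.0) rounds to 1.0
    print(stable_bce_with_logits(40.0, 0))  # 40.0
```

For moderate logits both functions agree; they only diverge once the sigmoid saturates in floating point.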

Weight initialisation for convolutions

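PyTorch’s convolutional layers use He (Kaiming) initialisation by default, specifically kaiming_uniform_ with a=sqrt(5), while TensorFlow’s Keras layers use Xavier (Glorot) initialisation, glorot_uniform. The bias defaults differ too: PyTorch samples biases uniformly, whereas Keras initialises them to zeros. If you are not careful while porting the code, this is easily missed.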
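The difference is larger than the names suggest. PyTorch’s kaiming_uniform_ with a=sqrt(5) reduces to U(-1/sqrt(fan_in), 1/sqrt(fan_in)), while Keras’ glorot_uniform samples from U(-sqrt(6/(fan_in + fan_out)), sqrt(6/(fan_in + fan_out))). A plain-Python sketch (helper names are hypothetical) computing both bounds for an example layer:

```python
import math


def conv_fans(in_channels, out_channels, kernel_h, kernel_w):
    """Fan-in and fan-out of a 2D convolution kernel."""
    receptive_field = kernel_h * kernel_w
    return in_channels * receptive_field, out_channels * receptive_field


def pytorch_conv_default_bound(fan_in):
    """Bound of kaiming_uniform_(w, a=sqrt(5)), PyTorch's default for conv weights."""
    a = math.sqrt(5.0)
    gain = math.sqrt(2.0 / (1.0 + a * a))  # leaky_relu gain, sqrt(1/3) here
    return gain * math.sqrt(3.0 / fan_in)  # simplifies to 1 / sqrt(fan_in)


def glorot_uniform_bound(fan_in, fan_out):
    """Bound of Keras' default glorot_uniform (Xavier) initialiser."""
    return math.sqrt(6.0 / (fan_in + fan_out))


if __name__ == "__main__":
    fan_in, fan_out = conv_fans(64, 128, 3, 3)    # a 3x3 conv, 64 -> 128 channels
    print(pytorch_conv_default_bound(fan_in))     # about 0.0417
    print(glorot_uniform_bound(fan_in, fan_out))  # about 0.0589
```

To reproduce the PyTorch weight bound in Keras, passing kernel_initializer=tf.keras.initializers.VarianceScaling(scale=1/3, mode='fan_in', distribution='uniform') should work, since that initialiser samples from U(-sqrt(3 * scale / fan_in), sqrt(3 * scale / fan_in)). The PyTorch bias default, U(-1/sqrt(fan_in), 1/sqrt(fan_in)), would need a separate bias initialiser.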

$l_{2}$ regularisation

PyTorch optimisers have a weight_decay parameter that applies $l_{2}$ regularisation to every parameter the optimiser updates, biases included. That means that if this is set:

torch.optim.SGD(net.parameters(), lr=0.001, weight_decay=alpha)

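we need to set

kernel_regularizer=tf.keras.regularizers.l2(alpha / 2)

in every layer in the equivalent TensorFlow code, plus the matching bias_regularizer, since PyTorch decays the biases as well. The factor of one half comes from the two conventions: PyTorch adds weight_decay * w directly to the gradient, while the Keras regularizer l2(l) adds l * sum(w^2) to the loss, whose gradient is 2 * l * w.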
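A minimal plain-Python sketch of where the factor of one half between the two conventions comes from (function names are hypothetical; this models a single plain SGD step, not either framework’s optimiser code): PyTorch-style weight_decay adds weight_decay * w to the gradient, whereas a penalty of l * sum(w^2) in the loss contributes 2 * l * w.

```python
def sgd_step_weight_decay(w, grad, lr, weight_decay):
    """PyTorch-style coupled weight decay: weight_decay * w is added to the gradient."""
    return [wi - lr * (gi + weight_decay * wi) for wi, gi in zip(w, grad)]


def sgd_step_l2_penalty(w, grad, lr, l2):
    """SGD on loss + l2 * sum(w_i ** 2), the penalty a Keras-style l2 regularizer adds."""
    penalty_grad = [2.0 * l2 * wi for wi in w]  # derivative of l2 * w_i ** 2
    return [wi - lr * (gi + pi) for wi, gi, pi in zip(w, grad, penalty_grad)]


if __name__ == "__main__":
    w, grad, lr, alpha = [0.5, -1.0, 2.0], [0.1, 0.2, -0.3], 0.1, 0.01
    print(sgd_step_weight_decay(w, grad, lr, alpha))
    print(sgd_step_l2_penalty(w, grad, lr, alpha / 2))  # same update as above
    print(sgd_step_l2_penalty(w, grad, lr, alpha))      # twice the decay
```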


These are just a few of the differences that cost me some time while porting code.
