How to train your classifier - review

This is the final part in our series exploring the objectives you could use to train a deep learning classifier.

We've met softmax cross entropy, teacher-student training, sampled softmax, value function estimation and policy gradients. We reviewed the core ideas and walked through a typical forward and backward pass. All that remains is to provide some demo code, so you can play around with these at your leisure.

Demo code

The demo can be found on GitHub or opened directly in Google Colab.

It follows our running example of training a classifier for small image patches on the CIFAR10 dataset, using PyTorch. If you run the code as-is, it should successfully train a model using each objective. The parameters have been chosen to give reasonable performance in each case.

  • softmax_cross_entropy (blog) trains quickly and reliably; it's the default choice for a reason!
  • teacher_student (blog) is as fast as or faster than softmax cross entropy (which it uses as the teacher). Note that this example is a bit pointless, since the teacher has the same architecture as the student.
  • sampled_softmax (blog) is slower than full softmax cross entropy. It can be improved by increasing the number of samples.
  • value_function (blog) trains relatively quickly (although still slower than full softmax cross entropy), but can be unreliable. Recall that it's playing a harder game than the previous techniques: a contextual multi-armed bandit problem.
  • policy_gradient (blog) is slower than full softmax cross entropy, and may be unreliable. The entropy weight hyperparameter can make a big difference.

To make these concrete, a minimal sketch of the training setup and of each objective's loss function follows below.
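First, the rough shape of the setup. The transform, batch size and model architecture here are assumptions for illustration, not the demo's exact code.

    import torch
    import torch.nn as nn
    import torchvision
    import torchvision.transforms as T

    # Assumed setup: a small convolutional classifier on CIFAR-10.
    train_set = torchvision.datasets.CIFAR10(
        root="data", train=True, download=True, transform=T.ToTensor())
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=64, shuffle=True)

    # Deliberately small model; the demo's architecture may well differ.
    model = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        nn.Flatten(),
        nn.Linear(64 * 8 * 8, 10),  # 10 CIFAR-10 classes
    )
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)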
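Softmax cross entropy and teacher-student training, sketched as loss functions over a batch of logits. The alpha and temperature defaults are illustrative assumptions, not the demo's values.

    import torch
    import torch.nn.functional as F

    def softmax_cross_entropy_loss(logits, labels):
        # Negative log-likelihood of the true class under the softmax.
        return F.cross_entropy(logits, labels)

    def teacher_student_loss(student_logits, teacher_logits, labels,
                             alpha=0.5, temperature=4.0):
        # Hard term: ordinary cross entropy against the true labels.
        hard = F.cross_entropy(student_logits, labels)
        # Soft term: KL divergence from the temperature-softened teacher
        # distribution to the student distribution.
        soft = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction="batchmean",
        ) * temperature ** 2
        return alpha * hard + (1 - alpha) * soft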
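Sampled softmax, sketched with uniform negative sampling over the classes. With only ten CIFAR-10 classes this is purely illustrative, and a more careful implementation would correct for the proposal distribution and avoid sampling the true class as a negative.

    import torch
    import torch.nn.functional as F

    def sampled_softmax_loss(features, weight, bias, labels, n_samples=4):
        # `weight` (num_classes x dim) and `bias` (num_classes) are the final
        # layer's parameters. Score the true class plus `n_samples` uniformly
        # sampled negatives, then apply cross entropy over that reduced
        # candidate set rather than over all classes.
        num_classes = weight.shape[0]
        batch = labels.shape[0]
        negatives = torch.randint(num_classes, (batch, n_samples),
                                  device=labels.device)
        candidates = torch.cat([labels.unsqueeze(1), negatives], dim=1)
        logits = torch.einsum("bd,bsd->bs", features, weight[candidates])
        logits = logits + bias[candidates]
        # The true class always sits at position 0 of the candidate set.
        targets = torch.zeros(batch, dtype=torch.long, device=logits.device)
        return F.cross_entropy(logits, targets)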
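Value function estimation, sketched as a contextual bandit: the network predicts a reward per action (class), actions are chosen epsilon-greedily, and only the chosen action's prediction is regressed towards the observed reward. The epsilon default is an assumption.

    import torch
    import torch.nn.functional as F

    def value_function_loss(values, labels, epsilon=0.1):
        # `values` holds the predicted rewards, one per class (action).
        batch, num_classes = values.shape
        greedy = values.argmax(dim=-1)
        random_actions = torch.randint(num_classes, (batch,),
                                       device=values.device)
        explore = torch.rand(batch, device=values.device) < epsilon
        actions = torch.where(explore, random_actions, greedy)
        # Reward is 1 for choosing the correct class, 0 otherwise; only the
        # chosen action's predicted value receives a gradient.
        rewards = (actions == labels).float()
        chosen = values.gather(1, actions.unsqueeze(1)).squeeze(1)
        return F.mse_loss(chosen, rewards)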
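Finally, policy gradient (REINFORCE) with a simple batch-mean baseline and an entropy bonus; the entropy_weight default is an assumption, and as noted above it can make a big difference.

    import torch
    from torch.distributions import Categorical

    def policy_gradient_loss(logits, labels, entropy_weight=0.01):
        # Sample an action from the policy, reward 1 if it matches the label,
        # and push up the log-probability of actions with positive advantage.
        policy = Categorical(logits=logits)
        actions = policy.sample()
        rewards = (actions == labels).float()
        baseline = rewards.mean()                    # simple batch-mean baseline
        advantage = rewards - baseline
        pg_term = -(advantage * policy.log_prob(actions)).mean()
        # Entropy bonus discourages the policy from collapsing too early.
        entropy_term = -entropy_weight * policy.entropy().mean()
        return pg_term + entropy_term

Each sketch can be dropped into a standard training loop in place of the usual cross-entropy call, computing the logits (or predicted values) from the model and stepping the optimiser as usual.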

Playing around

Please do have a play with it. Or even better, just throw this example away and have a go at implementing the objectives yourself, maybe for another dataset or domain. But if you like to learn by tweaking, here are a few things you could try:

  • Explore the hyperparameters. What do alpha, n_samples, epsilon and entropy_weight do?
  • Try training a deeper network, e.g. ResNet-18 from torchvision. Which objectives are harder to train?
  • Try changing the step size or optimiser. Are there better settings for certain objectives?
  • Try removing the baseline from policy gradient. How does it perform?
  • Can you make the value function more consistent, i.e. such that the predicted expected rewards sum to one across actions? How does performance change?

References