Acyr Locatelli

model.fit for a living

Custom optimizers in TensorFlow

Posted at — Sep 17, 2020

In this post, we will discuss how to implement a custom TensorFlow optimizer.

As an illustrative example, we will implement Learning Rate Dropout. This is a simple optimizer I came across a few months ago. The basic idea is to mask parameter updates (similarly to what happens to weights in standard dropout) while continuing to accumulate variables like momentum. More details can be found here.
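
To make the idea concrete, here is a toy sketch of a single masked update step (illustrative values, and plain SGD rather than Adam):

import tensorflow as tf

# Sample a Bernoulli mask and only update the surviving coordinates; the
# dropped coordinates keep their current value. In the full optimizer the
# moment accumulators still see every gradient.
dropout_rate, lr = 0.3, 0.01
var = tf.Variable([1.0, 2.0, 3.0, 4.0])
grad = tf.constant([0.1, 0.2, 0.3, 0.4])

mask = tf.cast(tf.random.uniform(grad.shape) > dropout_rate, var.dtype)
var.assign_sub(lr * mask * grad)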

Basics

Following the TensorFlow documentation, the setup for implementing a custom optimizer is fairly straightforward. We define a class that inherits from tf.keras.optimizers.Optimizer and implement four methods:

class AdamLRD(tf.keras.optimizers.Optimizer):
    """
    Implementation of Adam + Learning Rate Dropout.
    """

    def _resource_apply_dense(self, grad, var, apply_state=None):
        pass

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        pass

    def _create_slots(self, var_list):
        pass

    def get_config(self):
        pass

Note:

  1. The documentation has a misspelling: the method names are listed without the leading underscores.
  2. We don’t really need to implement _resource_apply_sparse if we don’t intend to use sparse tensors.

Useful methods

The Optimizer class has four “private” methods that are worth knowing:

def _set_hyper(self, name, value):
	...

def _get_hyper(self, name, dtype=None):
	...

def _serialize_hyperparameter(self, hyperparameter_name):
	...

def _decayed_lr(self, var_dtype):
	...

It is pretty clear what these functions do, but it is worth adding a little more context:

  1. For a short explanation of the advantages of using _set_hyper/_get_hyper instead of just setting the attributes ourselves, see here; there is also a small sketch after this list.
  2. The function _decayed_lr returns the (possibly decayed) learning rate at the current optimizer step. You can see the details in the code here.
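
As a toy sketch of these helpers (ToyOptimizer is made up purely for illustration, not part of the final code), hyper-parameters registered with _set_hyper can be read back as tensors of any dtype and serialized to plain Python values:

import tensorflow as tf

class ToyOptimizer(tf.keras.optimizers.Optimizer):
    """Toy subclass, only here to show the hyper-parameter helpers."""

    def __init__(self, dropout_rate=0.0, name='ToyOptimizer', **kwargs):
        super().__init__(name, **kwargs)
        self._set_hyper('dropout_rate', dropout_rate)

opt = ToyOptimizer(dropout_rate=0.3)
print(opt._get_hyper('dropout_rate', tf.float32))     # a float32 tensor
print(opt._serialize_hyperparameter('dropout_rate'))  # a plain Python value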

Let’s initialise the class

We will inherit directly from tf.keras.optimizers.Adam so we can skip implementing _create_slots. However, we will briefly discuss this method later on.

class AdamLRD(tf.keras.optimizers.Adam):

...

    def __init__(self, learning_rate=0.001, dropout_rate=0.0, beta_1=0.9,
                 beta_2=0.999, epsilon=1e-7, amsgrad=False, name='AdamLRD',
                 **kwargs):
        super().__init__(learning_rate=learning_rate, beta_1=beta_1,
                         beta_2=beta_2, epsilon=epsilon, amsgrad=amsgrad,
                         name=name, **kwargs)
        self._set_hyper('dropout_rate', dropout_rate)

The only additional hyper-parameter we need to keep track of is dropout_rate. The default value of 0.0 means the optimizer behaves exactly like Adam.

One of the advantages of using _set_hyper as opposed to just writing:

	self.dropout_rate = dropout_rate

is that we get to use the _serialize_hyperparameter function for free. This is particularly useful when we implement the get_config method.

class AdamLRD(tf.keras.optimizers.Adam):

...
    def get_config(self):
        config = super().get_config()
        config.update(
			{'dropout_rate': self._serialize_hyperparameter("dropout_rate")})
        return config
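
For completeness, here is a quick round trip through the config (a small sketch, assuming the AdamLRD class above; from_config is inherited from the base Optimizer):

opt = AdamLRD(learning_rate=1e-3, dropout_rate=0.1)
config = opt.get_config()               # a plain Python dict
restored = AdamLRD.from_config(config)  # rebuilds an equivalent optimizer
print(restored.get_config()['dropout_rate'])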

Creating slots

According to the documentation, slots are variables associated with the parameters being trained. In the case of Adam, we need to create two slots, three if we count the variable associated with AMSGrad. The source below can be found here.

class AdamLRD(tf.keras.optimizers.Adam):

...

	def _create_slots(self, var_list):
		for var in var_list:
			self.add_slot(var, 'm')
		for var in var_list:
			self.add_slot(var, 'v')
		if self.amsgrad:
			for var in var_list:
				self.add_slot(var, 'vhat')

As we are inheriting this from Adam, we don’t need to implement the method ourselves.
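
Slots are created lazily, on the first call to apply_gradients. A quick way to see them in action with stock Adam (a small demo, not from the post’s code; our subclass inherits exactly these slots):

import tensorflow as tf

opt = tf.keras.optimizers.Adam()
var = tf.Variable([1.0, 2.0])
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(var ** 2)
grads = tape.gradient(loss, [var])
opt.apply_gradients(zip(grads, [var]))  # triggers _create_slots
print(opt.get_slot(var, 'm'))  # first-moment slot
print(opt.get_slot(var, 'v'))  # second-moment slot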

Dense implementation

Now we get to the core of the optimizer. From the paper, all we need to do is drop out the update at each step. Most of the code below is Adam, with some modifications coming from here. As we will see, the learning-rate dropout itself is only a couple of lines.

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import (array_ops, control_flow_ops, math_ops,
                                   state_ops)

class AdamLRD(tf.keras.optimizers.Adam):

...

    def _resource_apply_dense(self, grad, var, apply_state=None):
        # We will use this to create tensors of the appropriate type.
        var_dtype = var.dtype.base_dtype

        # NOTE: this is where _decayed_lr comes in.
        lr_t = self._decayed_lr(var_dtype)
        local_step = math_ops.cast(self.iterations + 1, var_dtype)
        m = self.get_slot(var, 'm')
        v = self.get_slot(var, 'v')
        beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
        beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
        beta_1_power = math_ops.pow(beta_1_t, local_step)
        beta_2_power = math_ops.pow(beta_2_t, local_step)
        epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)

        # bias correction
        lr_t = lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power) 

        m_t = state_ops.assign(m,
                               beta_1_t * m + (1.0 - beta_1_t) * grad,
                               use_locking=self._use_locking)
        v_t = state_ops.assign(v,
                               beta_2_t * v + (1.0 - beta_2_t
                                               ) * math_ops.square(grad),
                               use_locking=self._use_locking)
        
        if self.amsgrad:
            vhat = self.get_slot(var, 'vhat')
            vhat_t = state_ops.assign(vhat,
                                      math_ops.maximum(vhat, v_t),
                                      use_locking=self._use_locking)
            var_delta = m_t / (math_ops.sqrt(vhat_t) + epsilon_t)
        else:
            var_delta = m_t / (math_ops.sqrt(v_t) + epsilon_t)

        # =================== Learning Rate Dropout ============================
        # Sample a Bernoulli(1 - dropout_rate) mask and zero out the learning
        # rate for the dropped coordinates; m and v are still fully updated.
        dropout_rate = self._get_hyper('dropout_rate', var_dtype)
        mask = tf.cast(
            tf.random.uniform(grad.shape, dtype=var_dtype) > dropout_rate,
            var_dtype)
        lr_d = lr_t * mask
        # =================== Back to business =================================

        var_t = math_ops.subtract(var, lr_d * var_delta)
        var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

        updates = [var_update, m_t, v_t]

        if self.amsgrad:
            updates.append(vhat_t)
        return control_flow_ops.group(*updates)

A small note on control_flow_ops.group

This method ensures that everything in the updates list has been computed. If we are executing eagerly, this is not an issue; in fact, we can see in the code that it simply returns None in that case.

This operation is identical to tf.group. You can have a look at its code and docs.
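
A tiny demonstration of the eager behaviour (the variables are illustrative):

import tensorflow as tf

a = tf.Variable(0.0)
b = tf.Variable(0.0)

# Eagerly, the assignments have already run by the time group is called,
# so tf.group simply returns None; in graph mode it returns a single op
# that completes only after both inputs have.
result = tf.group(a.assign_add(1.0), b.assign_add(1.0))
print(result)  # None when executing eagerly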

Sparse implementation

This is used a little less often, but it’s still worth implementing. The basic idea is to replace the dense operations with the scatter_* family of operations, which support sparse updates.

The only new method we need is defined in tf.keras.optimizers.Optimizer:

from tensorflow.python.ops import resource_variable_ops

...

def _resource_scatter_add(self, x, i, v):
    with ops.control_dependencies(
            [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
        return x.value()

The function ops.control_dependencies ensures everything has been computed before we call return. Similarly to the code above, this is not needed in eager mode – see here.

class AdamLRD(tf.keras.optimizers.Adam):

...

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        local_step = math_ops.cast(self.iterations + 1, var_dtype)
        m = self.get_slot(var, 'm')
        v = self.get_slot(var, 'v')
        beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
        beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
        beta_1_power = math_ops.pow(beta_1_t, local_step)
        beta_2_power = math_ops.pow(beta_2_t, local_step)
        epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)

        lr_t = lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)


        m_scaled_g_values = grad * (1 - beta_1_t)
        m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
        
        # ensures m_t is finished 
        with ops.control_dependencies([m_t]):
            m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

        v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
        v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
        with ops.control_dependencies([v_t]):
            v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

        if self.amsgrad:
            vhat = self.get_slot(var, 'vhat')
            vhat_t = state_ops.assign(vhat,
                                      math_ops.maximum(vhat, v_t),
                                      use_locking=self._use_locking)
            var_delta = m_t / (math_ops.sqrt(vhat_t) + epsilon_t)
        else:
            var_delta = m_t / (math_ops.sqrt(v_t) + epsilon_t)

        # =================== Learning Rate Dropout ============================
        # The mask is over the full variable shape: the delta below is dense,
        # because the moment estimates decay every entry, not just the
        # entries touched by this sparse gradient.
        dropout_rate = self._get_hyper('dropout_rate', var_dtype)
        mask = tf.cast(
            tf.random.uniform(array_ops.shape(var), dtype=var_dtype)
            > dropout_rate, var_dtype)
        lr_d = lr_t * mask
        # =================== Back to business =================================

        var_t = math_ops.subtract(var, lr_d * var_delta)
        var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

        updates = [var_update, m_t, v_t]

        if self.amsgrad:
            updates.append(vhat_t)
        return control_flow_ops.group(*updates)

A note on locking

This illustrates some issues with TensorFlow’s locking mechanism. TensorFlow uses advisory locks: a function can be called with use_locking=False and it will simply not check whether another op holds the lock. Moreover, reads of the variable are performed without locking, so it is possible to read an intermediate result.

Both of these can cause concurrency issues, so we need to be mindful.
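
Before wrapping up, here is a minimal smoke test (a sketch, assuming the AdamLRD class above is in scope; the data is random and only there to exercise the update):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer=AdamLRD(learning_rate=1e-3, dropout_rate=0.3),
              loss='mse')

x = np.random.randn(128, 4).astype(np.float32)
y = np.random.randn(128, 1).astype(np.float32)
model.fit(x, y, epochs=2, verbose=0)  # masked updates, full moment accumulation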

Conclusion

As we saw, the basic mechanism for implementing a custom TF optimiser is fairly straightforward. However, as always, the devil is in the details. I hope this small post is a good starting point for coding your own.
