A custom getter class that does the following:
- Moves trainable variables to the local collection and places them on the worker device.
- Generates the global variables (global center variables).
- Generates the grad variables (gradients), which record the sum of the gradients, and places them on the worker device.
Note that this class should be used together with tf.replica_device_setter, so that the global center variables and the global step variable are placed on the ps device.
Args:
  worker_device: The worker device on which the grad variables are placed.
__call__(getter, name, trainable, collections, *args, **kwargs)
Call self as a function.
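The three responsibilities listed above and the __call__ signature can be illustrated with a framework-free toy. This is only a sketch of the pattern and not the class's actual implementation: ToyVariable, toy_getter, ToyCustomGetter, the collection names, and the device strings are all hypothetical, and the pass-through for non-trainable variables is an assumption. The default device of toy_getter stands in for the placement that tf.replica_device_setter would otherwise provide.

```python
# Toy, framework-free model of a custom getter. Nothing here is TensorFlow
# API: ToyVariable, toy_getter, ToyCustomGetter, the collection names and
# the device strings are all hypothetical stand-ins.
PS_DEVICE = "/job:ps/task:0"
WORKER_DEVICE = "/job:worker/task:0"


class ToyVariable:
    def __init__(self, name, trainable, collections, device):
        self.name = name
        self.trainable = trainable
        self.collections = collections
        self.device = device


def toy_getter(name, trainable=True, collections=None, device=PS_DEVICE):
    """Stand-in for the framework's default getter.

    Defaulting to the ps device mimics the placement a device setter
    would apply to variables created without an explicit device.
    """
    return ToyVariable(name, trainable, list(collections or ["global"]), device)


class ToyCustomGetter:
    def __init__(self, worker_device):
        self._worker_device = worker_device
        self.global_map = {}  # local variable name -> global center variable
        self.grad_map = {}    # local variable name -> gradient-sum variable

    def __call__(self, getter, name, trainable, collections, *args, **kwargs):
        if not trainable:
            # Assumption: non-trainable variables pass through unchanged.
            return getter(name, *args, trainable=False,
                          collections=collections, **kwargs)
        # 1. Trainable variable: local collection, worker device.
        local = getter(name, trainable=True, collections=["local"],
                       device=self._worker_device)
        # 2. Global center variable: left on the getter's default (ps) device.
        self.global_map[name] = getter("global_center/" + name,
                                       trainable=False, collections=["global"])
        # 3. Gradient-sum variable: worker device.
        self.grad_map[name] = getter("grad_sum/" + name, trainable=False,
                                     collections=["local"],
                                     device=self._worker_device)
        return local


custom_getter = ToyCustomGetter(WORKER_DEVICE)
w = custom_getter(toy_getter, "w", True, None)
print(w.device, custom_getter.global_map["w"].device)
```

The caller only ever receives the local variable; the global center and gradient-sum variables are kept in maps on the getter object so that other code can look them up later by the local variable's name.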