Custom getter class used for model-average training.
tf.contrib.opt.ModelAverageCustomGetter(
worker_device
)
This custom getter does two things:

- Moves trainable variables to the local collection and places them on the worker device.
- Generates the corresponding global variables.

Note that the class should be used with `tf.compat.v1.train.replica_device_setter`, so that the global center variables and the global step variable can be placed on the ps device. Also, create variables with `tf.compat.v1.get_variable` rather than `tf.Variable`; only `get_variable` goes through this custom getter.
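For example:

    import tensorflow as tf  # TF 1.x: tf.contrib is not available in TF 2.x
    from tensorflow.contrib.opt import ModelAverageCustomGetter

    # worker_device, cluster, IMAGE_PIXELS and FLAGS are defined elsewhere.
    ma_custom_getter = ModelAverageCustomGetter(worker_device)
    with tf.device(tf.compat.v1.train.replica_device_setter(
            worker_device=worker_device,
            ps_device="/job:ps/cpu:0",
            cluster=cluster)):
        # Every get_variable call inside this scope goes through ma_custom_getter.
        with tf.compat.v1.variable_scope("", custom_getter=ma_custom_getter):
            hid_w = tf.compat.v1.get_variable(
                initializer=tf.random.truncated_normal(
                    [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
                    stddev=1.0 / IMAGE_PIXELS),
                name="hid_w")
            hid_b = tf.compat.v1.get_variable(
                initializer=tf.zeros([FLAGS.hidden_units]),
                name="hid_b")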
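The two steps listed above, together with the device-placement note, can be illustrated without TensorFlow. The pure-Python sketch below is hypothetical throughout: `FakeGraph`, `FakeVariable`, `ModelAverageGetterSketch`, `get_variable`, the collection strings and the `global_center_variable/` prefix are illustrative stand-ins, and `current_device` merely simulates the device that `replica_device_setter` would choose. It shows the idea, not TensorFlow's implementation.

```python
from dataclasses import dataclass
from typing import Any, Dict, List

# Hypothetical stand-ins for the GraphKeys collection names.
LOCAL_VARIABLES = "local_variables"
GLOBAL_VARIABLES = "variables"


@dataclass
class FakeVariable:
    name: str
    value: Any
    trainable: bool
    collections: List[str]
    device: str


class FakeGraph:
    """Records every variable created. current_device simulates the device
    a replica_device_setter would pick for variables (the ps)."""

    def __init__(self, current_device: str):
        self.current_device = current_device
        self.variables: Dict[str, FakeVariable] = {}

    def getter(self, name, initial_value, trainable, collections, device=None):
        var = FakeVariable(name, initial_value, trainable, list(collections),
                           device if device is not None else self.current_device)
        self.variables[name] = var
        return var


class ModelAverageGetterSketch:
    """Simplified illustration of the two steps (not TensorFlow's code)."""

    def __init__(self, worker_device: str):
        self._worker_device = worker_device
        self.local_to_global: Dict[str, str] = {}

    def __call__(self, getter, name, trainable, collections, **kwargs):
        if not trainable:
            # Non-trainable variables (e.g. a global step) pass through and
            # land wherever the simulated device setter puts them.
            return getter(name, trainable=trainable, collections=collections, **kwargs)
        # Step 1: trainable copy in the local collection, pinned to the worker.
        local_var = getter(name, trainable=True, collections=[LOCAL_VARIABLES],
                           device=self._worker_device, **kwargs)
        # Step 2: non-trainable global "center" copy with no explicit device,
        # so the simulated device setter places it on the ps.
        global_name = "global_center_variable/" + name  # illustrative prefix
        getter(global_name, trainable=False, collections=[GLOBAL_VARIABLES], **kwargs)
        self.local_to_global[name] = global_name
        return local_var


def get_variable(graph, custom_getter, name, initial_value, trainable=True):
    """Hypothetical get_variable that routes creation through custom_getter."""
    return custom_getter(graph.getter, name, trainable=trainable,
                         collections=[GLOBAL_VARIABLES], initial_value=initial_value)


graph = FakeGraph(current_device="/job:ps/cpu:0")
ma_getter = ModelAverageGetterSketch("/job:worker/task:0")
hid_w = get_variable(graph, ma_getter, "hid_w", [0.0, 0.0])
step = get_variable(graph, ma_getter, "global_step", 0, trainable=False)
print(hid_w.device, hid_w.collections)
# /job:worker/task:0 ['local_variables']
print(graph.variables["global_center_variable/hid_w"].device)
# /job:ps/cpu:0
```

The design point the sketch tries to capture is that the caller gets back the local, worker-resident copy, while the center copy is created as a side effect and left to the device setter.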
Args | |
---|---|
`worker_device` | String. Name of the worker device (for example, `/job:worker/task:0`); local copies of trainable variables are placed there.
Methods
__call__
__call__(
getter, name, trainable, collections, *args, **kwargs
)
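Call self as a function. This is the hook that `tf.compat.v1.variable_scope` invokes for each `tf.compat.v1.get_variable` call when the instance is passed as `custom_getter`.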
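The `__call__` signature names `trainable` and `collections` explicitly. If a scope forwards every argument as a keyword, Python binds those two to their parameters and leaves the rest (shape, initializer, and so on) in `kwargs`, to be passed through to the original getter. The sketch below assumes a scope that forwards arguments that way; `base_getter`, `RecordingGetter` and `scoped_get_variable` are hypothetical stand-ins, not TensorFlow APIs.

```python
def base_getter(name, shape=None, trainable=True, collections=None, **kwargs):
    """Hypothetical stand-in for the getter a scope hands to a custom getter."""
    return {"name": name, "shape": shape, "trainable": trainable,
            "collections": list(collections or []), "extra": kwargs}


class RecordingGetter:
    """Custom getter with the same parameter layout as __call__ above."""

    def __init__(self):
        self.calls = []

    def __call__(self, getter, name, trainable, collections, *args, **kwargs):
        # trainable and collections are bound from keyword arguments; everything
        # else (shape, dtype, initializer, ...) stays in kwargs and is forwarded.
        self.calls.append((name, trainable, sorted(kwargs)))
        return getter(name, *args, trainable=trainable,
                      collections=collections, **kwargs)


def scoped_get_variable(custom_getter, name, **kwargs):
    """Mimics a scope that forwards every argument to custom_getter as keywords."""
    kwargs.setdefault("trainable", True)
    kwargs.setdefault("collections", None)
    return custom_getter(getter=base_getter, name=name, **kwargs)


rec = RecordingGetter()
v = scoped_get_variable(rec, "hid_b", shape=[10])
print(rec.calls)
# [('hid_b', True, ['shape'])]
```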