tfp.util.externalize_variables_as_args(
    fn,
    fn_args=(),
    ancestor_variables=None,
    possible_ancestor_vars=None,
    assert_variable_override=False,
    name=None
)
Converts variables within a callable into explicit args.
Makes a new callable from fn which takes the arguments list(fn_args) + list(ancestor_variables). If ancestor_variables is not specified, it is inferred by checking which of possible_ancestor_vars actually influence the return value of fn (concretely, those variables for which the gradient of fn(*fn_args) is not None).
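The calling convention can be illustrated without TensorFlow. The pure-Python sketch below assumes a hypothetical VariableStore in which fn looks up its variables by name (loosely mirroring tf.get_variable). The wrapper temporarily swaps in caller-supplied values, so the variables become trailing positional arguments after fn_args. As in the real function, only the number of fn_args matters for splitting the arguments, not their values. VariableStore, externalize and ancestor_names are illustrative names and not part of TFP, and this is not TFP's implementation.

```python
from contextlib import contextmanager


class VariableStore:
    """Hypothetical stand-in for a graph's variable collection."""

    def __init__(self):
        self._values = {}
        self._overrides = {}

    def get(self, name, initial):
        # Create the variable on first use, similar in spirit to tf.get_variable.
        self._values.setdefault(name, initial)
        # An active override takes precedence over the stored value.
        return self._overrides.get(name, self._values[name])

    @contextmanager
    def override(self, mapping):
        # Install replacement values for the duration of the `with` block.
        saved = self._overrides
        self._overrides = {**saved, **mapping}
        try:
            yield
        finally:
            self._overrides = saved


def externalize(store, fn, fn_args, ancestor_names):
    """Return a callable taking *(list(fn_args) + ancestor values)."""
    n = len(fn_args)

    def wrapped(*args):
        # The first n args go to fn; the remaining args replace the variables.
        with store.override(dict(zip(ancestor_names, args[n:]))):
            return fn(*args[:n])

    return wrapped


store = VariableStore()


def foo(x):
    y = store.get("y", 10.0)
    return x + y


wrapped_foo = externalize(store, foo, fn_args=[1.0], ancestor_names=["y"])
print(foo(1.0))               # 11.0: uses the stored value of y
print(wrapped_foo(1.0, 3.0))  # 4.0: y is supplied explicitly
```

Restoring the previous overrides in a finally clause means that calling foo directly after the wrapper still sees the stored value.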
By default, possible_ancestor_vars is tf.trainable_variables() + tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES).
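The discovery rule relies on tf.gradients returning None for a variable that has no path to the output. As a rough analogue, the hypothetical sketch below models a computation as a graph of Node objects and keeps the candidate variables from which the output is reachable. It substitutes plain graph reachability for tf.gradients, so it ignores cases such as tf.stop_gradient, which can also produce a None gradient in TensorFlow. Node and discover_ancestors are illustrative names only.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False keeps identity-based equality and hashing
class Node:
    """Hypothetical node in a computation graph."""
    name: str
    inputs: list = field(default_factory=list)


def discover_ancestors(output, possible_ancestor_vars):
    """Return the candidates that `output` (transitively) depends on."""
    seen = set()
    stack = [output]
    # Depth-first walk from the output back through every input edge.
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        stack.extend(node.inputs)
    # Keep candidates in their original order, like a filtered list.
    return [v for v in possible_ancestor_vars if v in seen]


x = Node("x")
y = Node("y")
unused = Node("unused")
out = Node("add", inputs=[x, y])

print([v.name for v in discover_ancestors(out, [unused, y])])  # ['y']
```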
Examples
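import numpy as np
import tensorflow as tf  # TF1-style graph API (tf.get_variable)
import tensorflow_probability as tfp

num_samples = 2
num_dims = 1
dtype = np.float32

def foo(x):
  x = tf.convert_to_tensor(x, dtype=dtype, name="x")
  s = x.shape.as_list()
  y = tf.get_variable(
      name="y",
      dtype=dtype,
      initializer=np.arange(np.prod(s)).reshape(s).astype(dtype))
  return x + y

x = tf.constant(dtype([0.1, 0.2]))

wrapped_foo, discovered_ancestor_variables = (
    tfp.util.externalize_variables_as_args(foo, [x]))

new_x = dtype([[1.], [2.]])
new_y = dtype([[3.], [4.]])
new_result = wrapped_foo(new_x, new_y)
# ==> [[4.], [6.]]

# "y" already exists, so fetching it again requires reuse=True; dtype is passed
# by keyword because the second positional argument of tf.get_variable is shape.
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
  discovered_ancestor_variables == [tf.get_variable("y", dtype=dtype)]
# ==> True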
Args:
  fn: Python callable which returns a Tensor and accepts *fn_args.
  fn_args: Python list of args to fn. Represents dummy arguments passed to fn to trace its execution; actual values are unimportant. These args are only used to construct the output of fn and to resolve the ancestor tf.Variables. Default value: () (i.e., fn takes no args).
  ancestor_variables: Python list of tf.Variables. When None, the list is inferred as the variables for which the gradient of fn(*fn_args) is not None. Providing ancestor_variables directly avoids the internal call to fn. Default value: None (i.e., tf.Variable dependencies are discovered).
  possible_ancestor_vars: Python list of candidate tf.Variables which might be a dependency of computing fn(*fn_args). Default value: None (i.e., expanded as described above).
  assert_variable_override: Python bool indicating whether failing to find a tf.Variable in the override list raises an exception. Default value: False (i.e., a missing Variable triggers a warning).
  name: Python str name prefixed to Ops created by this function. Default value: None (i.e., "externalize_variables_as_args").
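The assert_variable_override flag only changes what happens when fn requests a variable that has no replacement. A minimal plain-Python sketch of that decision, written as a hypothetical lookup helper (resolve_variable and overrides are illustrative names, not part of TFP), raises in strict mode and otherwise warns and falls back to the variable's own value:

```python
import warnings


def resolve_variable(name, stored_value, overrides, assert_variable_override=False):
    """Return the override for `name`, falling back to the stored value."""
    if name in overrides:
        return overrides[name]
    message = 'Variable "{}" not overridden.'.format(name)
    if assert_variable_override:
        # Strict mode: a missing override is an error.
        raise ValueError(message)
    # Default mode: warn and keep using the variable's own value.
    warnings.warn(message)
    return stored_value


print(resolve_variable("y", 10.0, {"y": 3.0}))  # 3.0
print(resolve_variable("z", 5.0, {"y": 3.0}))   # 5.0, after emitting a warning
```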
Returns:
  wrapped_fn: Python callable taking arguments like *(list(fn_args) + discovered_ancestor_variables).
  discovered_ancestor_variables: Python list of tf.Variables known to be a dependency of fn(*fn_args).
Raises:
  ValueError: if assert_variable_override is True and a Variable is requested but not overridden.