Class outlining the default Tracing Protocol for Scalarizer.
tf_agents.bandits.multi_objective.multi_objective_scalarizer.ScalarizerTraceType(
value
)
If an argument has this trace type, the corresponding tf.function will retrace on every call.
Derived classes can override this behavior by specifying their own Tracing Protocol.
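The always-retrace behavior can be pictured with a toy, pure-Python model of a trace cache. This is a sketch under stated assumptions, not TensorFlow's dispatch logic: `AlwaysRetraceType` and `ToyTraceCache` are hypothetical, and the model assumes a cached trace is reused only when the new argument's type is a subtype of a previously traced type. A type whose `is_subtype_of` always returns False therefore triggers a new trace on every call.

```python
class AlwaysRetraceType:
  """Hypothetical trace type that never matches a previously traced type."""

  def __init__(self, value):
    self.value = value

  def is_subtype_of(self, other):
    # Never a subtype of anything, so no cached trace can be reused.
    return False


class ToyTraceCache:
  """Hypothetical cache: reuse a trace only if the new type is a subtype."""

  def __init__(self):
    self.traced_types = []
    self.trace_count = 0

  def call(self, arg_type):
    for cached_type in self.traced_types:
      if arg_type.is_subtype_of(cached_type):
        return "reused"
    # No compatible trace was found, so "retrace" and remember the type.
    self.traced_types.append(arg_type)
    self.trace_count += 1
    return "traced"


cache = ToyTraceCache()
scalarizer = object()  # stand-in for a Scalarizer instance
results = [cache.call(AlwaysRetraceType(scalarizer)) for _ in range(3)]
print(results)            # ['traced', 'traced', 'traced']
print(cache.trace_count)  # 3
```

Derived classes that return True from is_subtype_of for compatible values would let the toy cache report "reused" instead.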
Methods
cast
cast(
value, cast_context
) -> Any
Cast value to this type.
Args | |
---|---|
value | An input value belonging to this TraceType. |
cast_context | A context reserved for internal/future usage. |
Returns | |
---|---|
The value casted to this TraceType. |
Raises | |
---|---|
AssertionError | If _cast is not overridden in a subclass, the value is returned unchanged and must equal self.placeholder_value(); otherwise this error is raised. |
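The default cast behavior described in the Raises entry can be sketched in plain Python. `ConstantTraceType` is a hypothetical type whose only member is a single constant, and the sketch assumes that equality with the placeholder value is the entire check, as the Raises entry describes.

```python
class ConstantTraceType:
  """Hypothetical trace type whose only member is one constant value."""

  def __init__(self, value):
    self._value = value

  def placeholder_value(self, placeholder_context=None):
    # A constant type traces with the constant itself.
    return self._value

  def cast(self, value, cast_context=None):
    # No conversion: check that the value matches the placeholder and
    # return it unchanged, raising AssertionError otherwise.
    del cast_context  # reserved for internal/future usage
    assert value == self.placeholder_value(), (
        f"Can not cast {value!r} to type {self!r}")
    return value


t = ConstantTraceType(3)
print(t.cast(3))  # 3

try:
  t.cast(4)
  cast_failed = False
except AssertionError:
  cast_failed = True
print(cast_failed)  # True
```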
flatten
flatten() -> List['TraceType']
Returns a list of TensorSpecs corresponding to to_tensors values.
from_tensors
from_tensors(
tensors: Iterator[core.Tensor]
) -> Any
Generates a value of this type from Tensors.
Must consume the same fixed number of tensors that to_tensors produces.
Args | |
---|---|
tensors | An iterator from which the tensors can be pulled. |
Returns | |
---|---|
A value of this type. |
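Because from_tensors pulls from an iterator, several values can be rebuilt in order from one flat stream, provided each type consumes exactly as many items as its to_tensors produced. A pure-Python sketch, with floats standing in for Tensors and a hypothetical `PairTraceType`:

```python
from itertools import islice
from typing import Any, Iterator, List, Tuple


class PairTraceType:
  """Hypothetical trace type for 2-tuples; always uses exactly 2 'tensors'."""

  NUM_COMPONENTS = 2

  def to_tensors(self, value: Tuple[float, float]) -> List[float]:
    return [value[0], value[1]]

  def from_tensors(self, tensors: Iterator[Any]) -> Tuple[float, float]:
    # Pull exactly as many items as to_tensors produces, no more, no less.
    parts = list(islice(tensors, self.NUM_COMPONENTS))
    if len(parts) != self.NUM_COMPONENTS:
      raise ValueError("Not enough tensors to rebuild the value.")
    return (parts[0], parts[1])


trace_types = [PairTraceType(), PairTraceType()]
values = [(1.0, 2.0), (3.0, 4.0)]

# Flatten both values into one stream, then rebuild them in order.
flat = [t for ty, v in zip(trace_types, values) for t in ty.to_tensors(v)]
stream = iter(flat)
rebuilt = [ty.from_tensors(stream) for ty in trace_types]
print(flat)     # [1.0, 2.0, 3.0, 4.0]
print(rebuilt)  # [(1.0, 2.0), (3.0, 4.0)]
```

If one type consumed one item too many or too few, every value after it in the shared stream would be rebuilt from the wrong items, which is why the count has to be fixed.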
is_subtype_of
is_subtype_of(
_
)
Returns True if self is a subtype of other.
For example, tf.function uses subtyping for dispatch: if a.is_subtype_of(b) is True, then an argument of TraceType a can be used as an argument to a ConcreteFunction traced with TraceType b.
Args | |
---|---|
other | A TraceType object to be compared against. |
Example:
```python
from typing import Optional

import tensorflow as tf


class Dimension(tf.types.experimental.TraceType):
  def __init__(self, value: Optional[int]):
    self.value = value

  def is_subtype_of(self, other):
    # Either the value is the same or other has a generalized value that
    # can represent any specific ones.
    return (self.value == other.value) or (other.value is None)
```
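A pure-Python sketch of subtype-based dispatch, with the TensorFlow base class dropped so it runs standalone. The `find_reusable_trace` helper is hypothetical and only approximates the rule stated above: an argument may use any trace whose TraceType it is a subtype of.

```python
from typing import List, Optional


class Dimension:
  """Standalone stand-in for the Dimension example (no TensorFlow base)."""

  def __init__(self, value: Optional[int]):
    self.value = value

  def is_subtype_of(self, other: "Dimension") -> bool:
    # Same value, or other is the generalized (None) dimension.
    return self.value == other.value or other.value is None


def find_reusable_trace(arg: Dimension,
                        traced: List[Dimension]) -> Optional[int]:
  """Hypothetical: index of the first traced type `arg` can be passed to."""
  for i, traced_type in enumerate(traced):
    if arg.is_subtype_of(traced_type):
      return i
  return None


traced = [Dimension(2), Dimension(None)]
print(find_reusable_trace(Dimension(2), traced))          # 0 (exact match)
print(find_reusable_trace(Dimension(5), traced))          # 1 (None trace)
print(find_reusable_trace(Dimension(5), [Dimension(2)]))  # None (retrace)
```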
most_specific_common_supertype
most_specific_common_supertype(
_
)
Returns the most specific supertype of self and others, if one exists.
The returned TraceType is a supertype of self and others; that is, they are all subtypes (see is_subtype_of) of it.
It is also the most specific such type: it has no subtype that is also a common supertype of self and others.
If self and others have no common supertype, this returns None.
Args | |
---|---|
others | A sequence of TraceTypes. |
Example:
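```python
from typing import Optional, Sequence

import tensorflow as tf


class Dimension(tf.types.experimental.TraceType):
  def __init__(self, value: Optional[int]):
    self.value = value

  def most_specific_common_supertype(self, others: Sequence["Dimension"]):
    # If every other dimension has the same value, that exact dimension is
    # the most specific common supertype; otherwise generalize to None,
    # which can represent any specific value.
    if all(self.value == other.value for other in others):
      return Dimension(self.value)
    return Dimension(None)
```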
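The supertype rule can be applied axis by axis to generalize several shapes into one. A standalone sketch under stated assumptions: the `Dimension` class here drops the TensorFlow base class, `common_shape` is a hypothetical helper, and all shapes are assumed to have the same rank.

```python
from typing import List, Optional, Sequence


class Dimension:
  """Standalone stand-in for the Dimension example (no TensorFlow base)."""

  def __init__(self, value: Optional[int]):
    self.value = value

  def most_specific_common_supertype(
      self, others: Sequence["Dimension"]) -> "Dimension":
    # Keep the exact value only if every other dimension agrees with it;
    # otherwise fall back to the generalized, unknown dimension.
    if all(other.value == self.value for other in others):
      return Dimension(self.value)
    return Dimension(None)


def common_shape(shapes: List[List[Optional[int]]]) -> List[Optional[int]]:
  """Hypothetical helper: per-axis supertype of equal-rank shapes."""
  first, rest = shapes[0], shapes[1:]
  result = []
  for axis, value in enumerate(first):
    others = [Dimension(shape[axis]) for shape in rest]
    result.append(Dimension(value).most_specific_common_supertype(others).value)
  return result


print(common_shape([[2, 3], [2, 4]]))          # [2, None]
print(common_shape([[8, 3], [8, 3], [8, 3]]))  # [8, 3]
```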
placeholder_value
placeholder_value(
placeholder_context=None
)
Creates a placeholder for tracing.
tf.function traces with the placeholder value rather than the actual value. A single placeholder value can represent multiple different actual values, so the trace generated with it is more general and reusable, which saves expensive retracing.
Args | |
---|---|
placeholder_context | A context reserved for internal/future usage. |
For example, given a Fruit class with Apple and Mango subclasses and a tf.function get_mixed_flavor that takes two fruits, implementing:
```python
class Fruit:
  pass

class FruitTraceType:
  def placeholder_value(self, placeholder_context):
    return Fruit()
```
instructs tf.function to trace with Fruit() objects instead of the actual Apple() and Mango() objects when it receives a call to get_mixed_flavor(Apple(), Mango()). Similarly, Tensor arguments are replaced with Tensors of similar shape and dtype, output from a Placeholder op.
More generally, placeholder values are the arguments of a tf.function, as seen from the function's body:
```python
import tensorflow as tf

@tf.function
def foo(x):
  # Here `x` is the placeholder value.
  ...

x = tf.constant(1.0)
foo(x)  # Here `x` is the actual value.
```
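A toy, pure-Python model of tracing with placeholders: the function body runs once on placeholder values, and later calls with different actual values reuse that result. `Fruit`, `Apple`, `Mango`, `toy_trace` and its cache key are hypothetical; real tf.function builds a graph from the placeholders rather than caching a Python return value.

```python
from typing import Callable, Dict, List


class Fruit:
  """Hypothetical base class used only for this sketch."""


class Apple(Fruit):
  pass


class Mango(Fruit):
  pass


class FruitTraceType:
  def placeholder_value(self, placeholder_context=None):
    # Trace with a generic Fruit instead of the concrete Apple or Mango.
    return Fruit()


def toy_trace(fn: Callable) -> Callable:
  """Hypothetical: runs `fn` once on placeholders, then skips retracing."""
  cache: Dict[int, List[str]] = {}

  def call(*args):
    key = len(args)  # toy cache key: every argument has FruitTraceType
    if key not in cache:
      placeholders = [FruitTraceType().placeholder_value() for _ in args]
      cache[key] = fn(*placeholders)  # the body only ever sees placeholders
    return cache[key]

  call.cache = cache
  return call


@toy_trace
def get_mixed_flavor(a, b):
  # Record what the body observed: placeholder types, not Apple/Mango.
  return [type(a).__name__, type(b).__name__]


print(get_mixed_flavor(Apple(), Mango()))  # ['Fruit', 'Fruit']
print(get_mixed_flavor(Mango(), Apple()))  # ['Fruit', 'Fruit'] (no retrace)
print(len(get_mixed_flavor.cache))         # 1
```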
to_tensors
to_tensors(
value: Any
) -> List[core.Tensor]
Breaks down a value of this type into Tensors.
For a given TraceType instance, the number of tensors generated for any corresponding value should be constant.
Args | |
---|---|
value | A value belonging to this TraceType. |
Returns | |
---|---|
List of Tensors. |
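A pure-Python sketch of the fixed-count contract between flatten, to_tensors and from_tensors. `PointTraceType` is hypothetical, floats stand in for Tensors, and `(dtype, shape)` tuples stand in for TensorSpecs.

```python
from typing import Iterator, List, Tuple

# Hypothetical stand-in for a TensorSpec: a (dtype, shape) description.
Spec = Tuple[str, Tuple[int, ...]]


class PointTraceType:
  """Hypothetical trace type for (x, y) points; always 2 components."""

  def flatten(self) -> List[Spec]:
    # One spec per component emitted by to_tensors.
    return [("float32", ()), ("float32", ())]

  def to_tensors(self, value: Tuple[float, float]) -> List[float]:
    x, y = value
    return [x, y]  # constant length regardless of the value

  def from_tensors(self, tensors: Iterator[float]) -> Tuple[float, float]:
    # Consume exactly the two components that to_tensors produced.
    return (next(tensors), next(tensors))


point_type = PointTraceType()
for point in [(0.0, 1.0), (2.5, -3.0)]:
  parts = point_type.to_tensors(point)
  assert len(parts) == len(point_type.flatten())  # count is fixed per type
  assert point_type.from_tensors(iter(parts)) == point
print("round trip ok")
```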
__eq__
__eq__(
_
)
Return self==value.