tf_agents.bandits.networks.global_and_arm_feature_network.create_feed_forward_common_tower_network
Creates a common tower network with feedforward towers.
tf_agents.bandits.networks.global_and_arm_feature_network.create_feed_forward_common_tower_network(
    observation_spec: tf_agents.typing.types.NestedTensorSpec,
    global_layers: Sequence[int],
    arm_layers: Sequence[int],
    common_layers: Sequence[int],
    output_dim: int = 1,
    global_preprocessing_combiner: Optional[Callable[..., tf_agents.typing.types.LossFn]] = None,
    arm_preprocessing_combiner: Optional[Callable[..., tf_agents.typing.types.LossFn]] = None,
    activation_fn: Callable[[tf_agents.typing.types.Tensor], tf_agents.typing.types.Tensor] = tf.keras.activations.relu,
    name: Optional[str] = None
) -> tf_agents.typing.types.Network
The network produced by this function can be used either in GreedyRewardPredictionPolicy or in NeuralLinUCBPolicy.
In the former case, the network must have output_dim=1; its common tower will then be a QNetwork, and the policy uses the network as a reward prediction network.
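To make the reward-prediction case concrete, here is a minimal pure-Python sketch of the common-tower data flow. It assumes (the page does not spell this out) that the global tower runs once per observation, the arm tower runs once per arm, and the two outputs are concatenated before the common tower, which ends in a linear layer with output_dim units. Weights are random and untrained, and all helper names (make_tower, run_tower, build, forward) are hypothetical, not TF-Agents APIs. With output_dim=1 each arm gets one scalar, and a greedy policy simply picks the arm with the highest estimate.

```python
import random
from typing import Dict, List, Sequence, Tuple

Layer = Tuple[List[List[float]], List[float]]  # (weights[fan_in][units], biases[units])


def relu(v: float) -> float:
    return max(0.0, v)


def make_tower(input_dim: int, layer_sizes: Sequence[int], rng: random.Random) -> List[Layer]:
    """Hypothetical helper: random (untrained) weights for a stack of dense layers."""
    layers = []
    for units in layer_sizes:
        weights = [[rng.uniform(-0.5, 0.5) for _ in range(units)] for _ in range(input_dim)]
        layers.append((weights, [0.0] * units))
        input_dim = units
    return layers


def run_tower(x: List[float], layers: List[Layer], linear_last: bool = False) -> List[float]:
    """Apply dense layers in order; ReLU everywhere except optionally the last layer."""
    for k, (weights, biases) in enumerate(layers):
        pre = [b + sum(x[i] * weights[i][j] for i in range(len(x))) for j, b in enumerate(biases)]
        is_last = k == len(layers) - 1
        x = pre if (linear_last and is_last) else [relu(v) for v in pre]
    return x


def build(global_dim, arm_dim, global_layers, arm_layers, common_layers, output_dim=1, seed=0):
    rng = random.Random(seed)
    return {
        "global": make_tower(global_dim, global_layers, rng),
        "arm": make_tower(arm_dim, arm_layers, rng),
        # Assumption: the common tower sees [global_output, arm_output] concatenated,
        # and ends in a linear layer with output_dim units.
        "common": make_tower(global_layers[-1] + arm_layers[-1], list(common_layers) + [output_dim], rng),
        "output_dim": output_dim,
    }


def forward(net: Dict, observation: Dict[str, List]) -> List:
    g = run_tower(observation["global"], net["global"])  # once per observation
    outputs = []
    for arm_features in observation["per_arm"]:  # once per arm
        a = run_tower(arm_features, net["arm"])
        out = run_tower(g + a, net["common"], linear_last=True)  # g + a is list concatenation
        outputs.append(out[0] if net["output_dim"] == 1 else out)
    return outputs


obs = {"global": [0.1, 0.2, 0.3], "per_arm": [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [0.2, 0.9]]}
reward_net = build(3, 2, [8, 4], [8, 4], [4], output_dim=1)
scores = forward(reward_net, obs)  # one reward estimate per arm
greedy_arm = max(range(len(scores)), key=scores.__getitem__)
encoder = build(3, 2, [8, 4], [8, 4], [4], output_dim=5)
encodings = forward(encoder, obs)  # one length-5 encoding per arm
```

The same construction with output_dim > 1 yields one vector per arm instead of one scalar, which is the shape the encoding case below relies on.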
In the latter case, the network will be an encoding network whose output is consumed by a reward layer or a LinUCB method, and output_dim specifies the encoding dimension.
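As a rough illustration of what "consumed by a LinUCB method" can mean, the sketch below is a textbook LinUCB scorer over per-arm encodings, written in pure Python. It assumes a single linear model shared by all arms, fitted in closed form (A = reg·I + Σ x xᵀ, b = Σ r·x, θ = A⁻¹b), and scores each arm as θᵀx + alpha·sqrt(xᵀA⁻¹x). It is not the NeuralLinUCBPolicy implementation, and the names solve, LinUCB, alpha, and reg are hypothetical.

```python
import math
from typing import List, Sequence


def solve(a: List[List[float]], b: List[float]) -> List[float]:
    """Solve a @ x = b by Gauss-Jordan elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]  # augmented matrix [a | b]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(n):
            if r != col:
                factor = m[r][col] / m[col][col]
                m[r] = [m[r][k] - factor * m[col][k] for k in range(n + 1)]
    return [m[i][n] / m[i][i] for i in range(n)]


class LinUCB:
    """Textbook LinUCB with one linear model shared by all arms (hypothetical sketch)."""

    def __init__(self, dim: int, alpha: float = 1.0, reg: float = 1.0):
        self.alpha = alpha  # exploration strength
        self.a = [[reg if i == j else 0.0 for j in range(dim)] for i in range(dim)]  # reg * I
        self.b = [0.0] * dim

    def update(self, x: Sequence[float], reward: float) -> None:
        """Accumulate A += x x^T and b += reward * x for one observed (encoding, reward)."""
        for i in range(len(x)):
            self.b[i] += reward * x[i]
            for j in range(len(x)):
                self.a[i][j] += x[i] * x[j]

    def scores(self, encodings: Sequence[Sequence[float]]) -> List[float]:
        """Score each arm's encoding as theta^T x + alpha * sqrt(x^T A^-1 x)."""
        theta = solve(self.a, self.b)
        result = []
        for x in encodings:
            a_inv_x = solve(self.a, list(x))
            mean = sum(t * xi for t, xi in zip(theta, x))
            width = math.sqrt(sum(xi * v for xi, v in zip(x, a_inv_x)))
            result.append(mean + self.alpha * width)
        return result


# Two-dimensional encodings, standing in for network output with output_dim=2.
model = LinUCB(dim=2, alpha=0.5)
model.update([1.0, 0.0], reward=1.0)
model.update([0.0, 1.0], reward=0.0)
arm_scores = model.scores([[1.0, 0.0], [0.0, 1.0]])
chosen_arm = max(range(len(arm_scores)), key=arm_scores.__getitem__)  # 0
```

Both arms get the same exploration bonus here because each was observed once; arm 0 wins on its higher predicted reward.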
Args
  observation_spec: A nested tensor spec containing the specs for the global as well as the per-arm observations.
  global_layers: Sequence of ints specifying the layers of the global tower; each entry is the number of units in one fully connected layer.
  arm_layers: Sequence of ints specifying the layers of the arm tower, in the same format as global_layers.
  common_layers: Sequence of ints specifying the layers of the common tower, which processes the combined outputs of the global and arm towers.
  output_dim: The output dimension of the network. If 1, the common tower will be a QNetwork. Otherwise, the common tower will be an encoding network with the specified output dimension.
  global_preprocessing_combiner: Preprocessing combiner for the global features.
  arm_preprocessing_combiner: Preprocessing combiner for the per-arm features.
  activation_fn: A Keras activation specifying the activation function used in all layers. Defaults to relu.
  name: The network name to use. Shows up in TensorBoard losses.
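The two preprocessing combiners matter when the global or per-arm features are themselves nested, for example a dict of several feature vectors. In TF-Agents encoding networks a combiner is typically a Keras layer such as tf.keras.layers.Concatenate(axis=-1) that merges the flattened inputs into a single tensor. The pure-Python sketch below only mimics a concatenating combiner: it flattens dict entries in sorted-key order, as tf.nest does for dicts. The function and feature names are hypothetical, and the per-arm input is simplified to one dict per arm.

```python
from typing import Dict, List


def concat_combiner(features: Dict[str, List[float]]) -> List[float]:
    """Hypothetical stand-in for a concatenating combiner: merge named feature
    vectors into one flat vector, visiting keys in sorted order (as tf.nest
    flattens dicts) so the layout is deterministic."""
    combined: List[float] = []
    for key in sorted(features):
        combined.extend(features[key])
    return combined


# Nested global features -> one global vector.
global_features = {"user_segment": [0.0, 1.0, 0.0], "time_of_day": [0.25]}
global_vector = concat_combiner(global_features)  # [0.25, 0.0, 1.0, 0.0]

# Nested per-arm features, simplified to one dict per arm -> one vector per arm.
arm_features = [
    {"price": [0.3], "category": [1.0, 0.0]},
    {"price": [0.9], "category": [0.0, 1.0]},
]
arm_vectors = [concat_combiner(f) for f in arm_features]  # [[1.0, 0.0, 0.3], [0.0, 1.0, 0.9]]
```

The combined vectors are what the global and arm towers would then see as their flat inputs.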
Returns
  A network that takes observations adhering to observation_spec. With output_dim=1 it outputs a reward estimate for every action; otherwise it outputs an output_dim-dimensional encoding for every action.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-04-26 UTC.