Warning: This project is deprecated. TensorFlow Addons has stopped development.
The project will only be providing minimal maintenance releases until May 2024. See the full
announcement here or on GitHub.
tfa.seq2seq.ScheduledEmbeddingTrainingSampler
A training sampler that adds scheduled sampling.
Inherits From: TrainingSampler, Sampler
tfa.seq2seq.ScheduledEmbeddingTrainingSampler(
    sampling_probability: tfa.types.TensorLike,
    embedding_fn: Optional[Callable] = None,
    time_major: bool = False,
    seed: Optional[int] = None,
    scheduling_seed: Optional[TensorLike] = None
)
Returns -1s for sample_ids where no sampling took place; valid
sample id values elsewhere.
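The -1 convention can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not the library's implementation: the function name `scheduled_sample_ids`, its arguments, and the use of a single `random.Random` in place of `seed` and `scheduling_seed` are all assumptions made for illustration.

```python
import math
import random


def scheduled_sample_ids(logits, sampling_probability, rng):
    """Return one id per batch row: a sampled id, or -1 if no sampling took place.

    logits: list of per-row score lists (hypothetical stand-in for the
        decoder outputs at one time step).
    sampling_probability: float in [0, 1], shared by all rows in this sketch.
    rng: random.Random instance (stands in for both seeds; an assumption).
    """
    sample_ids = []
    for row in logits:
        if rng.random() < sampling_probability:
            # Categorical draw using softmax weights over the row's scores.
            peak = max(row)
            weights = [math.exp(score - peak) for score in row]
            sample_ids.append(rng.choices(range(len(row)), weights=weights)[0])
        else:
            # Sentinel: the next input should be read from the ground truth.
            sample_ids.append(-1)
    return sample_ids


rng = random.Random(0)
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
never_sampled = scheduled_sample_ids(logits, 0.0, rng)
always_sampled = scheduled_sample_ids(logits, 1.0, rng)
```

With a probability of 0.0 every entry is -1, and with 1.0 every entry is a valid index into the row's scores; intermediate values mix the two per row.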
Args
sampling_probability | A float32 0-D or 1-D tensor: the probability of sampling categorically from the output ids instead of reading directly from the inputs.
embedding_fn | A callable that takes a vector tensor of ids (argmax ids).
time_major | Python bool. Whether the tensors in inputs are time major. If False (default), they are assumed to be batch major.
seed | The sampling seed.
scheduling_seed | The schedule decision rule sampling seed.
Raises
ValueError | If sampling_probability is not a scalar or vector.
Attributes
batch_size | Batch size of the tensor returned by sample. Returns a scalar int32 tensor. The return value might not be available before initialize() is invoked; in that case, a ValueError is raised.
sample_ids_dtype | DType of the tensor returned by sample. Returns a DType. The return value might not be available before initialize() is invoked.
sample_ids_shape | Shape of the tensor returned by sample, excluding the batch dimension. Returns a TensorShape. The return value might not be available before initialize() is invoked.
Methods
initialize
initialize(
inputs, sequence_length=None, mask=None, embedding=None
)
Initialize the TrainingSampler.
Args
inputs | A (structure of) input tensors.
sequence_length | An int32 vector tensor.
mask | A boolean 2D tensor.
Returns
(finished, next_inputs), a tuple of two items. The first item is a
boolean vector indicating whether each item in the batch has
finished. The second item is the first slice of the input data along
the timestep dimension (usually the second dim of the input).
next_inputs
next_inputs(
time, outputs, state, sample_ids
)
Returns (finished, next_inputs, next_state).
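One way to picture how the -1 sentinel in sample_ids could drive the choice of next input is the pure-Python sketch below. It is not the library's code: `next_inputs_sketch`, the lookup-table `embedding_fn`, the rule that an item finishes once the next time step reaches its sequence length, and the use of zero vectors for finished items are all assumptions made for illustration.

```python
def next_inputs_sketch(time, state, sample_ids, inputs, sequence_length,
                       embedding_fn):
    """Return (finished, next_inputs, next_state) for batch-major list inputs."""
    next_time = time + 1
    # Assumption: an item is finished once next_time reaches its length.
    finished = [next_time >= length for length in sequence_length]
    next_inputs = []
    for row, sample_id, done in zip(inputs, sample_ids, finished):
        if done:
            # Simplification: finished items receive a zero vector.
            next_inputs.append([0.0] * len(row[0]))
        elif sample_id > -1:
            # Sampling took place: feed back the embedding of the sampled id.
            next_inputs.append(embedding_fn(sample_id))
        else:
            # No sampling (-1): read the ground-truth input at the next step.
            next_inputs.append(row[next_time])
    # The state is passed through unchanged in this sketch.
    return finished, next_inputs, state


# Hypothetical embedding table with two ids and depth 2.
table = {0: [0.0, 0.5], 1: [1.0, 1.5]}
inputs = [[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
          [[4.0, 4.0], [5.0, 5.0], [6.0, 6.0]],
          [[7.0, 7.0], [8.0, 8.0], [9.0, 9.0]]]

finished, next_in, next_state = next_inputs_sketch(
    time=0, state="s0", sample_ids=[-1, 1, 0], inputs=inputs,
    sequence_length=[3, 3, 1], embedding_fn=lambda i: table[i])
```

Here the first item reads its ground-truth input, the second receives the embedding of sampled id 1, and the third is already finished, so its sampled id is ignored.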
sample
sample(
time, outputs, state
)
Returns sample_ids.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2023-05-25 UTC.