tfm.nlp.ops.SamplingModule
Implementation for sampling strategies (go/decoding-tf-nlp).
tfm.nlp.ops.SamplingModule(
symbols_to_logits_fn,
vocab_size: int,
max_decode_length: int,
eos_id: int,
padded_decode: bool,
length_normalization_fn: Optional[Callable[[int, tf.DType], float]] = None,
top_k=0,
top_p=1.0,
sample_temperature=0.0,
enable_greedy: bool = True,
dtype: tf.DType = tf.float32,
decoding_name: Optional[str] = None,
extra_cache_output: bool = False
)
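The `top_k`, `top_p`, `sample_temperature`, and `enable_greedy` arguments correspond to standard sampling controls. The block below is a minimal pure-Python sketch of how such controls are commonly combined. It is not the module's implementation: the helper names (`softmax`, `filter_top_k`, `filter_top_p`, `sample_token`) are hypothetical, and treating a non-positive temperature as "no scaling" is an assumption made for this illustration.

```python
import math
import random


def softmax(logits):
    """Convert logits to probabilities, subtracting the max for stability."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def renormalize(probs):
    """Rescale the surviving probabilities so they sum to one."""
    total = sum(probs)
    return [p / total for p in probs]


def filter_top_k(probs, k):
    """Keep the k most probable tokens; k <= 0 disables the filter."""
    if k <= 0 or k >= len(probs):
        return list(probs)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(order[:k])
    return [p if i in keep else 0.0 for i, p in enumerate(probs)]


def filter_top_p(probs, p):
    """Keep the smallest set of top tokens whose cumulative mass reaches p."""
    if p >= 1.0:
        return list(probs)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, mass = set(), 0.0
    for i in order:
        keep.add(i)
        mass += probs[i]
        if mass >= p:
            break
    return [q if i in keep else 0.0 for i, q in enumerate(probs)]


def sample_token(logits, top_k=0, top_p=1.0, temperature=1.0,
                 greedy=False, rng=random):
    """Pick one token id from a single row of logits."""
    if greedy:
        # Greedy decoding ignores the sampling controls and takes the argmax.
        return max(range(len(logits)), key=lambda i: logits[i])
    # Assumption for this sketch: a non-positive temperature means no scaling.
    scale = temperature if temperature > 0 else 1.0
    probs = softmax([x / scale for x in logits])
    probs = renormalize(filter_top_k(probs, top_k))
    probs = filter_top_p(probs, top_p)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

With `top_k=1` the filtered distribution has a single non-zero entry, so sampling always returns the most likely token, the same result as `greedy=True`.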
Methods
generate
generate(
initial_ids: tf.Tensor,
initial_cache: Dict[str, tf.Tensor],
initial_log_probs: Optional[tf.Tensor] = None
) -> Output
Implements the decoding strategy (beam_search or sampling).
Args
initial_ids: Initial IDs to pass into symbols_to_logits_fn. An int tensor with shape [batch_size, 1].
initial_cache: Dictionary for caching model outputs from the previous step.
initial_log_probs: Optional initial log probabilities, used when decoding starts from a prefix sequence.
Returns
A tuple of tensors:
finished_sequence: shape [batch, max_seq_length]
finished_scores: shape [batch]
first_cache: the cache after the initial token
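To make the shapes above concrete, here is a toy pure-Python sketch of a greedy decoding loop. It is an illustration under stated assumptions, not the library's code: the callback contract `(ids, index, cache) -> (logits, cache)` is assumed, `toy_symbols_to_logits_fn`, `greedy_decode`, and `PAD_ID` are hypothetical names, and nested lists stand in for tensors.

```python
import math

PAD_ID = 0  # hypothetical padding id written after a row has emitted eos_id


def toy_symbols_to_logits_fn(ids, index, cache):
    """Toy model: strongly favors (last token + 1) modulo the vocabulary size."""
    vocab_size = 5
    logits = []
    for row in ids:
        favored = (row[-1] + 1) % vocab_size
        logits.append([10.0 if v == favored else 0.0 for v in range(vocab_size)])
    # The cache only counts calls here; a real model would store activations.
    new_cache = dict(cache, steps=cache.get("steps", 0) + 1)
    return logits, new_cache


def greedy_decode(symbols_to_logits_fn, initial_ids, initial_cache,
                  max_decode_length, eos_id):
    """Decode greedily from initial_ids, a [batch_size, 1] nested list."""
    sequences = [list(row) for row in initial_ids]
    scores = [0.0] * len(sequences)
    finished = [False] * len(sequences)
    cache = dict(initial_cache)
    for index in range(max_decode_length):
        logits, cache = symbols_to_logits_fn(sequences, index, cache)
        for b, row in enumerate(logits):
            if finished[b]:
                sequences[b].append(PAD_ID)
                continue
            best = max(range(len(row)), key=lambda v: row[v])
            # Log-softmax of the chosen token, accumulated as the row's score.
            peak = max(row)
            log_z = peak + math.log(sum(math.exp(x - peak) for x in row))
            scores[b] += row[best] - log_z
            sequences[b].append(best)
            finished[b] = best == eos_id
    return sequences, scores


sequences, scores = greedy_decode(
    toy_symbols_to_logits_fn, [[0], [1]], {}, max_decode_length=4, eos_id=3)
```

Each returned sequence has length `1 + max_decode_length` in this sketch, and each score is the summed log probability of the tokens chosen before the row finished.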
inf
inf()
Returns a value close to infinity that is still finite in dtype.
This is useful to get a very large value that is still zero when multiplied
by zero. The floating-point "Inf" value is NaN when multiplied by zero.
Returns
A very large value.
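The difference between a true infinity and a large finite value can be checked in plain Python. This is a small sketch: `LARGE_FINITE` is a hypothetical stand-in chosen for illustration, not necessarily the value `inf()` returns, and `masked_score` is a hypothetical helper showing why the distinction matters when a penalty is multiplied by a 0/1 mask.

```python
import math
import sys

# Hypothetical stand-in for a "close to infinity but finite" value.
LARGE_FINITE = sys.float_info.max

# A true infinity turns into NaN when multiplied by zero ...
true_inf_product = math.inf * 0.0
# ... while a large finite value becomes exactly zero.
finite_product = LARGE_FINITE * 0.0


def masked_score(score, mask, penalty):
    """Subtract `penalty` from `score` where mask is 1, leave it where mask is 0."""
    return score - penalty * mask


kept = masked_score(-1.5, 0.0, LARGE_FINITE)      # unmasked entry is unchanged
suppressed = masked_score(-1.5, 1.0, LARGE_FINITE)  # masked entry is pushed very low
broken = masked_score(-1.5, 0.0, math.inf)        # NaN leaks into an unmasked entry
```

With the finite penalty, unmasked scores pass through untouched, whereas the true infinity corrupts them with NaN even though the mask is zero.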
Last updated 2024-02-02 UTC.