
Module: tf.compat.v1.train

Support for training models.

See the Training guide.

Modules

experimental module: Public API for tf.train.experimental namespace.

queue_runner module: Public API for tf.train.queue_runner namespace.

Classes

class AdadeltaOptimizer: Optimizer that implements the Adadelta algorithm.

class AdagradDAOptimizer: Adagrad Dual Averaging algorithm for sparse linear models.

class AdagradOptimizer: Optimizer that implements the Adagrad algorithm.

class AdamOptimizer: Optimizer that implements the Adam algorithm.

class BytesList: Container that holds repeated fundamental values of byte type in the tf.train.Feature message.

class Checkpoint: Groups trackable objects, saving and restoring them.

class CheckpointManager: Manages multiple checkpoints by keeping some and deleting unneeded ones.

class CheckpointOptions: Options for constructing a Checkpoint.

class CheckpointSaverHook: Saves checkpoints every N steps or seconds.

class CheckpointSaverListener: Interface for listeners that take action before or after checkpoint save.

class ChiefSessionCreator: Creates a tf.compat.v1.Session for a chief.

class ClusterDef: Protocol message that defines a TensorFlow cluster as a set of jobs.

class ClusterSpec: Represents a cluster as a set of "tasks", organized into "jobs".

class Coordinator: A coordinator for threads.

class Example: An Example is a mostly-normalized data format for storing data for training and inference.

class ExponentialMovingAverage: Maintains moving averages of variables by employing an exponential decay.

class Feature: A Feature is a list which may hold zero or more values.

class FeatureList: Contains zero or more values of tf.train.Features.

class FeatureLists: Contains the mapping from name to tf.train.FeatureList.

class Features: Protocol message for describing the features of a tf.train.Example.

class FeedFnHook: Runs feed_fn and sets the feed_dict accordingly.

class FinalOpsHook: A hook which evaluates Tensors at the end of a session.

class FloatList: Container that holds repeated fundamental values of float type in the tf.train.Feature message.

class FtrlOptimizer: Optimizer that implements the FTRL algorithm.

class GlobalStepWaiterHook: Delays execution until global step reaches wait_until_step.

class GradientDescentOptimizer: Optimizer that implements the gradient descent algorithm.

class Int64List: Container that holds repeated fundamental values of int64 type in the tf.train.Feature message.

class JobDef: Protocol message that defines a single job in a cluster, mapping task indices to network addresses.

class LoggingTensorHook: Prints the given tensors every N local steps, every N seconds, or at end.

class LooperThread: A thread that runs code repeatedly, optionally on a timer.

class MomentumOptimizer: Optimizer that implements the Momentum algorithm.

class MonitoredSession: Session-like object that handles initialization, recovery and hooks.

class NanLossDuringTrainingError: Runtime error raised when the loss becomes NaN during training (for example by NanTensorHook).

class NanTensorHook: Monitors the loss tensor and stops training if loss is NaN.

class Optimizer: Base class for optimizers.

class ProfilerHook: Captures CPU/GPU profiling information every N steps or seconds.

class ProximalAdagradOptimizer: Optimizer that implements the Proximal Adagrad algorithm.

class ProximalGradientDescentOptimizer: Optimizer that implements the proximal gradient descent algorithm.

class QueueRunner: Holds a list of enqueue operations for a queue, each to be run in a thread.

class RMSPropOptimizer: Optimizer that implements the RMSProp algorithm (Tieleman and Hinton).

class Saver: Saves and restores variables.

class SaverDef: Protocol message that holds the configuration of a Saver.

class Scaffold: Structure to create or gather pieces commonly needed to train a model.

class SecondOrStepTimer: Timer that triggers at most once every N seconds or once every N steps.

class SequenceExample: A SequenceExample is a format for representing one or more sequences and some context.

class Server: An in-process TensorFlow server, for use in distributed training.

class ServerDef: Protocol message that configures a TensorFlow server (cluster, job name, task index, and protocol).

class SessionCreator: A factory for tf.Session.

class SessionManager: Training helper that restores from checkpoint and creates session.

class SessionRunArgs: Represents arguments to be added to a Session.run() call.

class SessionRunContext: Provides information about the session.run() call being made.

class SessionRunHook: Hook to extend calls to MonitoredSession.run().

class SessionRunValues: Contains the results of Session.run().

class SingularMonitoredSession: Session-like object that handles initialization, restoring, and hooks.

class StepCounterHook: Hook that counts steps per second.

class StopAtStepHook: Hook that requests stop at a specified step.

class SummarySaverHook: Saves summaries every N steps.

class Supervisor: A training helper that checkpoints models and computes summaries.

class SyncReplicasOptimizer: Optimizer wrapper that synchronizes and aggregates gradients across replicas, then passes them to the underlying optimizer.

class VocabInfo: Vocabulary information for warm-starting.

class WorkerSessionCreator: Creates a tf.compat.v1.Session for a worker.
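The optimizer classes listed above differ mainly in how they turn a gradient into a variable update. The plain-Python sketch below (no TensorFlow required) applies the update rules given in the TF1 docstrings of GradientDescentOptimizer, MomentumOptimizer (with use_nesterov=False) and AdagradOptimizer to a single scalar. The toy loss f(x) = (x - 3)^2, the helper names, and the hyperparameter values are illustrative assumptions; this is not TensorFlow's implementation, which applies the same rules to tensors through optimizer ops.

```python
import math


def grad(x):
    # Gradient of the toy loss f(x) = (x - 3)^2, which is minimized at x = 3.
    return 2.0 * (x - 3.0)


def gradient_descent(x, learning_rate=0.1, steps=200):
    for _ in range(steps):
        # variable -= learning_rate * gradient
        x -= learning_rate * grad(x)
    return x


def momentum_sgd(x, learning_rate=0.05, momentum=0.9, steps=500):
    accumulation = 0.0
    for _ in range(steps):
        # accumulation = momentum * accumulation + gradient
        accumulation = momentum * accumulation + grad(x)
        # variable -= learning_rate * accumulation
        x -= learning_rate * accumulation
    return x


def adagrad(x, learning_rate=0.5, initial_accumulator_value=0.1, steps=1000):
    accum = initial_accumulator_value
    for _ in range(steps):
        g = grad(x)
        # Running sum of squared gradients; it shrinks the effective step size.
        accum += g * g
        # variable -= learning_rate * gradient / sqrt(accum)
        x -= learning_rate * g / math.sqrt(accum)
    return x


print(gradient_descent(0.0), momentum_sgd(0.0), adagrad(0.0))
```

Each call starts from x = 0.0 and, with these settings, ends very close to the minimum at 3.0. Because Adagrad divides by the square root of the accumulated squared gradients, its effective step is smaller than the nominal learning_rate, which is why this sketch gives it a larger value than plain gradient descent.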

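ExponentialMovingAverage keeps a shadow copy of each variable and updates it with shadow_variable -= (1 - decay) * (shadow_variable - variable). When num_updates is supplied, the decay actually used is min(decay, (1 + num_updates) / (10 + num_updates)), so the average follows the variable closely early in training. A minimal plain-Python sketch of that update on a scalar follows; the function name ema_update is hypothetical and not part of the TensorFlow API.

```python
def ema_update(shadow, value, decay, num_updates=None):
    """One moving-average step, following the ExponentialMovingAverage formula."""
    if num_updates is not None:
        # Warm-up: cap the decay so early averages are not dominated by the initial shadow.
        decay = min(decay, (1.0 + num_updates) / (10.0 + num_updates))
    # Equivalent to: shadow = decay * shadow + (1 - decay) * value
    return shadow - (1.0 - decay) * (shadow - value)


# Track a variable that jumps from 0.0 to 10.0 and then stays there.
shadow = 0.0
for step in range(50):
    shadow = ema_update(shadow, 10.0, decay=0.9, num_updates=step)
print(shadow)
```

Without num_updates a single step with decay=0.9 only moves the shadow a tenth of the way toward the new value; with num_updates=0 the capped decay is 0.1, so the shadow moves nine tenths of the way.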
Functions

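exponential_decay and cosine_decay, listed among the functions below, compute a learning rate from the global step with closed-form formulas given in their docstrings: learning_rate * decay_rate ^ (global_step / decay_steps), with integer division when staircase=True, and learning_rate * ((1 - alpha) * 0.5 * (1 + cos(pi * step / decay_steps)) + alpha), with step clamped to decay_steps. The plain-Python sketch below evaluates those formulas for a single step; the real functions take and return tensors, and the argument names here only mirror theirs.

```python
import math


def exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False):
    p = global_step / decay_steps
    if staircase:
        # Integer division makes the rate drop in discrete jumps every decay_steps.
        p = math.floor(p)
    return learning_rate * decay_rate ** p


def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0):
    # After decay_steps the rate stays at its floor, alpha * learning_rate.
    step = min(global_step, decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    decayed = (1.0 - alpha) * cosine + alpha
    return learning_rate * decayed


for s in (0, 50, 100, 150):
    print(s, exponential_decay(0.1, s, 100, 0.5), cosine_decay(0.1, s, 100))
```

With staircase=False the exponential schedule decays smoothly at every step; with staircase=True it is constant within each window of decay_steps steps.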
MonitoredTrainingSession(...): Creates a MonitoredSession for training.

NewCheckpointReader(...): Returns a CheckpointReader for reading the contents of a checkpoint.

add_queue_runner(...): Adds a QueueRunner to a collection in the graph. (deprecated)

assert_global_step(...): Asserts global_step_tensor is a scalar int Variable or Tensor.

basic_train_loop(...): Basic loop to train a model.

batch(...): Creates batches of the tensors passed in the tensors argument. (deprecated)

batch_join(...): Runs a list of tensors to fill a queue to create batches of examples. (deprecated)

checkpoint_exists(...): Checks whether a V1 or V2 checkpoint exists with the specified prefix. (deprecated)

checkpoints_iterator(...): Continuously yield new checkpoint files as they appear.

cosine_decay(...): Applies cosine decay to the learning rate.

cosine_decay_restarts(...): Applies cosine decay with restarts to the learning rate.

create_global_step(...): Creates the global step tensor in the graph.

do_quantize_training_on_graphdef(...): A general quantization scheme is being developed in tf.contrib.quantize. (deprecated)

exponential_decay(...): Applies exponential decay to the learning rate.

export_meta_graph(...): Returns MetaGraphDef proto.

generate_checkpoint_state_proto(...): Generates a checkpoint state proto.

get_checkpoint_mtimes(...): Returns the mtimes (modification timestamps) of the checkpoints. (deprecated)

get_checkpoint_state(...): Returns CheckpointState proto from the "checkpoint" file.

get_global_step(...): Get the global step tensor.

get_or_create_global_step(...): Returns the global step tensor, creating it if necessary.

global_step(...): Small helper to get the global step.

import_meta_graph(...): Recreates a Graph saved in a MetaGraphDef proto.

init_from_checkpoint(...): Replaces tf.Variable initializers so they load from a checkpoint file.

input_producer(...): Output the rows of input_tensor to a queue for