Estimator with TPU support.
tf.compat.v1.estimator.tpu.TPUEstimator(
    model_fn=None, model_dir=None, config=None, params=None, use_tpu=True,
    train_batch_size=None, eval_batch_size=None, predict_batch_size=None,
    batch_axis=None, eval_on_tpu=True, export_to_tpu=True, export_to_cpu=True,
    warm_start_from=None, embedding_config_spec=None,
    export_saved_model_api_version=ExportSavedModelApiVersion.V1
)
Migrate to TF2
TPU Estimator manages its own TensorFlow graph and session, so it is not
compatible with TF2 behaviors. We recommend that you migrate to the newer
tf.distribute.TPUStrategy. See the
TPU guide for details.
TPUEstimator also supports training on CPU and GPU. You don't need to define
a separate tf.estimator.Estimator.
TPUEstimator handles many of the details of running on TPU devices, such as replicating inputs and models for each core, and returning to host periodically to run hooks.
TPUEstimator transforms a global batch size in params to a per-shard batch
size when calling the input_fn and model_fn. Users should specify the
global batch size in the constructor, and then get the batch size for each shard
in input_fn and model_fn by params['batch_size'].
For training, model_fn gets the per-core batch size;
input_fn may get the per-core or per-host batch size depending on
per_host_input_for_training in TPUConfig (see the docstring of TPUConfig for details).
For evaluation and prediction,
model_fn gets the per-core batch size and
input_fn gets the per-host batch size.
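The global-to-per-shard split described above can be sketched in plain Python (this is an illustration of the arithmetic, not the actual library code; the 8-core shard count is an assumed example):

```python
def per_shard_batch_size(global_batch_size, num_shards):
    """Mimic TPUEstimator's global -> per-shard batch size split.

    TPUEstimator requires the global batch size to be divisible by
    the number of shards (TPU cores).
    """
    if global_batch_size % num_shards != 0:
        raise ValueError(
            "batch size %d must be divisible by the number of shards %d"
            % (global_batch_size, num_shards))
    return global_batch_size // num_shards

# Example: with 8 cores, a global batch of 1024 arrives in model_fn
# as a per-core batch of 128 via params['batch_size'].
params = {"batch_size": per_shard_batch_size(1024, 8)}
```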
model_fn should return
TPUEstimatorSpec, which expects the eval_metrics
for TPU evaluation. If eval_on_tpu is False, the evaluation will execute on
CPU or GPU; in this case the following discussion on TPU evaluation does not
apply.
TPUEstimatorSpec.eval_metrics is a tuple of metric_fn and tensors, where
tensors could be a list of any nested structure of Tensors (see
TPUEstimatorSpec for details).
metric_fn takes the tensors and returns
a dict from metric string name to the result of calling a metric function,
namely a (metric_tensor, update_op) tuple.
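The eval_metrics contract can be sketched with plain-Python stand-ins (no real Tensors here; accuracy_value, accuracy_update_op, labels, and predictions are placeholder names used only to show the required shapes):

```python
def metric_fn(labels, predictions):
    # In real code these would come from tf.compat.v1.metrics.* calls,
    # each of which returns a (metric_tensor, update_op) tuple; plain
    # values stand in for the tensors to show the required dict shape.
    accuracy_value, accuracy_update_op = 0.0, "update_op"
    return {"accuracy": (accuracy_value, accuracy_update_op)}

# TPUEstimatorSpec.eval_metrics is the tuple (metric_fn, tensors),
# where tensors is a list or dict whose entries are passed to metric_fn.
eval_metrics = (metric_fn, {"labels": None, "predictions": None})

fn, tensors = eval_metrics
metrics = fn(**tensors)
```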
One can set use_tpu to
False for testing. All training, evaluation, and
predict will be executed on CPU.
input_fn and model_fn will receive
train_batch_size or eval_batch_size unmodified as
params['batch_size'].
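A minimal constructor sketch for CPU testing follows (a configuration fragment, not a runnable program: my_model_fn is assumed to be a user-defined model_fn, and the snippet requires a TensorFlow install that still ships the TF1 estimator API):

```python
import tensorflow as tf

# use_tpu=False runs train/evaluate/predict on CPU for testing;
# model_fn still receives train_batch_size as params['batch_size'].
estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
    model_fn=my_model_fn,  # assumed user-defined model_fn
    config=tf.compat.v1.estimator.tpu.RunConfig(),
    use_tpu=False,
    train_batch_size=64,
    eval_batch_size=64)
```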
TPU evaluation only works on a single host (one TPU worker) except BROADCAST mode.
input_fn for evaluation should NOT raise an end-of-input exception
(OutOfRangeError or StopIteration), and all evaluation steps