Migrate from TPU embedding_columns to the TPUEmbedding layer


This guide shows how to migrate embedding training on TPUs from TensorFlow 1's embedding_column API with TPUEstimator to TensorFlow 2's TPUEmbedding layer API with TPUStrategy.

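Embeddings are (large) matrices. They are lookup tables that map from a sparse feature space to dense vectors. Embeddings provide efficient, dense representations that capture complex similarities and relationships between features.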
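To make the "lookup table" idea concrete, here is a minimal pure-Python sketch. It is an illustration only, not how TensorFlow implements embeddings: the table values, the helper names lookup and mean_combine, and the choice of averaging several looked-up vectors are all assumptions made for this example.

```python
# A toy embedding table: one dense vector (row) per integer category ID.
# The sizes (vocabulary of 3, width of 3) and values are arbitrary.
embedding_table = [
    [0.0, 1.0, 2.0],  # vector for ID 0
    [9.0, 9.0, 9.0],  # vector for ID 1
    [4.0, 5.0, 6.0],  # vector for ID 2
]


def lookup(table, ids):
    """Return the dense vectors stored at the given integer IDs."""
    return [table[i] for i in ids]


def mean_combine(vectors):
    """Average several vectors element-wise into a single vector.

    Averaging is one possible way to reduce a multi-valued feature
    to one vector; it is an assumption of this sketch.
    """
    count = len(vectors)
    return [sum(column) / count for column in zip(*vectors)]


vectors = lookup(embedding_table, [0, 2])
combined = mean_combine(vectors)
print(vectors)   # the rows for IDs 0 and 2
print(combined)  # their element-wise average
```

A sparse input such as the pair of IDs [0, 2] therefore becomes a single dense vector whose width equals the width of the table.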

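TensorFlow includes specialized support for training embeddings on TPUs. This TPU-specific embedding support lets you train embeddings that are larger than the memory of a single TPU device, and use sparse and ragged inputs on TPUs.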

For additional information, refer to the API docs for the tfrs.layers.embedding.TPUEmbedding layer, as well as the docs for tf.tpu.experimental.embedding.TableConfig and tf.tpu.experimental.embedding.FeatureConfig. For an overview of tf.distribute.TPUStrategy, check out the distributed training guide and the Use TPUs guide. If you are migrating from TPUEstimator to TPUStrategy, check out the TPU migration guide.

Setup

Start by installing TensorFlow Recommenders and importing the necessary packages:

pip install tensorflow-recommenders
import tensorflow as tf
import tensorflow.compat.v1 as tf1

# TPUEmbedding layer is not part of TensorFlow.
import tensorflow_recommenders as tfrs

Then prepare a simple dataset for demonstration purposes:

features = [[1., 1.5]]
embedding_features_indices = [[0, 0], [0, 1]]
embedding_features_values = [0, 5]
labels = [[0.3]]
eval_features = [[4., 4.5]]
eval_embedding_features_indices = [[0, 0], [0, 1]]
eval_embedding_features_values = [4, 3]
eval_labels = [[0.8]]
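The indices and values above describe a sparse feature in coordinate form: each [row, column] index pairs with one value. The following sketch, which assumes the dense shape [1, 2] that the input functions later in this guide pass to SparseTensor, expands that form into an ordinary nested list. The helper coo_to_dense is hypothetical and exists only for illustration.

```python
def coo_to_dense(indices, values, dense_shape, default=0):
    """Expand coordinate-format (indices, values) into a dense 2-D list."""
    rows, cols = dense_shape
    dense = [[default] * cols for _ in range(rows)]
    for (row, col), value in zip(indices, values):
        dense[row][col] = value
    return dense


embedding_features_indices = [[0, 0], [0, 1]]
embedding_features_values = [0, 5]

# One example (row 0) that carries two category IDs: 0 and 5.
dense = coo_to_dense(
    embedding_features_indices, embedding_features_values, [1, 2])
print(dense)
```

Read this way, the single training example carries the category IDs 0 and 5, and the evaluation example carries the IDs 4 and 3.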

TensorFlow 1: Train embeddings on TPUs with TPUEstimator

In TensorFlow 1, you set up TPU embeddings with the tf.compat.v1.tpu.experimental.embedding_column API and train/evaluate the model on TPUs with tf.compat.v1.estimator.tpu.TPUEstimator.

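The inputs are integers ranging from zero to the vocabulary size of the TPU embedding table. Begin by encoding the inputs as categorical IDs with tf.feature_column.categorical_column_with_identity. Use "sparse_feature" for the key parameter, since that is the name of the integer-valued input feature, and set num_buckets to the vocabulary size of the embedding table (10).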

embedding_id_column = (
      tf1.feature_column.categorical_column_with_identity(
          key="sparse_feature", num_buckets=10))

Next, convert the sparse categorical inputs to a dense representation with tpu.experimental.embedding_column, where dimension is the width of the embedding table. The table stores one embedding vector for each of the num_buckets IDs.

embedding_column = tf1.tpu.experimental.embedding_column(
    embedding_id_column, dimension=5)
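As a rough back-of-the-envelope sketch, the size of an embedding table is its number of rows (num_buckets) times its width (dimension) times the bytes per value. The helper below is hypothetical, assumes 4-byte float32 values, and ignores optimizer state. The per-device memory figure is the 17179869184 bytes reported for each TPU device in the logs later in this guide.

```python
def embedding_table_bytes(vocabulary_size, dimension, bytes_per_value=4):
    """Approximate memory for the table weights alone (no optimizer slots)."""
    return vocabulary_size * dimension * bytes_per_value


# The table in this guide: 10 rows of width 5.
small_table = embedding_table_bytes(10, 5)

# A hypothetical production-scale table: one billion rows of width 128.
large_table = embedding_table_bytes(1_000_000_000, 128)

# Memory reported per TPU device in this guide's logs.
tpu_device_memory = 17179869184

print(small_table)                       # bytes for the toy table
print(large_table <= tpu_device_memory)  # does the large table fit on one device?
```

The toy table is tiny, but a table at the hypothetical larger scale would not fit on one device, which is the situation the TPU-specific embedding support is meant to handle.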

Now define the TPU-specific embedding configuration via tf.estimator.tpu.experimental.EmbeddingConfigSpec. You will pass it later to tf.estimator.tpu.TPUEstimator as the embedding_config_spec parameter.

embedding_config_spec = tf1.estimator.tpu.experimental.EmbeddingConfigSpec(
    feature_columns=(embedding_column,),
    optimization_parameters=(
        tf1.tpu.experimental.AdagradParameters(0.05)))
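For intuition about what an Adagrad optimizer with learning rate 0.05 does to a single weight, here is a sketch of the textbook Adagrad update. It is not the TPU implementation: the function adagrad_step, the starting accumulator value of 0.1, and the omission of any epsilon term are assumptions made for this example.

```python
import math


def adagrad_step(weight, grad, accumulator, learning_rate=0.05):
    """Apply one textbook Adagrad update to a scalar weight.

    The accumulator sums squared gradients, so weights that have seen
    large gradients take progressively smaller steps.
    """
    accumulator = accumulator + grad * grad
    weight = weight - learning_rate * grad / math.sqrt(accumulator)
    return weight, accumulator


weight, accumulator = 1.0, 0.1  # assumed starting values
weight, accumulator = adagrad_step(weight, grad=2.0, accumulator=accumulator)
print(weight, accumulator)
```

Because the accumulator only grows, repeating the same gradient yields a smaller step each time.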

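Next, to use a TPUEstimator, define:

  • An input function for the training data
  • An evaluation input function for the evaluation data
  • A model function that tells the TPUEstimator how the training op is defined from the features and labels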
def _input_fn(params):
  dataset = tf1.data.Dataset.from_tensor_slices((
      {"dense_feature": features,
       "sparse_feature": tf1.SparseTensor(
           embedding_features_indices,
           embedding_features_values, [1, 2])},
           labels))
  dataset = dataset.repeat()
  return dataset.batch(params['batch_size'], drop_remainder=True)

def _eval_input_fn(params):
  dataset = tf1.data.Dataset.from_tensor_slices((
      {"dense_feature": eval_features,
       "sparse_feature": tf1.SparseTensor(
           eval_embedding_features_indices,
           eval_embedding_features_values, [1, 2])},
           eval_labels))
  dataset = dataset.repeat()
  return dataset.batch(params['batch_size'], drop_remainder=True)

def _model_fn(features, labels, mode, params):
  embedding_features = tf1.keras.layers.DenseFeatures(embedding_column)(features)
  concatenated_features = tf1.keras.layers.Concatenate(axis=1)(
      [embedding_features, features["dense_feature"]])
  logits = tf1.layers.Dense(1)(concatenated_features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  optimizer = tf1.tpu.CrossShardOptimizer(optimizer)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)

With those functions defined, create a tf.distribute.cluster_resolver.TPUClusterResolver that provides the cluster information, and a tf.compat.v1.estimator.tpu.RunConfig object.

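Along with the model function you have defined, you can now create a TPUEstimator. Here, you simplify the flow by skipping checkpoint saving. You then specify the batch size for both training and evaluation for the TPUEstimator.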

cluster_resolver = tf1.distribute.cluster_resolver.TPUClusterResolver(tpu='')
print("All devices: ", tf1.config.list_logical_devices('TPU'))
All devices:  []
tpu_config = tf1.estimator.tpu.TPUConfig(
    iterations_per_loop=10,
    per_host_input_for_training=tf1.estimator.tpu.InputPipelineConfig
          .PER_HOST_V2)
config = tf1.estimator.tpu.RunConfig(
    cluster=cluster_resolver,
    save_checkpoints_steps=None,
    tpu_config=tpu_config)
estimator = tf1.estimator.tpu.TPUEstimator(
    model_fn=_model_fn, config=config, train_batch_size=8, eval_batch_size=8,
    embedding_config_spec=embedding_config_spec)
WARNING:tensorflow:Estimator's model_fn (<function _model_fn at 0x7eff1dbf4ae8>) includes params argument, but params are not passed to Estimator.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpc68an8jx
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpc68an8jx', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.240.1.2:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.240.1.2:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.240.1.2:8470', '_evaluation_master': 'grpc://10.240.1.2:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=10, num_shards=None, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None, eval_training_input_configuration=2, experimental_host_call_every_n_steps=1, experimental_allow_per_host_v2_parallel_get_next=False, experimental_feed_hook=None), '_cluster': <tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver.TPUClusterResolver object at 0x7eff1dbfa2b0>}
INFO:tensorflow:_TPUContext: eval_on_tpu True

Call TPUEstimator.train to begin training the model:

estimator.train(_input_fn, steps=1)
INFO:tensorflow:Querying Tensorflow master (grpc://10.240.1.2:8470) for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, -3018931587863375246)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 1249032734884062775)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, -3881759543008185868)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, -3421771184935649663)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 8872583169621331661)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, -1222373804129613329)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 6258068298163390748)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 5190265587768274342)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 3073578684150069836)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 2071242092327503173)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, -1319360343564144287)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/tpu/feature_column_v2.py:479: IdentityCategoricalColumn._num_buckets (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.
Instructions for updating:
The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.
INFO:tensorflow:Querying Tensorflow master (grpc://10.240.1.2:8470) for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, -3018931587863375246)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 1249032734884062775)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, -3881759543008185868)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, -3421771184935649663)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 8872583169621331661)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, -1222373804129613329)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 6258068298163390748)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 5190265587768274342)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 3073578684150069836)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 2071242092327503173)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, -1319360343564144287)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Bypassing TPUEstimator hook
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:TPU job name worker
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:758: Variable.load (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer Variable.assign which has equivalent behavior in 2.X.
INFO:tensorflow:Initialized dataset iterators in 0 seconds
INFO:tensorflow:Installing graceful shutdown hook.
INFO:tensorflow:Creating heartbeat manager for ['/job:worker/replica:0/task:0/device:CPU:0']
INFO:tensorflow:Configuring worker heartbeat: shutdown_mode: WAIT_FOR_COORDINATOR

INFO:tensorflow:Init TPU system
INFO:tensorflow:Initialized TPU in 9 seconds
INFO:tensorflow:Starting infeed thread controller.
INFO:tensorflow:Starting outfeed thread controller.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Outfeed finished for iteration (0, 0)
INFO:tensorflow:loss = 0.5212165, step = 1
INFO:tensorflow:Stop infeed thread controller
INFO:tensorflow:Shutting down InfeedController thread.
INFO:tensorflow:InfeedController received shutdown signal, stopping.
INFO:tensorflow:Infeed thread finished, shutting down.
INFO:tensorflow:infeed marked as finished
INFO:tensorflow:Stop output thread controller
INFO:tensorflow:Shutting down OutfeedController thread.
INFO:tensorflow:OutfeedController received shutdown signal, stopping.
INFO:tensorflow:Outfeed thread finished, shutting down.
INFO:tensorflow:outfeed marked as finished
INFO:tensorflow:Shutdown TPU system.
INFO:tensorflow:Loss for final step: 0.5212165.
INFO:tensorflow:training_loop marked as finished
<tensorflow_estimator.python.estimator.tpu.tpu_estimator.TPUEstimator at 0x7eff1dbfa7b8>

Then call TPUEstimator.evaluate to evaluate the model using the evaluation data:

estimator.evaluate(_eval_input_fn, steps=1)
INFO:tensorflow:Could not find trained model in model_dir: /tmp/tmpc68an8jx, running initialization to evaluate.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Querying Tensorflow master (grpc://10.240.1.2:8470) for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, -3018931587863375246)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 1249032734884062775)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, -3881759543008185868)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, -3421771184935649663)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 8872583169621331661)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, -1222373804129613329)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 6258068298163390748)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 5190265587768274342)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 3073578684150069836)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 2071242092327503173)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, -1319360343564144287)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:3406: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-02-05T13:21:42
INFO:tensorflow:TPU job name worker
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Init TPU system
INFO:tensorflow:Initialized TPU in 11 seconds
INFO:tensorflow:Starting infeed thread controller.
INFO:tensorflow:Starting outfeed thread controller.
INFO:tensorflow:Initialized dataset iterators in 0 seconds
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Outfeed finished for iteration (0, 0)
INFO:tensorflow:Evaluation [1/1]
INFO:tensorflow:Stop infeed thread controller
INFO:tensorflow:Shutting down InfeedController thread.
INFO:tensorflow:InfeedController received shutdown signal, stopping.
INFO:tensorflow:Infeed thread finished, shutting down.
INFO:tensorflow:infeed marked as finished
INFO:tensorflow:Stop output thread controller
INFO:tensorflow:Shutting down OutfeedController thread.
INFO:tensorflow:OutfeedController received shutdown signal, stopping.
INFO:tensorflow:Outfeed thread finished, shutting down.
INFO:tensorflow:outfeed marked as finished
INFO:tensorflow:Shutdown TPU system.
INFO:tensorflow:Inference Time : 12.50468s
INFO:tensorflow:Finished evaluation at 2022-02-05-13:21:54
INFO:tensorflow:Saving dict for global step 1: global_step = 1, loss = 36.28813
INFO:tensorflow:evaluation_loop marked as finished
{'loss': 36.28813, 'global_step': 1}

TensorFlow 2: Train embeddings on TPUs with TPUStrategy

In TensorFlow 2, to train on the TPU workers, use tf.distribute.TPUStrategy together with the Keras APIs for model definition and training/evaluation. (Refer to the Use TPUs guide for more examples of training with Keras Model.fit and with a custom training loop that uses tf.function and tf.GradientTape.)

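Because some initialization work is needed to connect to the remote cluster and initialize the TPU workers, start by creating a TPUClusterResolver to provide the cluster information and connect to the cluster. (Learn more in the TPU initialization section of the Use TPUs guide.)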

cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Initializing the TPU system: grpc://10.240.1.2:8470
INFO:tensorflow:Finished initializing TPU system.
All devices:  [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]

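Next, prepare your data. This is similar to how you created a dataset in the TensorFlow 1 example, except that the dataset function is now passed a tf.distribute.InputContext object rather than a params dict. You can use this object to determine the local batch size (and which host this pipeline is for, so that you can partition your data properly).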

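  • When using the tfrs.layers.embedding.TPUEmbedding API, it is important to include the drop_remainder=True option when batching the dataset with Dataset.batch, since TPUEmbedding requires a fixed batch size.
  • Additionally, the same batch size must be used for evaluation and training if they take place on the same set of devices.
  • Finally, you should use tf.keras.utils.experimental.DatasetCreator along with the special input option experimental_fetch_to_device=False in tf.distribute.InputOptions (which holds strategy-specific configurations). This is demonstrated below: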
global_batch_size = 8

def _input_dataset(context: tf.distribute.InputContext):
  dataset = tf.data.Dataset.from_tensor_slices((
      {"dense_feature": features,
       "sparse_feature": tf.SparseTensor(
           embedding_features_indices,
           embedding_features_values, [1, 2])},
           labels))
  dataset = dataset.shuffle(10).repeat()
  dataset = dataset.batch(
      context.get_per_replica_batch_size(global_batch_size),
      drop_remainder=True)
  return dataset.prefetch(2)

def _eval_dataset(context: tf.distribute.InputContext):
  dataset = tf.data.Dataset.from_tensor_slices((
      {"dense_feature": eval_features,
       "sparse_feature": tf.SparseTensor(
           eval_embedding_features_indices,
           eval_embedding_features_values, [1, 2])},
           eval_labels))
  dataset = dataset.repeat()
  dataset = dataset.batch(
      context.get_per_replica_batch_size(global_batch_size),
      drop_remainder=True)
  return dataset.prefetch(2)

input_options = tf.distribute.InputOptions(
    experimental_fetch_to_device=False)

input_dataset = tf.keras.utils.experimental.DatasetCreator(
    _input_dataset, input_options=input_options)

eval_dataset = tf.keras.utils.experimental.DatasetCreator(
    _eval_dataset, input_options=input_options)
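The two batching ideas used above, splitting a global batch size across replicas and dropping an incomplete final batch, can be sketched in plain Python. The helpers per_replica_batch_size and batch are hypothetical stand-ins written for this illustration; they are not the TensorFlow implementations, and the divisibility check is an assumption about how an uneven split should be handled.

```python
def per_replica_batch_size(global_batch_size, num_replicas_in_sync):
    """Split a global batch evenly across replicas."""
    if global_batch_size % num_replicas_in_sync != 0:
        raise ValueError(
            "global batch size must be divisible by the number of replicas")
    return global_batch_size // num_replicas_in_sync


def batch(examples, batch_size, drop_remainder=True):
    """Group examples into lists of batch_size, optionally dropping a short tail."""
    batches = [examples[i:i + batch_size]
               for i in range(0, len(examples), batch_size)]
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


# A global batch of 8 on the 8 TPU cores listed in this guide's logs.
print(per_replica_batch_size(8, 8))

# Ten examples in batches of four: the last two examples are dropped.
print(batch(list(range(10)), 4))
print(batch(list(range(10)), 4, drop_remainder=False))
```

With drop_remainder=True, every batch has exactly the same size, which matches the fixed-batch-size requirement described in the list above.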

Next, once the data is prepared, create a TPUStrategy, and define a model, metrics, and an optimizer under the scope of this strategy (Strategy.scope).

You should pick a number for steps_per_execution in Model.compile, since it specifies the number of batches to run during each tf.function call and is critical for performance. This argument is similar to iterations_per_loop used in TPUEstimator.
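As a simplified counting model of why this value matters, the sketch below estimates how many separate function calls are needed to run one epoch. The helper num_function_calls is hypothetical, and the assumption that a final partial group of steps counts as one extra call is a simplification rather than a statement about Keras internals.

```python
import math


def num_function_calls(steps_per_epoch, steps_per_execution):
    """Estimate the number of calls needed to cover all steps of one epoch."""
    return math.ceil(steps_per_epoch / steps_per_execution)


# The values used in this guide: 10 steps per epoch, 10 steps per execution.
print(num_function_calls(10, 10))

# Running one step per call needs ten times as many calls.
print(num_function_calls(10, 1))
```

Fewer calls generally mean less per-call overhead between the host and the TPU, at the cost of coarser-grained progress reporting.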

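The features and table configuration that were specified in TensorFlow 1 via tf.tpu.experimental.embedding_column (and tf.tpu.experimental.shared_embedding_column) can be specified directly in TensorFlow 2 via a pair of configuration objects: tf.tpu.experimental.embedding.FeatureConfig and tf.tpu.experimental.embedding.TableConfig.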

(Refer to the associated API documentation for more details.)

strategy = tf.distribute.TPUStrategy(cluster_resolver)
with strategy.scope():
  optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
  dense_input = tf.keras.Input(shape=(2,), dtype=tf.float32, batch_size=global_batch_size)
  sparse_input = tf.keras.Input(shape=(), dtype=tf.int32, batch_size=global_batch_size)
  embedded_input = tfrs.layers.embedding.TPUEmbedding(
      feature_config=tf.tpu.experimental.embedding.FeatureConfig(
          table=tf.tpu.experimental.embedding.TableConfig(
              vocabulary_size=10,
              dim=5,
              initializer=tf.initializers.TruncatedNormal(mean=0.0, stddev=1)),
          name="sparse_input"),
      optimizer=optimizer)(sparse_input)
  input = tf.keras.layers.Concatenate(axis=1)([dense_input, embedded_input])
  result = tf.keras.layers.Dense(1)(input)
  model = tf.keras.Model(inputs={"dense_feature": dense_input, "sparse_feature": sparse_input}, outputs=result)
  model.compile(optimizer, "mse", steps_per_execution=10)
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

With that, you are ready to train the model with the training dataset:

model.fit(input_dataset, epochs=5, steps_per_epoch=10)
Epoch 1/5
10/10 [==============================] - 2s 164ms/step - loss: 0.4005
Epoch 2/5
10/10 [==============================] - 0s 3ms/step - loss: 0.0036
Epoch 3/5
10/10 [==============================] - 0s 3ms/step - loss: 3.0932e-05
Epoch 4/5
10/10 [==============================] - 0s 3ms/step - loss: 2.5767e-07
Epoch 5/5
10/10 [==============================] - 0s 3ms/step - loss: 2.1366e-09
<keras.callbacks.History at 0x7efd8c461c18>

Finally, evaluate the model using the evaluation dataset:

model.evaluate(eval_dataset, steps=1, return_dict=True)
1/1 [==============================] - 1s 1s/step - loss: 15.3952
{'loss': 15.395216941833496}

Next steps

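Learn more about setting up TPU-specific embeddings in the API docs for tfrs.layers.embedding.TPUEmbedding, tf.tpu.experimental.embedding.TableConfig, and tf.tpu.experimental.embedding.FeatureConfig.

For more information about TPUStrategy in TensorFlow 2, refer to the distributed training guide and the Use TPUs guide.

To learn more about customizing your training, refer to the related TensorFlow guides.

TPUs, Google's specialized ASICs for machine learning, are available through Google Colab, the TPU Research Cloud, and Cloud TPU.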