Distributed training with DTensors


Overview

DTensor provides a way to distribute the training of your model across devices to improve efficiency, reliability, and scalability. For more details on DTensor concepts, see the DTensor Programming Guide.

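In this tutorial, you will train a sentiment analysis model using DTensors. The example demonstrates three distributed training schemes:

  • Data parallel training, where the training samples are sharded (partitioned) across devices.
  • Model parallel training, where the model variables are sharded across devices.
  • Spatial parallel training, where the features of the input data are sharded across devices (also known as spatial partitioning).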

The training portion of this tutorial is based on a Kaggle guide notebook on sentiment analysis. To learn about the complete training and evaluation workflow (without DTensor), refer to that notebook.

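This tutorial walks through the following steps:

  • First, some data cleaning to obtain a tf.data.Dataset of tokenized sentences and their polarity.

  • Next, building an MLP model with custom Dense and BatchNorm layers, using a tf.Module to track the inference variables. The model constructor takes additional Layout arguments to control the sharding of variables.

  • For training, you will first use data parallel training together with the checkpoint feature of tf.experimental.dtensor, and then continue with model parallel training and spatial parallel training.

  • The final section briefly describes the interaction between tf.saved_model and tf.experimental.dtensor as of TensorFlow 2.9.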

Setup

DTensor is included in the TensorFlow 2.9.0 release.

pip install --quiet --upgrade --pre tensorflow tensorflow-datasets

Next, import tensorflow and tensorflow.experimental.dtensor, and configure TensorFlow to use 8 virtual CPUs.

Even though this example uses CPUs, DTensor works the same way on CPU, GPU, or TPU devices.

import tempfile
import numpy as np
import tensorflow_datasets as tfds

import tensorflow as tf

from tensorflow.experimental import dtensor
print('TensorFlow version:', tf.__version__)
TensorFlow version: 2.15.0
def configure_virtual_cpus(ncpu):
  phy_devices = tf.config.list_physical_devices('CPU')
  tf.config.set_logical_device_configuration(phy_devices[0], [
        tf.config.LogicalDeviceConfiguration(),
    ] * ncpu)

configure_virtual_cpus(8)
DEVICES = [f'CPU:{i}' for i in range(8)]

tf.config.list_logical_devices('CPU')
[LogicalDevice(name='/device:CPU:0', device_type='CPU'),
 LogicalDevice(name='/device:CPU:1', device_type='CPU'),
 LogicalDevice(name='/device:CPU:2', device_type='CPU'),
 LogicalDevice(name='/device:CPU:3', device_type='CPU'),
 LogicalDevice(name='/device:CPU:4', device_type='CPU'),
 LogicalDevice(name='/device:CPU:5', device_type='CPU'),
 LogicalDevice(name='/device:CPU:6', device_type='CPU'),
 LogicalDevice(name='/device:CPU:7', device_type='CPU')]

Download the dataset

Download the IMDB reviews dataset to train the sentiment analysis model:

train_data = tfds.load('imdb_reviews', split='train', shuffle_files=True, batch_size=64)
train_data
<_PrefetchDataset element_spec={'label': TensorSpec(shape=(None,), dtype=tf.int64, name=None), 'text': TensorSpec(shape=(None,), dtype=tf.string, name=None)}>

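Prepare the data

First, tokenize the text. Here, use an extension of one-hot encoding: the 'tf_idf' mode of tf.keras.layers.TextVectorization.

  • For the sake of speed, limit the number of tokens to 1200.
  • To keep the tf.Module simple, run TextVectorization as a preprocessing step before training.

The final result of the data cleaning section is a Dataset with the tokenized text as x and the label as y.

Note: Running TextVectorization as a preprocessing step is neither a usual practice nor a recommended one, because it assumes that the training data fits into the client memory, which is not always the case.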
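To make "an extension of one-hot encoding" concrete, here is a minimal pure-Python sketch of TF-IDF encoding. It is not the Keras implementation: the function name tf_idf_vectors is hypothetical, and both the out-of-vocabulary bucket at index 0 and the IDF formula log(1 + N / (1 + df)) are assumptions made for illustration. The point is that where a multi-hot encoding stores 0/1 per token, TF-IDF stores the token count weighted by how rare the token is across documents.

```python
import math
from collections import Counter


def tf_idf_vectors(docs, vocab):
    """Encode tokenized documents as TF-IDF vectors over a fixed vocabulary.

    Assumptions (not taken from Keras source): index 0 is an out-of-vocabulary
    (OOV) bucket, and idf(t) = log(1 + N / (1 + df(t))), where N is the number
    of documents and df(t) is the number of documents containing token t.
    """
    index = {tok: i + 1 for i, tok in enumerate(vocab)}  # 0 is the OOV bucket
    width = len(vocab) + 1
    n_docs = len(docs)

    # Document frequency: count each token index at most once per document.
    df = Counter()
    for doc in docs:
        df.update({index.get(tok, 0) for tok in doc})
    idf = [math.log(1 + n_docs / (1 + df[i])) for i in range(width)]

    vectors = []
    for doc in docs:
        counts = Counter(index.get(tok, 0) for tok in doc)
        # A multi-hot encoding would store 0/1 here; TF-IDF stores count * idf.
        vectors.append([counts[i] * idf[i] for i in range(width)])
    return vectors


docs = [["good", "movie"], ["bad", "movie", "movie"], ["good", "plot", "twist"]]
vectors = tf_idf_vectors(docs, vocab=["movie", "good", "bad"])
```

Each vector has one slot per vocabulary entry plus the OOV slot, so with max_tokens=1200 every review becomes a fixed-width float vector, which is the (None, 1200) float32 shape printed for train_data_vec.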

text_vectorization = tf.keras.layers.TextVectorization(output_mode='tf_idf', max_tokens=1200, output_sequence_length=None)
text_vectorization.adapt(data=train_data.map(lambda x: x['text']))
def vectorize(features):
  return text_vectorization(features['text']), features['label']

train_data_vec = train_data.map(vectorize)
train_data_vec
<_MapDataset element_spec=(TensorSpec(shape=(None, 1200), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>

Build a neural network with DTensor

Now, build a multi-layer perceptron (MLP) network with DTensor. The network uses fully connected Dense and BatchNorm layers.

DTensor expands TensorFlow through single-program multiple-data (SPMD) expansion of regular TensorFlow Ops, according to the dtensor.Layout attributes of their input Tensors and variables.

Variables of DTensor-aware layers are dtensor.DVariable, and the constructors of DTensor-aware layer objects take additional Layout inputs in addition to the usual layer parameters.

Note: As of TensorFlow 2.9, Keras layers such as tf.keras.layers.Dense and tf.keras.layers.BatchNormalization accept dtensor.Layout arguments. Refer to the DTensor Keras Integration Tutorial for more information on using Keras with DTensor.

Dense layer

The following custom Dense layer defines 2 layer variables: \(W_{ij}\) is the variable for the weights, and \(b_j\) is the variable for the biases.

\[ y_j = \sigma(\sum_i x_i W_{ij} + b_j) \]

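Deducing the Layout

The layouts used by the layer below (in particular, deriving the bias layout from the last axis of the weight layout) follow from these observations:

  • The preferred DTensor sharding for the operands of a matrix dot product \(t_j = \sum_i x_i W_{ij}\) is to shard \(\mathbf{W}\) and \(\mathbf{x}\) the same way along the \(i\)-axis.

  • The preferred DTensor sharding for the operands of a matrix sum \(t_j + b_j\) is to shard \(\mathbf{t}\) and \(\mathbf{b}\) the same way along the \(j\)-axis.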
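The two observations can be checked numerically with a pure-Python sketch that simulates devices as loop iterations. The helper names (matmul, sharded_contraction, sharded_bias_add) are hypothetical and no DTensor API is used: when x and W are sharded the same way along i, each shard produces a partial product and summing the partials (the job of an all-reduce) recovers t; when t and b are sharded the same way along j, each shard adds its slice of b with no communication at all.

```python
def matmul(x, w):
    """Dense product of nested lists: (rows x k) @ (k x cols)."""
    return [[sum(x[r][i] * w[i][c] for i in range(len(w)))
             for c in range(len(w[0]))]
            for r in range(len(x))]


def sharded_contraction(x, w, num_shards):
    """t = x @ W with x and W sharded the same way along the contracted i axis.

    Each (simulated) device holds x[:, i-block] and W[i-block, :] and computes
    a partial product; summing the partials, which is what an all-reduce
    across devices would do, recovers the full t.
    """
    size = len(w) // num_shards  # assumes len(w) is divisible by num_shards
    total = None
    for s in range(num_shards):
        block = range(s * size, (s + 1) * size)
        x_local = [[row[i] for i in block] for row in x]
        w_local = [w[i] for i in block]
        partial = matmul(x_local, w_local)
        total = partial if total is None else [
            [a + p for a, p in zip(ra, rp)] for ra, rp in zip(total, partial)]
    return total


def sharded_bias_add(t, b, num_shards):
    """t + b with t and b sharded the same way along the output j axis.

    Each shard adds its own slice of b locally; the j-blocks are only
    concatenated here so that the result can be inspected.
    """
    size = len(b) // num_shards
    blocks = []
    for s in range(num_shards):
        cols = range(s * size, (s + 1) * size)
        blocks.append([[row[j] + b[j] for j in cols] for row in t])
    return [sum((blk[r] for blk in blocks), []) for r in range(len(t))]


x = [[1, 2, 3, 4], [5, 6, 7, 8]]
w = [[1, 0], [0, 1], [1, 1], [2, -1]]
t = sharded_contraction(x, w, num_shards=2)
y = sharded_bias_add(t, [10, 20], num_shards=2)
```

This is why the code below builds the bias layout with weight_layout.delete([0]): the bias inherits the sharding of the weight's output axis j, so the addition stays local.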

class Dense(tf.Module):

  def __init__(self, input_size, output_size,
               init_seed, weight_layout, activation=None):
    super().__init__()

    random_normal_initializer = tf.function(tf.random.stateless_normal)

    self.weight = dtensor.DVariable(
        dtensor.call_with_layout(
            random_normal_initializer, weight_layout,
            shape=[input_size, output_size],
            seed=init_seed
            ))
    if activation is None:
      activation = lambda x:x
    self.activation = activation

    # bias is sharded the same way as the last axis of weight.
    bias_layout = weight_layout.delete([0])

    self.bias = dtensor.DVariable(
        dtensor.call_with_layout(tf.zeros, bias_layout, [output_size]))

  def __call__(self, x):
    y = tf.matmul(x, self.weight) + self.bias
    y = self.activation(y)

    return y

BatchNorm

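A batch normalization layer helps avoid mode collapse during training. In this case, adding a batch normalization layer helps keep model training from producing a model that outputs only zeros.

The constructor of the custom BatchNorm layer below does not take a Layout argument, because BatchNorm has no layer variables. It still works with DTensor, because 'x', the only input to the layer, is already a DTensor that represents the global batch.

Note: In DTensor, the input Tensor 'x' always represents the global batch. Therefore, tf.nn.batch_normalization is applied to the global batch. This differs from training with tf.distribute.MirroredStrategy, where the Tensor 'x' represents only the per-replica shard of the batch (the local batch).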
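The difference between global-batch and local-batch statistics can be illustrated with a small pure-Python sketch (the normalize helper is hypothetical; it mimics tf.nn.moments plus tf.nn.batch_normalization with offset 0 and scale 1, using the population variance). With two replicas whose local batches differ only by a constant shift, per-replica normalization erases that shift, while normalizing the global batch preserves it.

```python
import math


def normalize(values, eps=1e-5):
    """Standardize one feature over a batch (population mean and variance)."""
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(variance + eps) for v in values]


# One feature column of a global batch of 8 examples.
global_batch = [1.0, 2.0, 3.0, 4.0, 10.0, 11.0, 12.0, 13.0]

# DTensor: 'x' is the global batch, so the statistics span all 8 examples
# even when the batch is physically sharded across devices.
dtensor_view = normalize(global_batch)

# Per-replica view (as with MirroredStrategy): each of 2 replicas
# normalizes only its own local batch of 4 examples.
local_batches = [global_batch[:4], global_batch[4:]]
per_replica_view = [y for batch in local_batches for y in normalize(batch)]
```

In the per-replica view both halves map to identical values, whereas in the DTensor view the first four examples stay below the last four.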

class BatchNorm(tf.Module):

  def __init__(self):
    super().__init__()

  def __call__(self, x, training=True):
    if not training:
      # This branch is not used in the Tutorial.
      pass
    mean, variance = tf.nn.moments(x, axes=[0])
    return tf.nn.batch_normalization(x, mean, variance, 0.0, 1.0, 1e-5)

A full-featured batch normalization layer (such as tf.keras.layers.BatchNormalization) needs Layout arguments for its variables.

def make_keras_bn(bn_layout):
  return tf.keras.layers.BatchNormalization(gamma_layout=bn_layout,
                                            beta_layout=bn_layout,
                                            moving_mean_layout=bn_layout,
                                            moving_variance_layout=bn_layout,
                                            fused=False)

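Putting layers together

Next, build a multi-layer perceptron (MLP) network with the building blocks above. The diagram below shows the input x and the weight matrices of the two Dense layers, without any DTensor sharding or replication applied.

Figure: The input and weight matrices of a non-distributed model.

The output of the first Dense layer is passed as the input to the second Dense layer (after the BatchNorm). Therefore, the preferred DTensor sharding for the output of the first Dense layer (\(\mathbf{W_1}\)) and the input of the second Dense layer (\(\mathbf{W_2}\)) is to shard \(\mathbf{W_1}\) and \(\mathbf{W_2}\) the same way along the common axis \(\hat{j}\).

Although the layout deduction shows that the 2 layouts are not independent, for the sake of a simple model interface the MLP takes 2 Layout arguments, one per Dense layer.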

from typing import Tuple

class MLP(tf.Module):

  def __init__(self, dense_layouts: Tuple[dtensor.Layout, dtensor.Layout]):
    super().__init__()

    self.dense1 = Dense(
        1200, 48, (1, 2), dense_layouts[0], activation=tf.nn.relu)
    self.bn = BatchNorm()
    self.dense2 = Dense(48, 2, (3, 4), dense_layouts[1])

  def __call__(self, x):
    y = x
    y = self.dense1(y)
    y = self.bn(y)
    y = self.dense2(y)
    return y

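The tradeoff between precision in the Layout deduction constraints and simplicity of the API is a common design point for APIs that use DTensor. It is also possible to capture the dependency between Layouts with a different API. For example, the MLPStricter class creates the Layout objects in its constructor.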

class MLPStricter(tf.Module):

  def __init__(self, mesh, input_mesh_dim, inner_mesh_dim1, output_mesh_dim):
    super().__init__()

    self.dense1 = Dense(
        1200, 48, (1, 2), dtensor.Layout([input_mesh_dim, inner_mesh_dim1], mesh),
        activation=tf.nn.relu)
    self.bn = BatchNorm()
    self.dense2 = Dense(48, 2, (3, 4), dtensor.Layout([inner_mesh_dim1, output_mesh_dim], mesh))


  def __call__(self, x):
    y = x
    y = self.dense1(y)
    y = self.bn(y)
    y = self.dense2(y)
    return y

To make sure the model runs, probe it with fully replicated layouts and a fully replicated batch of 'x' input.

WORLD = dtensor.create_mesh([("world", 8)], devices=DEVICES)

model = MLP([dtensor.Layout.replicated(WORLD, rank=2),
             dtensor.Layout.replicated(WORLD, rank=2)])

sample_x, sample_y = train_data_vec.take(1).get_single_element()
sample_x = dtensor.copy_to_mesh(sample_x, dtensor.Layout.replicated(WORLD, rank=2))
print(model(sample_x))
tf.Tensor([[-5.61041546 5.04737568]
 [-7.14075 6.86515808]
 [-3.10483789 1.58168292]
 ...
 [6.87280369 -3.56776071]
 [8.27548695 -5.70918465]
 [-1.98807716 1.71495843]], layout="sharding_specs:unsharded,unsharded, mesh:|world=8|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:0/device:CPU:5,/job:localhost/replica:0/task:0/device:CPU:6,/job:localhost/replica:0/task:0/device:CPU:7", shape=(64, 2), dtype=float32)

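Moving data to the device

Usually, tf.data iterators (and other data fetching methods) yield tensor objects backed by the local host device memory. This data must be transferred to the accelerator device memory that backs the DTensor's component tensors.

dtensor.copy_to_mesh is unsuitable for this situation, because DTensor's global perspective makes it replicate the input tensors to all devices. So, this tutorial uses a helper function, repack_local_tensor, to facilitate the transfer of data. This helper function uses dtensor.pack to send (and only send) the shard of the global batch that is intended for a replica to the device backing that replica.

This simplified function assumes a single client. In a multi-client application, determining the correct way to split the local tensors, and the mapping between the splits and the local devices, can be laborious.

Additional DTensor APIs that simplify tf.data integration, supporting both single-client and multi-client applications, are planned. Please stay tuned.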

def repack_local_tensor(x, layout):
  """Repacks a local Tensor-like to a DTensor with layout.

  This function assumes a single-client application.
  """
  x = tf.convert_to_tensor(x)
  sharded_dims = []

  # For every sharded dimension, use tf.split to split the tensor along that dimension.
  # The result is a nested list of split tensors in queue[0].
  queue = [x]
  for axis, dim in enumerate(layout.sharding_specs):
    if dim == dtensor.UNSHARDED:
      continue
    num_splits = layout.shape[axis]
    queue = tf.nest.map_structure(lambda x: tf.split(x, num_splits, axis=axis), queue)
    sharded_dims.append(dim)

  # Now we can build the list of component tensors by looking up the location in
  # the nested list of split-tensors created in queue[0].
  components = []
  for locations in layout.mesh.local_device_locations():
    t = queue[0]
    for dim in sharded_dims:
      split_index = locations[dim]  # Only valid on single-client mesh.
      t = t[split_index]
    components.append(t)

  return dtensor.pack(components, layout)
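The split-and-lookup logic of repack_local_tensor can be mirrored in pure Python for a batch sharded along its first axis. This is a sketch under stated assumptions: the helper names (device_locations, components_per_device) are hypothetical, and devices are assumed to be enumerated in row-major order over the mesh coordinates, which is what the loop over local_device_locations() relies on in a single-client setting. Mesh dimensions that do not shard the batch are simply not used for the lookup, so devices that differ only in those coordinates receive identical copies.

```python
import itertools


def device_locations(mesh_dims):
    """Mesh coordinates of every device, in row-major order.

    mesh_dims is a list of (name, size) pairs, e.g. [("batch", 4), ("model", 2)].
    Assumption: device k sits at the k-th coordinate in row-major order.
    """
    names = [name for name, _ in mesh_dims]
    ranges = [range(size) for _, size in mesh_dims]
    return [dict(zip(names, coords)) for coords in itertools.product(*ranges)]


def components_per_device(batch, batch_dim, mesh_dims):
    """Slice of a batch (a list of examples) that each device should hold.

    The batch is split into equal contiguous chunks along axis 0 by the mesh
    dimension `batch_dim` (like tf.split); the chunk index is the device's
    coordinate on `batch_dim`.
    """
    num_splits = dict(mesh_dims)[batch_dim]
    size = len(batch) // num_splits  # assumes an evenly divisible batch
    splits = [batch[i * size:(i + 1) * size] for i in range(num_splits)]
    return [splits[loc[batch_dim]] for loc in device_locations(mesh_dims)]


global_batch = list(range(64))  # stand-ins for 64 examples
data_parallel = components_per_device(global_batch, "batch", [("batch", 8)])
model_parallel = components_per_device(
    global_batch, "batch", [("batch", 4), ("model", 2)])
```

On the 1-D mesh each of the 8 devices gets its own 8 examples; on the 2-D mesh, devices 0 and 1 share a "batch" coordinate and therefore receive the same 16 examples.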

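Data parallel training

In this section, you will train the MLP model with data parallel training. The following sections cover model parallel training and spatial parallel training.

Data parallel training is a commonly used scheme for distributed machine learning:

  • Model variables are replicated on each of the N devices.
  • A global batch is split into N per-replica batches.
  • Each per-replica batch is trained on its replica device.
  • The gradients are reduced (aggregated) across all replicas before the variable updates are collectively applied on every replica.

Data parallel training can provide nearly linear speedup with respect to the number of devices.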
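Why reducing gradients is enough can be seen with a toy one-parameter linear model standing in for the MLP (all names here are hypothetical). Because the loss is a sum over examples, as with tf.reduce_sum in the training step, the gradient on the global batch equals the sum of the per-replica gradients, so every replica can apply the same update and the replicated variables stay in sync.

```python
def grad_sum_squared_error(w, xs, ys):
    """d/dw of sum_k (w * x_k - y_k)**2 for a one-parameter linear model."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys))


xs = [1, 2, 3, 4, 5, 6, 7, 8]   # global batch of inputs
ys = [2, 4, 5, 8, 11, 12, 13, 17]
w = 1                            # the replicated model variable
num_replicas = 4

# Gradient of the summed loss on the whole global batch.
full_batch_grad = grad_sum_squared_error(w, xs, ys)

# Each replica sees only its per-replica batch and computes a local gradient.
size = len(xs) // num_replicas
local_grads = [
    grad_sum_squared_error(w, xs[r * size:(r + 1) * size],
                           ys[r * size:(r + 1) * size])
    for r in range(num_replicas)
]

# Reducing (summing) the local gradients reproduces the global gradient,
# so every replica applies the same update.
reduced_grad = sum(local_grads)
learning_rate = 0.01
new_w = w - learning_rate * reduced_grad
```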

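Creating a data parallel mesh

A typical data parallel training loop uses a DTensor Mesh that consists of a single batch dimension, where each device becomes a replica of the model that receives a shard of the global batch.

Figure: Data parallel mesh

Because the replicated model runs on the replicas, the model variables are fully replicated (unsharded).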

mesh = dtensor.create_mesh([("batch", 8)], devices=DEVICES)

model = MLP([dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh),
             dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh),])

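Packing training data to DTensors

The training data batch is packed into DTensors sharded along the 'batch' (first) axis, so that DTensor evenly distributes the training data across the 'batch' mesh dimension.

Note: In DTensor, the batch size always refers to the global batch size. Choose the batch size so that it can be divided evenly by the size of the batch mesh dimension.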

def repack_batch(x, y, mesh):
  x = repack_local_tensor(x, layout=dtensor.Layout(['batch', dtensor.UNSHARDED], mesh))
  y = repack_local_tensor(y, layout=dtensor.Layout(['batch'], mesh))
  return x, y

sample_x, sample_y = train_data_vec.take(1).get_single_element()
sample_x, sample_y = repack_batch(sample_x, sample_y, mesh)

print('x', sample_x[:, 0])
print('y', sample_y)
x tf.Tensor({"CPU:0": [85.6979828 57.1319885 139.655975 ... 260.267944 438.011902 111.089973], "CPU:1": [117.437973 66.6539841 107.915977 ... 146.003967 260.267944 47.6099892], "CPU:2": [136.481964 215.831955 285.659943 ... 355.487915 206.309952 101.567978], "CPU:3": [107.915977 57.1319885 79.3499832 ... 63.4799881 203.135956 371.35791], "CPU:4": [206.309952 73.0019836 34.9139938 ... 82.5239792 44.4359894 69.8279877], "CPU:5": [95.2199783 219.005951 434.837891 ... 98.3939819 95.2199783 345.965912], "CPU:6": [174.569962 282.485931 38.0879898 ... 234.875946 79.3499832 79.3499832], "CPU:7": [215.831955 590.363892 107.915977 ... 238.049942 244.397949 82.5239792]}, layout="sharding_specs:batch, mesh:|batch=8|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:0/device:CPU:5,/job:localhost/replica:0/task:0/device:CPU:6,/job:localhost/replica:0/task:0/device:CPU:7", shape=(64,), dtype=float32)
y tf.Tensor({"CPU:0": [0 0 0 ... 1 0 0], "CPU:1": [0 0 0 ... 0 1 0], "CPU:2": [1 1 1 ... 1 1 1], "CPU:3": [1 0 0 ... 0 0 0], "CPU:4": [0 0 0 ... 0 0 1], "CPU:5": [0 0 0 ... 0 0 1], "CPU:6": [0 1 1 ... 1 0 1], "CPU:7": [0 1 1 ... 1 0 1]}, layout="sharding_specs:batch, mesh:|batch=8|0,1,2,3,4,5,6,7|0,1,2,3,4,5,6,7|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3,/job:localhost/replica:0/task:0/device:CPU:4,/job:localhost/replica:0/task:0/device:CPU:5,/job:localhost/replica:0/task:0/device:CPU:6,/job:localhost/replica:0/task:0/device:CPU:7", shape=(64,), dtype=int64)

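Training step

This example uses a stochastic gradient descent optimizer with a custom training loop (CTL). For more information on these topics, refer to the Custom training loop guide and walkthrough.

The train_step is encapsulated as a tf.function to indicate that its body is to be traced as a TensorFlow Graph. The body of train_step consists of a forward inference pass, a backward gradient pass, and the variable update.

Note that the body of train_step does not contain any special DTensor annotations. Instead, train_step contains only high-level TensorFlow operations that process the inputs x and y from the global view of the input batch and the model. All of the DTensor annotations (Mesh, Layout) are factored out of the train step.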

# Refer to the CTL (custom training loop guide)
@tf.function
def train_step(model, x, y, learning_rate=tf.constant(1e-4)):
  with tf.GradientTape() as tape:
    logits = model(x)
    # tf.reduce_sum sums the batch sharded per-example loss to a replicated
    # global loss (scalar).
    loss = tf.reduce_sum(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=y))
  parameters = model.trainable_variables
  gradients = tape.gradient(loss, parameters)
  for parameter, parameter_gradient in zip(parameters, gradients):
    parameter.assign_sub(learning_rate * parameter_gradient)

  # Define some metrics
  accuracy = 1.0 - tf.reduce_sum(tf.cast(tf.argmax(logits, axis=-1, output_type=tf.int64) != y, tf.float32)) / x.shape[0]
  loss_per_sample = loss / len(x)
  return {'loss': loss_per_sample, 'accuracy': accuracy}
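The metric bookkeeping at the end of train_step can be mirrored in pure Python to make the arithmetic explicit. This is an illustrative sketch with hypothetical helper names, not the TensorFlow implementation: the per-example sparse softmax cross-entropy is summed over the global batch, divided by the batch size to report a per-sample loss, and accuracy is 1 minus the fraction of examples whose argmax prediction differs from the label.

```python
import math


def sparse_softmax_cross_entropy(logits, label):
    """-log(softmax(logits)[label]) for one example, computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[label]


def step_metrics(batch_logits, labels):
    """Mirror the loss and accuracy bookkeeping at the end of train_step."""
    n = len(labels)
    loss = sum(sparse_softmax_cross_entropy(row, y)
               for row, y in zip(batch_logits, labels))
    predictions = [max(range(len(row)), key=row.__getitem__)
                   for row in batch_logits]
    mistakes = sum(p != y for p, y in zip(predictions, labels))
    return {"loss": loss / n, "accuracy": 1.0 - mistakes / n}


metrics = step_metrics([[0.5, 0.0], [2.0, 0.0], [0.0, 3.0], [1.0, 0.0]],
                       labels=[0, 0, 1, 1])
```

Here three of the four argmax predictions match their labels, so the reported accuracy is 0.75.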

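Checkpointing

You can checkpoint a DTensor model with tf.train.Checkpoint out of the box. Saving and restoring sharded DVariables performs an efficient sharded save and restore. Currently, when using tf.train.Checkpoint.save and tf.train.Checkpoint.restore, all DVariables must be on the same host mesh, and DVariables and regular variables cannot be saved together. You can learn more about checkpointing in this guide.

When a DTensor checkpoint is restored, the Layouts of the variables can differ from those used when the checkpoint was saved. That is, saving a DTensor model is agnostic to layouts and meshes; the layout only affects the efficiency of sharded saving. You can save a DTensor model with one mesh and layout and restore it on a different mesh and layout. This tutorial uses this feature to continue training in the model parallel training and spatial parallel training sections.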
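A toy model of why a checkpoint can be restored under a different layout: if what is saved is, logically, the global value of each variable, then any new sharding can be produced by re-splitting that value. The sketch below is a conceptual illustration with hypothetical helpers (split_rows, save, restore); it is not how tf.train.Checkpoint lays out its files.

```python
def split_rows(value, num_shards):
    """Split a 2-D value (list of rows) into equal row shards."""
    size = len(value) // num_shards  # assumes even divisibility
    return [value[i * size:(i + 1) * size] for i in range(num_shards)]


def save(shards):
    """'Save' by recording the global value assembled from its shards."""
    return [row for shard in shards for row in shard]


def restore(saved_value, num_shards):
    """'Restore' by re-splitting the global value for the new layout."""
    return split_rows(saved_value, num_shards)


weight = [[1, 2], [3, 4], [5, 6], [7, 8]]
# Sharded 4 ways along axis 0 at save time ...
checkpoint = save(split_rows(weight, 4))
# ... and restored 2 ways, as if onto a different mesh and layout.
restored = restore(checkpoint, 2)
```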

CHECKPOINT_DIR = tempfile.mkdtemp()

def start_checkpoint_manager(model):
  ckpt = tf.train.Checkpoint(root=model)
  manager = tf.train.CheckpointManager(ckpt, CHECKPOINT_DIR, max_to_keep=3)

  if manager.latest_checkpoint:
    print("Restoring a checkpoint")
    ckpt.restore(manager.latest_checkpoint).assert_consumed()
  else:
    print("New training")
  return manager

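Training loop

For the data parallel training scheme, train for a few epochs and report the progress. 2 epochs are not sufficient to train the model; an accuracy of 50% is no better than random guessing.

Enable checkpointing so that you can resume training later. In the following sections, you will load the checkpoint and continue training with a different parallel scheme.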

num_epochs = 2
manager = start_checkpoint_manager(model)

for epoch in range(num_epochs):
  step = 0
  pbar = tf.keras.utils.Progbar(target=int(train_data_vec.cardinality()), stateful_metrics=[])
  metrics = {'epoch': epoch}
  for x,y in train_data_vec:

    x, y = repack_batch(x, y, mesh)

    metrics.update(train_step(model, x, y, 1e-2))

    pbar.update(step, values=metrics.items(), finalize=False)
    step += 1
  manager.save()
  pbar.update(step, values=metrics.items(), finalize=True)
New training
391/391 [==============================] - 5s 14ms/step - epoch: 0.0000e+00 - loss: 1.3644 - accuracy: 0.5345
391/391 [==============================] - 4s 10ms/step - epoch: 1.0000 - loss: 0.9407 - accuracy: 0.5978

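Model parallel training

If you switch to a 2-dimensional Mesh and shard the model variables along the second mesh dimension, the training becomes model parallel.

In model parallel training, each model replica spans multiple devices (2 in this case):

  • There are 4 model replicas, and the training data batch is distributed across the 4 replicas.
  • The 2 devices within a single model replica receive replicated training data.

Figure: Model parallel mesh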

mesh = dtensor.create_mesh([("batch", 4), ("model", 2)], devices=DEVICES)
model = MLP([dtensor.Layout([dtensor.UNSHARDED, "model"], mesh), 
             dtensor.Layout(["model", dtensor.UNSHARDED], mesh)])
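It can help to work out the per-device shapes these layouts produce. The sketch below uses a hypothetical helper, local_shard_shape, and a string constant standing in for dtensor.UNSHARDED; the rule it applies is that a tensor axis sharded over a mesh dimension is divided by that dimension's size, while unsharded axes keep their global size.

```python
UNSHARDED = "unsharded"  # stand-in for dtensor.UNSHARDED


def local_shard_shape(global_shape, sharding_specs, mesh_dims):
    """Per-device shape of a tensor under a DTensor-style layout.

    sharding_specs gives, for each tensor axis, either UNSHARDED or the name
    of the mesh dimension that shards it; a sharded axis is divided by that
    mesh dimension's size. mesh_dims is a list of (name, size) pairs.
    """
    sizes = dict(mesh_dims)
    shape = []
    for dim, spec in zip(global_shape, sharding_specs):
        if spec == UNSHARDED:
            shape.append(dim)
            continue
        if dim % sizes[spec]:
            raise ValueError(
                f"axis of size {dim} is not divisible by mesh dim {spec!r}")
        shape.append(dim // sizes[spec])
    return tuple(shape)


mesh = [("batch", 4), ("model", 2)]
w1_local = local_shard_shape((1200, 48), [UNSHARDED, "model"], mesh)
w2_local = local_shard_shape((48, 2), ["model", UNSHARDED], mesh)
b1_local = local_shard_shape((48,), ["model"], mesh)  # weight layout minus axis 0
x_local = local_shard_shape((64, 1200), ["batch", UNSHARDED], mesh)
```

Each device holds half of the 48 hidden units: W1's output axis and W2's input axis are both split over "model", so the 24 hidden activations a device computes are exactly the rows of W2 it owns.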

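Because the training data is still sharded along the batch dimension, you can reuse the same repack_batch function as in the data parallel training case. DTensor automatically replicates the per-replica batch to all devices inside the replica along the "model" mesh dimension.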

def repack_batch(x, y, mesh):
  x = repack_local_tensor(x, layout=dtensor.Layout(['batch', dtensor.UNSHARDED], mesh))
  y = repack_local_tensor(y, layout=dtensor.Layout(['batch'], mesh))
  return x, y

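Next, run the training loop. The training loop reuses the same checkpoint manager as the data parallel training example, and the code looks identical.

You can continue training the model that was trained with data parallelism under model parallel training.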

num_epochs = 2
manager = start_checkpoint_manager(model)

for epoch in range(num_epochs):
  step = 0
  pbar = tf.keras.utils.Progbar(target=int(train_data_vec.cardinality()))
  metrics = {'epoch': epoch}
  for x,y in train_data_vec:
    x, y = repack_batch(x, y, mesh)
    metrics.update(train_step(model, x, y, 1e-2))
    pbar.update(step, values=metrics.items(), finalize=False)
    step += 1
  manager.save()
  pbar.update(step, values=metrics.items(), finalize=True)
Restoring a checkpoint
391/391 [==============================] - 4s 11ms/step - epoch: 0.0000e+00 - loss: 0.7249 - accuracy: 0.6642
391/391 [==============================] - 4s 10ms/step - epoch: 1.0000 - loss: 0.6449 - accuracy: 0.7058

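Spatial parallel training

When training on data of very high dimensionality (for example, a very large image or a video), it may be desirable to shard along the feature dimension. This is called spatial partitioning, which was first introduced into TensorFlow for training models with large 3-D input samples.

Figure: Spatial parallel mesh

DTensor also supports this case. The only change you need to make is to create a Mesh that includes a feature dimension, and apply the corresponding Layout.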

mesh = dtensor.create_mesh([("batch", 2), ("feature", 2), ("model", 2)], devices=DEVICES)
model = MLP([dtensor.Layout(["feature", "model"], mesh), 
             dtensor.Layout(["model", dtensor.UNSHARDED], mesh)])

Shard the input data along the feature dimension when packing the input tensors into DTensors. Do this with a slightly different repack function, repack_batch_for_spt, where spt stands for Spatial Parallel Training.

def repack_batch_for_spt(x, y, mesh):
  # Shard the data on the feature dimension as well as the batch dimension.
  x = repack_local_tensor(x, layout=dtensor.Layout(['batch', 'feature'], mesh))
  y = repack_local_tensor(y, layout=dtensor.Layout(['batch'], mesh))
  return x, y
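With x sharded on "feature" and the first Dense weight laid out as ["feature", "model"], the first matrix product becomes a 2-D block computation. The pure-Python sketch below (hypothetical helper names; devices simulated by loops; the batch axis left unsharded for brevity) shows that device (f, m) multiplies its feature block of x by block (f, m) of W, and that summing the partial products over f, a reduction across the "feature" mesh dimension, yields column block m of the hidden activations.

```python
def matmul(x, w):
    """Dense product of nested lists: (rows x k) @ (k x cols)."""
    return [[sum(x[r][i] * w[i][c] for i in range(len(w)))
             for c in range(len(w[0]))]
            for r in range(len(x))]


def sub_block(mat, rows, cols):
    """Extract mat[rows, cols] as a nested list."""
    return [[mat[r][c] for c in cols] for r in rows]


def spatial_hidden(x, w, n_feature, n_model):
    """x @ W with x sharded on its feature axis and W sharded [feature, model].

    Device (f, m) holds x[:, f-block] and W[f-block, m-block] and computes a
    partial product. Summing partials over f gives column block m of the
    result; the "model" blocks are concatenated only so the result can be
    checked.
    """
    rows = range(len(x))
    fs, ms = len(w) // n_feature, len(w[0]) // n_model  # assume divisibility
    column_blocks = []
    for m in range(n_model):
        m_cols = range(m * ms, (m + 1) * ms)
        acc = [[0] * ms for _ in rows]
        for f in range(n_feature):
            f_idx = range(f * fs, (f + 1) * fs)
            part = matmul(sub_block(x, rows, f_idx), sub_block(w, f_idx, m_cols))
            acc = [[a + p for a, p in zip(ra, rp)] for ra, rp in zip(acc, part)]
        column_blocks.append(acc)
    return [sum((blk[r] for blk in column_blocks), []) for r in rows]


x = [[1, 2, 3, 4], [5, 6, 7, 8]]       # 2 examples x 4 features
w = [[1, 0], [0, 1], [1, 1], [2, -1]]  # 4 features x 2 hidden units
hidden = spatial_hidden(x, w, n_feature=2, n_model=2)
```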

Spatial parallel training can also continue from a checkpoint created with the other parallel training schemes.

num_epochs = 2

manager = start_checkpoint_manager(model)
for epoch in range(num_epochs):
  step = 0
  metrics = {'epoch': epoch}
  pbar = tf.keras.utils.Progbar(target=int(train_data_vec.cardinality()))

  for x, y in train_data_vec:
    x, y = repack_batch_for_spt(x, y, mesh)
    metrics.update(train_step(model, x, y, 1e-2))

    pbar.update(step, values=metrics.items(), finalize=False)
    step += 1
  manager.save()
  pbar.update(step, values=metrics.items(), finalize=True)
Restoring a checkpoint
  7/391 [..............................] - ETA: 37s - epoch: 0.0000e+00 - loss: 0.6948 - accuracy: 0.6895
390/391 [============================>.] - ETA: 0s - epoch: 0.0000e+00 - loss: 0.5933 - accuracy: 0.7363
E0000 00:00:1704997447.729200   70402 op_level_cost_estimator.cc:1121] Incompatible Matrix dimensions
E0000 00:00:1704997447.733975   70410 op_level_cost_estimator.cc:1121] Incompatible Matrix dimensions
E0000 00:00:1704997447.735483   70404 op_level_cost_estimator.cc:1121] Incompatible Matrix dimensions
E0000 00:00:1704997447.739732   70414 op_level_cost_estimator.cc:1121] Incompatible Matrix dimensions
E0000 00:00:1704997447.744464   70408 op_level_cost_estimator.cc:1121] Incompatible Matrix dimensions
E0000 00:00:1704997447.752916   70416 op_level_cost_estimator.cc:1121] Incompatible Matrix dimensions
E0000 00:00:1704997447.767094   70412 op_level_cost_estimator.cc:1121] Incompatible Matrix dimensions
E0000 00:00:1704997447.771650   70406 op_level_cost_estimator.cc:1121] Incompatible Matrix dimensions
WARNING:tensorflow:5 out of the last 5 calls to <function MultiDeviceSaver.save.<locals>.tf_function_save at 0x7fb2f00908b0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
391/391 [==============================] - 4s 11ms/step - epoch: 0.0000e+00 - loss: 0.5933 - accuracy: 0.7362
387/391 [============================>.] - ETA: 0s - epoch: 1.0000 - loss: 0.5674 - accuracy: 0.7541
WARNING:tensorflow:6 out of the last 6 calls to <function MultiDeviceSaver.save.<locals>.tf_function_save at 0x7fb2f0090940> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
391/391 [==============================] - 4s 9ms/step - epoch: 1.0000 - loss: 0.5666 - accuracy: 0.7544

SavedModel and DTensor

The integration of DTensor with SavedModel is still under development.

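As of TensorFlow 2.11, tf.saved_model can save sharded and replicated DTensor models, and saving performs an efficient sharded save across the devices of the mesh. However, once a model is saved, all DTensor annotations are lost, and the saved signatures can only be used with regular Tensors, not DTensors.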

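# Build a single-device mesh and a fully replicated (unsharded) MLP for export.
mesh = dtensor.create_mesh([("world", 1)], devices=DEVICES[:1])
mlp = MLP([dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh),
           dtensor.Layout([dtensor.UNSHARDED, dtensor.UNSHARDED], mesh)])

# Restore the latest training checkpoint (if one exists) into the MLP.
manager = start_checkpoint_manager(mlp)

# Chain text vectorization and the MLP so the exported model accepts raw strings.
model_for_saving = tf.keras.Sequential([
  text_vectorization,
  mlp
])

# Serving signature: a batch of strings in, a dict holding the model output out.
@tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
def run(inputs):
  return {'result': model_for_saving(inputs)}

tf.saved_model.save(
    model_for_saving, "/tmp/saved_model",
    signatures=run)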
Restoring a checkpoint
INFO:tensorflow:Assets written to: /tmp/saved_model/assets

As of TensorFlow 2.9.0, the loaded signature can only be called with a regular Tensor or a fully replicated DTensor (which is converted to a regular Tensor).
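The following is a minimal sketch of calling a loaded signature with a regular Tensor. It uses a toy module (the hypothetical `Doubler` class) in place of the tutorial's DTensor MLP and does not exercise DTensor itself. The sketch exports a signature with `tf.saved_model.save`, reloads it with `tf.saved_model.load`, and calls the `serving_default` signature with an ordinary `tf.Tensor`.

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
  """Hypothetical toy module standing in for the tutorial's exported model."""

  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  def run(self, x):
    # Return a dict so the signature output is keyed, like {'result': ...} in the tutorial.
    return {"result": 2.0 * x}


module = Doubler()
export_dir = tempfile.mkdtemp()
# A single signature is exported under the default key "serving_default".
tf.saved_model.save(module, export_dir, signatures=module.run.get_concrete_function())

# The loaded signature is a plain concrete function over regular Tensors.
loaded = tf.saved_model.load(export_dir)
serving_fn = loaded.signatures["serving_default"]

# Call it with a regular tf.Tensor; the keyword matches the argument name `x`.
out = serving_fn(x=tf.constant([1.0, 2.0]))["result"]
print(out.numpy())  # [2. 4.]
```

The exported model in this tutorial would presumably be called the same way, passing a batch of strings (for example `sample_batch['text']`) to its `serving_default` signature.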

sample_batch = train_data.take(1).get_single_element()
sample_batch
{'label': <tf.Tensor: shape=(64,), dtype=int64, numpy=
 array([0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1,
        1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1])>,
 'text': <tf.Tensor: shape=(64,), dtype=string, numpy=
 array([b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.",
        b'I have been known to fall asleep during films, but this is usually due to a combination of things including, really tired, being warm and comfortable on the sette and having just eaten a lot. However on this occasion I fell asleep because the film was rubbish. The plot development was constant. Constantly slow and boring. Things seemed to happen, but with no explanation of what was causing them or why. I admit, I may have missed part of the film, but i watched the majority of it and everything just seemed to happen of its own accord without any real concern for anything else. I cant recommend this film at all.',
        b'Mann photographs the Alberta Rocky Mountains in a superb fashion, and Jimmy Stewart and Walter Brennan give enjoyable performances as they always seem to do. <br /><br />But come on Hollywood - a Mountie telling the people of Dawson City, Yukon to elect themselves a marshal (yes a marshal!) and to enforce the law themselves, then gunfighters battling it out on the streets for control of the town? <br /><br />Nothing even remotely resembling that happened on the Canadian side of the border during the Klondike gold rush. Mr. Mann and company appear to have mistaken Dawson City for Deadwood, the Canadian North for the American Wild West.<br /><br />Canadian viewers be prepared for a Reefer Madness type of enjoyable howl with this ludicrous plot, or, to shake your head in disgust.',
        b'This is the kind of film for a snowy Sunday afternoon when the rest of the world can go ahead with its own business as you descend into a big arm-chair and mellow for a couple of hours. Wonderful performances from Cher and Nicolas Cage (as always) gently row the plot along. There are no rapids to cross, no dangerous waters, just a warm and witty paddle through New York life at its best. A family film in every sense and one that deserves the praise it received.',
        b'As others have mentioned, all the women that go nude in this film are mostly absolutely gorgeous. The plot very ably shows the hypocrisy of the female libido. When men are around they want to be pursued, but when no "men" are around, they become the pursuers of a 14 year old boy. And the boy becomes a man really fast (we should all be so lucky at this age!). He then gets up the courage to pursue his true love.',
        b"This is a film which should be seen by anybody interested in, effected by, or suffering from an eating disorder. It is an amazingly accurate and sensitive portrayal of bulimia in a teenage girl, its causes and its symptoms. The girl is played by one of the most brilliant young actresses working in cinema today, Alison Lohman, who was later so spectacular in 'Where the Truth Lies'. I would recommend that this film be shown in all schools, as you will never see a better on this subject. Alison Lohman is absolutely outstanding, and one marvels at her ability to convey the anguish of a girl suffering from this compulsive disorder. If barometers tell us the air pressure, Alison Lohman tells us the emotional pressure with the same degree of accuracy. Her emotional range is so precise, each scene could be measured microscopically for its gradations of trauma, on a scale of rising hysteria and desperation which reaches unbearable intensity. Mare Winningham is the perfect choice to play her mother, and does so with immense sympathy and a range of emotions just as finely tuned as Lohman's. Together, they make a pair of sensitive emotional oscillators vibrating in resonance with one another. This film is really an astonishing achievement, and director Katt Shea should be proud of it. The only reason for not seeing it is if you are not interested in people. But even if you like nature films best, this is after all animal behaviour at the sharp edge. Bulimia is an extreme version of how a tormented soul can destroy her own body in a frenzy of despair. And if we don't sympathise with people suffering from the depths of despair, then we are dead inside.",
        b'Okay, you have:<br /><br />Penelope Keith as Miss Herringbone-Tweed, B.B.E. (Backbone of England.) She\'s killed off in the first scene - that\'s right, folks; this show has no backbone!<br /><br />Peter O\'Toole as Ol\' Colonel Cricket from The First War and now the emblazered Lord of the Manor.<br /><br />Joanna Lumley as the ensweatered Lady of the Manor, 20 years younger than the colonel and 20 years past her own prime but still glamourous (Brit spelling, not mine) enough to have a toy-boy on the side. It\'s alright, they have Col. Cricket\'s full knowledge and consent (they guy even comes \'round for Christmas!) Still, she\'s considerate of the colonel enough to have said toy-boy her own age (what a gal!)<br /><br />David McCallum as said toy-boy, equally as pointlessly glamourous as his squeeze. Pilcher couldn\'t come up with any cover for him within the story, so she gave him a hush-hush job at the Circus.<br /><br />and finally:<br /><br />Susan Hampshire as Miss Polonia Teacups, Venerable Headmistress of the Venerable Girls\' Boarding-School, serving tea in her office with a dash of deep, poignant advice for life in the outside world just before graduation. Her best bit of advice: "I\'ve only been to Nancherrow (the local Stately Home of England) once. I thought it was very beautiful but, somehow, not part of the real world." Well, we can\'t say they didn\'t warn us.<br /><br />Ah, Susan - time was, your character would have been running the whole show. They don\'t write \'em like that any more. Our loss, not yours.<br /><br />So - with a cast and setting like this, you have the re-makings of "Brideshead Revisited," right?<br /><br />Wrong! They took these 1-dimensional supporting roles because they paid so well. After all, acting is one of the oldest temp-jobs there is (YOU name another!)<br /><br />First warning sign: lots and lots of backlighting. 
They get around it by shooting outdoors - "hey, it\'s just the sunlight!"<br /><br />Second warning sign: Leading Lady cries a lot. When not crying, her eyes are moist. That\'s the law of romance novels: Leading Lady is "dewy-eyed."<br /><br />Henceforth, Leading Lady shall be known as L.L.<br /><br />Third warning sign: L.L. actually has stars in her eyes when she\'s in love. Still, I\'ll give Emily Mortimer an award just for having to act with that spotlight in her eyes (I wonder . did they use contacts?)<br /><br />And lastly, fourth warning sign: no on-screen female character is "Mrs." She\'s either "Miss" or "Lady."<br /><br />When all was said and done, I still couldn\'t tell you who was pursuing whom and why. I couldn\'t even tell you what was said and done.<br /><br />To sum up: they all live through World War II without anything happening to them at all.<br /><br />OK, at the end, L.L. finds she\'s lost her parents to the Japanese prison camps and baby sis comes home catatonic. Meanwhile (there\'s always a "meanwhile,") some young guy L.L. had a crush on (when, I don\'t know) comes home from some wartime tough spot and is found living on the street by Lady of the Manor (must be some street if SHE\'s going to find him there.) Both war casualties are whisked away to recover at Nancherrow (SOMEBODY has to be "whisked away" SOMEWHERE in these romance stories!)<br /><br />Great drama.',
        b'The film is based on a genuine 1950s novel.<br /><br />Journalist Colin McInnes wrote a set of three "London novels": "Absolute Beginners", "City of Spades" and "Mr Love and Justice". I have read all three. The first two are excellent. The last, perhaps an experiment that did not come off. But McInnes\'s work is highly acclaimed; and rightly so. This musical is the novelist\'s ultimate nightmare - to see the fruits of one\'s mind being turned into a glitzy, badly-acted, soporific one-dimensional apology of a film that says it captures the spirit of 1950s London, and does nothing of the sort.<br /><br />Thank goodness Colin McInnes wasn\'t alive to witness it.',
        b'I really love the sexy action and sci-fi films of the sixties and its because of the actress\'s that appeared in them. They found the sexiest women to be in these films and it didn\'t matter if they could act (Remember "Candy"?). The reason I was disappointed by this film was because it wasn\'t nostalgic enough. The story here has a European sci-fi film called "Dragonfly" being made and the director is fired. So the producers decide to let a young aspiring filmmaker (Jeremy Davies) to complete the picture. They\'re is one real beautiful woman in the film who plays Dragonfly but she\'s barely in it. Film is written and directed by Roman Coppola who uses some of his fathers exploits from his early days and puts it into the script. I wish the film could have been an homage to those early films. They could have lots of cameos by actors who appeared in them. There is one actor in this film who was popular from the sixties and its John Phillip Law (Barbarella). Gerard Depardieu, Giancarlo Giannini and Dean Stockwell appear as well. I guess I\'m going to have to continue waiting for a director to make a good homage to the films of the sixties. If any are reading this, "Make it as sexy as you can"! I\'ll be waiting!',
        b'Sure, this one isn\'t really a blockbuster, nor does it target such a position. "Dieter" is the first name of a quite popular German musician, who is either loved or hated for his kind of acting and thats exactly what this movie is about. It is based on the autobiography "Dieter Bohlen" wrote a few years ago but isn\'t meant to be accurate on that. The movie is filled with some sexual offensive content (at least for American standard) which is either amusing (not for the other "actors" of course) or dumb - it depends on your individual kind of humor or on you being a "Bohlen"-Fan or not. Technically speaking there isn\'t much to criticize. Speaking of me I find this movie to be an OK-movie.',
        b'During a sleepless night, I was switching through the channels & found this embarrassment of a movie. What were they thinking?<br /><br />If this is life after "Remote Control" for Kari (Wuhrer) Salin, no wonder she\'s gone nowhere.<br /><br />And why did David Keith take this role? It\'s pathetic!<br /><br />Anyway, I turned on the movie near the end, so I didn\'t get much of the plot. But this must\'ve been the best part. This nerdy college kid brings home this dominatrix-ish girl...this scene is straight out of the comic books -- or the cheap porn movies. She calls the mother anal retentive and kisses the father "Oh, I didn\'t expect tongue!" Great lines!<br /><br />After this, I had to see how it ended..<br /><br />Well, of course, this bitch from hell has a helluva past, so the SWAT team is upstairs. And yes...they surround her! And YES YES! The kid blows her brains out!!!! AHAHHAHAHAHA!!<br /><br />This is must-see TV. <br /><br />',
        b'Cute film about three lively sisters from Switzerland (often seen running about in matching outfits) who want to get their parents back together (seems mom is still carrying the torch for dad) - so they sail off to New York to stop the dad from marrying a blonde gold-digger he calls "Precious". Dad hasn\'t seen his daughters in ten years, they (oddly enough) don\'t seem to mind and think he\'s wonderful, and meanwhile Precious seems to lead a life mainly run by her overbearing mother (Alice Brady), a woman who just wants to see to it her daughter marries a rich man. The sisters get the idea of pushing Precious into the path of a drunken Hungarian count, tricking the two gold-digging women into thinking he is one of the richest men in Europe. But a case of mistaken identity makes the girls think the count is good-looking Ray Milland, who goes along with the scheme \'cause he has a crush on sister Kay.<br /><br />This film is enjoyable, light fare. Barbara Read as Kay comes across as sweet and pretty, Ray Milland looks oh so young and handsome here (though, unfortunately, is given little to do), Alice Brady is quite good as the scheming mother - but it is Deanna Durbin, a real charmer and cute as a button playing youngest sister Penny, who pretty much steals the show. With absolutely beautiful vocals, she sings several songs throughout the film, though I actually would have liked to have seen them feature her even more in this. The plot in this film is a bit silly, but nevertheless, I found the film to be entertaining and fun.',
        b"This 1984 version of the Dickens' classic `A Christmas Carol,' directed by Clive Donner, stars George C. Scott as Ebenezer Scrooge. By this time around, the challenge for the filmmaker was to take such familiar material and make it seem fresh and new again; and, happily to say, with this film Donner not only met the challenge but surpassed any expectations anyone might have had for it. He tells the story with precision and an eye to detail, and extracts performances from his actors that are nothing less than superlative, especially Scott. One could argue that the definitive portrayal of Scrooge-- one of the best known characters in literary fiction, ever-- was created by Alastair Sim in the 1951 film; but I think with his performance here, Scott has now achieved that distinction. There is such a purity and honesty in his Scrooge that it becomes difficult to even consider anyone else in the role once you've seen Scott do it; simply put, he IS Scrooge. And what a tribute it is to such a gifted actor; to be able to take such a well known figure and make it so uniquely his own is quite miraculous. It is truly a joy to see an actor ply his trade so well, to be able to make a character so real, from every word he utters down to the finest expression of his face, and to make it all ring so true. It's a study in perfection.<br /><br />The other members of the cast are splendid as well, but then again they have to be in order to maintain the integrity of Scott's performance; and they do. Frank Finlay is the Ghost of Jacob Marley; a notable turn, though not as memorable, perhaps, as the one by Alec Guinness (as Marley) in the film, `Scrooge.' 
Angela Pleasence is a welcome visage as the Spirit of Christmas Past; Edward Woodward, grand and boisterous, and altogether convincing as the Spirit of Christmas Present; and Michael Carter, grim and menacing as the Spirit of Christmas Yet To Come.<br /><br />David Warner hits just the right mark with his Bob Cratchit, bringing a sincerity to the role that measures up well to the standard of quality set by Scott's Scrooge, and Susannah York fares just as well as Mrs. Cratchit. The real gem to be found here, though, is the performance of young Anthony Walters as Tiny Tim; it's heartfelt without ever becoming maudlin, and simply one of the best interpretations-- and the most real-- ever presented on film.<br /><br />The excellent supporting cast includes Roger Rees (Fred Holywell, and also the narrator of the film), Caroline Langrishe (Janet Holywell), Lucy Gutteridge (Belle), Michael Gough (Mr. Poole) and Joanne Whalley (Fan). A flawless presentation, this version of `A Christmas Carol' sets the standard against which all others must be gauged; no matter how many versions you may have seen, watching this one is like seeing it for the first time ever. And forever after, whenever you think of Scrooge, the image your mind will conjure up will be that of George C. Scott. A thoroughly entertaining and satisfying experience, this film demands a place in the annual schedule of the holiday festivities of every home. I rate this one 10/10.",
        b'Put the blame on executive producer Wes Craven and financiers the Weinsteins for this big-budget debacle: a thrash-metal updating of "Dracula", with a condescending verbal jab at Bram Stoker (who probably wouldn\'t want his name on this thing anyway) and nothing much for the rest of us except slasher-styled jolts and gore. Christopher Plummer looks winded as Van Helsing in the modern-day--not just a descendant of Van Helsing but the real thing; he keeps himself going with leeches obtained from Count Dracula\'s corpse, which is exhumed from its coffin after being stolen from Van Helsing\'s vault and flown to New Orleans. This is just what New Orleans needs in the 21st Century! The film, well-produced but without a single original idea (except for multi-racial victims), is both repulsive and lazy, and after about an hour starts repeating itself. * from ****',
        b'Hilarious, evocative, confusing, brilliant film. Reminds me of Bunuel\'s L\'Age D\'Or or Jodorowsky\'s Holy Mountain-- lots of strange characters mucking about and looking for..... what is it? I laughed almost the whole way through, all the while keeping a peripheral eye on the bewildered and occasionally horrified reactions of the audience that surrounded me in the theatre. Entertaining through and through, from the beginning to the guts and poisoned entrails all the way to the end, if it was an end. I only wish i could remember every detail. It haunts me sometimes.<br /><br />Honestly, though, i have only the most positive recollections of this film. As it doesn\'t seem to be available to take home and watch, i suppose i\'ll have to wait a few more years until Crispin Glover comes my way again with his Big Slide Show (and subsequent "What is it?" screening)... I saw this film in Atlanta almost directly after being involved in a rather devastating car crash, so i was slightly dazed at the time, which was perhaps a very good state of mind to watch the prophetic talking arthropods and the retards in the superhero costumes and godlike Glover in his appropriate burly-Q setting, scantily clad girlies rising out of the floor like a magnificent DADAist wet dream.<br /><br />Is it a statement on Life As We Know It? Of course everyone EXPECTS art to be just that. I rather think that the truth is more evident in the absences and in the negative space. What you don\'t tell us is what we must deduce, but is far more valid than the lies that other people feed us day in and day out. Rather one "WHAT IS IT?" than 5000 movies like "Titanic" or "Sleepless in Seattle" (shudder, gag, groan).<br /><br />Thank you, Mr. Glover (additionally a fun man to watch on screen or at his Big Slide Show-- smart, funny, quirky, and outrageously hot). Make more films, write more books, keep the nightmare alive.',
        b'It was disgusting and painful. What a waste of a cast! I swear, the audience (1/2 full) laughed TWICE in 90 minutes. This is not a lie. Do not even rent it.<br /><br />Zeta Jones was just too mean to be believable.<br /><br />Cusack was OK. Just OK. I felt sorry for him (the actor) in case people remember this mess.<br /><br />Roberts was the same as she always is. Charming and sweet, but with no purpose. The "romance" with John was completely unbelievable.',
        b'This is a straight-to-video movie, so it should go without saying that it\'s not going to rival the first Lion King, but that said, this was downright good.<br /><br />My kids loved this, but that\'s a given, they love anything that\'s a cartoon. The big shock was that *I* liked it too, it was laugh out loud funny at some parts (even the fart jokes*), had lots of rather creative tie-ins with the first movie, and even some jokes that you had to be older to understand (but without being risqu\xc3\xa9 like in Shrek ["do you think he\'s compensating for something?"]).<br /><br />A special note on the fart jokes, I was surprised to find that none of the jokes were just toilet noises (in fact there were almost no noises/imagery at all, the references were actually rather subtle), they actually had a setup/punchline/etc, and were almost in good taste. I\'d like my kids to think that there\'s more to humor than going to the bathroom, and this movie is fine in those regards.<br /><br />Hmm what else? The music was so-so, not nearly as creative as in the first or second movie, but plenty of fun for the kids. No painfully corny moments, which was a blessing for me. A little action but nothing too scary (the Secret of NIMH gave my kids nightmares, not sure a G rating was appropriate for that one...)<br /><br />All in all I\'d say this is a great movie for kids of any age, one that\'s 100% safe to let them watch (I try not to be overly sensitive but I\'ve had to jump up and turn off the TV during a few movies that were less kid-appropriate than expected) - but you\'re safe to leave the room during this one. I\'d say stick around anyway though, you might find that you enjoy it too :)',
        b'Finally, Timon and Pumbaa in their own film...<br /><br />\'The Lion King 1 1/2: Hakuna Matata\' is an irreverent new take on a classic tale. Which classic tale, you ask? Why, \'The Lion King\' of course!<br /><br />Yep, if there\'s one thing that Disney is never short of, it\'s narcissism.<br /><br />But that doesn\'t mean that this isn\'t a good film. It\'s basically the events of \'The Lion King\' as told from Timon and Pumbaa\'s perspective. And it\'s because of this that you\'ll have to know the story of \'The Lion King\' by heart to see where they\'re coming from.<br /><br />Anyway, at one level I was watching this and thinking "Oh my god this is so lame..." and on another level I was having a ball. Much of the humour is predictable - I mean, when Pumbaa makes up two beds, a big one for himself and a small one for Timon, within the first nanosecond we all know that Timon is going to take the big one. But that doesn\'t stop it from being hilarious, which, IMO, is \'Hakuna Matata\' in a nutshell. It\'s not what happens, it\'s how.<br /><br />And a note of warning: there are also some fart jokes. Seriously, did you expect anything else in a film where Pumbaa takes centre stage? But as fart jokes go, these are especially good, and should satisfy even the most particular connoisseur.<br /><br />The returning voice talent is great. I\'m kinda surprised that some of the actors were willing to return, what with most of them only having two or three lines (if they\'re lucky). Whoopi Goldberg is particularly welcome.<br /><br />The music is also great. 
From \'Digga Tunnah\' at the start to \'That\'s all I need\', an adaption of \'Warthog Rhapsody\' (a song that was cut from \'The Lion King\' and is frankly much improved in this incarnation), the music leaves me with nothing to complain about whatsoever.<br /><br />In the end, Timon and Pumbaa are awesome characters, and while it may be argued that \'Hakuna Matata\' is simply an excuse to see them in various fun and assorted compromising situations then so be it. It\'s rare to find characters that you just want to spend time with.<br /><br />Am I starting to sound creepy?<br /><br />Either way, \'The Lion King 1 1/2\' is great if you\'ve seen \'The Lion King\' far too many times. Especially if you are right now thinking "Don\'t be silly, there\'s no such thing as seeing \'The Lion King\' too many times!"',
        b'Indian Directors have it tough, They have to compete with movies like "Laggan" where 11 henpecked,Castrated males defend their village and half of them are certifiable idiots. "Devdas", a hapless, fedar- festooned foreign return drinking to oblivion, with characters running in endless corridors oblivious to any one\'s feelings or sentiments-alas they live in an ornate squalor of red tapestry and pageantry. But to make a good movie, you have to tight-rope walk to appease the frontbenchers who are the quentessential gapers who are mesmerized with Split skirts and Dishum-Dishum fights preferably involving a nitwit "Bollywood" leading actor who is marginally handsome. So you can connect with a director who wants to tell a tale of Leonine village head who in own words "defending his Village" this is considered a violent movie or too masculine for a male audience. There are very few actors who can convey the anger and pathos like Nana Patekar (Narasimhan). Nana Patekar lets you in his courtyard and watch him beret and mock the Politician when his loyal admirers burst in laughter with every word of satire thrown at him, meanwhile his daughter is bathing his Grandson.This is as authentic a scene you can get in rural India. Nana Patekar is the essential actor who belongs to the old school of acting which is a disappearing breed in Hindi Films. The violence depicted is an intricate part of storytelling with Song&Dances thrown in for the gawkers without whom movies won\'t sell, a sad but true state of affairs. Faster this changes better for "Bollywood". All said and done this is one good Movie.',
        b"Nathan Detroit runs illegal craps games for high rollers in NYC, but the heat is on and he can't find a secure location. He bets chronic gambler Sky Masterson that Sky can't make a prim missionary, Sarah Brown, go out to dinner with him. Sky takes up the challenge, but both men have some surprises in store \xc2\x85<br /><br />This is one of those expensive fifties MGM musicals in splashy colour, with big sets, loud music, larger-than-life roles and performances to match; Broadway photographed for the big screen if you like that sort of thing, which I don't. My main problem with these type of movies is simply the music. I like all kinds of music, from Albinoni to ZZ Top, but Broadway show tunes in swing time with never-ending pah-pah-tah-dah trumpet flourishes at the end of every fourth bar aren't my cup of tea. This was written by the tag team of Frank Loesser, Mankiewicz, Jo Swerling and Abe Burrows (based on a couple of Damon Runyon stories), and while the plot is quite affable the songs are weak. Blaine's two numbers for example are identical, unnecessary, don't advance the plot and grate on the ears (and are also flagrantly misogynistic if that sort of thing bothers you). There are only two memorable tunes, Luck Be A Lady (sung by Brando, not Sinatra as you might expect) and Sit Down, You're Rockin' The Boat (nicely performed by Kaye) but you have to sit through two hours to get to them. The movie's trump card is a young Brando giving a thoughtful, laid-back performance; he also sings quite well and even dances a little, and is evenly matched with the always interesting Simmons. The sequence where the two of them escape to Havana for the night is a welcome respite from all the noise, bustle and vowel-murdering of Noo Yawk. Fans of musicals may dig this, but in my view a musical has to do something more than just film the stage show.",
        b"I can still remember first seeing this on TV. I couldn't believe TVNZ let it on! I had to own it! A lot of the humor will be lost on non-NZ'ers, but give it a go! <br /><br />Since finishing the Back of the Y series Matt and Chris have gone on to bigger and better(?) things. NZ's greatest dare-devil stuntman, Randy Campbell has often appeared on the British TV series Balls of Steel. Yes, he still f^@ks up all his stunts because he is too drunk.<br /><br />Also the 'house band' Deja Voodoo have since released 2 albums, Brown Sabbath and Back in Brown. The band consists of members of the Back of the Y team and singles such as 'I Would Give You One of My Beers (But I've Only Got 6)' and 'You Weren't Even Born in The 80's' continue their humor.<br /><br />The South-By-Southwest film festival also featured their feature length film 'The Devil Made Me Do It' which will be released early 2008 in NZ.<br /><br />All up, if you don't find these guys funny then you can just F%^K OFF!!",
        b'In a time of magic, barbarians and demons abound a diabolical tyrant named Nekhron and his mother Queen Juliane who lives in the realm of ice and wants to conquer the region of fire ruled by the King Jerol but when his beautiful daughter Princess Teegra has been kidnapped by Nekhron\'s goons, a warrior named Larn must protect her and must defeat Nekhron from taking over the world and the kingdom with the help of an avenger named Darkwolf.<br /><br />A nicely done and excellent underrated animated fantasy epic that combines live actors with animation traced over them ( rotoscoping), it\'s Ralph Bakshi\'s second best movie only with "American Pop" being number one and "Heavy Traffic" being third and "Wizards" being fourth. It\'s certainly better than his "Cool World" or "Lord of the Rings", the artwork is designed by famed artist Frank Farzetta and the animation has good coloring and there\'s also a hottie for the guys.<br /><br />I highly recommend this movie to fantasy and animation lovers everywhere especially the new 2-Disc Limited Edition DVD from Blue Underground.<br /><br />Also recommended: "The Black Cauldron", "The Dark Crystal", "Conan The Barbarian", "The Wizard of Oz", " Rock & Rule", "Wizards", "Heavy Metal", "Starchaser: Legend of Orin", "Fantastic Planet", " Princess Mononoke", " Nausicca: Valley of the Wind", " Conan The Destroyer", " Willow", " The Princess Bride", "Lord of the Rings ( 1978)", " The Sword in The Stone", " Excalibur", " Army of Darkness", " Krull", "Dragonheart", " King Arthur", " The Hobbit", " Return of the King ( 1980)", "Conquest", " American Pop", " Jason and The Argonauts", " Clash of the Titans", " The Last Unicorn", " The Secret of NIMH", "The Flight of Dragons", " Hercules (Disney)", " Legend", " The Chronicles of Narnia", " Harry Potter and The Goblet of Fire".',
        b'A pretty memorable movie of the animals-killing-people variety, specifically similar to "Willard" in that it stars an aging character actor (in this case, a step down a bit to the level of Les Tremayne, who puts in the only distinguished performance I\'ve seen him give) in a role as a man whose life is unbalanced and who subsequently decides to use his animal friends to exact revenge on those who have wronged him. Yes, this is one of those movies where pretty much everybody is despicable, so that you will cheer when they die, and really the selection of actors, locations, etc. couldn\'t be better at giving the film an atmosphere of shabby decadence.<br /><br />Tremayne\'s character is "Snakey Bender", and he is certainly the most interesting thing about the movie: an aged snake collector who is obsessed with John Philip Souza\'s music. When the local preacher clamps down on his practice of collecting small animals from the local schoolchildren as bait for his snakes, and his friend gets married to a stripper (thus upsetting his ritual Wednesday night band concert) he goes on the rampage, in the process creating a memorable pile-up of clunkers beneath the cliff where he dumps the wrecks after disposing of their unfortunate owners. One amusing game you can play while watching "Snakes" is to place bets on which cars will land the farthest down the cliff.<br /><br />All in all, very cheap and exploitative, but will really be a lot of fun for fans of these kinds of movies.',
        b"Except for an awkward scene, this refreshing fairy tale fantasy has a fun and delightful undercurrent of adult cynical wit that charms its way into the audience as well as a soundtrack that powerfully moves this fairy epic along. Except for one of the Robert DeNiro scenes that doesn't come across smooth and appears out of sync with the tone of the rest of the movie, this luscious romantic fairy tail has a great storytelling feel and the strong magic and the fine balance between serious adventure scenes and the lighter spiritual humor is well done. In the updated tradition of THE PRINCESS BRIDE this contemporary presentation of magic and love is captivating. Eight out of Ten Stars.",
        b'In this film we have the fabulous opportunity to see what happened to Timon and Pumbaa in the film when they are not shown - which is a lot! This film even goes back to before Simba and (presumbably) just after the birth of Kiara. <br /><br />Quite true to the first film, "Lion King 1/2 (or Lion King 3 in other places)" is a funny, entertaining, exciting and surprising film (or sequel if that\'s what you want to call it). A bundle of surprises and hilarity await for you!<br /><br />While Timon and Pumbaa are watching a film at the cinema (with a remote control), Timon and Pumbaa have an argument of what point of "The Lion King" they are going to start watching, as Timon wants to go to the part when he and Pumbaa come in and Pumbaa wants to go back to the beginning. They have a very fair compromise of watching the film of their own story, which is what awaits... It starts with Timon\'s first home...<br /><br />For anyone with a good sense of humour who liked the first films of just about any age, enjoy "Lion King 1/2"! :-)',
        b"Well, i rented this movie and found out it realllllllly sucks. It is about that family with the stepmother and the same stupid fights in the family,then the cool son comes with his stupid camera and he likes to take a photo to damaged building and weird things and weird movie ,and then he asks his father to take him to a side trip and simply agrees, etc etc etc..... They go to that town which no one know it exists (blah blah blah) And the most annoying thing is that the movie ends and yet you don't understand what is THAT MOVIE!!!!I have seen many mystery movies but that was the worst, Honestly it doesn't have a description at all and i wish i didn't see it.",
        b'I actually had quite high hopes going into this movie, so I took what was given with a grain of salt and hoped for the best. About 1/3 of the way through the film I simply had to give up, quite simply the movie is a mish-mash of stuff happening for no apparent reason and it\'s all disconnected. I love movies that make you think, but this movie was just a bunch of ideas thrown together and never really connected.<br /><br />Don\'t think it\'s David Lynch-esquire as some would have you believe, it is nowhere near that realm other than some trippy visuals. Saying it\'s artsy to disguise the fact there\'s no apparent plot or story is just a manner or justifying why you wasted the 1.5 hours in the film. The acting was good, but that cannot save lack of story. I do agree with the one comment posted previously... "it\'s like being in some other person\'s head... while they\'re on drugs," in other words nothing makes sense.',
        b"I liked the initial premise to this film which is what led me to hunt it out but the problem I quickly found is that one pretty much knows what's going to happen within the first 20-30 minutes ( the doubles will come from behind the mirror and take over everybody).<br /><br />There is no real twist (which is fine) , but the final reveal doesn't make a great deal of sense either (how can she be racked with uncertainty and fear for the whole film, if she's an evil id from beyond the mirror?).<br /><br />Admittedly the scenes 'beyond the mirror' were chilling when they first appeared and the blonde's murder is also effectively creepy, but ultimately alas this seems to be a film in search of a story or a more engaging script, piling atmosphere upon atmosphere and over the top scary sound design for 80-90 minutes does not really cut it, in fact it gets quite dull.",
        b"My main problem with the film is that it goes on too long. Other then that, it's pretty good. Paul Muni plays a poor Chinese farmer who is about to get married through an arranged marriage. Luise Rainer is a servant girl who gets married to Muni. They live with Muni's father on a farm and they are doing pretty bad. When he finally gets some money to buy some more land, a drought hits and nothing is growing. Everybody stars to head north by Muni stays behind at first. When they leave and arrive at town they find that their are no jobs and they are worse off than before. They even think about selling their youngest daughter as a slave for some money but decide against it. When a bunch of people start looting the town, the military show up and start executing people . Paul Muni does a good job and Luise Rainer won a second oscar for this movie.",
        b"This movie never made it to theaters in our area, so when it became available on DVD I was one of the first to rent it. For once, I should listened to the critics and passed on this one.<br /><br />Despite the excellent line up of actors the movie was very disappointing. I can see now why it went straight to video. <br /><br />I had thought that with Bloom, Ledger, and Rush it could have some value. All have done wonderful work in the past. <br /><br />The movie was slow moving and never pulled me in. I failed to develop much empathy for the characters and had to fight the urge to fast-forward just to get to the end. <br /><br />I do not recommend this film even if you are thinking of renting it for only for 'eye candy' purposes. It won't satisfy even that.",
        b'Mike Brady (Michael Garfield who had a minuscule part in the classic "The Warriors") is the first person in the community to realize that there\'s murderous slugs in his small town. Not just any slugs, mind you, but carnivorous killer bigger then normal, mutated by toxic waste slugs (who still only go as fast as a normal slug, which isn\'t that frightening, but I digress). No one will believe him at first, but they will. Oh yes, they will.<br /><br />OK, killer slugs are right above psychotic sloths and right below Johnathon Winters as Mork\'s baby in the creepiness factor. So the absurdness of it all is quite apparent from the get go. The flick is fun somewhat through and is of the \'so bad that it\'s good\' variety. I appreciate that they spelled out that this was Slugs: the Movie as opposed to Slugs: the Children\'s Game or Slugs: the Other White Meat. Probably not worthy of watching it more than once and promptly forgetting it except for playing a rather obscure trivia game. Director Juan Piquer Sim\xc3\xb3n is more widely known for his previous films "Pod People" (which MST3K deservedly mocked) and "Peices" (which is quite possibly the funnest bad movie ever made) <br /><br />Eye Candy: Kari Rose shows T&A <br /><br />My Grade: D+ <br /><br />DVD Extras: Merely a theatrical trailer for this movie',
        b'(Honestly, Barbra, I know it\'s you who\'s klicking all those "NO"s on my review. 22 times?? How many people did you have to instruct to help you out here? Don\'t you have anything better to do, like look at yourself in the mirror all day?)<br /><br />Steven Spielberg told Barbra that this was "the best movie I\'ve seen since \'Citizen Kane\'". That pretty much says it all - and serves as a dire warning!<br /><br />What are the ingredients for a sure-fire cinematic disaster, and one that will haunt you, never letting you forget the tears of both laughter and pain? The ingredients: Barbra Streisand\'s face, a musical, feminism, Barbra Streisand\'s voice, Barbra Streisand directing, and an ultra-corny/idiotic premise.<br /><br />Hollywood is full of egomaniacs, this much we know. In fact, nearly everyone \xc2\x96 by definition \xc2\x96 has to be an egomaniac in Hollywood. Why would anyone want to act? For the "art"?!? Well, if you\'re dumb enough to believe what they tell you in their carefully prepared interviews\xc2\x85 And Streisand has the biggest ego of them all! This is quite an achievement. To be surrounded by narcissistic cretins, and yet to manage to top them all \xc2\x96 remarkable.<br /><br />The movie, like all her "solo" endeavors, is an ego trip straight out of hell. Every scene Streisand is in is automatically ruined. Stillborn. But as it that weren\'t enough, she sings a whole bunch of Streisandy songs \xc2\x96 you know, the kind that enabled the Mariah Careys, the Celine Dions, and the Whitney Hustons of this world to poison our precious air-waves for decades now. Just for that she deserves not one but 100 South Park episodes mocking her.<br /><br />The premise, Streisand dressing up as a man to study to become a rabbi, sounds like a zany ZAZ comedy. Apart from it being a clich\xc3\xa9, the obvious problem is that Streisand doesn\'t look like a woman nor does she look like a man \xc2\x96 in fact I\'m not even sure she\'s human. 
The way she looks in this movie, well\xc2\x85 it cannot be described in words. E.T. looks like a high-school jock by comparison. She looks more alien than Michael Jackson in the year 2015. She looks HORRIBLE.<br /><br />The songs. They made me shiver. Particularly "Papa Can You Hear Me Squeel Like A Demented Female Walrus In Heat?" and "Tomorrow Night I\'ll Prepare the Sequel, YENTL 2: THE RETURN OF THE BITCH".<br /><br />Did you know that Streisand considered having a nose-job early on in her career, but changed her mind when they told her her voice might change? Can you believe that? She should have done it! Killing two flies with one swipe, that\'s what it would have been.<br /><br />If you\'re interested in reading my biographies of Barbra Streisand and other Hollywood intellectuals, contact me by e-mail.<br /><br />SHOULD BARBRA STREISAND FINALLY GO INTO RETIREMENT? CLICK "YES" OR "NO".',
        b'The "Confidential" part was meant to piggy-back on the popular appeal of the lurid magazine of the same name, while the labor racketeering theme tied in with headline Congressional investigations of the day. However, despite the A-grade B-movie cast and some good script ideas, the movie plods along for some 73 minutes. It\'s a cheap-jack production all the way. What\'s needed to off-set the poor production values is some imagination, especially from uninspired director Sidney Salkow. A few daylight location shots, for example, would have helped relieve the succession of dreary studio sets. A stylish helmsman like Anthony Mann might have done something with the thick-ear material, but Salkow treats it as just another pay-day exercise. Too bad that Brian Keith\'s typical low-key style doesn\'t work here, coming across as merely wooden and lethargic, at the same time cult figure Elisha Cook Jr. goes over the top as a wild-eyed drunk. Clearly, Salkow is no actor\'s director. But, you\'ve got to hand it to that saucy little number Beverly Garland who treats her role with characteristic verve and dedication. Too bad, she wasn\'t in charge. My advice-- skip it, unless you\'re into ridiculous bar-girls who do nothing else but knock back whiskeys in typical strait-jacketed 50\'s fashion.',
        b'I was thirteen years old, when I saw this movie. I expected a lot of action. Since Escape From New York was 16-rated in Germany I entered the movie as fallback. It was so boring. Afterwards I realized that this was just crap where a husband exhibits his wife. I mean today you do this via internet and you pay for instant access. It is more then 20 years ago, but I am still angry that I waste my time with this film. This is a soft-porno for schoolboys. Undressing Bo Derek and painting her with color - nice. But then they should named the film Undressing Bo and painting her.',
        b"I'm sorry, I had high hopes for this movie. Unfortunately, it was too long, too thin and too weak to hold my attention. When I realized the whole movie was indeed only about an older guy reliving his dream, I felt cheated. Surely it could have been a device to bring us into something deeper, something more meaningful.<br /><br />So, don't buy a large drink or you'll be running to the rest room. My kids didn't enjoy it either. Ah well.",
        b"A movie made for contemporary audience. The masses get to see what they want to see. Action, comedy, drama and of course sensuous scenes as well. This is not exactly a movie that one would feel comfortable watching with entire family. It isn't for eyes of children. I had to fast forward quite a number of scenes.<br /><br />If it is just entertainment you are looking for, then this movie has it all. The songs are catchy. A lavish production, I must add.<br /><br />However, the message of the movie is not universal. It emphasizes on the idea of karma. That is, if you do good, you will get good. And if you do evil, you will get evil. The fruit of good deeds is good, while the fruit of evil is evil. <br /><br />In real life, this is not always true. It is well-known that most people do not get justice in this world. While it is true that some evil people do meet with an evil end, there are many who escape. And then, there are many people who do good, and yet in return they meet with a sorry end.<br /><br />If you don't care about the message, and all you want is an escape from worldly reality, this movie is an entertainer alright.",
        b'Of the three remakes of this plot, I like them all, I have all three on VHS and in addition have a copy of this one on DVD. There is just enough variation in the scripts to make all three entertaining and re-watchable. In addition has any other film been remade three times with such all star casts in each? Of course the main stars in this one are great, but the supporting actors are also superb. I particularly like William Tracy as Pepi. He was such a scene stealer that I have searched to find other movies he is in. He appeared in many, but most are not available. As the other comments, I also say - buy this one.',
        b'"Wild Tigers I have Known." It will only be showing in big cities, to be sure. It is one of those films SO artsy, that it makes no sense what so ever, except to the director! I HATE those! And all of those oh-so-alternative/artsy people try DESPERATELY to find "metaphors" in what is EVIDENT horseshit.<br /><br />There was NO plot, no story, no moral, no chronology, and nothing amusing or even touching. To me, it was a bunch of scenes thrown together that had nothing to do with one another, and were all for "show" to show how "artsy" and "visual" they could get. It was an ATTEMPT at yet ANOTHER teen angst film, but missed the mark on every level humanly possible. Then the credits roll! I was waiting for it to make SENSE! I was waiting for "the good part." I own about 60 independent films in my DVD collection, many of which could arguably be called "art house" films. This will NOT be amongst them. You will be very angry at yourself for paying to see this film, much less ever buying it on DVD.',
        b'This is a big step down after the surprisingly enjoyable original. This sequel isn\'t nearly as fun as part one, and it instead spends too much time on plot development. Tim Thomerson is still the best thing about this series, but his wisecracking is toned down in this entry. The performances are all adequate, but this time the script lets us down. The action is merely routine and the plot is only mildly interesting, so I need lots of silly laughs in order to stay entertained during a "Trancers" movie. Unfortunately, the laughs are few and far between, and so, this film is watchable at best.',
        b"This is one of three 80's movies that I can think of that were sadly overlooked at the time and unfortunately, still overlooked. One of the others was Clownhouse directed by Victor Salva, a movie horribly overlook due to Salva's legal/sexual problems. Another would be Cameron's Closet which strikes me as somewhat underrated--not great, but not nearly as bad as the reviews I've seen. Paper House is well worth your time and I think that it is one of those very quiet films that will just stick in your brain for far longer than you might think. I mean, 10 years after I've seen it and I still give it some pause, whereas something that I might have seen 6 months ago has gone into the ether.",
        b"I have yet to read a negative professional review of this movie. I guess I must have missed something. The beginning is intriguing, the three main characters meet late at night in an otherwise empty bar and entertain each other with invented stories. That's the best part. After the three go their separate ways, the film splits into three threads. That's when boredom sets in. Certainly, the thread with the Felliniesque babushkas who make dolls out of chewed bread is at first an eye opening curiosity. Unfortunately, the director beat this one to death, even injecting a wild plot line that leads nowhere in particular. Bottom line: a two-hour plot-thin listlessness. If you suffer from insomnia, view it in bed and you will have a good night sleep.",
        b'Although the likeliness of someone focusing on THIS comment among the other 80+ for this movie is low, I feel that I have to say something about this one. I am not the kind of movie-watcher who pays attention to production value, thought-provoking dialog, or brilliant acting & directing. However, I claim that this movie sucks. I don\'t know why I don\'t like it... I mean it has almost everything i want out of a horror movie: blood, outrageousness, unintentional humor, etc. According to this evidence it should be my favorite. Still, Zombi 3 is a baaad movie.<br /><br />There are just too many things that compels you to yell at the screen. Like when the girl leaves the army guy when their car breaks down to find water (this spoils nothing so don\'t worry). She walks into what I see as an abandoned hotel or something. Did she not see that there was a friggin\' lake in the middle of the building??? Yes she\'s looking for water and passes up a lake. Why? Cuz she wants to know why the people (who aren\'t there cuz the place is abandoned) won\'t answer her when she calls out: "Is anybody there?" Oh this is just a little, insignificant piece of the big picture I\'m painting.<br /><br />There is a reason, though, why I gave this film more than 1 star. It\'s one of those movies where if you forget how bad it really is, like I have a few times, you\'ll want to watch it again because it\'s just so over-the-top in every aspect. I called it blood in the first paragraph, but this movie has no blood, it has an ocean of gore. Also, it has pretty weird creatures in it as well: a zombie-baby (with an adult-size hand???) and a magically flying head to name just two.<br /><br />You know when you try to think of the worst and cheesiest movies ever made and you come up with \'50\'s sci-fi movies? I believe that Zombi 3 and movies like it should top those. 
It has all the elements: scientists arguing with the government, warnings of the apocalypse on the radio, armies battling monsters, and so on. This IS the Plan 9 of the \'80\'s! While I won\'t say that this is a waste of money if you want to buy it, just expect the very worst. And when you find out that expecting the worst is underestimating Zombi 3, it won\'t be all that bad. You might actually like it, I\'m not saying that\'s impossible.<br /><br />Don\'t think I hate this movie, I don\'t... really. Oh, P.S. Killing Birds (aka Zombie 5) rules! (did I just blow my credibility?)',
        b'One of the weaker Carry On adventures sees Sid James as the head of a crime gang stealing contraceptive pills. The fourth of the series to be hospital-based, it\'s possibly the least of the genre. There\'s a curiously flat feel throughout, with all seemingly squandered on below-par material. This is far from the late-70s nadir, but Williams, James, Bresslaw, Maynard et al. are all class performers yet not given the backing of a script equal to their ability.<br /><br />Most of the gags are onrunning, rather than episodic as Carry Ons usually are. So that instead of the traditional hit and miss ratio, if you don\'t find the joke funny in the first place you\'re stuck with it for most of the film. These continuous plot strands include Williams \xc2\x96 for no good reason \xc2\x96 worrying that he\'s changing sex, and Kenneth Cope in drag. Like the stagy physical pratt falls, the whole thing feels more contrived than in other movies, and lacking in cast interest. Continuing this theme, Matron lacks the customary pun and innuendo format, largely opting for characterisation and consequence to provide the humour. In fact, the somewhat puerile series of laboured misunderstandings and forced circumstance reminds one more of Terry and June ... so it\'s appropriate that Terry Scott is present, mugging futilely throughout.<br /><br />Some dialogue exchanges have a bit of the old magic, such as this between Scott and Cope: "What about a little drink?" "Oh, no, no, I never touch it." "Oh. Cigarette then?" "No, I never touch them." "That leaves only one thing to offer you." "I never touch that either." That said, while a funny man in his own right (livening up the duller episodes of Randall and Hopkirk (Deceased) no end), you do feel that Cope isn\'t quite tapped in to the self-parodying Carry On idealology and that Bernard Bresslaw dressed as a nurse would be far funnier. 
This does actually happen, in part, though only for the last fifteen minutes.<br /><br />Williams attempting to seduce Hattie Jacques while Charles Hawtrey is hiding in a cupboard is pure drawer room farce, but lacks the irony to carry it off. That said, Williams\'s description of premarital relations is priceless: "You don\'t just go into the shop and buy enough for the whole room, you tear yourself off a little strip and try it first!" "That may be so," counters Jacques, "but you\'re not going to stick me up against a wall." Williams really comes to life in his scenes with Hattie, and you can never get bored of hearing a tin whistle whenever someone accidentally flashes their knickers.<br /><br />Carry On Matron is not a bad film by any means, just a crushingly bog-standard one.',
        b'I was looking forward to this ride, and was horribly disappointed.<br /><br />And I am very easily amused at roller coaster and amusement park rides.<br /><br />The roller coaster part was just okay - and that was all of about 30 seconds of a 90 second ride. <br /><br />It was visually dull and poorly executed. <br /><br />It was trying desperately to be like a mixture of the far superior Indiana Jones and Space Mountain rides and Disneyland, and failed in every aspect.<br /><br />It was not thrilling or exciting in the least.',
        b'I really wish i could give this a negative vote, because i think i just wasted 83 minutes of my life watching the worst horror movie ever put to film. the acting was just god awful, i mean REALLLYYYY bad, the dialog was worse, the script sounded like it was written by.... i can\'t think of anything horrible enough to say. And the day "outside" and the night "inside" shots make you think the events took over several days. Terribly acted, directed, written, etc etc. all the way down to the gofer how gets lunch for everyone. STAY AWAY FROM THIS ONE AT ALL COSTS. If my only saving grace to stay out of hell is by doing one good deed, it is to tell the world to not watch this crap. This movie is the exact reason why horror movies are never taken seriously.',
        b"The director Sidney J. Furie has created in Hollow Point a post-modern absurdist masterpiece that challenges and constantly surprises the audience. <br /><br />Sidney J. Furie dares to ask the question of what happens to the tired conventional traditionalist paradigms of 'plot' and 'characterisation' when you remove the crutches of 'motivation' and 'reason'. <br /><br />The result leads me to say that my opinion of him could not possibly get any higher.<br /><br />One and a half stars.<br /><br />P.S. Nothing in this movie makes any sense, the law enforcement agents are flat out unlikeable and the organised criminals are full on insane.",
        b"Luise Rainer received an Oscar for her performance in The Good Earth. Unfortunately, her role required no. She did not say much and looked pale throughout the film. Luise's character was a slave then given away to marriage to Paul Muni's character (he did a fantastic job for his performance). Set in ancient Asia, both actors were not Asian, but were very convincing in their roles. I hope that Paul Muni received an Oscar for his performance, because that is what Luise must have gotten her Oscar for. She must have been a breakthrough actress, one of the first to method act. This seems like something that Hollywood does often. Al Pacino has played an Italian and Cuban. I felt Luise's performance to be lackluster throughout, and when she died, she did not change in expression from any previous scenes. She stayed the same throughout the film; she only changed her expression or emotion maybe twice. If her brilliant acting was so subtle, I suppose I did not see it.",
        b'I thought Rachel York was fantastic as "Lucy." I have seen her in "Kiss Me, Kate" and "Victor/Victoria," as well, and in each of these performances she has developed very different, and very real, characterizations. She is a chameleon who can play (and sing) anything!<br /><br />I am very surprised at how many negative reviews appear here regarding Rachel\'s performance in "Lucy." Even some bonafide TV and entertainment critics seem to have missed the point of her portrayal. So many people have focused on the fact that Rachel doesn\'t really look like Lucy. My response to that is, "So what?" I wasn\'t looking for a superficial impersonation of Lucy. I wanted to know more about the real woman behind the clown. And Rachel certainly gave us that, in great depth. I also didn\'t want to see someone simply "doing" classic Lucy routines. Therefore I was very pleased with the decision by the producers and director to have Rachel portray Lucy in rehearsal for the most memorable of these skits - Vitameatavegamin and The Candy Factory. (It seems that some of the reviewers didn\'t realize that these two scenes were meant to be rehearsal sequences and not the actual skits). This approach, I thought, gave an innovative twist to sketches that so many of us know by heart. I also thought Rachel was terrifically fresh and funny in these scenes. And she absolutely nailed the routines that were recreated - the Professor and the Grape Stomping, in particular. There was one moment in the Grape scene where the corner of Rachel\'s mouth had the exact little upturn that I remember Lucy having. I couldn\'t believe she was able to capture that - and so naturally.<br /><br />I wonder if many of the folks who criticized the performance were expecting to see the Lucille Ball of "I Love Lucy" throughout the entire movie. After all, those of us who came to know her only through TV would not have any idea what Lucy was really like in her early movie years. 
I think Rachel showed a natural progression in the character that was brilliant. She planted all the right seeds for us to see the clown just waiting to emerge, given the right set of circumstances. Lucy didn\'t fit the mold of the old studio system. In her frustrated attempts to become the stereotypical movie star of that era, she kept repressing what would prove to be her ultimate gifts.<br /><br />I believe that Rachel deftly captured the comedy, drama, wit, sadness, anger, passion, love, ambition, loyalty, sexiness, self absorption, childishness, and stoicism all rolled into one complex American icon. And she did it with an authenticity and freshness that was totally endearing. "Lucy" was a star turn for Rachel York. I hope it brings a flood of great roles her way in the future. I also hope it brings her an Emmy.',
        b'This might be the poorest example of amateur propaganda ever made. The writers and producers should study the German films of the thirties and forties. They knew how to sell. Even soviet-style clunky leader as god-like father-figure were better done. Disappointing. The loss of faith, regained in church at last second just in time for daddy to be "saved" by the Hoover/God was not too bad. Unfortunately, it seemed rushed and not nearly melodramatic enough. A few misty heavenlier shots of the angelical Hoover up in the corner of the screen-beaming and nodding- would have added a lot. The best aspect is Hoover only saving the deserving family and children WHO had "proven" their worth. Unfortunately, other poor homeless were portrayed as likable and even good- yet the Hoover-God doesn\'t help them. A better approach would have been shots of them drinking spirits to show the justice of their condition. Finally, bright and cheerful scenes of recovery (after Hoover saved the country from the depression) should have rolled at the end. We could see then how Hoover-God had saved not just THIS deserving family, but all the truly deserving. Amateurist at best.',
        b'"The Plainsman" represents the directorial prowess of Cecil B. DeMille at its most inaccurate and un-factual. It sets up parallel plots for no less stellar an entourage than Wild Bill Hickok (Gary Cooper), Buffalo Bill Cody (James Ellison), Calamity Jane (Jean Arthur), George Armstrong Custer and Abraham Lincoln to interact, even though in reality Lincoln was already dead at the time the story takes place. Every once in a while DeMille floats dangerously close toward the truth, but just as easily veers away from it into unabashed spectacle and showmanship. The film is an attempt to buttress Custer\'s last stand with a heap of fiction that is only loosely based on the lives of people, who were already the product of manufactured stuffs and legends. Truly, this is the world according to DeMille - a zeitgeist in the annals of entertainment, but a pretty campy relic by today\'s standards.<br /><br />TRANSFER: Considering the vintage of the film, this is a moderately appealing transfer, with often clean whites and extremely solid blacks. There\'s a considerable amount of film grain in some scenes and an absence of it at other moments. All in all, the image quality is therefore somewhat inconsistent, but it is never all bad or all good \xc2\x96 just a bit better than middle of the road. Age related artifacts are kept to a minimum and digital anomalies do not distract. The audio is mono but nicely balanced.<br /><br />EXTRAS: Forget it. It\'s Universal! BOTTOM LINE: As pseudo-history painted on celluloid, this western is compelling and fun. Just take its characters and story with a grain of salt \xc2\x96 in some cases \xc2\x96 a whole box seems more appropriate!',
        b'Not the most successful television project John Cleese ever did, "Strange Case" has the feel of a first draft that was rushed into production before any revisions could be made. There are some silly ideas throughout and even a few clever ones, but the story as a whole unfortunately doesn\'t add up to much.<br /><br />Arthur Lowe is a hoot, though, as Dr. Watson, bionic bits and all. "Good Lord."',
        b'Only after some contemplation did I decide I liked this movie. And after reading comments from all the other posters here, and thinking about it some more, I decided that I liked it tremendously. I love American films - probably because they are so narrative. They usually have a well-defined beginning, middle, and end. "Presque rien," on the other hand, makes no such attempt. I disagree with other posters that say it\'s \'too artsy.\' In every way, this film is meant to evoke your sense memories. So often throughout the film you feel like you\'re there... you feel the summer sun, the breezes, the heat, the winter chill, the companionship, the loneliness, etc., etc.<br /><br />In every way, the director pulls you into the lives of the characters - which is why so many people feel so strongly that the movie disappointed them. After I finished watching it, I felt the same. But upon some reflection, I recognized that this is how the movie had to be: the \'story\' isn\'t the narrative, it\'s the emotions you (the viewer) feel.<br /><br />The lighting, scenery, and camera angles immerse you in the scenes - they\'re rich, exquisite, and alive with detail and nuance. Although I normally cannot countenance films without a fully developed plot (after all, isn\'t a movie \'supposed\' to tell a story), this film is definitely one of my new favorites.',
        b'I just saw the movie on tv. I really enjoyed it. I like a good mystery. and this one had me guessing up to the end. Sean Connery did a good job. I would recomend it to a friend.',
        b'I know that originally, this film was NOT a box office hit, but in light of recent Hollywood releases (most of which have been decidedly formula-ridden, plot less, pointless, "save-the-blonde-chick-no-matter-what" drivel), Feast of All Saints, certainly in this sorry context deserves a second opinion. The film--like the book--loses anchoring in some of the historical background, but it depicts a uniquely American dilemma set against the uniquely horrific American institution of human enslavement, and some of its tragic (and funny, and touching) consequences.<br /><br />And worthy of singling out is the youthful Robert Ri\'chard, cast as the leading figure, Marcel, whose idealistic enthusiasm is truly universal as he sets out in the beginning of his \'coming of age,\' only to be cruelly disappointed at what turns out to become his true education in the ways of the Southern plantation world of Louisiana, at the apex of the antebellum period. When I saw the previews featuring the (dreaded) blond-haired Ri\'chard, I expected a buffoon, a fop, a caricature--I was pleasantly surprised.<br /><br />Ossie Davis, Ruby Dee, the late Ben Vereen, Pam Grier, Victoria Rowell and even Jasmine Guy lend vivid imagery and formidable skill as actors in the backdrop tapestry of placage, voodoo, Creole "aristocracy," and Haitian revolt woven into this tale of human passion, hate, love, family, and racial perplexity in a society which is supposedly gone and yet somehow is still with us.',
        b'How Disney can you get? Preppy rich girls act like idiots, buy a bunch of stuff, and get taught a lesson. Is Disney trying to send a lesson to itself? That maybe while buying everything it should maybe still be human? Whatever the psycho-analysis, this movie sucked.<br /><br />The girls want a rich party for their rich lives. But then money disappears and they have to use their riches to get the milk plant (yes, milk) going to employ the workers. They keep it afloat until daddy comes home. And the man at the beginning, who appears to be the one that takes the money, is the one. But the ending is dumb. Webcam in the Cayman Islands? Huh? Not worth my time ever again. <br /><br />But it is better than Howl\'s Moving Castle. "D-"',
        b'I was surprised at the low rating this film got from viewers. I saw it one late night on TV and it hit the spot - I actually think it was back in 1989 when it first appeared. Yet I remember it pretty well, with a nice twist or two, and an interesting ambiance on a windmill farm. Michael Pollard looks suitably seedy for his role which pretty much sums up the unfulfilled early promise of his career, and everyone else plays it pretty straight ahead. I definitely recommend it as a rental, although some of the themes, which might have seemed a bit edgy in 1989, now may seem tame, which is a shame, considering that contemporary "edginess" is often just used as a necessary marketing tool, sort of like clamoring just to get noticed.',
        b"As the film begins a narrator warns us THE SCREAMING SKULL is so terrifying you might die of fright--and if such happens a free burial is guaranteed. Well, I don't think any one has died of fright from seeing this film, but a few may have died of boredom. THE SCREAMING SKULL is the sort of movie that makes Ed Wood look good.<br /><br />Very loosely based on the famous Francis Marion Crawford story, SKULL is about a wealthy but nervous woman who marries a sinister man whose first wife died under mysterious circumstances. Once installed in his home, she is tormented by a half-wit gardener, a badly executed portrait, peacocks, and ultimately a skull that rolls around the room and causes her to scream a lot. And to her credit, actress Peggy Webber screams rather well.<br /><br />Unfortunately, her ability to do so is the high point of the film. The plot is pretty transparent, to say the least, and while the cast is actually okay, the script is dreadful and the movie so uninspired you'll be ready to run screaming yourself. True, the thing only runs about sixty-eight minutes, but it all feels a lot longer. Add to this a truly terrible print quality and there you are.<br /><br />There are films that are so bad they are fun to watch. It is true that THE SCREAMING SKULL has a few howlers--but the film drags so much I couldn't work up more than an occasional giggle, and by the time the whole thing is over your head will roll from ennui. If it weren't for Peggy Webber's way with a scream, this would be the surefire cure for insomnia. Give it a miss.<br /><br />GFT, Amazon Reviewer",
        b'Elegance and class are not always the first words that come to mind when folks (at least folks who might do such a thing) sit around and talk about film noir. <br /><br />Yet some of the best films of the genre, "Out of the Past," "The Killers," "In A Lonely Place," "Night and the City," manage a level of sleek sophistication that elevates them beyond a moody catch phrase and its connotations of foreboding shadows, fedoras, and femme-fatales. <br /><br />"Where the Sidewalk Ends," a fairly difficult to find film -- the only copy in perhaps the best stocked video store in Manhattan was a rough bootleg from the AMC cable channel -- belongs in a category with these classics.<br /><br />From the moment the black cloud of opening credits pass, a curtain is drawing around rogue loner detective Marc Dixon\'s crumbling world, and as the moments pass, it inches ever closer, threatening suffocation. <br /><br />Sure, he\'s that familiar "cop with a dark past", but Dana Andrews gives Dixon a bleak stare and troubled intensity that makes you as uncomfortable as he seems. And yeah, he\'s been smacking around suspects for too long, and the newly promoted chief (Karl Malden, in a typically robust and commanding outing) is warning him "for the last time." <br /><br />Yet Dixon hates these thugs too much to stop now. And boy didn\'t they had have it coming? <br /><br />"Hoods, dusters, mugs, gutter nickel-rats" he spits when that tough nut of a boss demotes him and rolls out all of the complaints the bureau has been receiving about Dixon\'s right hook. The advice is for him to cool off for his own good. But instead he takes matters into his own hands. <br /><br />And what a world of trouble he finds when he relies on his instincts, and falls back on a nature that may or may not have been passed down from a generation before. <br /><br />Right away he\'s in deep with the cops, the syndicate, his own partner. 
Dixon\'s questionable involvement in a murder "investigation" threatens his job, makes him wonder whether he is simply as base as those he has sworn to bring in. Like Bogart in "Lonely Place," can he "escape what he is?"<br /><br />When he has nowhere else to turn, he discovers that he has virtually doomed his unexpected relationship with a seraphic beauty (the marvelous Gene Tierney) who seems as if she can turn his barren bachelor\'s existence into something worth coming home to. <br /><br />The pacing of this superb film is taut and gripping. The group of writers that contributed to the production polished the script to a high gloss -- the dialogue is snappy without disintegrating into dated parody fodder, passionate without becoming melodramatic or sappy. <br /><br />And all of this top-notch direction and acting isn\'t too slick or buffed to loosen the film\'s emotional hold. Gene Tierney\'s angelic, soft-focus beauty is used to great effect. She shows herself to be an actress of considerable range, and her gentle, kind nature is as boundless here as is her psychosis in "Leave Her to Heaven." The scenes between Tierney and Andrews\'s Dixon grow more intense and touching the closer he seems to self-destruction. <br /><br />Near the end of his rope, cut, bruised, and exhausted Dixon summarizes his lot: "Innocent people can get into terrible jams, too,.." he says. "One false move and you\'re in over your head." <br /><br />Perhaps what makes this film so totally compelling is the sense that things could go wildly wrong for almost anyone -- especially for someone who is trying so hard to do right -- with one slight shift in the wind, one wrong decision or punch, or, most frighteningly, due to factors you have no control over. Noir has always reflected the darkest fears, brought them to the surface. "Where the Sidewalk Ends" does so in a realistic fashion. 
<br /><br />(One nit-pick of an aside: This otherwise sterling film has a glaringly poor dub of a blonde model that wouldn\'t seem out of place on Mystery Science Theater. How very odd.) <br /><br />But Noir fans -- heck, ANY movie fans -- who haven\'t seen this one are in for a terrific treat.',
        b"Having not seen the previous two in the trilogy of Bourne movies, I was a little reluctant to watch The Bourne Ultimatum.<br /><br />However it was a very thrilling experience and I didn't have the problem of not understanding what was happening due to not seeing the first two films. Each part of the story was easy to understand and I fell in love with The Bourne Ultimatum before it had reached the interval! I don't think I have ever watched such an exquisitely made, and gripping film, especially an action film. Since I usually shy away from action and thriller type movies, this was such great news to me. Ultimatum is one of the most enthralling films, it grabs your attention from the first second till the last minute before the credits roll.<br /><br />Matt Damon was simply fantastic as his role as Jason Bourne. I've heard a lot about his great performances in the Bourne 1+2, and now, this fabulous actor has one more to add to his list. I look forward to seeing more of his movies in the future.<br /><br />The stunts were handled with style - each one was done brilliantly and I was just shocked by the impressiveness of this movie. Well done.",
        b"Calling this a romantic comedy is accurate but nowadays misleading. The genre has sadly deteriorated into cliches, too focused on making the main couple get together and with very little room for ambience and other stories, making it formulaic and overly predictable.<br /><br />The Shop Around the Corner does not suffer from these illnesses: it manages to create a recognisably middle/eastern-European atmosphere and has a strong cast besides the (also strong) nominal leads; I avoid using the words 'supporting cast' as for example Mr. Matuschek (Frank Morgan) has a central role to the film and his story is equally if not more important than the romance.<br /><br />The 1998 film You've Got Mail borrowed the 'anonymous pen-pal' idea from this film and has therefore been billed as a remake. This is not correct and in fact unfair to the new movie - it shares the genre and borrows a plot element, but that is all.",
        b"The book on which this movie is based was excellent; it took a while to come to grips with Houellebecq's unconventional style but once I understood the mood behind the writing I was completely drawn into the author's world of sadness. In fact, no other book has affected me so much. This is not necessarily a good thing - it elucidated my own personal struggle and has made the futility of my own struggle harder to accept. Houellebecq's insights are masterfully captured by Harel and the hero's apathy and indifference to a world which has rejected him is perfectly portrayed. This is a movie which reveals today's society for the lowly male in all its horror. Hopefully, things will change in the future but for the present we have to accept the rat-race as shown in this movie. It's probably best that Harel or Houellebecq do not create a work of genius like this again. One is enough for any man.",
        b"A hilarious and insightful perspective of the dating world is portrayed in this off beat comedy by first time writer/director Peter M. Cohen. The story unfolds as the four male protagonists meet weekly at the local diner to confer about their dating woes. We meet Brad: a good-looking, wall-street playboy with a quick-wit and sharp tongue; Zeek: a cynical, sensitive writer; Jonathan: a sexually perplexed nice guy with an affinity for hand creams and masturbation; and Eric: the married guy, who cherishes his weekly encounters with his single friends in hope for some enlightenment to his boring and banal married existence. The trials and tribulations of the men's single lives in New York are amusingly expressed, mirroring that of Sex in the City and HBO's new comedy The Mind of Married Man, and bring an astute light to scamming. The story takes a twist as the three singletons meet Mia--wittily played by Amanda Peet-and all fall for her. She seduces them each with her uncanny ability to conform to the personalities' they exhibit. When they come to realize they have all met and fallen in love with the same woman, they chose her over their friendship. Whipped is a realistic portrayal of the dating world, one that the critic's failed to recognize. In plain language, they missed the point. The protagonist's here are caricatures of real people. The exaggerations are hysterical, mixing satire and humility, and are not to be taken as seriously as the critic's disparagement suggests. See this movie, you'll laugh from start to finish.",
        b"Judging from this film and THE STRONG MAN, made the same year, I would not place Harry Langdon at the top of the list of great silent screen comedians. There simply is not enough there. Perhaps he was on his way to developing his style but sabotaged himself by taking his first big successes too seriously. In any event, all of his tricks are reminiscent of the greater funny men, but he lacks the acrobatic skills of Keaton and the prodigious ingenuity of Lloyd. He also undermines his own persona by dressing and walking like Chaplin's tramp character. His trademarks are childlike innocence, timidity of approach and a tendency to under-react to calamity by looking perplexed, batting his eyes or touching his pursed lips with the tip of his forefinger. The comedy in Langdon's films results from fate throwing various obstacles in his path which he tries to overcome in wimpy or na\xc3\xafve ways or with a minimum of physicality, such as throwing rocks at an approaching tornado to drive it away, propping up a collapsing building with a two-by-four or dodging boulders by lifting a leg so that they roll under him. In this story, about the son of a shoemaker who joins a cross-country walking race to publicize a rival company's footwear, he manages to win by sheer luck. There is nothing here that hasn't been done far better by the Big Three.",
        b'I didn\'t expect Val Kilmer to make a convincing John Holmes, but I found myself forgetting that it wasn\'t the porn legend himself. In fact, the entire cast turned in amazing performances in this vastly under-rated movie.<br /><br />As some have mentioned earlier, seek out the two-disc set and watch the "Wadd" documentary first; it will give you a lot of background on the story which will be helpful in appreciating the movie. <br /><br />Some people seem unhappy about the LAPD crime scene video being included on the DVD. There are a number of reasons that it might have been included, one of which is that John Holmes\' trial for the murders was the first ever in the United States where such footage was used by the prosecution. If you don\'t want to see it, it\'s easy to avoid; it\'s clearly identified as "LAPD Crime Scene Footage" on the menu!'],
       dtype=object)>}
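# Load the SavedModel written to /tmp/saved_model earlier in the tutorial.
loaded = tf.saved_model.load("/tmp/saved_model")

# Run the default serving signature on the raw review text of the sample batch.
run_sig = loaded.signatures["serving_default"]
result = run_sig(sample_batch['text'])['result']
# Accuracy: fraction of examples whose highest-scoring class matches the label.
np.mean(tf.argmax(result, axis=-1) == sample_batch['label'])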
0.65625
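The last expression takes the index of the largest output score for each example, compares it with the label, and averages the matches.

The pure-Python sketch below restates that computation. It also shows that the same value can be computed over batch shards. Each shard contributes only its count of correct predictions and its example count. Summing those counts across shards reproduces the single-device value; under a batch-sharded layout, a cross-device reduction would play this role. The helper names (`argmax`, `shard_rows`, `accuracy`, `sharded_accuracy`) and the toy logits are hypothetical illustrations. They are not part of the tutorial or of the DTensor API.

```python
def argmax(row):
    """Index of the largest score (first one on ties, like np.argmax)."""
    return max(range(len(row)), key=lambda i: row[i])


def shard_rows(rows, num_shards):
    """Split rows into equal, contiguous shards along the batch axis."""
    if len(rows) % num_shards:
        raise ValueError("batch size must be divisible by num_shards")
    step = len(rows) // num_shards
    return [rows[i * step:(i + 1) * step] for i in range(num_shards)]


def accuracy(logits, labels):
    """Single-device accuracy: mean of argmax(logits) == label."""
    return sum(argmax(r) == y for r, y in zip(logits, labels)) / len(labels)


def sharded_accuracy(logits, labels, num_shards):
    """Same metric from per-shard counts followed by a sum reduction."""
    correct = total = 0
    for logit_shard, label_shard in zip(shard_rows(logits, num_shards),
                                        shard_rows(labels, num_shards)):
        # Each shard only needs its local number of correct predictions...
        correct += sum(argmax(r) == y for r, y in zip(logit_shard, label_shard))
        total += len(label_shard)
    # ...and summing the counts across shards stands in for an all-reduce.
    return correct / total


# Toy two-class scores and labels, made up for illustration only.
toy_logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4],
              [0.2, 0.8], [0.9, 0.1], [0.4, 0.6], [0.7, 0.3]]
toy_labels = [1, 0, 0, 0, 1, 1, 1, 0]
```

This sketch reduces counts rather than averaging per-shard accuracies. The two approaches agree only when every shard has the same size, whereas summed counts stay correct for uneven shards such as a final partial batch.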

Next steps

This tutorial demonstrated how to build and train an MLP sentiment analysis model with DTensor.

Through the Mesh and Layout primitives, DTensor can transform a TensorFlow tf.function into a distributed program suited to a variety of training schemes.
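A layout decides which piece of each tensor lives on which device of a mesh. In DTensor, a layout such as `dtensor.Layout(['batch', dtensor.UNSHARDED], mesh)` shards axis 0 over the mesh dimension named `"batch"` and replicates axis 1.

The pure-Python sketch below does not call DTensor. Instead, it uses a hypothetical helper, `local_slices`, to compute the `(start, stop)` range of each axis that every device would hold under such a layout. It assumes evenly divisible axes and devices numbered in row-major order over the mesh. The `UNSHARDED` constant here is only a stand-in for `dtensor.UNSHARDED`.

```python
import itertools

# Stand-in for dtensor.UNSHARDED; this sketch does not import TensorFlow.
UNSHARDED = "unsharded"


def local_slices(global_shape, sharding_specs, mesh_dims):
    """Hypothetical helper: per-device (start, stop) ranges for each tensor axis.

    global_shape:   shape of the full (global) tensor.
    sharding_specs: one entry per axis, either a mesh dimension name or UNSHARDED.
    mesh_dims:      list of (name, size) pairs, e.g. [("batch", 8)].
    Assumption: devices are numbered in row-major order over the mesh.
    """
    if len(global_shape) != len(sharding_specs):
        raise ValueError("need one sharding spec per tensor axis")
    names = [name for name, _ in mesh_dims]
    sizes = dict(mesh_dims)
    per_device = []
    # itertools.product yields mesh coordinates with the last dimension varying fastest.
    for coords in itertools.product(*(range(size) for _, size in mesh_dims)):
        position = dict(zip(names, coords))
        slices = []
        for axis_len, spec in zip(global_shape, sharding_specs):
            if spec == UNSHARDED:
                # Replicated axis: every device holds the whole axis.
                slices.append((0, axis_len))
                continue
            num_shards = sizes[spec]
            if axis_len % num_shards:
                raise ValueError(f"axis of length {axis_len} is not divisible by {num_shards}")
            step = axis_len // num_shards
            start = position[spec] * step
            slices.append((start, start + step))
        per_device.append(tuple(slices))
    return per_device


# Data parallel: batch axis sharded over 8 devices, feature axis replicated.
data_parallel = local_slices((32, 4), ["batch", UNSHARDED], [("batch", 8)])
# Data + model parallel on a 2x3 mesh: rows over "batch", columns over "model".
mixed = local_slices((8, 6), ["batch", "model"], [("batch", 2), ("model", 3)])
```

In this sketch, only the sharding specs change between the two calls, and the tensor and the computation that uses it stay the same. That is, loosely, why a single tf.function can serve data-parallel, model-parallel, and spatial-parallel training once the layouts are chosen.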

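In a real-world machine learning application, you should apply evaluation and cross-validation to avoid producing an overfitted model. The techniques introduced in this tutorial can also be used to parallelize evaluation.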
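Cross-validation rotates which part of the data is held out for evaluation, so every example is used for validation exactly once. Below is a minimal k-fold splitting sketch in plain Python. `k_fold_indices` is a hypothetical helper; scikit-learn's `sklearn.model_selection.KFold` provides a maintained equivalent. Because the helper takes contiguous folds, it assumes the examples have already been shuffled.

```python
def k_fold_indices(num_examples, k):
    """Hypothetical helper: yield (train_idx, val_idx) lists for k contiguous folds.

    The first num_examples % k folds get one extra example, so fold sizes
    differ by at most one. Assumes the data was shuffled beforehand.
    """
    if not 1 < k <= num_examples:
        raise ValueError("k must satisfy 1 < k <= num_examples")
    fold_sizes = [num_examples // k + (1 if i < num_examples % k else 0)
                  for i in range(k)]
    start = 0
    for size in fold_sizes:
        stop = start + size
        val_idx = list(range(start, stop))
        # Everything outside the held-out fold is used for training.
        train_idx = list(range(0, start)) + list(range(stop, num_examples))
        yield train_idx, val_idx
        start = stop
```

In practice, one would train a fresh model on each `train_idx`, evaluate it on the matching `val_idx` (each evaluation can itself be batch-sharded), and average the k scores.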

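Composing a model with tf.Module from scratch takes a lot of effort, and reusing existing building blocks such as layers and helper functions can drastically speed up model development. As of TensorFlow 2.9, all Keras layers under tf.keras.layers accept DTensor layouts as arguments and can be used to build DTensor models. You can even reuse a Keras model directly with DTensor without modifying the model implementation. For details on using DTensor with Keras, see the DTensor Keras integration tutorial.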