![]() |
![]() |
![]() |
![]() |
モデルの進行状況は、トレーニング中およびトレーニング後に保存できます。モデルが中断したところから再開できるので、長いトレーニング時間を回避できます。また、保存することによりモデルを共有したり、他の人による作業の再現が可能になります。研究モデルや手法を公開する場合、ほとんどの機械学習の実践者は次を共有します。
- モデルを構築するプログラム
- モデルのトレーニング済みモデルの重みやパラメータ
このデータを共有することで、他の人がモデルがどの様に動作するかを理解したり、新しいデータに試してみたりすることが容易になります。
注意: TensorFlow モデルはコードであり、信頼できないコードに注意する必要があります。詳細については、TensorFlow を安全に使用するをご覧ください。
オプション
使用している API に応じて、さまざまな方法で TensorFlow モデルを保存できます。このガイドでは、高レベル API である tf.keras を使用して、TensorFlow でモデルを構築およびトレーニングします。他のアプローチについては、TensorFlow 保存と復元ガイドまたは Eager で保存するを参照してください。
設定
インストールとインポート
TensorFlow をインストールし、依存関係インポートします。
pip install pyyaml h5py # Required to save models in HDF5 format
import os
import tensorflow as tf
from tensorflow import keras
print(tf.version.VERSION)
2022-08-09 01:30:16.793268: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2022-08-09 01:30:17.530275: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory 2022-08-09 01:30:17.530510: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory 2022-08-09 01:30:17.530523: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly. 2.10.0-rc0
サンプルデータセットの取得
ここでは、重みの保存と読み込みをデモするために、MNIST データセットを使います。デモの実行を速くするため、最初の 1,000 件のサンプルだけを使います。
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
モデルの定義
簡単なシーケンシャルモデルを構築することから始めます。
# Define a simple sequential model
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10)
])
model.compile(optimizer='adam',
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
return model
# Create a basic model instance
model = create_model()
# Display the model's architecture
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 512) 401920 dropout (Dropout) (None, 512) 0 dense_1 (Dense) (None, 10) 5130 ================================================================= Total params: 407,050 Trainable params: 407,050 Non-trainable params: 0 _________________________________________________________________
トレーニング中にチェックポイントを保存する
再トレーニングせずにトレーニング済みモデルを使用したり、トレーニングプロセスを中断したところから再開することもできます。tf.keras.callbacks.ModelCheckpoint
コールバックを使用すると、トレーニング中でもトレーニングの終了時でもモデルを継続的に保存できます。
チェックポイントコールバックの使い方
トレーニング中にのみ重みを保存する tf.keras.callbacks.ModelCheckpoint
コールバックを作成します。
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
verbose=1)
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=10,
validation_data=(test_images, test_labels),
callbacks=[cp_callback]) # Pass callback to training
# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
Epoch 1/10 25/32 [======================>.......] - ETA: 0s - loss: 1.2484 - sparse_categorical_accuracy: 0.6525 Epoch 1: saving model to training_1/cp.ckpt 32/32 [==============================] - 1s 15ms/step - loss: 1.1156 - sparse_categorical_accuracy: 0.6900 - val_loss: 0.7113 - val_sparse_categorical_accuracy: 0.7900 Epoch 2/10 26/32 [=======================>......] - ETA: 0s - loss: 0.4196 - sparse_categorical_accuracy: 0.8810 Epoch 2: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 6ms/step - loss: 0.4216 - sparse_categorical_accuracy: 0.8790 - val_loss: 0.5237 - val_sparse_categorical_accuracy: 0.8380 Epoch 3/10 26/32 [=======================>......] - ETA: 0s - loss: 0.2906 - sparse_categorical_accuracy: 0.9219 Epoch 3: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 6ms/step - loss: 0.2897 - sparse_categorical_accuracy: 0.9250 - val_loss: 0.4878 - val_sparse_categorical_accuracy: 0.8490 Epoch 4/10 26/32 [=======================>......] - ETA: 0s - loss: 0.2088 - sparse_categorical_accuracy: 0.9519 Epoch 4: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 6ms/step - loss: 0.2105 - sparse_categorical_accuracy: 0.9510 - val_loss: 0.4425 - val_sparse_categorical_accuracy: 0.8510 Epoch 5/10 26/32 [=======================>......] - ETA: 0s - loss: 0.1390 - sparse_categorical_accuracy: 0.9724 Epoch 5: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.1443 - sparse_categorical_accuracy: 0.9710 - val_loss: 0.4340 - val_sparse_categorical_accuracy: 0.8600 Epoch 6/10 26/32 [=======================>......] - ETA: 0s - loss: 0.1183 - sparse_categorical_accuracy: 0.9760 Epoch 6: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.1138 - sparse_categorical_accuracy: 0.9780 - val_loss: 0.4266 - val_sparse_categorical_accuracy: 0.8640 Epoch 7/10 25/32 [======================>.......] - ETA: 0s - loss: 0.0788 - sparse_categorical_accuracy: 0.9887 Epoch 7: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.0805 - sparse_categorical_accuracy: 0.9860 - val_loss: 0.4252 - val_sparse_categorical_accuracy: 0.8610 Epoch 8/10 26/32 [=======================>......] - ETA: 0s - loss: 0.0595 - sparse_categorical_accuracy: 0.9976 Epoch 8: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.0624 - sparse_categorical_accuracy: 0.9970 - val_loss: 0.4010 - val_sparse_categorical_accuracy: 0.8680 Epoch 9/10 26/32 [=======================>......] - ETA: 0s - loss: 0.0532 - sparse_categorical_accuracy: 0.9952 Epoch 9: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.0520 - sparse_categorical_accuracy: 0.9940 - val_loss: 0.4168 - val_sparse_categorical_accuracy: 0.8620 Epoch 10/10 27/32 [========================>.....] - ETA: 0s - loss: 0.0383 - sparse_categorical_accuracy: 0.9977 Epoch 10: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.0372 - sparse_categorical_accuracy: 0.9980 - val_loss: 0.4228 - val_sparse_categorical_accuracy: 0.8600 <keras.callbacks.History at 0x7fcadcf2f070>
この結果、エポックごとに更新される一連のTensorFlowチェックポイントファイルが作成されます。
os.listdir(checkpoint_dir)
['cp.ckpt.index', 'cp.ckpt.data-00000-of-00001', 'checkpoint']
2 つのモデルが同じアーキテクチャを共有している限り、それらの間で重みを共有できます。したがって、重みのみからモデルを復元する場合は、元のモデルと同じアーキテクチャでモデルを作成してから、その重みを設定します。
次に、トレーニングされていない新しいモデルを再構築し、テストセットで評価します。トレーニングされていないモデルは、偶然誤差(10% 以下の正解率)で実行されます。
# Create a basic model instance
model = create_model()
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 2.3168 - sparse_categorical_accuracy: 0.1070 - 154ms/epoch - 5ms/step Untrained model, accuracy: 10.70%
次に、チェックポイントから重みをロードし、再び評価します。
# Loads the weights
model.load_weights(checkpoint_path)
# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4228 - sparse_categorical_accuracy: 0.8600 - 70ms/epoch - 2ms/step Restored model, accuracy: 86.00%
チェックポイントコールバックのオプション
このコールバックには、チェックポイントに一意な名前をつけたり、チェックポイントの頻度を調整するためのオプションがあります。
新しいモデルをトレーニングし、5 エポックごとに一意な名前のチェックポイントを保存します。
# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
batch_size = 32
# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
verbose=1,
save_weights_only=True,
save_freq=5*batch_size)
# Create a new model instance
model = create_model()
# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=50,
batch_size=batch_size,
callbacks=[cp_callback],
validation_data=(test_images, test_labels),
verbose=0)
Epoch 5: saving model to training_2/cp-0005.ckpt Epoch 10: saving model to training_2/cp-0010.ckpt Epoch 15: saving model to training_2/cp-0015.ckpt Epoch 20: saving model to training_2/cp-0020.ckpt Epoch 25: saving model to training_2/cp-0025.ckpt Epoch 30: saving model to training_2/cp-0030.ckpt Epoch 35: saving model to training_2/cp-0035.ckpt Epoch 40: saving model to training_2/cp-0040.ckpt Epoch 45: saving model to training_2/cp-0045.ckpt Epoch 50: saving model to training_2/cp-0050.ckpt <keras.callbacks.History at 0x7fcad00a6940>
次に、できあがったチェックポイントを確認し、最後のものを選択します。
os.listdir(checkpoint_dir)
['cp-0000.ckpt.index', 'cp-0000.ckpt.data-00000-of-00001', 'cp-0050.ckpt.data-00000-of-00001', 'cp-0045.ckpt.data-00000-of-00001', 'cp-0050.ckpt.index', 'cp-0030.ckpt.index', 'cp-0040.ckpt.index', 'cp-0005.ckpt.data-00000-of-00001', 'cp-0045.ckpt.index', 'cp-0040.ckpt.data-00000-of-00001', 'cp-0005.ckpt.index', 'cp-0010.ckpt.index', 'cp-0020.ckpt.index', 'cp-0015.ckpt.data-00000-of-00001', 'checkpoint', 'cp-0035.ckpt.index', 'cp-0035.ckpt.data-00000-of-00001', 'cp-0020.ckpt.data-00000-of-00001', 'cp-0030.ckpt.data-00000-of-00001', 'cp-0025.ckpt.index', 'cp-0025.ckpt.data-00000-of-00001', 'cp-0015.ckpt.index', 'cp-0010.ckpt.data-00000-of-00001']
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
'training_2/cp-0050.ckpt'
注意: デフォルトの TensorFlow 形式では、最新の 5 つのチェックポイントのみが保存されます。
テストのため、モデルをリセットし最後のチェックポイントを読み込みます。
# Create a new model instance
model = create_model()
# Load the previously saved weights
model.load_weights(latest)
# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4917 - sparse_categorical_accuracy: 0.8810 - 159ms/epoch - 5ms/step Restored model, accuracy: 88.10%
これらのファイルは何?
上記のコードは、バイナリ形式でトレーニングされた重みのみを含む checkpoint 形式のファイルのコレクションに重みを格納します。チェックポイントには、次のものが含まれます。
- 1 つ以上のモデルの重みのシャード。
- どの重みがどのシャードに格納されているかを示すインデックスファイル。
一台のマシンでモデルをトレーニングしている場合は、接尾辞が .data-00000-of-00001
のシャードが 1 つあります。
手動で重みを保存する
Model.save_weights
メソッドを使用して手動で重みを保存します。デフォルトでは、tf.keras
、特に save_weights
は、.ckpt
拡張子を持つ TensorFlow のチェックポイント形式を使用します (HDF5 に .h5
拡張子を付けて保存する方法については、モデルの保存とシリアル化ガイドを参照してください)。
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')
# Create a new model instance
model = create_model()
# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function. WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate 32/32 - 0s - loss: 0.4917 - sparse_categorical_accuracy: 0.8810 - 162ms/epoch - 5ms/step Restored model, accuracy: 88.10%
モデル全体の保存
model.save
を呼ぶことで、モデルのアーキテクチャや重み、トレーニングの設定を単一のファイル/フォルダに保存できます。これにより、オリジナルの Python コード (*) にアクセスせずにモデルを使えるように、モデルをエクスポートできます。オプティマイザの状態も復旧されるため、中断したところからトレーニングを再開できます。
モデル全体を 2 つの異なるファイル形式 (SavedModel
とHDF5
) で保存できます。TensorFlow SavedModel
形式は、TF2.x のデフォルトのファイル形式ですが、モデルは HDF5
形式で保存できます。モデル全体を 2 つのファイル形式で保存する方法の詳細については、以下の説明をご覧ください。
完全に動作するモデルを保存すると TensorFlow.js (Saved Model、HDF5) で読み込んで、ブラウザ上でトレーニングや実行したり、TensorFlow Lite (Saved Model、HDF5) を用いてモバイルデバイス上で実行できるよう変換することもできるので非常に便利です。
カスタムのオブジェクト (クラスを継承したモデルやレイヤー) は保存や読み込みを行うとき、特別な注意を必要とします。以下のカスタムオブジェクトの保存*を参照してください。
SavedModel フォーマットとして
SavedModel 形式は、モデルをシリアル化するもう 1 つの方法です。この形式で保存されたモデルは、tf.keras.models.load_model
を使用して復元でき、TensorFlow Serving と互換性があります。SavedModel をサービングおよび検査する方法についての詳細は、SavedModel ガイドを参照してください。以下のセクションでは、モデルを保存および復元する手順を示します。
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model')
Epoch 1/5 32/32 [==============================] - 0s 2ms/step - loss: 1.1613 - sparse_categorical_accuracy: 0.6720 Epoch 2/5 32/32 [==============================] - 0s 2ms/step - loss: 0.4246 - sparse_categorical_accuracy: 0.8820 Epoch 3/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2814 - sparse_categorical_accuracy: 0.9310 Epoch 4/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2200 - sparse_categorical_accuracy: 0.9490 Epoch 5/5 32/32 [==============================] - 0s 2ms/step - loss: 0.1569 - sparse_categorical_accuracy: 0.9670 INFO:tensorflow:Assets written to: saved_model/my_model/assets
SavedModel 形式は、protobuf バイナリと TensorFlow チェックポイントを含むディレクトリです。保存されたモデルディレクトリを調べます。
# my_model directory
ls saved_model
# Contains an assets folder, saved_model.pb, and variables folder.
ls saved_model/my_model
my_model assets keras_metadata.pb saved_model.pb variables
保存したモデルから新しい Keras モデルを再度読み込みます。
new_model = tf.keras.models.load_model('saved_model/my_model')
# Check its architecture
new_model.summary()
Model: "sequential_5" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_10 (Dense) (None, 512) 401920 dropout_5 (Dropout) (None, 512) 0 dense_11 (Dense) (None, 10) 5130 ================================================================= Total params: 407,050 Trainable params: 407,050 Non-trainable params: 0 _________________________________________________________________
復元されたモデルは、元のモデルと同じ引数でコンパイルされます。読み込まれたモデルで評価と予測を実行してみてください。
# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
print(new_model.predict(test_images).shape)
32/32 - 0s - loss: 0.4259 - sparse_categorical_accuracy: 0.8580 - 165ms/epoch - 5ms/step Restored model, accuracy: 85.80% 32/32 [==============================] - 0s 1ms/step (1000, 10)
HDF5ファイルとして
Keras は HDF5 の標準に従ったベーシックな保存形式も提供します。
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')
Epoch 1/5 32/32 [==============================] - 0s 2ms/step - loss: 1.1513 - sparse_categorical_accuracy: 0.6720 Epoch 2/5 32/32 [==============================] - 0s 2ms/step - loss: 0.4328 - sparse_categorical_accuracy: 0.8780 Epoch 3/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2859 - sparse_categorical_accuracy: 0.9210 Epoch 4/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2024 - sparse_categorical_accuracy: 0.9500 Epoch 5/5 32/32 [==============================] - 0s 2ms/step - loss: 0.1554 - sparse_categorical_accuracy: 0.9690
保存したファイルを使ってモデルを再作成します。
# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')
# Show the model architecture
new_model.summary()
Model: "sequential_6" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_12 (Dense) (None, 512) 401920 dropout_6 (Dropout) (None, 512) 0 dense_13 (Dense) (None, 10) 5130 ================================================================= Total params: 407,050 Trainable params: 407,050 Non-trainable params: 0 _________________________________________________________________
正解率を検査します。
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
32/32 - 0s - loss: 0.4225 - sparse_categorical_accuracy: 0.8660 - 164ms/epoch - 5ms/step Restored model, accuracy: 86.60%
Keras は、アーキテクチャを検査することでモデルを保存します。この手法ではすべてが保存されます。
- 重みの値
- モデルのアーキテクチャ
- モデルのトレーニング構成(
.compile()
メソッドに渡すもの) - あれば、オプティマイザとその状態(中断した所からトレーニングを再開するため)
Keras は v1.x
(tf.compat.v1.train
にあります) のオプティマイザを保存できません。これらはチェックポイントと互換性がないためです。v1.x のオプティマイザでは、オプティマイザの状態を読み込ませてモデルを再度コンパイルする必要があります。
カスタムオブジェクトの保存
SavedModel 形式を使用している場合は、このセクションをスキップできます。HDF5 と SavedModel の主な違いは、HDF5 はオブジェクト構成を使用してモデルアーキテクチャを保存するのに対し、SavedModel は実行グラフを保存することです。したがって、SavedModels は、元のコードを必要とせずに、サブクラス化されたモデルやカスタムレイヤーなどのカスタムオブジェクトを保存できます。
カスタムオブジェクトを HDF5 に保存するには、次の手順を実行します。
- オブジェクトで
get_config
メソッドを定義し、オプションでfrom_config
クラスメソッドを定義します。get_config(self)
は、オブジェクトの再作成に必要なパラメータの JSON シリアル化可能なディクショナリを返します。from_config(cls, config){/code0 }は、<code data-md-type="codespan">get_config
から返された構成を使用して新しいオブジェクトを作成します。デフォルトでは、この関数は構成を初期化 kwargs (return cls(**config)
) として使用します。
- モデルを読み込むときに、オブジェクトを
custom_objects
引数に渡します。引数は、文字列クラス名を Python クラスにマッピングするディクショナリである必要があります。(例:tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})
)
カスタムオブジェクトと get_config
の例については、レイヤーとモデルを最初から作成するチュートリアルを参照してください。
# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.