このページは Cloud Translation API によって翻訳されました。
Switch to English

転移学習と微調整

TensorFlow.orgで見る Google Colabで実行 GitHubでソースを表示する ノートブックをダウンロード

このチュートリアルでは、事前に訓練されたネットワークからの転移学習を使用して、猫と犬の画像を分類する方法を学習します。

事前トレーニング済みモデルは、大規模なデータセット、通常は大規模な画像分類タスクで以前にトレーニングされた保存済みネットワークです。事前学習済みモデルをそのまま使用するか、転移学習を使用して、このモデルを特定のタスクにカスタマイズします。

画像分類のための転移学習の背後にある直観は、モデルが大規模で一般的な十分なデータセットでトレーニングされる場合、このモデルは視覚世界の一般的なモデルとして効果的に機能するということです。その後、大規模なデータセットで大規模なモデルをトレーニングすることにより、ゼロから始めることなく、これらの学習された特徴マップを利用できます。

このノートブックでは、事前トレーニング済みモデルをカスタマイズする2つの方法を試します。

  1. 特徴抽出:以前のネットワークで学習した表現を使用して、新しいサンプルから意味のある特徴を抽出します。最初からトレーニングする新しい分類子を事前トレーニング済みモデルの上に追加するだけで、データセットについて以前に学習した特徴マップを再利用できます。

    モデル全体を(再)トレーニングする必要はありません。基本の畳み込みネットワークには、画像の分類に一般的に役立つ機能がすでに含まれています。ただし、事前トレーニング済みモデルの最後の分類部分は、元の分類タスクに固有であり、その後、モデルがトレーニングされたクラスのセットに固有です。

  2. 微調整:凍結されたモデルベースの上位レイヤーのいくつかをフリーズ解除し、新しく追加された分類子レイヤーとベースモデルの最後のレイヤーの両方を共同でトレーニングします。これにより、ベースモデルの高次の特徴表現を「微調整」して、特定のタスクとの関連性を高めることができます。

一般的な機械学習ワークフローに従います。

  1. データを調べて理解する
  2. この場合はKeras ImageDataGeneratorを使用して、入力パイプラインを構築します。
  3. モデルを作成する
    • 事前トレーニング済みのベースモデル(および事前トレーニング済みの重み)にロードする
    • 上に分類レイヤーを積み重ねます
  4. モデルをトレーニングする
  5. モデルを評価する
pip install -q tf-nightly
 import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

from tensorflow.keras.preprocessing import image_dataset_from_directory
 

データの前処理

データダウンロード

このチュートリアルでは、猫と犬の数千の画像を含むデータセットを使用します。画像を含むzipファイルをダウンロードして抽出し、 tf.keras.preprocessing.image_dataset_from_directoryユーティリティを使用してトレーニングと検証のためのtf.data.Datasetを作成します。このチュートリアルでは、画像の読み込みについて詳しく学習できます。

 _URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

train_dataset = image_dataset_from_directory(train_dir,
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             image_size=IMG_SIZE)
 
Downloading data from https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
68608000/68606236 [==============================] - 1s 0us/step
Found 2000 files belonging to 2 classes.

 validation_dataset = image_dataset_from_directory(validation_dir,
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  image_size=IMG_SIZE)
 
Found 1000 files belonging to 2 classes.

トレーニングセットの最初の9つの画像とラベルを表示します。

 class_names = train_dataset.class_names

plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")
 

png

元のデータセットにはテストセットが含まれていないため、テストセットを作成します。そのためには、 tf.data.experimental.cardinality使用して、検証セットで使用可能なデータのバッチ数を決定し、それらの20%をテストセットに移動します。

 val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
 
 print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_dataset))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))
 
Number of validation batches: 26
Number of test batches: 6

パフォーマンスのためにデータセットを構成する

バッファ付きプリフェッチを使用して、I / Oがブロックされることなくディスクからイメージをロードします。この方法の詳細については、 データパフォーマンスガイドを参照してください。

 AUTOTUNE = tf.data.experimental.AUTOTUNE

train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)
 

データ拡張を使用する

大きな画像データセットがない場合は、回転や水平反転などのランダムで現実的な変換をトレーニング画像に適用して、サンプルの多様性を人為的に導入することをお勧めします。これは、トレーニングデータのさまざまな側面にモデルを公開し、 過剰適合を減らすのに役立ちます。このチュートリアルでは、データ拡張について詳しく学ぶことができます。

 data_augmentation = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
  tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])
 

これらのレイヤーを同じ画像に繰り返し適用して、結果を見てみましょう。

 for image, _ in train_dataset.take(1):
  plt.figure(figsize=(10, 10))
  first_image = image[0]
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
    plt.imshow(augmented_image[0] / 255)
    plt.axis('off')
 

png

ピクセル値の再スケーリング

tf.keras.applications.MobileNetV2に、ベースモデルとして使用するtf.keras.applications.MobileNetV2をダウンロードします。このモデルは[-1,1]のピクセル値を想定していますが、この時点では、画像のピクセル値は[0-255]です。それらを再スケーリングするには、モデルに含まれている前処理メソッドを使用します。

 preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
 
 rescale = tf.keras.layers.experimental.preprocessing.Rescaling(1./127.5, offset= -1)
 

事前トレーニング済みのconvnetsからベースモデルを作成する

Googleで開発されたMobileNet V2モデルから基本モデルを作成します。これは、1.4Mの画像と1000のクラスで構成される大規模なデータセットであるImageNetデータセットで事前トレーニングされています。 ImageNetのようなカテゴリの多種多様な研究のトレーニングデータセットであるjackfruitsyringe 。この知識の基盤は、特定のデータセットから猫と犬を分類するのに役立ちます。

最初に、特徴抽出に使用するMobileNet V2のレイヤーを選択する必要があります。最後の分類レイヤー(機械学習モデルのほとんどの図は下から上に行くため、「上」にあります)はあまり役に立ちません。代わりに、フラット化操作の前の最後のレイヤーに依存するという一般的な方法に従います。この層を「ボトルネック層」と呼びます。ボトルネックレイヤー機能は、ファイナル/トップレイヤーと比較して、より一般性を保持しています。

最初に、ImageNetでトレーニングされた重みがプリロードされたMobileNet V2モデルをインスタンス化します。 include_top = False引数を指定することにより、上部に分類レイヤーを含まないネットワークをロードします。これは、特徴抽出に最適です。

 # Create the base model from the pre-trained model MobileNet V2
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
 
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

この特徴抽出は各変換160x160x3に画像を5x5x1280の機能ブロック。それが画像のバッチの例に対して何をするか見てみましょう:

 image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)
 
(32, 5, 5, 1280)

特徴抽出

このステップでは、前のステップで作成したたたみ込みベースをフリーズし、特徴抽出として使用します。さらに、その上に分類子を追加し、最上位の分類子をトレーニングします。

畳み込みベースをフリーズ

モデルをコンパイルしてトレーニングする前に、畳み込みベースをフリーズすることが重要です。 (layer.trainable = Falseを設定して)フリーズすると、特定のレイヤーのウェイトがトレーニング中に更新されなくなります。 MobileNet V2には多くのレイヤーがあるため、モデル全体のtrainableフラグをFalseに設定すると、すべてのレイヤーがフリーズします。

 base_model.trainable = False
 

BatchNormalizationレイヤーに関する重要な注意

多くのモデルにはtf.keras.layers.BatchNormalizationレイヤーが含まれてtf.keras.layers.BatchNormalizationます。このレイヤーは特殊なケースであり、このチュートリアルの後半で示すように、微調整のコンテキストでは注意が必要です。

layer.trainable = Falseを設定すると、 BatchNormalizationレイヤーは推論モードで実行され、平均と分散の統計は更新されません。

微調整を行うためにBatchNormalizationレイヤーを含むモデルをフリーズ解除するときは、ベースモデルを呼び出すときにtraining = Falseを渡して、BatchNormalizationレイヤーを推論モードに維持する必要があります。そうでない場合、トレーニング不可能な重みに適用される更新は、モデルが学習したモデルを破壊します。

詳細については、 転移学習ガイドをご覧ください。

 # Let's take a look at the base model architecture
base_model.summary()
 
Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D)       (None, 161, 161, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 80, 80, 32)   864         Conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 80, 80, 32)   128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 80, 80, 32)   0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 80, 80, 32)   288         Conv1_relu[0][0]                 
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 80, 80, 32)   128         expanded_conv_depthwise[0][0]    
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 80, 80, 32)   0           expanded_conv_depthwise_BN[0][0] 
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, 80, 80, 16)   512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 80, 80, 16)   64          expanded_conv_project[0][0]      
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, 80, 80, 96)   1536        expanded_conv_project_BN[0][0]   
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 80, 80, 96)   384         block_1_expand[0][0]             
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, 80, 80, 96)   0           block_1_expand_BN[0][0]          
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, 81, 81, 96)   0           block_1_expand_relu[0][0]        
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, 40, 40, 96)   864         block_1_pad[0][0]                
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, 40, 40, 96)   384         block_1_depthwise[0][0]          
__________________________________________________________________________________________________
block_1_depthwise_relu (ReLU)   (None, 40, 40, 96)   0           block_1_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_1_project (Conv2D)        (None, 40, 40, 24)   2304        block_1_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_1_project_BN (BatchNormal (None, 40, 40, 24)   96          block_1_project[0][0]            
__________________________________________________________________________________________________
block_2_expand (Conv2D)         (None, 40, 40, 144)  3456        block_1_project_BN[0][0]         
__________________________________________________________________________________________________
block_2_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_2_expand[0][0]             
__________________________________________________________________________________________________
block_2_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_2_expand_BN[0][0]          
__________________________________________________________________________________________________
block_2_depthwise (DepthwiseCon (None, 40, 40, 144)  1296        block_2_expand_relu[0][0]        
__________________________________________________________________________________________________
block_2_depthwise_BN (BatchNorm (None, 40, 40, 144)  576         block_2_depthwise[0][0]          
__________________________________________________________________________________________________
block_2_depthwise_relu (ReLU)   (None, 40, 40, 144)  0           block_2_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_2_project (Conv2D)        (None, 40, 40, 24)   3456        block_2_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_2_project_BN (BatchNormal (None, 40, 40, 24)   96          block_2_project[0][0]            
__________________________________________________________________________________________________
block_2_add (Add)               (None, 40, 40, 24)   0           block_1_project_BN[0][0]         
                                                                 block_2_project_BN[0][0]         
__________________________________________________________________________________________________
block_3_expand (Conv2D)         (None, 40, 40, 144)  3456        block_2_add[0][0]                
__________________________________________________________________________________________________
block_3_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_3_expand[0][0]             
__________________________________________________________________________________________________
block_3_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_3_expand_BN[0][0]          
__________________________________________________________________________________________________
block_3_pad (ZeroPadding2D)     (None, 41, 41, 144)  0           block_3_expand_relu[0][0]        
__________________________________________________________________________________________________
block_3_depthwise (DepthwiseCon (None, 20, 20, 144)  1296        block_3_pad[0][0]                
__________________________________________________________________________________________________
block_3_depthwise_BN (BatchNorm (None, 20, 20, 144)  576         block_3_depthwise[0][0]          
__________________________________________________________________________________________________
block_3_depthwise_relu (ReLU)   (None, 20, 20, 144)  0           block_3_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_3_project (Conv2D)        (None, 20, 20, 32)   4608        block_3_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_3_project_BN (BatchNormal (None, 20, 20, 32)   128         block_3_project[0][0]            
__________________________________________________________________________________________________
block_4_expand (Conv2D)         (None, 20, 20, 192)  6144        block_3_project_BN[0][0]         
__________________________________________________________________________________________________
block_4_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_4_expand[0][0]             
__________________________________________________________________________________________________
block_4_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_4_expand_BN[0][0]          
__________________________________________________________________________________________________
block_4_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_4_expand_relu[0][0]        
__________________________________________________________________________________________________
block_4_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_4_depthwise[0][0]          
__________________________________________________________________________________________________
block_4_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_4_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_4_project (Conv2D)        (None, 20, 20, 32)   6144        block_4_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_4_project_BN (BatchNormal (None, 20, 20, 32)   128         block_4_project[0][0]            
__________________________________________________________________________________________________
block_4_add (Add)               (None, 20, 20, 32)   0           block_3_project_BN[0][0]         
                                                                 block_4_project_BN[0][0]         
__________________________________________________________________________________________________
block_5_expand (Conv2D)         (None, 20, 20, 192)  6144        block_4_add[0][0]                
__________________________________________________________________________________________________
block_5_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_5_expand[0][0]             
__________________________________________________________________________________________________
block_5_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_5_expand_BN[0][0]          
__________________________________________________________________________________________________
block_5_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_5_expand_relu[0][0]        
__________________________________________________________________________________________________
block_5_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_5_depthwise[0][0]          
__________________________________________________________________________________________________
block_5_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_5_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_5_project (Conv2D)        (None, 20, 20, 32)   6144        block_5_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_5_project_BN (BatchNormal (None, 20, 20, 32)   128         block_5_project[0][0]            
__________________________________________________________________________________________________
block_5_add (Add)               (None, 20, 20, 32)   0           block_4_add[0][0]                
                                                                 block_5_project_BN[0][0]         
__________________________________________________________________________________________________
block_6_expand (Conv2D)         (None, 20, 20, 192)  6144        block_5_add[0][0]                
__________________________________________________________________________________________________
block_6_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_6_expand[0][0]             
__________________________________________________________________________________________________
block_6_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_6_expand_BN[0][0]          
__________________________________________________________________________________________________
block_6_pad (ZeroPadding2D)     (None, 21, 21, 192)  0           block_6_expand_relu[0][0]        
__________________________________________________________________________________________________
block_6_depthwise (DepthwiseCon (None, 10, 10, 192)  1728        block_6_pad[0][0]                
__________________________________________________________________________________________________
block_6_depthwise_BN (BatchNorm (None, 10, 10, 192)  768         block_6_depthwise[0][0]          
__________________________________________________________________________________________________
block_6_depthwise_relu (ReLU)   (None, 10, 10, 192)  0           block_6_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_6_project (Conv2D)        (None, 10, 10, 64)   12288       block_6_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_6_project_BN (BatchNormal (None, 10, 10, 64)   256         block_6_project[0][0]            
__________________________________________________________________________________________________
block_7_expand (Conv2D)         (None, 10, 10, 384)  24576       block_6_project_BN[0][0]         
__________________________________________________________________________________________________
block_7_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_7_expand[0][0]             
__________________________________________________________________________________________________
block_7_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_7_expand_BN[0][0]          
__________________________________________________________________________________________________
block_7_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_7_expand_relu[0][0]        
__________________________________________________________________________________________________
block_7_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_7_depthwise[0][0]          
__________________________________________________________________________________________________
block_7_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_7_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_7_project (Conv2D)        (None, 10, 10, 64)   24576       block_7_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_7_project_BN (BatchNormal (None, 10, 10, 64)   256         block_7_project[0][0]            
__________________________________________________________________________________________________
block_7_add (Add)               (None, 10, 10, 64)   0           block_6_project_BN[0][0]         
                                                                 block_7_project_BN[0][0]         
__________________________________________________________________________________________________
block_8_expand (Conv2D)         (None, 10, 10, 384)  24576       block_7_add[0][0]                
__________________________________________________________________________________________________
block_8_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_8_expand[0][0]             
__________________________________________________________________________________________________
block_8_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_8_expand_BN[0][0]          
__________________________________________________________________________________________________
block_8_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_8_expand_relu[0][0]        
__________________________________________________________________________________________________
block_8_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_8_depthwise[0][0]          
__________________________________________________________________________________________________
block_8_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_8_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_8_project (Conv2D)        (None, 10, 10, 64)   24576       block_8_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_8_project_BN (BatchNormal (None, 10, 10, 64)   256         block_8_project[0][0]            
__________________________________________________________________________________________________
block_8_add (Add)               (None, 10, 10, 64)   0           block_7_add[0][0]                
                                                                 block_8_project_BN[0][0]         
__________________________________________________________________________________________________
block_9_expand (Conv2D)         (None, 10, 10, 384)  24576       block_8_add[0][0]                
__________________________________________________________________________________________________
block_9_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_9_expand[0][0]             
__________________________________________________________________________________________________
block_9_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_9_expand_BN[0][0]          
__________________________________________________________________________________________________
block_9_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_9_expand_relu[0][0]        
__________________________________________________________________________________________________
block_9_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_9_depthwise[0][0]          
__________________________________________________________________________________________________
block_9_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_9_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_9_project (Conv2D)        (None, 10, 10, 64)   24576       block_9_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_9_project_BN (BatchNormal (None, 10, 10, 64)   256         block_9_project[0][0]            
__________________________________________________________________________________________________
block_9_add (Add)               (None, 10, 10, 64)   0           block_8_add[0][0]                
                                                                 block_9_project_BN[0][0]         
__________________________________________________________________________________________________
block_10_expand (Conv2D)        (None, 10, 10, 384)  24576       block_9_add[0][0]                
__________________________________________________________________________________________________
block_10_expand_BN (BatchNormal (None, 10, 10, 384)  1536        block_10_expand[0][0]            
__________________________________________________________________________________________________
block_10_expand_relu (ReLU)     (None, 10, 10, 384)  0           block_10_expand_BN[0][0]         
__________________________________________________________________________________________________
block_10_depthwise (DepthwiseCo (None, 10, 10, 384)  3456        block_10_expand_relu[0][0]       
__________________________________________________________________________________________________
block_10_depthwise_BN (BatchNor (None, 10, 10, 384)  1536        block_10_depthwise[0][0]         
__________________________________________________________________________________________________
block_10_depthwise_relu (ReLU)  (None, 10, 10, 384)  0           block_10_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_10_project (Conv2D)       (None, 10, 10, 96)   36864       block_10_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_10_project_BN (BatchNorma (None, 10, 10, 96)   384         block_10_project[0][0]           
__________________________________________________________________________________________________
block_11_expand (Conv2D)        (None, 10, 10, 576)  55296       block_10_project_BN[0][0]        
__________________________________________________________________________________________________
block_11_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_11_expand[0][0]            
__________________________________________________________________________________________________
block_11_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_11_expand_BN[0][0]         
__________________________________________________________________________________________________
block_11_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_11_expand_relu[0][0]       
__________________________________________________________________________________________________
block_11_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_11_depthwise[0][0]         
__________________________________________________________________________________________________
block_11_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_11_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_11_project (Conv2D)       (None, 10, 10, 96)   55296       block_11_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_11_project_BN (BatchNorma (None, 10, 10, 96)   384         block_11_project[0][0]           
__________________________________________________________________________________________________
block_11_add (Add)              (None, 10, 10, 96)   0           block_10_project_BN[0][0]        
                                                                 block_11_project_BN[0][0]        
__________________________________________________________________________________________________
block_12_expand (Conv2D)        (None, 10, 10, 576)  55296       block_11_add[0][0]               
__________________________________________________________________________________________________
block_12_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_12_expand[0][0]            
__________________________________________________________________________________________________
block_12_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_12_expand_BN[0][0]         
__________________________________________________________________________________________________
block_12_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_12_expand_relu[0][0]       
__________________________________________________________________________________________________
block_12_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_12_depthwise[0][0]         
__________________________________________________________________________________________________
block_12_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_12_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_12_project (Conv2D)       (None, 10, 10, 96)   55296       block_12_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_12_project_BN (BatchNorma (None, 10, 10, 96)   384         block_12_project[0][0]           
__________________________________________________________________________________________________
block_12_add (Add)              (None, 10, 10, 96)   0           block_11_add[0][0]               
                                                                 block_12_project_BN[0][0]        
__________________________________________________________________________________________________
block_13_expand (Conv2D)        (None, 10, 10, 576)  55296       block_12_add[0][0]               
__________________________________________________________________________________________________
block_13_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_13_expand[0][0]            
__________________________________________________________________________________________________
block_13_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_13_expand_BN[0][0]         
__________________________________________________________________________________________________
block_13_pad (ZeroPadding2D)    (None, 11, 11, 576)  0           block_13_expand_relu[0][0]       
__________________________________________________________________________________________________
block_13_depthwise (DepthwiseCo (None, 5, 5, 576)    5184        block_13_pad[0][0]               
__________________________________________________________________________________________________
block_13_depthwise_BN (BatchNor (None, 5, 5, 576)    2304        block_13_depthwise[0][0]         
__________________________________________________________________________________________________
block_13_depthwise_relu (ReLU)  (None, 5, 5, 576)    0           block_13_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_13_project (Conv2D)       (None, 5, 5, 160)    92160       block_13_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_13_project_BN (BatchNorma (None, 5, 5, 160)    640         block_13_project[0][0]           
__________________________________________________________________________________________________
block_14_expand (Conv2D)        (None, 5, 5, 960)    153600      block_13_project_BN[0][0]        
__________________________________________________________________________________________________
block_14_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_14_expand[0][0]            
__________________________________________________________________________________________________
block_14_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_14_expand_BN[0][0]         
__________________________________________________________________________________________________
block_14_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_14_expand_relu[0][0]       
__________________________________________________________________________________________________
block_14_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_14_depthwise[0][0]         
__________________________________________________________________________________________________
block_14_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_14_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_14_project (Conv2D)       (None, 5, 5, 160)    153600      block_14_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_14_project_BN (BatchNorma (None, 5, 5, 160)    640         block_14_project[0][0]           
__________________________________________________________________________________________________
block_14_add (Add)              (None, 5, 5, 160)    0           block_13_project_BN[0][0]        
                                                                 block_14_project_BN[0][0]        
__________________________________________________________________________________________________
block_15_expand (Conv2D)        (None, 5, 5, 960)    153600      block_14_add[0][0]               
__________________________________________________________________________________________________
block_15_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_15_expand[0][0]            
__________________________________________________________________________________________________
block_15_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_15_expand_BN[0][0]         
__________________________________________________________________________________________________
block_15_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_15_expand_relu[0][0]       
__________________________________________________________________________________________________
block_15_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_15_depthwise[0][0]         
__________________________________________________________________________________________________
block_15_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_15_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_15_project (Conv2D)       (None, 5, 5, 160)    153600      block_15_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_15_project_BN (BatchNorma (None, 5, 5, 160)    640         block_15_project[0][0]           
__________________________________________________________________________________________________
block_15_add (Add)              (None, 5, 5, 160)    0           block_14_add[0][0]               
                                                                 block_15_project_BN[0][0]        
__________________________________________________________________________________________________
block_16_expand (Conv2D)        (None, 5, 5, 960)    153600      block_15_add[0][0]               
__________________________________________________________________________________________________
block_16_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_16_expand[0][0]            
__________________________________________________________________________________________________
block_16_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_16_expand_BN[0][0]         
__________________________________________________________________________________________________
block_16_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_16_expand_relu[0][0]       
__________________________________________________________________________________________________
block_16_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_16_depthwise[0][0]         
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_16_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_16_project (Conv2D)       (None, 5, 5, 320)    307200      block_16_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, 5, 5, 320)    1280        block_16_project[0][0]           
__________________________________________________________________________________________________
Conv_1 (Conv2D)                 (None, 5, 5, 1280)   409600      block_16_project_BN[0][0]        
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, 5, 5, 1280)   5120        Conv_1[0][0]                     
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, 5, 5, 1280)   0           Conv_1_bn[0][0]                  
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984
__________________________________________________________________________________________________

分類ヘッドを追加する

機能のブロックから予測を生成するには、 tf.keras.layers.GlobalAveragePooling2Dレイヤーを使用して、画像を1つの1280要素のベクトルに変換することにより、空間的な5x5空間的位置を平均します。

 global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)
 
(32, 1280)

tf.keras.layers.Denseレイヤーを適用して、これらの機能を画像ごとの単一の予測に変換します。この予測はlogitまたは生の予測値として扱われるため、ここではアクティブ化関数は必要ありません。正の数はクラス1を予測し、負の数はクラス0を予測します。

 prediction_layer = tf.keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)
 
(32, 1)

Keras Functional APIを使用して、データ拡張、再スケーリング、base_model、および特徴抽出レイヤーをつなぎ合わせてモデルを構築します 。前述のように、モデルにはBatchNormalizationレイヤーが含まれているため、training = Falseを使用します。

 inputs = tf.keras.Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)
 

モデルをコンパイルする

トレーニングする前にモデルをコンパイルします。 2つのクラスがあるため、モデルは線形出力を提供するため、 from_logits=Truefrom_logits=Trueてバイナリクロスエントロピー損失を使用します。

 base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(lr=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
 
 model.summary()
 
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

MobileNetの2.5Mパラメーターは凍結されていますが、高密度レイヤーには1.2Kのトレーニング可能なパラメーターがあります。これらは、2つのtf.Variableオブジェクトである重みとバイアスに分けられます。

 len(model.trainable_variables)
 
2

モデルをトレーニングする

10エポックのトレーニング後、検証セットの精度が最大94%になるはずです。

 initial_epochs = 10

loss0, accuracy0 = model.evaluate(validation_dataset)
 
26/26 [==============================] - 3s 100ms/step - loss: 0.7056 - accuracy: 0.5669

 print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))
 
initial loss: 0.72
initial accuracy: 0.54

 history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)
 
Epoch 1/10
63/63 [==============================] - 6s 93ms/step - loss: 0.6745 - accuracy: 0.6030 - val_loss: 0.5018 - val_accuracy: 0.7710
Epoch 2/10
63/63 [==============================] - 4s 57ms/step - loss: 0.5142 - accuracy: 0.7145 - val_loss: 0.3587 - val_accuracy: 0.8762
Epoch 3/10
63/63 [==============================] - 3s 54ms/step - loss: 0.3978 - accuracy: 0.8095 - val_loss: 0.2745 - val_accuracy: 0.9183
Epoch 4/10
63/63 [==============================] - 3s 52ms/step - loss: 0.3302 - accuracy: 0.8480 - val_loss: 0.2351 - val_accuracy: 0.9369
Epoch 5/10
63/63 [==============================] - 3s 53ms/step - loss: 0.2973 - accuracy: 0.8730 - val_loss: 0.1903 - val_accuracy: 0.9517
Epoch 6/10
63/63 [==============================] - 3s 53ms/step - loss: 0.2735 - accuracy: 0.8785 - val_loss: 0.1661 - val_accuracy: 0.9579
Epoch 7/10
63/63 [==============================] - 3s 52ms/step - loss: 0.2609 - accuracy: 0.8925 - val_loss: 0.1508 - val_accuracy: 0.9592
Epoch 8/10
63/63 [==============================] - 3s 53ms/step - loss: 0.2340 - accuracy: 0.9090 - val_loss: 0.1412 - val_accuracy: 0.9641
Epoch 9/10
63/63 [==============================] - 3s 52ms/step - loss: 0.2282 - accuracy: 0.8980 - val_loss: 0.1312 - val_accuracy: 0.9653
Epoch 10/10
63/63 [==============================] - 3s 53ms/step - loss: 0.2047 - accuracy: 0.9170 - val_loss: 0.1254 - val_accuracy: 0.9653

学習曲線

MobileNet V2基本モデルを固定特徴抽出器として使用する場合のトレーニングと検証の精度/損失の学習曲線を見てみましょう。

 acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()
 

png

それほどではありませんが、トレーニングメトリックはエポックの平均を報告するのに対し、検証メトリックはエポック後に評価されるため、検証メトリックはわずかに長くトレーニングされたモデルを参照します。

微調整

特徴抽出実験では、MobileNet V2基本モデルの上にあるいくつかのレイヤーのみをトレーニングしました。事前トレーニング済みネットワークの重みは、トレーニング中に更新されませんでした

パフォーマンスをさらに向上させる1つの方法は、追加済みの分類子のトレーニングとともに、事前トレーニング済みモデルの最上位レイヤーの重みをトレーニング(または「微調整」)することです。トレーニングプロセスでは、一般的な機能マップからデータセットに関連付けられている機能に重みを強制的に調整します。

また、MobileNetモデル全体ではなく、少数の最上位層を微調整する必要があります。ほとんどのたたみ込みネットワークでは、上位のレイヤーほど、それはより特殊化されます。最初の数層は、ほとんどすべてのタイプの画像に一般化する非常にシンプルで一般的な機能を学習します。上に行くほど、特徴はモデルがトレーニングされたデータセットに固有のものになります。微調整の目標は、一般的な学習を上書きするのではなく、これらの特殊な機能を新しいデータセットで動作するように適合させることです。

モデルの最上層をフリーズ解除します

必要なのは、 base_model解除し、最下層をトレーニングできないように設定することだけです。次に、モデルを再コンパイルし(これらの変更を有効にするために必要)、トレーニングを再開します。

 base_model.trainable = True
 
 # Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable =  False
 
Number of layers in the base model:  155

モデルをコンパイルする

はるかに大きなモデルをトレーニングしていて、事前トレーニング済みの重みを再適応させたいので、この段階では低い学習率を使用することが重要です。そうしないと、モデルがすぐにオーバーフィットする可能性があります。

 model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10),
              metrics=['accuracy'])
 
 model.summary()
 
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,863,873
Non-trainable params: 395,392
_________________________________________________________________

 len(model.trainable_variables)
 
58

モデルのトレーニングを続ける

以前に収束するようにトレーニングした場合、この手順により、精度が数パーセントポイント向上します。

 fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs

history_fine = model.fit(train_dataset,
                         epochs=total_epochs,
                         initial_epoch=history.epoch[-1],
                         validation_data=validation_dataset)
 
Epoch 10/20
63/63 [==============================] - 8s 122ms/step - loss: 0.1497 - accuracy: 0.9485 - val_loss: 0.0611 - val_accuracy: 0.9752
Epoch 11/20
63/63 [==============================] - 3s 54ms/step - loss: 0.1177 - accuracy: 0.9518 - val_loss: 0.0575 - val_accuracy: 0.9851
Epoch 12/20
63/63 [==============================] - 3s 54ms/step - loss: 0.0856 - accuracy: 0.9712 - val_loss: 0.0459 - val_accuracy: 0.9851
Epoch 13/20
63/63 [==============================] - 3s 54ms/step - loss: 0.1041 - accuracy: 0.9549 - val_loss: 0.0383 - val_accuracy: 0.9851
Epoch 14/20
63/63 [==============================] - 3s 54ms/step - loss: 0.0743 - accuracy: 0.9699 - val_loss: 0.0380 - val_accuracy: 0.9864
Epoch 15/20
63/63 [==============================] - 3s 53ms/step - loss: 0.0785 - accuracy: 0.9708 - val_loss: 0.0429 - val_accuracy: 0.9814
Epoch 16/20
63/63 [==============================] - 3s 54ms/step - loss: 0.0578 - accuracy: 0.9774 - val_loss: 0.0354 - val_accuracy: 0.9839
Epoch 17/20
63/63 [==============================] - 3s 54ms/step - loss: 0.0634 - accuracy: 0.9746 - val_loss: 0.0441 - val_accuracy: 0.9864
Epoch 18/20
63/63 [==============================] - 3s 54ms/step - loss: 0.0561 - accuracy: 0.9830 - val_loss: 0.0343 - val_accuracy: 0.9876
Epoch 19/20
63/63 [==============================] - 3s 54ms/step - loss: 0.0608 - accuracy: 0.9766 - val_loss: 0.0375 - val_accuracy: 0.9864
Epoch 20/20
63/63 [==============================] - 3s 54ms/step - loss: 0.0560 - accuracy: 0.9760 - val_loss: 0.0376 - val_accuracy: 0.9876

MobileNet V2基本モデルの最後の数層を微調整し、その上に分類子をトレーニングするときのトレーニングと検証の精度/損失の学習曲線を見てみましょう。検証の損失はトレーニングの損失よりもはるかに大きいため、過剰適合が発生する可能性があります。

また、新しいトレーニングセットは比較的小さく、元のMobileNet V2データセットに類似しているため、オーバーフィットする可能性があります。

モデルを微調整した後、検証セットの精度はほぼ98%に達します。

 acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']
 
 plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],
          plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()
 

png

評価と予測

最後に、テストセットを使用して、新しいデータに対するモデルのパフォーマンスを確認できます。

 loss, accuracy = model.evaluate(test_dataset)
print('Test accuracy :', accuracy)
 
6/6 [==============================] - 1s 84ms/step - loss: 0.0452 - accuracy: 0.9792
Test accuracy : 0.9791666865348816

これで、このモデルを使用してペットが猫か犬かを予測する準備が整いました。

 #Retrieve a batch of images from the test set
image_batch, label_batch = test_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()

# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)

print('Predictions:\n', predictions.numpy())
print('Labels:\n', label_batch)

plt.figure(figsize=(10, 10))
for i in range(9):
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(image_batch[i].astype("uint8"))
  plt.title(class_names[predictions[i]])
  plt.axis("off")
 
Predictions:
 [1 0 0 0 0 0 1 1 0 1 1 0 1 1 1 1 1 0 0 0 1 1 0 0 1 0 0 1 0 1 1 0]
Labels:
 [1 1 0 0 0 0 1 1 0 1 1 0 1 1 1 1 1 0 0 0 1 1 0 0 1 0 0 1 0 1 1 0]

png

概要

  • 特徴抽出のための事前トレーニング済みモデルの使用 :小さなデータセットを操作する場合、同じドメイン内のより大きなデータセットでトレーニングされたモデルによって学習された特徴を利用することが一般的な方法です。これは、事前にトレーニングされたモデルをインスタンス化し、完全に接続された分類子を上に追加することによって行われます。事前トレーニング済みモデルは「凍結」され、トレーニング中に分類子の重みのみが更新されます。この場合、たたみ込みベースは各画像に関連付けられたすべての特徴を抽出し、抽出された特徴のセットを与えられた画像クラスを決定する分類器をトレーニングしました。

  • 事前トレーニング済みモデルの微調整 :パフォーマンスをさらに向上させるには、事前トレーニング済みモデルの最上位レイヤーを、微調整によって新しいデータセットに再利用することができます。この場合、モデルがデータセットに固有の高レベルの特徴を学習するように重みを調整しました。この手法は通常、トレーニングデータセットが大きく、事前トレーニング済みモデルがトレーニングされた元のデータセットと非常に類似している場合に推奨されます。

詳細については、 転移学習ガイドをご覧ください

 
#
# Copyright (c) 2017 François Chollet                                                                                                                    # IGNORE_COPYRIGHT: cleared by OSS licensing
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.