전이 학습 및 미세 조정

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기 노트북 다운로드

설정

import numpy as np
import tensorflow as tf
from tensorflow import keras
2022-12-14 22:32:32.798503: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:32:32.798593: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:32:32.798602: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

시작하기

전이 학습은 한 가지 문제에 대해 학습한 기능을 가져와서 비슷한 새로운 문제에 활용하는 것으로 구성됩니다. 예를 들어, 너구리를 식별하는 방법을 배운 모델의 기능은 너구리를 식별하는 모델을 시작하는 데 유용할 수 있습니다.

전이 학습은 일반적으로 전체 모델을 처음부터 훈련하기에는 데이터세트에 데이터가 너무 적은 작업에 대해 수행됩니다.

딥 러닝의 맥락에서 전이 학습의 가장 일반적인 구현은 다음 워크플로와 같습니다.

  1. 이전에 훈련된 모델에서 레이어를 가져옵니다.
  2. 추후 훈련 라운드 중에 포함된 정보가 손상되지 않도록 동결합니다.
  3. 고정된 레이어 위에 훈련할 수 있는 새 레이어를 추가합니다. 해당 레이어는 기존 기능을 새로운 데이터세트에 대한 예측으로 전환하는 방법을 배웁니다.
  4. 데이터세트에서 새로운 레이어를 훈련합니다.

마지막으로 선택적인 단계는 미세 조정입니다. 이 단계는 위에서 얻은 전체 모델(또는 모델의 일부)을 동결 해제하고 학습률이 매우 낮은 새로운 데이터에 대해 재훈련하는 과정으로 구성됩니다. 이는 사전 훈련된 특성을 새로운 데이터에 점진적으로 적용함으로써 의미 있는 개선을 달성할 수 있습니다.

먼저, Keras의 trainable API에 대해 자세히 살펴보겠습니다. 이 API는 대부분의 전이 학습 및 미세 조정 워크플로의 기초가 됩니다.

그런 다음 ImageNet 데이터세트에서 사전 훈련된 모델을 사용하여 Kaggle "cats vs dogs" 분류 데이터세트에서 재훈련함으로써 일반적인 워크플로를 시연합니다.

이것은 Deep Learning with Python과 2016 블로그 게시물 "building powerful image classification models using very little data"로부터 조정되었습니다.

레이어 동결: trainable 속성의 이해

레이어 및 모델에는 세 가지 가중치 속성이 있습니다.

  • weights는 레이어의 모든 가중치 변수 목록입니다.
  • trainable_weights는 훈련 중 손실을 최소화하기 위해 업데이트(그래디언트 디센트를 통해)되어야 하는 목록입니다.
  • non_trainable_weights는 훈련되지 않은 가중치 변수의 목록입니다. 일반적으로 순방향 전달 중에 모델에 의해 업데이트됩니다.

예제: Dense 레이어에는 2개의 훈련 가능한 가중치가 있습니다(커널 및 바이어스)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

일반적으로 모든 가중치는 훈련이 가능합니다. 훈련할 수 없는 가중치가 있는 유일한 내장 레이어는 BatchNormalization 레이어입니다. 훈련할 수 없는 가중치를 사용하여 훈련 중 입력의 평균 및 분산을 추적합니다. 훈련할 수 없는 가중치를 사용자 정의 레이어에서 사용하는 방법을 배우려면 새 레이어를 처음부터 작성하는 방법을 참조하세요.

예제: BatchNormalization 레이어에는 2개의 훈련 가능한 가중치와 2개의 훈련할 수 없는 가중치가 있습니다

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

레이어 및 모델에는 boolean 속성 trainable도 있습니다. 값은 변경될 수 있습니다. layer.trainableFalse로 설정하면 모든 레이어의 가중치가 훈련 가능에서 훈련 불가능으로 이동합니다. 이를 레이어 "동결"이라고 합니다. 동결 레이어의 상태는 fit()을 사용하거나 trainable_weights에 의존하는 사용자 정의 루프를 사용해 훈련하여 그래디언트 업데이트를 적용할 때도 훈련하는 동안 업데이트되지 않습니다.

예제: trainableFalse로 설정

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

훈련 가능한 가중치가 훈련할 수 없게 되면 훈련 중에 그 값이 더는 업데이트되지 않습니다.

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 1s/step - loss: 0.1928

layer.trainable 속성을 레이어가 추론 모드 또는 훈련 모드에서 순방향 전달을 실행해야 하는지를 제어하는 layer.__call__()의 인수 training과 혼동하지 마세요. 자세한 내용은 Keras FAQ를 참조하세요.

trainable 속성의 재귀 설정

모델 또는 하위 레이어가 있는 레이어에서 trainable = False를 설정하면 모든 하위 레이어도 훈련할 수 없게 됩니다.

예제:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

일반적인 전이 학습 워크플로

이를 통해 Keras에서 일반적인 전이 학습 워크플로를 구현할 수 있습니다.

  1. 기본 모델을 인스턴스화하고 사전 훈련된 가중치를 여기에 로드합니다.
  2. trainable = False를 설정하여 기본 모델의 모든 레이어를 동결합니다.
  3. 기본 모델에서 하나 이상의 레이어 출력 위에 새 모델을 만듭니다.
  4. 새 데이터세트에서 새 모델을 훈련합니다.

보다 가벼운 대안 워크플로는 다음과 같습니다.

  1. 기본 모델을 인스턴스화하고 사전 훈련된 가중치를 여기에 로드합니다.
  2. 이를 통해 새로운 데이터세트를 실행하고 기본 모델의 하나의(또는 여러) 레이어의 출력을 기록합니다. 이를 특성 추출이라 합니다.
  3. 이 출력을 더 작은 새 모델의 입력 데이터로 사용합니다.

이 두 번째 워크플로의 주요 장점은 훈련 epoch마다 한 번이 아니라 한 번의 데이터로 기본 모델을 실행한다는 것입니다. 따라서 훨씬 빠르고 저렴합니다.

그러나 두 번째 워크플로의 문제점은 훈련 중에 새 모델의 입력 데이터를 동적으로 수정할 수 없다는 것입니다. 예를 들어 데이터 증강을 수행할 때 필요합니다. 전이 학습은 일반적으로 새 데이터세트에 데이터가 너무 작아서 전체 규모의 모델을 처음부터 훈련할 수 없는 작업에 사용되며, 이러한 시나리오에서는 데이터 증강이 매우 중요합니다. 따라서 다음 내용에서는 첫 번째 워크플로에 중점을 둘 것입니다.

Keras의 첫 번째 워크플로는 다음과 같습니다.

먼저, 사전 훈련된 가중치를 사용하여 기본 모델을 인스턴스화합니다.

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

그런 다음 기본 모델을 동결합니다.

base_model.trainable = False

맨 위에 새 모델을 만듭니다.

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

새 데이터로 모델을 훈련합니다.

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

미세 조정

모델이 새로운 데이터에 수렴하면 기본 모델의 일부 또는 전부를 동결 해제하고 학습률이 매우 낮은 전체 모델을 전체적으로 재훈련할 수 있습니다.

이 단계는 선택적으로 마지막 단계이며 점진적으로 개선할 수 있습니다. 또한 잠재적으로 빠른 과대적합을 초래할 수 있습니다. 명심하세요.

동결된 레이어가 있는 모델이 수렴하도록 훈련된 에만 이 단계를 수행하는 것이 중요합니다. 무작위로 초기화된 훈련 가능한 레이어를 사전 훈련된 특성을 보유하는 훈련 가능한 레이어와 혼합하는 경우, 무작위로 초기화된 레이어는 훈련 중에 매우 큰 그래디언트 업데이트를 유발하여 사전 훈련된 특성을 파괴합니다.

또한 이 단계에서는 일반적으로 매우 작은 데이터 집합에서 첫 번째 훈련보다 훨씬 더 큰 모델을 훈련하기 때문에, 매우 낮은 학습률을 사용하는 것이 중요합니다. 결과적으로 큰 가중치 업데이트를 적용하면 과도하게 빠른 과대적합의 위험이 있습니다. 여기에서는 사전 훈련된 가중치만 점진적인 방식으로 다시 적용하려고 합니다.

다음은 전체 기본 모델의 미세 조정을 구현하는 방법입니다.

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

compile()trainable에 대한 중요 사항

모델에서 compile()을 호출하는 것은 해당 모델의 동작을 "동결"하기 위한 것입니다. 이는 compile이 다시 호출될 때까지 모델이 컴파일될 때 trainable 속성값이 해당 모델의 수명 동안 유지되어야 함을 의미합니다. 따라서 trainable 값을 변경하면 모델에서 compile()을 다시 호출하여 변경 사항을 적용합니다.

BatchNormalization 레이어에 대한 중요 사항

많은 이미지 모델에는 BatchNormalization 레이어가 포함되어 있습니다. 그 레이어는 상상할 수 있는 모든 수에서 특별한 경우입니다. 다음은 명심해야 할 몇 가지 사항입니다.

  • BatchNormalization에는 훈련 중에 업데이트되는 훈련 불가능한 2개의 가중치가 포함되어 있습니다. 입력의 평균과 분산을 추적하는 변수입니다.
  • bn_layer.trainable = False를 설정하면 BatchNormalization 레이어가 추론 모드에서 실행되며 평균 및 분산 통계가 업데이트되지 않습니다. 가중치 훈련 및 추론/훈련 모드가 직교 개념이므로 일반적으로 다른 레이어의 경우에는 해당하지 않습니다. 그러나 BatchNormalization 레이어의 경우 두 가지가 묶여 있습니다.
  • 미세 조정을 위해 BatchNormalization 레이어를 포함하는 모델을 동결 해제하면 기본 모델을 호출할 때 training=False를 전달하여 BatchNormalization 레이어를 추론 모드로 유지해야 합니다. 그렇지 않으면 훈련 불가능한 가중치에 적용된 업데이트로 인해 모델이 학습한 내용이 갑작스럽게 파괴됩니다.

이 가이드 끝의 엔드 투 엔드 예제에서 이 패턴이 적용되는 것을 볼 수 있습니다.

사용자 정의 훈련 루프를 사용한 전이 학습 및 미세 조정

fit() 대신 자체 저수준 훈련 루프를 사용하는 경우 워크플로는 본질적으로 동일하게 유지됩니다. 그래디언트 업데이트를 적용할 때 목록 model.trainable_weights만 고려해야 합니다.

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

미세 조정의 경우도 마찬가지입니다.

엔드 투 엔드 예제: 고양이 vs 개 데이터세트에서 이미지 분류 모델 미세 조정

이러한 개념을 강화하기 위해 구체적인 엔드 투 엔드 전이 학습 및 미세 조정 예제를 안내합니다. ImageNet에서 사전 훈련된 Xception 모델을 로드하고 Kaggle "cats vs. dogs" 분류 데이터세트에서 사용합니다.

데이터 얻기

먼저 TFDS를 사용해 고양이 vs 개 데이터세트를 가져옵니다. 자체 데이터세트가 있다면 유틸리티 tf.keras.preprocessing.image_dataset_from_directory를 사용하여 클래스별 폴더에 보관된 디스크의 이미지 세트에서 유사한 레이블이 지정된 데이터세트 객체를 생성할 수 있습니다.

전이 학습은 매우 작은 데이터로 작업할 때 가장 유용합니다. 데이터세트를 작게 유지하기 위해 원래 훈련 데이터(25,000개 이미지)의 40%를 훈련에, 10%를 유효성 검사에, 10%를 테스트에 사용합니다.

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

이것은 훈련 데이터세트에서 처음 9개의 이미지입니다. 보시다시피 이미지는 모두 크기가 다릅니다.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

레이블 1이 "개"이고 레이블 0이 "고양이"임을 알 수 있습니다.

데이터 표준화하기

원시 이미지는 다양한 크기를 가지고 있습니다. 또한 각 픽셀은 0에서 255 사이의 3개의 정숫값(RGB 레벨값)으로 구성됩니다. 이는 신경망에 공급하기 적합하지 않습니다. 2가지를 수행해야 합니다.

  • 고정된 이미지 크기로 표준화합니다. 150x150을 선택합니다.
  • 정상 픽셀 값은 -1 과 1 사이입니다. 모델 자체의 일부로 Normalization 레이어를 사용합니다.

일반적으로 이미 사전 처리된 데이터를 사용하는 모델과 달리 원시 데이터를 입력으로 사용하는 모델을 개발하는 것이 좋습니다. 모델에 사전 처리 된 데이터가 필요한 경우 모델을 내보내 다른 위치(웹 브라우저, 모바일 앱)에서 사용할 때마다 동일한 사전 처리 파이프 라인을 다시 구현해야하기 때문입니다. 이것은 매우 까다로워집니다. 따라서 모델에 도달하기 전에 가능한 최소한의 전처리를 수행해야합니다.

여기에서는 데이터 파이프라인에서 이미지 크기 조정을 수행하고(심층 신경망은 인접한 데이터 배치만 처리할 수 있기 때문에) 입력값 스케일링을 모델의 일부로 생성합니다.

이미지 크기를 150x150으로 조정해 보겠습니다.

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089

또한 데이터를 일괄 처리하고 캐싱 및 프리페치를 사용하여 로딩 속도를 최적화합니다.

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

무작위 데이터 증강 사용하기

큰 이미지 데이터세트가 없는 경우 임의의 수평 뒤집기 또는 작은 임의의 회전과 같이 훈련 이미지에 무작위이지만 사실적인 변형을 적용하여 샘플 다양성을 인위적으로 도입하는 것이 좋습니다. 이것은 과대적합을 늦추면서 모델을 훈련 데이터의 다른 측면에 노출하는 데 도움이 됩니다.

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)

다양한 무작위 변형 후 첫 번째 배치의 첫 번째 이미지가 어떻게 보이는지 시각화해 보겠습니다.

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
2022-12-14 22:32:42.196850: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:5 out of the last 5 calls to <function pfor.<locals>.f at 0x7ff800249af0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:5 out of the last 5 calls to <function pfor.<locals>.f at 0x7ff800249af0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:6 out of the last 6 calls to <function pfor.<locals>.f at 0x7ff8001af820> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:6 out of the last 6 calls to <function pfor.<locals>.f at 0x7ff8001af820> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.

png

모델 빌드하기

이제 앞에서 설명한 청사진을 따르는 모델을 만들어 보겠습니다.

참고 사항:

  • 입력 값(처음에는 [0, 255] 범위)을 [-1, 1] 범위로 조정하기 위해 Rescaling 레이어를 추가합니다.
  • 정규화를 위해 분류 레이어 앞에 Dropout 레이어를 추가합니다.
  • 기본 모델을 호출할 때 training=False를 전달하여 추론 모드에서 실행되므로 미세 조정을 위해 기본 모델을 동결 해제한 후에도 batchnorm 통계가 업데이트되지 않습니다.
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83683744/83683744 [==============================] - 2s 0us/step
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_5 (InputLayer)        [(None, 150, 150, 3)]     0         
                                                                 
 sequential_3 (Sequential)   (None, 150, 150, 3)       0         
                                                                 
 rescaling (Rescaling)       (None, 150, 150, 3)       0         
                                                                 
 xception (Functional)       (None, 5, 5, 2048)        20861480  
                                                                 
 global_average_pooling2d (G  (None, 2048)             0         
 lobalAveragePooling2D)                                          
                                                                 
 dropout (Dropout)           (None, 2048)              0         
                                                                 
 dense_7 (Dense)             (None, 1)                 2049      
                                                                 
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

최상위 레이어 훈련하기

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
144/291 [=============>................] - ETA: 11s - loss: 0.2055 - binary_accuracy: 0.9043
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
260/291 [=========================>....] - ETA: 2s - loss: 0.1746 - binary_accuracy: 0.9213
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
274/291 [===========================>..] - ETA: 1s - loss: 0.1708 - binary_accuracy: 0.9230
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
277/291 [===========================>..] - ETA: 1s - loss: 0.1697 - binary_accuracy: 0.9235
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
291/291 [==============================] - ETA: 0s - loss: 0.1684 - binary_accuracy: 0.9242
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
291/291 [==============================] - 32s 92ms/step - loss: 0.1684 - binary_accuracy: 0.9242 - val_loss: 0.0826 - val_binary_accuracy: 0.9686
Epoch 2/20
291/291 [==============================] - 25s 85ms/step - loss: 0.1167 - binary_accuracy: 0.9501 - val_loss: 0.0765 - val_binary_accuracy: 0.9703
Epoch 3/20
291/291 [==============================] - 24s 84ms/step - loss: 0.1160 - binary_accuracy: 0.9520 - val_loss: 0.0805 - val_binary_accuracy: 0.9708
Epoch 4/20
291/291 [==============================] - 24s 84ms/step - loss: 0.1054 - binary_accuracy: 0.9539 - val_loss: 0.0725 - val_binary_accuracy: 0.9716
Epoch 5/20
291/291 [==============================] - 24s 84ms/step - loss: 0.1011 - binary_accuracy: 0.9567 - val_loss: 0.0761 - val_binary_accuracy: 0.9699
Epoch 6/20
291/291 [==============================] - 24s 84ms/step - loss: 0.0958 - binary_accuracy: 0.9603 - val_loss: 0.0727 - val_binary_accuracy: 0.9699
Epoch 7/20
291/291 [==============================] - 24s 84ms/step - loss: 0.0964 - binary_accuracy: 0.9603 - val_loss: 0.0752 - val_binary_accuracy: 0.9712
Epoch 8/20
291/291 [==============================] - 24s 84ms/step - loss: 0.1000 - binary_accuracy: 0.9612 - val_loss: 0.0729 - val_binary_accuracy: 0.9721
Epoch 9/20
291/291 [==============================] - 24s 84ms/step - loss: 0.0978 - binary_accuracy: 0.9577 - val_loss: 0.0710 - val_binary_accuracy: 0.9716
Epoch 10/20
291/291 [==============================] - 24s 83ms/step - loss: 0.0918 - binary_accuracy: 0.9615 - val_loss: 0.0716 - val_binary_accuracy: 0.9703
Epoch 11/20
291/291 [==============================] - 24s 84ms/step - loss: 0.0933 - binary_accuracy: 0.9614 - val_loss: 0.0729 - val_binary_accuracy: 0.9712
Epoch 12/20
291/291 [==============================] - 24s 84ms/step - loss: 0.0954 - binary_accuracy: 0.9594 - val_loss: 0.0757 - val_binary_accuracy: 0.9703
Epoch 13/20
291/291 [==============================] - 24s 83ms/step - loss: 0.0888 - binary_accuracy: 0.9624 - val_loss: 0.0752 - val_binary_accuracy: 0.9708
Epoch 14/20
291/291 [==============================] - 24s 84ms/step - loss: 0.0921 - binary_accuracy: 0.9630 - val_loss: 0.0742 - val_binary_accuracy: 0.9695
Epoch 15/20
291/291 [==============================] - 24s 84ms/step - loss: 0.0902 - binary_accuracy: 0.9615 - val_loss: 0.0750 - val_binary_accuracy: 0.9690
Epoch 16/20
291/291 [==============================] - 24s 83ms/step - loss: 0.0930 - binary_accuracy: 0.9616 - val_loss: 0.0723 - val_binary_accuracy: 0.9712
Epoch 17/20
291/291 [==============================] - 24s 84ms/step - loss: 0.0868 - binary_accuracy: 0.9648 - val_loss: 0.0782 - val_binary_accuracy: 0.9682
Epoch 18/20
291/291 [==============================] - 24s 83ms/step - loss: 0.0866 - binary_accuracy: 0.9644 - val_loss: 0.0747 - val_binary_accuracy: 0.9712
Epoch 19/20
291/291 [==============================] - 24s 83ms/step - loss: 0.0881 - binary_accuracy: 0.9636 - val_loss: 0.0834 - val_binary_accuracy: 0.9699
Epoch 20/20
291/291 [==============================] - 25s 84ms/step - loss: 0.0904 - binary_accuracy: 0.9648 - val_loss: 0.0770 - val_binary_accuracy: 0.9712
<keras.callbacks.History at 0x7ff9301266a0>

전체 모델의 미세 조정 수행하기

마지막으로 기본 모델을 동결 해제하고 낮은 학습률로 전체 모델을 전체적으로 훈련합니다.

중요한 것은 기본 모델이 훈련 가능하지만 모델 빌드를 호출할 때 training=False를 전달했으므로 여전히 추론 모드로 실행되고 있다는 것입니다. 이는 내부의 배치 정규화 레이어가 배치 통계를 업데이트하지 않음을 의미합니다. 만약 레이어가 배치 통계를 업데이트한다면 지금까지 모델이 학습한 표현에 혼란을 줄 수 있습니다.

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_5 (InputLayer)        [(None, 150, 150, 3)]     0         
                                                                 
 sequential_3 (Sequential)   (None, 150, 150, 3)       0         
                                                                 
 rescaling (Rescaling)       (None, 150, 150, 3)       0         
                                                                 
 xception (Functional)       (None, 5, 5, 2048)        20861480  
                                                                 
 global_average_pooling2d (G  (None, 2048)             0         
 lobalAveragePooling2D)                                          
                                                                 
 dropout (Dropout)           (None, 2048)              0         
                                                                 
 dense_7 (Dense)             (None, 1)                 2049      
                                                                 
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
291/291 [==============================] - 79s 193ms/step - loss: 0.0790 - binary_accuracy: 0.9686 - val_loss: 0.0589 - val_binary_accuracy: 0.9768
Epoch 2/10
291/291 [==============================] - 55s 188ms/step - loss: 0.0538 - binary_accuracy: 0.9786 - val_loss: 0.0495 - val_binary_accuracy: 0.9798
Epoch 3/10
291/291 [==============================] - 55s 188ms/step - loss: 0.0433 - binary_accuracy: 0.9838 - val_loss: 0.0497 - val_binary_accuracy: 0.9807
Epoch 4/10
291/291 [==============================] - 55s 188ms/step - loss: 0.0379 - binary_accuracy: 0.9861 - val_loss: 0.0591 - val_binary_accuracy: 0.9772
Epoch 5/10
291/291 [==============================] - 55s 189ms/step - loss: 0.0268 - binary_accuracy: 0.9902 - val_loss: 0.0540 - val_binary_accuracy: 0.9794
Epoch 6/10
291/291 [==============================] - 55s 188ms/step - loss: 0.0180 - binary_accuracy: 0.9938 - val_loss: 0.0592 - val_binary_accuracy: 0.9802
Epoch 7/10
291/291 [==============================] - 55s 189ms/step - loss: 0.0156 - binary_accuracy: 0.9947 - val_loss: 0.0518 - val_binary_accuracy: 0.9837
Epoch 8/10
291/291 [==============================] - 55s 188ms/step - loss: 0.0151 - binary_accuracy: 0.9939 - val_loss: 0.0467 - val_binary_accuracy: 0.9832
Epoch 9/10
291/291 [==============================] - 55s 189ms/step - loss: 0.0109 - binary_accuracy: 0.9962 - val_loss: 0.0511 - val_binary_accuracy: 0.9824
Epoch 10/10
291/291 [==============================] - 55s 188ms/step - loss: 0.0103 - binary_accuracy: 0.9969 - val_loss: 0.0497 - val_binary_accuracy: 0.9819
<keras.callbacks.History at 0x7ff7e4297af0>

10 epoch 후에, 미세 조정이 크게 개선됩니다.