Overview
In this codelab you'll train a simple image classification model on the CIFAR10 dataset, and then run a "membership inference attack" against this model to assess whether an attacker can "guess" if a particular sample was present in the training set. You will use the TF Privacy Report to visualize results from multiple models and model checkpoints.
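To make the attack concrete: a loss-threshold attack guesses "member" whenever the model's loss on an example falls below some threshold, and sweeping that threshold traces out an ROC curve. The minimal sketch below (not TF Privacy's implementation) computes the attack's AUC and its attacker advantage, assuming advantage is defined as the largest |TPR - FPR| over all thresholds. `threshold_attack` is a hypothetical helper and the losses are synthetic.

```python
import numpy as np
from sklearn import metrics


def threshold_attack(loss_train: np.ndarray, loss_test: np.ndarray):
    """Hypothetical loss-threshold membership inference attack.

    Lower loss is treated as stronger evidence of membership, so the
    negated loss is used as the attack score.
    """
    # Members (training examples) are labeled 1, non-members (test examples) 0.
    is_member = np.concatenate([np.ones(len(loss_train), dtype=int),
                                np.zeros(len(loss_test), dtype=int)])
    scores = -np.concatenate([loss_train, loss_test])
    fpr, tpr, _ = metrics.roc_curve(is_member, scores)
    auc = metrics.auc(fpr, tpr)
    # Assumed definition: best gap between true and false positive rates.
    advantage = float(np.max(np.abs(tpr - fpr)))
    return auc, advantage


# Synthetic losses for illustration only: members tend to have lower loss.
rng = np.random.default_rng(0)
loss_train = rng.exponential(scale=0.3, size=1000)
loss_test = rng.exponential(scale=1.0, size=1000)
auc, advantage = threshold_attack(loss_train, loss_test)
print(f"AUC={auc:.3f}, attacker advantage={advantage:.3f}")
```

An AUC near 0.5 (and an advantage near 0) means the attacker does no better than chance; larger values indicate that the model treats training examples measurably differently from unseen ones.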
Setup
import numpy as np
from typing import Tuple
from scipy import special
from sklearn import metrics
import tensorflow as tf
import tensorflow_datasets as tfds
# Set verbosity.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from sklearn.exceptions import ConvergenceWarning
import warnings
warnings.simplefilter(action="ignore", category=ConvergenceWarning)
warnings.simplefilter(action="ignore", category=FutureWarning)
Install TensorFlow Privacy.
pip install tensorflow_privacy
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyMetric
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import privacy_report
import tensorflow_privacy
Train two models, with privacy metrics
This section trains a pair of keras.Model classifiers on the CIFAR-10 dataset. During training it collects privacy metrics that will be used to generate reports in the next section.
The first step is to define some hyperparameters:
dataset = 'cifar10'
num_classes = 10
activation = 'relu'
num_conv = 3
batch_size = 50
epochs_per_report = 2
total_epochs = 50
lr = 0.001
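These values also fix how much work each training run does. A quick check of the arithmetic, assuming the standard 50,000-image CIFAR-10 training split:

```python
# Values copied from the hyperparameter cell; 50,000 is the size of the
# CIFAR-10 training split.
num_train_examples = 50_000
batch_size = 50
epochs_per_report = 2
total_epochs = 50

steps_per_epoch = num_train_examples // batch_size     # batches per epoch
total_steps = steps_per_epoch * total_epochs            # optimizer updates per model
reports_per_model = total_epochs // epochs_per_report   # privacy reports per model

print(steps_per_epoch, total_steps, reports_per_model)  # 1000 50000 25
```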
Next, load the dataset. There's nothing privacy-specific in this code.
print('Loading the dataset.')
train_ds = tfds.as_numpy(
    tfds.load(dataset, split=tfds.Split.TRAIN, batch_size=-1))
test_ds = tfds.as_numpy(
    tfds.load(dataset, split=tfds.Split.TEST, batch_size=-1))
x_train = train_ds['image'].astype('float32') / 255.
y_train_indices = train_ds['label'][:, np.newaxis]
x_test = test_ds['image'].astype('float32') / 255.
y_test_indices = test_ds['label'][:, np.newaxis]
# Convert class vectors to binary class matrices.
y_train = tf.keras.utils.to_categorical(y_train_indices, num_classes)
y_test = tf.keras.utils.to_categorical(y_test_indices, num_classes)
input_shape = x_train.shape[1:]
assert x_train.shape[0] % batch_size == 0, "The tensorflow_privacy optimizer doesn't handle partial batches"
Loading the dataset.
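The preprocessing above can be reproduced with plain NumPy, which may help if you want to check it without TensorFlow. This sketch assumes `to_categorical` behaves like indexing an identity matrix for integer labels, and adds a hypothetical `trim_to_full_batches` helper for datasets whose size is not a multiple of `batch_size`. CIFAR-10's 50,000 training images divide evenly by 50, so this codelab does not need the trimming step.

```python
import numpy as np


def one_hot(label_indices: np.ndarray, num_classes: int) -> np.ndarray:
    """NumPy sketch of one-hot encoding for an (N, 1) or (N,) label array."""
    flat = np.asarray(label_indices).reshape(-1)
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(num_classes, dtype='float32')[flat]


def trim_to_full_batches(x: np.ndarray, y: np.ndarray, batch_size: int):
    """Hypothetical helper: drop trailing examples so every batch is full."""
    n = (len(x) // batch_size) * batch_size
    return x[:n], y[:n]


labels = np.array([[3], [0], [9]])
print(one_hot(labels, 10).argmax(axis=1))  # [3 0 9]
```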
Next define a function to build the models.
def small_cnn(input_shape: Tuple[int, ...],
              num_classes: int,
              num_conv: int,
              activation: str = 'relu') -> tf.keras.models.Sequential:
  """Set up a small CNN for image classification.
  Args:
    input_shape: Integer tuple for the shape of the images.
    num_classes: Number of prediction classes.
    num_conv: Number of convolutional layers.
    activation: The activation function to use for conv and dense layers.
  Returns:
    The Keras model.
  """
  model = tf.keras.models.Sequential()
  model.add(tf.keras.layers.Input(shape=input_shape))
  # Conv layers: each 3x3 convolution is followed by 2x2 max pooling.
  for _ in range(num_conv):
    model.add(tf.keras.layers.Conv2D(32, (3, 3), activation=activation))
    model.add(tf.keras.layers.MaxPooling2D())
  model.add(tf.keras.layers.Flatten())
  model.add(tf.keras.layers.Dense(64, activation=activation))
  # The output layer produces logits, hence from_logits=True in the loss.
  model.add(tf.keras.layers.Dense(num_classes))
  model.compile(
      loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
      metrics=['accuracy'])
  return model
Build two small CNN models using that function: the first with two convolutional layers and the second with three. Both use the same Adam optimizer, so the number of convolutional layers is the only architectural difference, and you can compare how privacy risk evolves for each.
model_2layers = small_cnn(
    input_shape, num_classes, num_conv=2, activation=activation)
model_3layers = small_cnn(
    input_shape, num_classes, num_conv=3, activation=activation)
Define a callback to collect privacy metrics
Next, define a keras.callbacks.Callback that periodically runs some privacy attacks against the model and logs the results.
The Keras fit method calls the on_epoch_end method after each training epoch. Its epoch argument is the (0-based) epoch number.
You could implement this procedure by writing a loop that repeatedly calls Model.fit(..., epochs=epochs_per_report) and runs the attack code. The callback is used here only because it cleanly separates the training logic from the privacy evaluation logic.
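The epoch bookkeeping in the callback below can be checked in isolation. This sketch mirrors its reporting condition with a hypothetical `should_report` function: because Keras passes a 0-based index, adding 1 before testing divisibility makes the reports land on epochs 2, 4, ..., 50.

```python
def should_report(epoch: int, epochs_per_report: int) -> bool:
    """Hypothetical mirror of the check in PrivacyMetrics.on_epoch_end.

    `epoch` is the 0-based index Keras passes to on_epoch_end.
    """
    return (epoch + 1) % epochs_per_report == 0


# Simulate the indices Keras would pass during a 50-epoch run.
reported = [epoch + 1 for epoch in range(50) if should_report(epoch, 2)]
print(reported[:3], reported[-1])  # [2, 4, 6] 50
```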
class PrivacyMetrics(tf.keras.callbacks.Callback):
  def __init__(self, epochs_per_report, model_name):
    super().__init__()
    self.epochs_per_report = epochs_per_report
    self.model_name = model_name
    self.attack_results = []
  def on_epoch_end(self, epoch, logs=None):
    # Keras passes a 0-based epoch index; convert it to a 1-based count.
    epoch = epoch + 1
    if epoch % self.epochs_per_report != 0:
      return
    print(f'\nRunning privacy report for epoch: {epoch}\n')
    logits_train = self.model.predict(x_train, batch_size=batch_size)
    logits_test = self.model.predict(x_test, batch_size=batch_size)
    # The model outputs logits; convert them to probabilities for the attacks.
    prob_train = special.softmax(logits_train, axis=1)
    prob_test = special.softmax(logits_test, axis=1)
    # Add metadata to generate a privacy report.
    privacy_report_metadata = PrivacyReportMetadata(
        # Show the validation accuracy on the plot.
        # It's the value passed as accuracy_train that gets plotted.
        accuracy_train=logs['val_accuracy'],
        accuracy_test=logs['val_accuracy'],
        epoch_num=epoch,
        model_variant_label=self.model_name)
    attack_results = mia.run_attacks(
        AttackInputData(
            labels_train=y_train_indices[:, 0],
            labels_test=y_test_indices[:, 0],
            probs_train=prob_train,
            probs_test=prob_test),
        SlicingSpec(entire_dataset=True, by_class=True),
        attack_types=(AttackType.THRESHOLD_ATTACK,
                      AttackType.LOGISTIC_REGRESSION),
        privacy_report_metadata=privacy_report_metadata)
    self.attack_results.append(attack_results)
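The `SlicingSpec(entire_dataset=True, by_class=True)` argument asks for attack results on the whole dataset and on each class separately. The sketch below approximates what a per-class slice means for a loss-threshold attack: it computes each example's cross-entropy from the softmax probabilities and scores each class independently. It is an illustration under those assumptions, not the library's code; `per_example_loss` and `auc_by_class` are hypothetical helpers.

```python
import numpy as np
from sklearn import metrics


def per_example_loss(probs: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Cross-entropy of each example, computed from softmax probabilities."""
    eps = 1e-12  # avoid log(0) for confidently wrong predictions
    return -np.log(probs[np.arange(len(labels)), labels] + eps)


def auc_by_class(probs_train, labels_train, probs_test, labels_test):
    """Hypothetical per-class loss-threshold attack; returns {class: AUC}."""
    loss_train = per_example_loss(probs_train, labels_train)
    loss_test = per_example_loss(probs_test, labels_test)
    result = {}
    for c in np.unique(np.concatenate([labels_train, labels_test])):
        members = loss_train[labels_train == c]
        non_members = loss_test[labels_test == c]
        if len(members) == 0 or len(non_members) == 0:
            continue  # a slice needs both members and non-members
        y = np.concatenate([np.ones(len(members), dtype=int),
                            np.zeros(len(non_members), dtype=int)])
        # Lower loss counts as stronger evidence of membership.
        scores = -np.concatenate([members, non_members])
        result[int(c)] = metrics.roc_auc_score(y, scores)
    return result
```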
Train the models
The next code block trains the two models. The all_reports list collects the results from both models' training runs. Each report is tagged with the model_name, so there's no confusion about which model generated which report.
all_reports = []
callback = PrivacyMetrics(epochs_per_report, "2 Layers")
history = model_2layers.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=total_epochs,
    validation_data=(x_test, y_test),
    callbacks=[callback],
    shuffle=True)
all_reports.extend(callback.attack_results)
Epoch 1/50 - loss: 1.5489 - accuracy: 0.4400 - val_loss: 1.3055 - val_accuracy: 0.5434
Epoch 2/50 - loss: 1.2281 - accuracy: 0.5660 - val_loss: 1.1349 - val_accuracy: 0.6052
Epoch 3/50 - loss: 1.0912 - accuracy: 0.6191 - val_loss: 1.0731 - val_accuracy: 0.6275
Epoch 4/50 - loss: 1.0020 - accuracy: 0.6517 - val_loss: 0.9989 - val_accuracy: 0.6555
Epoch 5/50 - loss: 0.9483 - accuracy: 0.6682 - val_loss: 0.9970 - val_accuracy: 0.6518
Epoch 6/50 - loss: 0.8956 - accuracy: 0.6883 - val_loss: 0.9491 - val_accuracy: 0.6712
Epoch 7/50 - loss: 0.8582 - accuracy: 0.6999 - val_loss: 0.9541 - val_accuracy: 0.6687
Epoch 8/50 - loss: 0.8200 - accuracy: 0.7154 - val_loss: 0.9357 - val_accuracy: 0.6827
Epoch 9/50 - loss: 0.7896 - accuracy: 0.7256 - val_loss: 0.9699 - val_accuracy: 0.6749
Epoch 10/50 - loss: 0.7591 - accuracy: 0.7371 - val_loss: 0.9192 - val_accuracy: 0.6874
Epoch 11/50 - loss: 0.7360 - accuracy: 0.7445 - val_loss: 0.9426 - val_accuracy: 0.6868
Epoch 12/50 - loss: 0.7114 - accuracy: 0.7505 - val_loss: 0.9126 - val_accuracy: 0.6925
Epoch 13/50 - loss: 0.6860 - accuracy: 0.7589 - val_loss: 0.9031 - val_accuracy: 0.6994
Epoch 14/50 - loss: 0.6686 - accuracy: 0.7665 - val_loss: 0.9265 - val_accuracy: 0.6958
Epoch 15/50 - loss: 0.6462 - accuracy: 0.7744 - val_loss: 0.9466 - val_accuracy: 0.6876
Epoch 16/50 - loss: 0.6313 - accuracy: 0.7792 - val_loss: 0.9245 - val_accuracy: 0.6987
Epoch 17/50 - loss: 0.6119 - accuracy: 0.7864 - val_loss: 0.9866 - val_accuracy: 0.6871
Epoch 18/50 - loss: 0.5913 - accuracy: 0.7923 - val_loss: 0.9489 - val_accuracy: 0.6913
Epoch 19/50 - loss: 0.5717 - accuracy: 0.7994 - val_loss: 0.9975 - val_accuracy: 0.6922
Epoch 20/50 - loss: 0.5594 - accuracy: 0.8042 - val_loss: 0.9910 - val_accuracy: 0.6893
Epoch 21/50 - loss: 0.5445 - accuracy: 0.8090 - val_loss: 1.0158 - val_accuracy: 0.6835
Epoch 22/50 - loss: 0.5295 - accuracy: 0.8128 - val_loss: 1.0173 - val_accuracy: 0.6890
Epoch 23/50 - loss: 0.5188 - accuracy: 0.8177 - val_loss: 1.0150 - val_accuracy: 0.6950
Epoch 24/50 - loss: 0.5072 - accuracy: 0.8206 - val_loss: 1.0205 - val_accuracy: 0.6970
Epoch 25/50 - loss: 0.4915 - accuracy: 0.8276 - val_loss: 1.0376 - val_accuracy: 0.6966
Epoch 26/50 - loss: 0.4775 - accuracy: 0.8306 - val_loss: 1.1048 - val_accuracy: 0.6791
Epoch 27/50 - loss: 0.4711 - accuracy: 0.8326 - val_loss: 1.1257 - val_accuracy: 0.6826
Epoch 28/50 - loss: 0.4560 - accuracy: 0.8375 - val_loss: 1.0947 - val_accuracy: 0.6957
Epoch 29/50 - loss: 0.4422 - accuracy: 0.8426 - val_loss: 1.1163 - val_accuracy: 0.6948
Epoch 30/50 - loss: 0.4340 - accuracy: 0.8474 - val_loss: 1.1556 - val_accuracy: 0.6861
Epoch 31/50 - loss: 0.4176 - accuracy: 0.8504 - val_loss: 1.1577 - val_accuracy: 0.6895
Epoch 32/50 - loss: 0.4093 - accuracy: 0.8530 - val_loss: 1.1948 - val_accuracy: 0.6876
Epoch 33/50 - loss: 0.3998 - accuracy: 0.8560 - val_loss: 1.1922 - val_accuracy: 0.6845
Epoch 34/50 - loss: 0.3891 - accuracy: 0.8607 - val_loss: 1.2492 - val_accuracy: 0.6855
Epoch 35/50 - loss: 0.3857 - accuracy: 0.8619 - val_loss: 1.2948 - val_accuracy: 0.6796
Epoch 36/50 - loss: 0.3699 - accuracy: 0.8670 - val_loss: 1.2682 - val_accuracy: 0.6859
Epoch 37/50 - loss: 0.3611 - accuracy: 0.8714 - val_loss: 1.3642 - val_accuracy: 0.6740
Epoch 38/50 - loss: 0.3464 - accuracy: 0.8763 - val_loss: 1.3078 - val_accuracy: 0.6831
Epoch 39/50 - loss: 0.3435 - accuracy: 0.8761 - val_loss: 1.3791 - val_accuracy: 0.6833
Epoch 40/50 - loss: 0.3331 - accuracy: 0.8799 - val_loss: 1.3722 - val_accuracy: 0.6813
Epoch 41/50 - loss: 0.3299 - accuracy: 0.8815 - val_loss: 1.4196 - val_accuracy: 0.6821
Epoch 42/50 - loss: 0.3149 - accuracy: 0.8855 - val_loss: 1.4826 - val_accuracy: 0.6766
Epoch 43/50 - loss: 0.3131 - accuracy: 0.8856 - val_loss: 1.4955 - val_accuracy: 0.6730
Epoch 44/50 - loss: 0.3016 - accuracy: 0.8920 - val_loss: 1.5527 - val_accuracy: 0.6743
Epoch 45/50 - loss: 0.2967 - accuracy: 0.8937 - val_loss: 1.5531 - val_accuracy: 0.6690
Epoch 46/50 - loss: 0.2901 - accuracy: 0.8953 - val_loss: 1.6468 - val_accuracy: 0.6681
Epoch 47/50 - loss: 0.2818 - accuracy: 0.8991 - val_loss: 1.6265 - val_accuracy: 0.6739
Epoch 48/50 - loss: 0.2723 - accuracy: 0.9015 - val_loss: 1.6512 - val_accuracy: 0.6729
Epoch 49/50 - loss: 0.2724 - accuracy: 0.9004 - val_loss: 1.7084 - val_accuracy: 0.6721
Epoch 50/50 - loss: 0.2594 - accuracy: 0.9047 - val_loss: 1.7358 - val_accuracy: 0.6724
callback = PrivacyMetrics(epochs_per_report, "3 Layers")
history = model_3layers.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=total_epochs,
    validation_data=(x_test, y_test),
    callbacks=[callback],
    shuffle=True)
all_reports.extend(callback.attack_results)
Epoch 1/50 - loss: 1.6857 - accuracy: 0.3792 - val_loss: 1.4281 - val_accuracy: 0.4705
Epoch 2/50 - loss: 1.3741 - accuracy: 0.5024 - val_loss: 1.3383 - val_accuracy: 0.5206
Epoch 3/50 - loss: 1.2677 - accuracy: 0.5451 - val_loss: 1.2325 - val_accuracy: 0.5586
Epoch 4/50 - loss: 1.1836 - accuracy: 0.5781 - val_loss: 1.1777 - val_accuracy: 0.5782
Epoch 5/50 - loss: 1.1109 - accuracy: 0.6036 - val_loss: 1.0978 - val_accuracy: 0.6111
Epoch 6/50 - loss: 1.0553 - accuracy: 0.6263 - val_loss: 1.0666 - val_accuracy: 0.6237
Epoch 7/50 - loss: 1.0103 - accuracy: 0.6410 - val_loss: 1.0321 - val_accuracy: 0.6362
Epoch 8/50 - loss: 0.9686 - accuracy: 0.6568 - val_loss: 1.0183 - val_accuracy: 0.6400
Epoch 9/50 - loss: 0.9335 - accuracy: 0.6708 - val_loss: 0.9733 - val_accuracy: 0.6560
Epoch 10/50 - loss: 0.9055 - accuracy: 0.6799 - val_loss: 0.9580 - val_accuracy: 0.6668
Epoch 11/50 - loss: 0.8814 - accuracy: 0.6891 - val_loss: 0.9386 - val_accuracy: 0.6728
Epoch 12/50 - loss: 0.8548 - accuracy: 0.6958 - val_loss: 0.9416 - val_accuracy: 0.6680
Epoch 13/50 - loss: 0.8310 - accuracy: 0.7065 - val_loss: 0.9164 - val_accuracy: 0.6794
Epoch 14/50 - loss: 0.8156 - accuracy: 0.7107 - val_loss: 0.9930 - val_accuracy: 0.6525
Epoch 15/50 - loss: 0.7929 - accuracy: 0.7185 - val_loss: 0.9061 - val_accuracy: 0.6841
Epoch 16/50 - loss: 0.7757 - accuracy: 0.7266 - val_loss: 0.9097 - val_accuracy: 0.6889
Epoch 17/50 - loss: 0.7617 - accuracy: 0.7302 - val_loss: 0.9195 - val_accuracy: 0.6835
Epoch 18/50 - loss: 0.7415 - accuracy: 0.7378 - val_loss: 0.9179 - val_accuracy: 0.6825
Epoch 19/50 - loss: 0.7346 - accuracy: 0.7388 - val_loss: 0.9075 - val_accuracy: 0.6901
Epoch 20/50 - loss: 0.7222 - accuracy: 0.7436 - val_loss: 0.9032 - val_accuracy: 0.6905
Epoch 21/50 - loss: 0.7079 - accuracy: 0.7510 - val_loss: 0.9201 - val_accuracy: 0.6862
Epoch 22/50 - loss: 0.7002 - accuracy: 0.7505 - val_loss: 0.9163 - val_accuracy: 0.6899
Epoch 23/50 - loss: 0.6854 - accuracy: 0.7582 - val_loss: 0.9310 - val_accuracy: 0.6873
Epoch 24/50 - loss: 0.6739 - accuracy: 0.7619 - val_loss: 0.8897 - val_accuracy: 0.7008
Epoch 25/50 - loss: 0.6706 - accuracy: 0.7612 - val_loss: 0.8804 - val_accuracy: 0.7021
Epoch 26/50 - loss: 0.6565 - accuracy: 0.7675 - val_loss: 0.9391 - val_accuracy: 0.6837
Epoch 27/50 - loss: 0.6507 - accuracy: 0.7696 - val_loss: 0.9059 - val_accuracy: 0.6998
Epoch 28/50 - loss: 0.6392 - accuracy: 0.7724 - val_loss: 0.9589 - val_accuracy: 0.6819
Epoch 29/50 - loss: 0.6320 - accuracy: 0.7764 - val_loss: 0.9041 - val_accuracy: 0.7046
Epoch 30/50 - loss: 0.6218 - accuracy: 0.7806 - val_loss: 0.9187 - val_accuracy: 0.6995
Epoch 31/50 - loss: 0.6187 - accuracy: 0.7803 - val_loss: 0.9133 - val_accuracy: 0.6969
Epoch 32/50 - loss: 0.6086 - accuracy: 0.7845 - val_loss: 0.9354 - val_accuracy: 0.6970
Epoch 33/50 - loss: 0.6051 - accuracy: 0.7853 - val_loss: 0.9219 - val_accuracy: 0.6944
Epoch 34/50 - loss: 0.5975 - accuracy: 0.7879 - val_loss: 0.9085 - val_accuracy: 0.7054
Epoch 35/50 - loss: 0.5912 - accuracy: 0.7901 - val_loss: 0.9090 - val_accuracy: 0.7068
Epoch 36/50 - loss: 0.5849 - accuracy: 0.7901 - val_loss: 0.9203 - val_accuracy: 0.7017
Epoch 37/50 - loss: 0.5792 - accuracy: 0.7947 - val_loss: 0.9327 - val_accuracy: 0.6967
Epoch 38/50 - loss: 0.5710 - accuracy: 0.7952 - val_loss: 0.9674 - val_accuracy: 0.6986
Epoch 39/50 - loss: 0.5666 - accuracy: 0.7987 - val_loss: 0.9880 - val_accuracy: 0.6884
Epoch 40/50 - loss: 0.5592 - accuracy: 0.8005 - val_loss: 0.9379 - val_accuracy: 0.6990
Epoch 41/50 - loss: 0.5574 - accuracy: 0.8019 - val_loss: 0.9314 - val_accuracy: 0.7021
Epoch 42/50 - loss: 0.5496 - accuracy: 0.8029 - val_loss: 0.9624 - val_accuracy: 0.6985
Epoch 43/50 - loss: 0.5489 - accuracy: 0.8053 - val_loss: 0.9367 - val_accuracy: 0.7037
Epoch 44/50 - loss: 0.5398 - accuracy: 0.8079 - val_loss: 0.9536 - val_accuracy: 0.7002
Epoch 45/50 - loss: 0.5390 - accuracy: 0.8079 - val_loss: 0.9832 - val_accuracy: 0.6908
Epoch 46/50 - loss: 0.5338 - accuracy: 0.8098 - val_loss: 0.9792 - val_accuracy: 0.6969
Epoch 47/50 - loss: 0.5275 - accuracy: 0.8129 - val_loss: 0.9721 - val_accuracy: 0.7000
Epoch 48/50 - loss: 0.5224 - accuracy: 0.8124 - val_loss: 0.9630 - val_accuracy: 0.7058
Epoch 49/50 - loss: 0.5221 - accuracy: 0.8142 - val_loss: 0.9902 - val_accuracy: 0.6972
Epoch 50/50 - loss: 0.5158 - accuracy: 0.8154 - val_loss: 0.9978 - val_accuracy: 0.6944
Epoch Plots
You can visualize how privacy risk evolves during training by probing the model periodically (e.g. every 5 epochs). This lets you pick the point in time with the best performance/privacy trade-off.
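Choosing that point can be treated as a constrained selection. Among the probed checkpoints, keep those whose attack AUC stays within a privacy budget you set, then take the most accurate one. The sketch below assumes you have already collected one (epoch, validation accuracy, attack AUC) record per probe. The Checkpoint class, the best_tradeoff helper, the 0.60 budget and all of the numbers are hypothetical and are not part of TF Privacy.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Checkpoint:
    """One periodic probe: the epoch, its validation accuracy and the attack AUC."""
    epoch: int
    val_accuracy: float
    attack_auc: float


def best_tradeoff(checkpoints: List[Checkpoint], max_auc: float) -> Optional[Checkpoint]:
    """Return the most accurate checkpoint whose attack AUC is at or below max_auc.

    Returns None when every checkpoint exceeds the privacy budget.
    """
    eligible = [c for c in checkpoints if c.attack_auc <= max_auc]
    if not eligible:
        return None
    # On equal accuracy, prefer the earlier (less trained) checkpoint.
    return max(eligible, key=lambda c: (c.val_accuracy, -c.epoch))


# Hypothetical probe results (illustrative numbers, not output of this codelab).
probes = [
    Checkpoint(epoch=10, val_accuracy=0.62, attack_auc=0.53),
    Checkpoint(epoch=20, val_accuracy=0.67, attack_auc=0.56),
    Checkpoint(epoch=30, val_accuracy=0.69, attack_auc=0.60),
    Checkpoint(epoch=40, val_accuracy=0.70, attack_auc=0.64),
]

choice = best_tradeoff(probes, max_auc=0.60)
print(choice.epoch)  # 30
```

The budget itself is a policy decision; the helper only makes the trade-off explicit once you have picked one.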
Use the TF Privacy Membership Inference Attack module to generate AttackResults. These AttackResults get combined into an AttackResultsCollection. The TF Privacy Report is designed to analyze the provided AttackResultsCollection.
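from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import privacy_report
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyMetric
# Combine the attack results collected during training (all_reports) and plot them per epoch.
results = AttackResultsCollection(all_reports)
privacy_metrics = (PrivacyMetric.AUC, PrivacyMetric.ATTACKER_ADVANTAGE)
epoch_plot = privacy_report.plot_by_epochs(
    results, privacy_metrics=privacy_metrics)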
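Both plotted metrics summarize the ROC curve of a membership inference attack: AUC is the area under that curve, and attacker advantage is commonly defined as the largest gap |TPR - FPR| over all attack thresholds. The from-scratch sketch below illustrates the idea under simplifying assumptions and is not TF Privacy's implementation. It assumes a threshold attacker that flags an example as a training member when its score (for example, the model's confidence on the true label) is at or above a threshold. The helper names and the scores are hypothetical.

```python
from typing import List, Tuple


def roc_points(member_scores: List[float],
               nonmember_scores: List[float]) -> List[Tuple[float, float]]:
    """(FPR, TPR) points for an attack that flags 'member' when score >= threshold."""
    thresholds = sorted(set(member_scores) | set(nonmember_scores), reverse=True)
    points = [(0.0, 0.0)]  # threshold above every score: nothing is flagged
    for t in thresholds:
        tpr = sum(s >= t for s in member_scores) / len(member_scores)
        fpr = sum(s >= t for s in nonmember_scores) / len(nonmember_scores)
        points.append((fpr, tpr))
    # Lowering the threshold never decreases FPR or TPR, so points are sorted by FPR
    # and the lowest threshold flags everything, ending at (1.0, 1.0).
    return points


def auc(points: List[Tuple[float, float]]) -> float:
    """Trapezoidal area under a ROC curve whose points are sorted by FPR."""
    return sum((x1 - x0) * (y0 + y1) / 2
               for (x0, y0), (x1, y1) in zip(points, points[1:]))


def attacker_advantage(points: List[Tuple[float, float]]) -> float:
    """Largest |TPR - FPR| over all thresholds."""
    return max(abs(tpr - fpr) for fpr, tpr in points)


# Hypothetical attack scores: training members tend to score higher than non-members.
members = [0.9, 0.8, 0.7, 0.6]
nonmembers = [0.75, 0.5, 0.4, 0.3]

points = roc_points(members, nonmembers)
print(auc(points))                 # 0.875
print(attacker_advantage(points))  # 0.75
```

An AUC of 0.5 and an advantage of 0 mean the attacker does no better than random guessing; values closer to 1 mean the model leaks more information about which examples were in its training set.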
As a rule, privacy vulnerability tends to increase as the number of epochs goes up. This holds across model variants as well as across attacker types.
Two-layer models (with fewer convolutional layers) are generally more vulnerable than their three-layer counterparts.
Now let's see how model performance changes with respect to privacy risk.
Privacy vs Utility
privacy_metrics = (PrivacyMetric.AUC, PrivacyMetric.ATTACKER_ADVANTAGE)
utility_privacy_plot = privacy_report.plot_privacy_vs_accuracy(
    results, privacy_metrics=privacy_metrics)

# Label the x-axis of every subplot as validation accuracy.
for axis in utility_privacy_plot.axes:
  axis.set_xlabel('Validation accuracy')
Three-layer models (perhaps because they have too many parameters) only reach a train accuracy of 0.85. The two-layer models achieve roughly equal performance at that level of privacy risk, but they continue to improve in accuracy.
You can also see that the line for the two-layer models gets steeper. This means that further marginal gains in train accuracy come at the expense of much greater privacy vulnerability.
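The steepening can be quantified as the slope between consecutive checkpoints on the privacy vs. utility plot: slope_i = (r_{i+1} - r_i) / (a_{i+1} - a_i), where a is accuracy and r is the privacy metric. A minimal sketch, assuming you have (accuracy, privacy metric) pairs for one model in training order. The function name and the numbers are hypothetical.

```python
from typing import List, Optional, Tuple


def marginal_privacy_cost(
    history: List[Tuple[float, float]]
) -> List[Optional[float]]:
    """Change in privacy metric per unit of accuracy gained, between consecutive checkpoints.

    `history` holds (accuracy, privacy_metric) pairs for one model, in training order.
    A larger slope means each extra unit of accuracy costs more privacy.
    Pairs where accuracy did not improve yield None, because the slope is not meaningful there.
    """
    slopes: List[Optional[float]] = []
    for (acc_prev, risk_prev), (acc_next, risk_next) in zip(history, history[1:]):
        gain = acc_next - acc_prev
        slopes.append((risk_next - risk_prev) / gain if gain > 0 else None)
    return slopes


# Hypothetical (accuracy, attack AUC) pairs for one model; not results from this codelab.
history = [(0.60, 0.52), (0.65, 0.54), (0.68, 0.58), (0.70, 0.66)]
slopes = marginal_privacy_cost(history)
for slope in slopes:
    print(round(slope, 2))
```

With these made-up numbers the slope rises from about 0.4 to about 4.0, which is the steepening pattern described above: later accuracy gains are bought with disproportionately larger increases in attack AUC.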
This is the end of the tutorial. Feel free to analyze your own results.