Diese Seite wurde von der Cloud Translation API übersetzt.
Switch to English

Vorgefertigte TF-Gittermodelle

Ansicht auf TensorFlow.org In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen

Überblick

Vorgefertigte Modelle sind schnelle und einfache Methoden zum Erstellen von TFL tf.keras.model Instanzen für typische Anwendungsfälle. In diesem Handbuch werden die Schritte beschrieben, die zum Erstellen eines vorgefertigten TFL-Modells und zum Trainieren / Testen erforderlich sind.

Konfiguration

Installieren des TF-Gitterpakets:


pip install -q tensorflow-lattice pydot

Erforderliche Pakete importieren:

import tensorflow as tf

import copy
import logging
import numpy as np
import pandas as pd
import sys
import tensorflow_lattice as tfl
logging.disable(sys.maxsize)

Herunterladen des UCI Statlog (Heart) -Datensatzes:

csv_file = tf.keras.utils.get_file(
    'heart.csv', 'http://storage.googleapis.com/applied-dl/heart.csv')
df = pd.read_csv(csv_file)
train_size = int(len(df) * 0.8)
train_dataframe = df[:train_size]
test_dataframe = df[train_size:]
df.head()

Extrahieren und konvertieren Sie Features und Beschriftungen in Tensoren:

# Features:
# - age
# - sex
# - cp        chest pain type (4 values)
# - trestbps  resting blood pressure
# - chol      serum cholestoral in mg/dl
# - fbs       fasting blood sugar > 120 mg/dl
# - restecg   resting electrocardiographic results (values 0,1,2)
# - thalach   maximum heart rate achieved
# - exang     exercise induced angina
# - oldpeak   ST depression induced by exercise relative to rest
# - slope     the slope of the peak exercise ST segment
# - ca        number of major vessels (0-3) colored by flourosopy
# - thal      3 = normal; 6 = fixed defect; 7 = reversable defect
#
# This ordering of feature names will be the exact same order that we construct
# our model to expect.
feature_names = [
    'age', 'sex', 'cp', 'chol', 'fbs', 'trestbps', 'thalach', 'restecg',
    'exang', 'oldpeak', 'slope', 'ca', 'thal'
]
feature_name_indices = {name: index for index, name in enumerate(feature_names)}
# This is the vocab list and mapping we will use for the 'thal' categorical
# feature.
thal_vocab_list = ['normal', 'fixed', 'reversible']
thal_map = {category: i for i, category in enumerate(thal_vocab_list)}
# Custom function for converting thal categories to buckets
def convert_thal_features(thal_features):
  # Note that two examples in the test set are already converted.
  return np.array([
      thal_map[feature] if feature in thal_vocab_list else feature
      for feature in thal_features
  ])


# Custom function for extracting each feature.
def extract_features(dataframe,
                     label_name='target',
                     feature_names=feature_names):
  features = []
  for feature_name in feature_names:
    if feature_name == 'thal':
      features.append(
          convert_thal_features(dataframe[feature_name].values).astype(float))
    else:
      features.append(dataframe[feature_name].values.astype(float))
  labels = dataframe[label_name].values.astype(float)
  return features, labels
train_xs, train_ys = extract_features(train_dataframe)
test_xs, test_ys = extract_features(test_dataframe)
# Let's define our label minimum and maximum.
min_label, max_label = float(np.min(train_ys)), float(np.max(train_ys))
# Our lattice models may have predictions above 1.0 due to numerical errors.
# We can subtract this small epsilon value from our output_max to make sure we
# do not predict values outside of our label bound.
numerical_error_epsilon = 1e-5

Festlegen der Standardwerte für das Training in diesem Handbuch:

LEARNING_RATE = 0.01
BATCH_SIZE = 128
NUM_EPOCHS = 500
PREFITTING_NUM_EPOCHS = 10

Funktionskonfigurationen

Die Feature-Kalibrierung und die Konfiguration pro Feature werden mit tfl.configs.FeatureConfig festgelegt . Zu den Feature-Konfigurationen gehören Monotonieeinschränkungen, Regularisierung pro Feature (siehe tfl.configs.RegularizerConfig ) und Gittergrößen für Gittermodelle.

Beachten Sie, dass wir die Feature-Konfiguration für jedes Feature, das unser Modell erkennen soll, vollständig angeben müssen. Andernfalls kann das Modell nicht erkennen, dass eine solche Funktion vorhanden ist.

Quantile berechnen

Obwohl die Standardeinstellung für pwl_calibration_input_keypoints in tfl.configs.FeatureConfig "Quantile" ist, müssen wir für vorgefertigte Modelle die Eingabeschlüsselpunkte manuell definieren. Dazu definieren wir zunächst unsere eigene Hilfsfunktion zur Berechnung von Quantilen.

def compute_quantiles(features,
                      num_keypoints=10,
                      clip_min=None,
                      clip_max=None,
                      missing_value=None):
  # Clip min and max if desired.
  if clip_min is not None:
    features = np.maximum(features, clip_min)
    features = np.append(features, clip_min)
  if clip_max is not None:
    features = np.minimum(features, clip_max)
    features = np.append(features, clip_max)
  # Make features unique.
  unique_features = np.unique(features)
  # Remove missing values if specified.
  if missing_value is not None:
    unique_features = np.delete(unique_features,
                                np.where(unique_features == missing_value))
  # Compute and return quantiles over unique non-missing feature values.
  return np.quantile(
      unique_features,
      np.linspace(0., 1., num=num_keypoints),
      interpolation='nearest').astype(float)

Definieren unserer Funktionskonfigurationen

Nachdem wir unsere Quantile berechnen können, definieren wir eine Feature-Konfiguration für jedes Feature, das unser Modell als Eingabe verwenden soll.

# Feature configs are used to specify how each feature is calibrated and used.
feature_configs = [
    tfl.configs.FeatureConfig(
        name='age',
        lattice_size=3,
        monotonicity='increasing',
        # We must set the keypoints manually.
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['age']],
            num_keypoints=5,
            clip_max=100),
        # Per feature regularization.
        regularizer_configs=[
            tfl.configs.RegularizerConfig(name='calib_wrinkle', l2=0.1),
        ],
    ),
    tfl.configs.FeatureConfig(
        name='sex',
        num_buckets=2,
    ),
    tfl.configs.FeatureConfig(
        name='cp',
        monotonicity='increasing',
        # Keypoints that are uniformly spaced.
        pwl_calibration_num_keypoints=4,
        pwl_calibration_input_keypoints=np.linspace(
            np.min(train_xs[feature_name_indices['cp']]),
            np.max(train_xs[feature_name_indices['cp']]),
            num=4),
    ),
    tfl.configs.FeatureConfig(
        name='chol',
        monotonicity='increasing',
        # Explicit input keypoints initialization.
        pwl_calibration_input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],
        # Calibration can be forced to span the full output range by clamping.
        pwl_calibration_clamp_min=True,
        pwl_calibration_clamp_max=True,
        # Per feature regularization.
        regularizer_configs=[
            tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),
        ],
    ),
    tfl.configs.FeatureConfig(
        name='fbs',
        # Partial monotonicity: output(0) <= output(1)
        monotonicity=[(0, 1)],
        num_buckets=2,
    ),
    tfl.configs.FeatureConfig(
        name='trestbps',
        monotonicity='decreasing',
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['trestbps']], num_keypoints=5),
    ),
    tfl.configs.FeatureConfig(
        name='thalach',
        monotonicity='decreasing',
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['thalach']], num_keypoints=5),
    ),
    tfl.configs.FeatureConfig(
        name='restecg',
        # Partial monotonicity: output(0) <= output(1), output(0) <= output(2)
        monotonicity=[(0, 1), (0, 2)],
        num_buckets=3,
    ),
    tfl.configs.FeatureConfig(
        name='exang',
        # Partial monotonicity: output(0) <= output(1)
        monotonicity=[(0, 1)],
        num_buckets=2,
    ),
    tfl.configs.FeatureConfig(
        name='oldpeak',
        monotonicity='increasing',
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['oldpeak']], num_keypoints=5),
    ),
    tfl.configs.FeatureConfig(
        name='slope',
        # Partial monotonicity: output(0) <= output(1), output(1) <= output(2)
        monotonicity=[(0, 1), (1, 2)],
        num_buckets=3,
    ),
    tfl.configs.FeatureConfig(
        name='ca',
        monotonicity='increasing',
        pwl_calibration_num_keypoints=4,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['ca']], num_keypoints=4),
    ),
    tfl.configs.FeatureConfig(
        name='thal',
        # Partial monotonicity:
        # output(normal) <= output(fixed)
        # output(normal) <= output(reversible)
        monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],
        num_buckets=3,
        # We must specify the vocabulary list in order to later set the
        # monotonicities since we used names and not indices.
        vocabulary_list=thal_vocab_list,
    ),
]

Als nächstes müssen wir sicherstellen, dass die Monotonien für Funktionen, für die wir ein benutzerdefiniertes Vokabular verwendet haben (wie z. B. 'thal' oben), richtig eingestellt sind.

tfl.premade_lib.set_categorical_monotonicities(feature_configs)

Kalibriertes lineares Modell

Um ein vorgefertigtes TFL-Modell zu erstellen, erstellen Sie zunächst eine Modellkonfiguration aus tfl.configs . Ein kalibriertes lineares Modell wird mit tfl.configs.CalibratedLinearConfig erstellt . Es wendet eine stückweise lineare und kategoriale Kalibrierung auf die Eingabemerkmale an, gefolgt von einer linearen Kombination und einer optionalen stückweise linearen Ausgabekalibrierung. Bei Verwendung der Ausgabekalibrierung oder bei Angabe von Ausgabegrenzen wendet die lineare Ebene eine gewichtete Mittelung auf kalibrierte Eingaben an.

In diesem Beispiel wird ein kalibriertes lineares Modell für die ersten 5 Features erstellt.

# Model config defines the model structure for the premade model.
linear_model_config = tfl.configs.CalibratedLinearConfig(
    feature_configs=feature_configs[:5],
    use_bias=True,
    # We must set the output min and max to that of the label.
    output_min=min_label,
    output_max=max_label,
    output_calibration=True,
    output_calibration_num_keypoints=10,
    output_initialization=np.linspace(min_label, max_label, num=10),
    regularizer_configs=[
        # Regularizer for the output calibrator.
        tfl.configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),
    ])
# A CalibratedLinear premade model constructed from the given model config.
linear_model = tfl.premade.CalibratedLinear(linear_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(linear_model, show_layer_names=False, rankdir='LR')

png

Wie bei jedem anderen tf.keras.Model kompilieren wir das Modell und passen es an unsere Daten an.

linear_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
linear_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
<tensorflow.python.keras.callbacks.History at 0x7fd5c079a1d0>

Nachdem wir unser Modell trainiert haben, können wir es auf unserem Testset bewerten.

print('Test Set Evaluation...')
print(linear_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 0s 1ms/step - loss: 0.4644 - auc: 0.8459
[0.46442732214927673, 0.8458647131919861]

Kalibriertes Gittermodell

Ein kalibriertes Gittermodell wird mit tfl.configs.CalibratedLatticeConfig erstellt . Ein kalibriertes Gittermodell wendet eine stückweise lineare und kategoriale Kalibrierung auf die Eingabemerkmale an, gefolgt von einem Gittermodell und einer optionalen stückweise linearen Ausgabekalibrierung.

In diesem Beispiel wird ein kalibriertes Gittermodell für die ersten 5 Features erstellt.

# This is a calibrated lattice model: inputs are calibrated, then combined
# non-linearly using a lattice layer.
lattice_model_config = tfl.configs.CalibratedLatticeConfig(
    feature_configs=feature_configs[:5],
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    regularizer_configs=[
        # Torsion regularizer applied to the lattice to make it more linear.
        tfl.configs.RegularizerConfig(name='torsion', l2=1e-2),
        # Globally defined calibration regularizer is applied to all features.
        tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-2),
    ])
# A CalibratedLattice premade model constructed from the given model config.
lattice_model = tfl.premade.CalibratedLattice(lattice_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(lattice_model, show_layer_names=False, rankdir='LR')

png

Nach wie vor kompilieren, passen und bewerten wir unser Modell.

lattice_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
lattice_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(lattice_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 0s 2ms/step - loss: 0.4789 - auc_1: 0.8409
[0.4789487421512604, 0.8408521413803101]

Kalibriertes Gitterensemble-Modell

Wenn die Anzahl der Features groß ist, können Sie ein Ensemble-Modell verwenden, das mehrere kleinere Gitter für Teilmengen der Features erstellt und deren Ausgabe mittelt, anstatt nur ein einziges großes Gitter zu erstellen. Ensemble-Gittermodelle werden mit tfl.configs.CalibratedLatticeEnsembleConfig erstellt . Ein kalibriertes Gitterensemble-Modell wendet eine stückweise lineare und kategoriale Kalibrierung auf die Eingabefunktion an, gefolgt von einem Ensemble von Gittermodellen und einer optionalen stückweise linearen Ausgabekalibrierung.

Explizite Initialisierung des Gitterensembles

Wenn Sie bereits wissen, welche Teilmengen von Features Sie in Ihre Gitter einspeisen möchten, können Sie die Gitter explizit mithilfe von Feature-Namen festlegen. In diesem Beispiel wird ein kalibriertes Gitterensemble-Modell mit 5 Gittern und 3 Merkmalen pro Gitter erstellt.

# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combined non-linearly and averaged using multiple lattice layers.
explicit_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=feature_configs,
    lattices=[['trestbps', 'chol', 'ca'], ['fbs', 'restecg', 'thal'],
              ['fbs', 'cp', 'oldpeak'], ['exang', 'slope', 'thalach'],
              ['restecg', 'age', 'sex']],
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label])
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config.
explicit_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    explicit_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    explicit_ensemble_model, show_layer_names=False, rankdir='LR')

png

Nach wie vor kompilieren, passen und bewerten wir unser Modell.

explicit_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
explicit_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(explicit_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 0s 2ms/step - loss: 0.4373 - auc_2: 0.8615
[0.437343567609787, 0.8615288734436035]

Zufälliges Gitterensemble

Wenn Sie nicht sicher sind, welche Teilmengen von Merkmalen in Ihre Gitter eingespeist werden sollen, können Sie auch zufällige Teilmengen von Merkmalen für jedes Gitter verwenden. In diesem Beispiel wird ein kalibriertes Gitterensemble-Modell mit 5 Gittern und 3 Merkmalen pro Gitter erstellt.

# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combined non-linearly and averaged using multiple lattice layers.
random_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=feature_configs,
    lattices='random',
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    random_seed=42)
# Now we must set the random lattice structure and construct the model.
tfl.premade_lib.set_random_lattice_ensemble(random_ensemble_model_config)
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config.
random_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    random_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    random_ensemble_model, show_layer_names=False, rankdir='LR')

png

Nach wie vor kompilieren, passen und bewerten wir unser Modell.

random_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
random_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(random_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 0s 2ms/step - loss: 0.4034 - auc_3: 0.9223
[0.40344616770744324, 0.9223057627677917]

RTL Layer Random Lattice Ensemble

Wenn Sie ein zufälliges Gitterensemble verwenden, können Sie festlegen, dass das Modell eine einzelne tfl.layers.RTL verwendet. Wir stellen fest, dass tfl.layers.RTL nur Monotonieeinschränkungen unterstützt und für alle Features dieselbe tfl.layers.RTL und keine Regularisierung pro Feature haben muss. Beachten Sie, dass Sie mit einer tfl.layers.RTL Ebene auf viel größere Ensembles tfl.layers.Lattice als mit separaten tfl.layers.Lattice Instanzen.

In diesem Beispiel wird ein kalibriertes Gitterensemble-Modell mit 5 Gittern und 3 Merkmalen pro Gitter erstellt.

# Make sure our feature configs have the same lattice size, no per-feature
# regularization, and only monotonicity constraints.
rtl_layer_feature_configs = copy.deepcopy(feature_configs)
for feature_config in rtl_layer_feature_configs:
  feature_config.lattice_size = 2
  feature_config.unimodality = 'none'
  feature_config.reflects_trust_in = None
  feature_config.dominates = None
  feature_config.regularizer_configs = None
# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combined non-linearly and averaged using multiple lattice layers.
rtl_layer_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=rtl_layer_feature_configs,
    lattices='rtl_layer',
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    random_seed=42)
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config. Note that we do not have to specify the lattices by calling
# a helper function (like before with random) because the RTL Layer will take
# care of that for us.
rtl_layer_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    rtl_layer_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    rtl_layer_ensemble_model, show_layer_names=False, rankdir='LR')

png

Nach wie vor kompilieren, passen und bewerten wir unser Modell.

rtl_layer_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
rtl_layer_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(rtl_layer_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 0s 2ms/step - loss: 0.4287 - auc_4: 0.8684
[0.42873889207839966, 0.8684210777282715]

Crystals Lattice Ensemble

Premade bietet auch einen heuristischen Algorithmus für die Anordnung von Merkmalen namens Crystals . Um den Crystals-Algorithmus zu verwenden, trainieren wir zunächst ein Voranpassungsmodell, das paarweise Merkmalsinteraktionen schätzt. Wir ordnen dann das endgültige Ensemble so an, dass Merkmale mit mehr nichtlinearen Wechselwirkungen in denselben Gittern liegen.

Die vorgefertigte Bibliothek bietet Hilfsfunktionen zum Erstellen der vormontierten Modellkonfiguration und zum Extrahieren der Kristallstruktur. Beachten Sie, dass das Voranpassungsmodell nicht vollständig trainiert werden muss, daher sollten einige Epochen ausreichen.

In diesem Beispiel wird ein kalibriertes Gitterensemble-Modell mit 5 Gittern und 3 Merkmalen pro Gitter erstellt.

# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combines non-linearly and averaged using multiple lattice layers.
crystals_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=feature_configs,
    lattices='crystals',
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    random_seed=42)
# Now that we have our model config, we can construct a prefitting model config.
prefitting_model_config = tfl.premade_lib.construct_prefitting_model_config(
    crystals_ensemble_model_config)
# A CalibratedLatticeEnsemble premade model constructed from the given
# prefitting model config.
prefitting_model = tfl.premade.CalibratedLatticeEnsemble(
    prefitting_model_config)
# We can compile and train our prefitting model as we like.
prefitting_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
prefitting_model.fit(
    train_xs,
    train_ys,
    epochs=PREFITTING_NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=False)
# Now that we have our trained prefitting model, we can extract the crystals.
tfl.premade_lib.set_crystals_lattice_ensemble(crystals_ensemble_model_config,
                                              prefitting_model_config,
                                              prefitting_model)
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config.
crystals_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    crystals_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    crystals_ensemble_model, show_layer_names=False, rankdir='LR')

png

Nach wie vor kompilieren, passen und bewerten wir unser Modell.

crystals_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
crystals_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(crystals_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 0s 2ms/step - loss: 0.4039 - auc_5: 0.8853
[0.40386414527893066, 0.885338306427002]