Zapisz datę! Google I / O powraca w dniach 18-20 maja Zarejestruj się teraz
Ta strona została przetłumaczona przez Cloud Translation API.
Switch to English

Gotowe modele kratowe TF

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło w serwisie GitHub Pobierz notatnik

Przegląd

Gotowe modele to szybkie i łatwe sposoby tworzenia instancji TFL tf.keras.model dla typowych przypadków użycia. Ten przewodnik przedstawia kroki potrzebne do skonstruowania gotowego modelu TFL i wytrenowania / przetestowania go.

Ustawiać

Instalowanie pakietu TF Lattice:

pip install -q tensorflow-lattice pydot

Importowanie wymaganych pakietów:

import tensorflow as tf

import copy
import logging
import numpy as np
import pandas as pd
import sys
import tensorflow_lattice as tfl
logging.disable(sys.maxsize)

Pobieranie zestawu danych UCI Statlog (Heart):

csv_file = tf.keras.utils.get_file(
    'heart.csv', 'http://storage.googleapis.com/download.tensorflow.org/data/heart.csv')
df = pd.read_csv(csv_file)
train_size = int(len(df) * 0.8)
train_dataframe = df[:train_size]
test_dataframe = df[train_size:]
df.head()
Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/heart.csv
16384/13273 [=====================================] - 0s 0us/step

Wyodrębnij i przekonwertuj elementy i etykiety na tensory:

# Features:
# - age
# - sex
# - cp        chest pain type (4 values)
# - trestbps  resting blood pressure
# - chol      serum cholestoral in mg/dl
# - fbs       fasting blood sugar > 120 mg/dl
# - restecg   resting electrocardiographic results (values 0,1,2)
# - thalach   maximum heart rate achieved
# - exang     exercise induced angina
# - oldpeak   ST depression induced by exercise relative to rest
# - slope     the slope of the peak exercise ST segment
# - ca        number of major vessels (0-3) colored by flourosopy
# - thal      3 = normal; 6 = fixed defect; 7 = reversable defect
#
# This ordering of feature names will be the exact same order that we construct
# our model to expect.
feature_names = [
    'age', 'sex', 'cp', 'chol', 'fbs', 'trestbps', 'thalach', 'restecg',
    'exang', 'oldpeak', 'slope', 'ca', 'thal'
]
feature_name_indices = {name: index for index, name in enumerate(feature_names)}
# This is the vocab list and mapping we will use for the 'thal' categorical
# feature.
thal_vocab_list = ['normal', 'fixed', 'reversible']
thal_map = {category: i for i, category in enumerate(thal_vocab_list)}
# Custom function for converting thal categories to buckets
def convert_thal_features(thal_features):
  # Note that two examples in the test set are already converted.
  return np.array([
      thal_map[feature] if feature in thal_vocab_list else feature
      for feature in thal_features
  ])


# Custom function for extracting each feature.
def extract_features(dataframe,
                     label_name='target',
                     feature_names=feature_names):
  features = []
  for feature_name in feature_names:
    if feature_name == 'thal':
      features.append(
          convert_thal_features(dataframe[feature_name].values).astype(float))
    else:
      features.append(dataframe[feature_name].values.astype(float))
  labels = dataframe[label_name].values.astype(float)
  return features, labels
train_xs, train_ys = extract_features(train_dataframe)
test_xs, test_ys = extract_features(test_dataframe)
# Let's define our label minimum and maximum.
min_label, max_label = float(np.min(train_ys)), float(np.max(train_ys))
# Our lattice models may have predictions above 1.0 due to numerical errors.
# We can subtract this small epsilon value from our output_max to make sure we
# do not predict values outside of our label bound.
numerical_error_epsilon = 1e-5

Ustawianie wartości domyślnych używanych do treningu w tym przewodniku:

LEARNING_RATE = 0.01
BATCH_SIZE = 128
NUM_EPOCHS = 500
PREFITTING_NUM_EPOCHS = 10

Konfiguracje funkcji

Kalibracja funkcji i konfiguracje dla poszczególnych funkcji są ustawiane za pomocą tfl.configs.FeatureConfig . Konfiguracje funkcji obejmują ograniczenia monotoniczności, regularyzację poszczególnych funkcji (patrz tfl.configs.RegularizerConfig ) i rozmiary krat dla modeli kratowych.

Zauważ, że musimy w pełni określić konfigurację funkcji dla każdej funkcji, którą chcemy, aby nasz model rozpoznawał. W przeciwnym razie model nie będzie miał możliwości dowiedzenia się, że taka funkcja istnieje.

Oblicz kwantyle

Chociaż domyślnym ustawieniem dla pwl_calibration_input_keypoints w tfl.configs.FeatureConfig jest „kwantyle”, dla gotowych modeli musimy ręcznie zdefiniować wejściowe punkty kluczowe. Aby to zrobić, najpierw definiujemy naszą własną funkcję pomocniczą do obliczania kwantyli.

def compute_quantiles(features,
                      num_keypoints=10,
                      clip_min=None,
                      clip_max=None,
                      missing_value=None):
  # Clip min and max if desired.
  if clip_min is not None:
    features = np.maximum(features, clip_min)
    features = np.append(features, clip_min)
  if clip_max is not None:
    features = np.minimum(features, clip_max)
    features = np.append(features, clip_max)
  # Make features unique.
  unique_features = np.unique(features)
  # Remove missing values if specified.
  if missing_value is not None:
    unique_features = np.delete(unique_features,
                                np.where(unique_features == missing_value))
  # Compute and return quantiles over unique non-missing feature values.
  return np.quantile(
      unique_features,
      np.linspace(0., 1., num=num_keypoints),
      interpolation='nearest').astype(float)

Definiowanie naszych konfiguracji funkcji

Teraz, gdy możemy obliczyć nasze kwantyle, definiujemy konfigurację funkcji dla każdej funkcji, którą nasz model ma przyjąć jako dane wejściowe.

# Feature configs are used to specify how each feature is calibrated and used.
feature_configs = [
    tfl.configs.FeatureConfig(
        name='age',
        lattice_size=3,
        monotonicity='increasing',
        # We must set the keypoints manually.
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['age']],
            num_keypoints=5,
            clip_max=100),
        # Per feature regularization.
        regularizer_configs=[
            tfl.configs.RegularizerConfig(name='calib_wrinkle', l2=0.1),
        ],
    ),
    tfl.configs.FeatureConfig(
        name='sex',
        num_buckets=2,
    ),
    tfl.configs.FeatureConfig(
        name='cp',
        monotonicity='increasing',
        # Keypoints that are uniformly spaced.
        pwl_calibration_num_keypoints=4,
        pwl_calibration_input_keypoints=np.linspace(
            np.min(train_xs[feature_name_indices['cp']]),
            np.max(train_xs[feature_name_indices['cp']]),
            num=4),
    ),
    tfl.configs.FeatureConfig(
        name='chol',
        monotonicity='increasing',
        # Explicit input keypoints initialization.
        pwl_calibration_input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],
        # Calibration can be forced to span the full output range by clamping.
        pwl_calibration_clamp_min=True,
        pwl_calibration_clamp_max=True,
        # Per feature regularization.
        regularizer_configs=[
            tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),
        ],
    ),
    tfl.configs.FeatureConfig(
        name='fbs',
        # Partial monotonicity: output(0) <= output(1)
        monotonicity=[(0, 1)],
        num_buckets=2,
    ),
    tfl.configs.FeatureConfig(
        name='trestbps',
        monotonicity='decreasing',
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['trestbps']], num_keypoints=5),
    ),
    tfl.configs.FeatureConfig(
        name='thalach',
        monotonicity='decreasing',
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['thalach']], num_keypoints=5),
    ),
    tfl.configs.FeatureConfig(
        name='restecg',
        # Partial monotonicity: output(0) <= output(1), output(0) <= output(2)
        monotonicity=[(0, 1), (0, 2)],
        num_buckets=3,
    ),
    tfl.configs.FeatureConfig(
        name='exang',
        # Partial monotonicity: output(0) <= output(1)
        monotonicity=[(0, 1)],
        num_buckets=2,
    ),
    tfl.configs.FeatureConfig(
        name='oldpeak',
        monotonicity='increasing',
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['oldpeak']], num_keypoints=5),
    ),
    tfl.configs.FeatureConfig(
        name='slope',
        # Partial monotonicity: output(0) <= output(1), output(1) <= output(2)
        monotonicity=[(0, 1), (1, 2)],
        num_buckets=3,
    ),
    tfl.configs.FeatureConfig(
        name='ca',
        monotonicity='increasing',
        pwl_calibration_num_keypoints=4,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['ca']], num_keypoints=4),
    ),
    tfl.configs.FeatureConfig(
        name='thal',
        # Partial monotonicity:
        # output(normal) <= output(fixed)
        # output(normal) <= output(reversible)
        monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],
        num_buckets=3,
        # We must specify the vocabulary list in order to later set the
        # monotonicities since we used names and not indices.
        vocabulary_list=thal_vocab_list,
    ),
]

Następnie musimy upewnić się, że poprawnie ustawiliśmy monotonię dla funkcji, w których używaliśmy niestandardowego słownictwa (takiego jak „thal” powyżej).

tfl.premade_lib.set_categorical_monotonicities(feature_configs)

Kalibrowany model liniowy

Aby zbudować gotowy model TFL, najpierw utwórz konfigurację modelu z tfl.configs . Skalibrowany model liniowy jest konstruowany przy użyciu tfl.configs.CalibratedLinearConfig . Stosuje odcinkowo-liniową i kategorialną kalibrację cech wejściowych, po której następuje kombinacja liniowa i opcjonalna wyjściowa kalibracja odcinkowo-liniowa. Podczas korzystania z kalibracji wyjściowej lub gdy określone są granice wyjściowe, warstwa liniowa zastosuje uśrednienie ważone na skalibrowanych wejściach.

Ten przykład tworzy skalibrowany model liniowy dla pierwszych 5 elementów.

# Model config defines the model structure for the premade model.
linear_model_config = tfl.configs.CalibratedLinearConfig(
    feature_configs=feature_configs[:5],
    use_bias=True,
    # We must set the output min and max to that of the label.
    output_min=min_label,
    output_max=max_label,
    output_calibration=True,
    output_calibration_num_keypoints=10,
    output_initialization=np.linspace(min_label, max_label, num=10),
    regularizer_configs=[
        # Regularizer for the output calibrator.
        tfl.configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),
    ])
# A CalibratedLinear premade model constructed from the given model config.
linear_model = tfl.premade.CalibratedLinear(linear_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(linear_model, show_layer_names=False, rankdir='LR')

png

Teraz, podobnie jak w przypadku każdego innego tf.keras.Model , kompilujemy i dopasowujemy model do naszych danych.

linear_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
linear_model.fit(
    train_xs[:5],
    train_ys,
    epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=False)
<tensorflow.python.keras.callbacks.History at 0x7ff2bf765860>

Po wytrenowaniu naszego modelu możemy go ocenić na naszym zestawie testowym.

print('Test Set Evaluation...')
print(linear_model.evaluate(test_xs[:5], test_ys))
Test Set Evaluation...
2/2 [==============================] - 0s 3ms/step - loss: 0.4849 - auc: 0.8214
[0.48487865924835205, 0.8214285969734192]

Kalibrowany model kratowy

Skalibrowany model sieci jest konstruowany przy użyciu tfl.configs.CalibratedLatticeConfig . Skalibrowany model sieciowy stosuje odcinkowo-liniową i kategorialną kalibrację elementów wejściowych, a następnie model sieciowy i opcjonalną wyjściową kalibrację odcinkowo-liniową.

Ten przykład tworzy skalibrowany model sieciowy dla pierwszych 5 elementów.

# This is a calibrated lattice model: inputs are calibrated, then combined
# non-linearly using a lattice layer.
lattice_model_config = tfl.configs.CalibratedLatticeConfig(
    feature_configs=feature_configs[:5],
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    regularizer_configs=[
        # Torsion regularizer applied to the lattice to make it more linear.
        tfl.configs.RegularizerConfig(name='torsion', l2=1e-2),
        # Globally defined calibration regularizer is applied to all features.
        tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-2),
    ])
# A CalibratedLattice premade model constructed from the given model config.
lattice_model = tfl.premade.CalibratedLattice(lattice_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(lattice_model, show_layer_names=False, rankdir='LR')

png

Tak jak poprzednio, kompilujemy, dopasowujemy i oceniamy nasz model.

lattice_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
lattice_model.fit(
    train_xs[:5],
    train_ys,
    epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=False)
print('Test Set Evaluation...')
print(lattice_model.evaluate(test_xs[:5], test_ys))
Test Set Evaluation...
2/2 [==============================] - 1s 3ms/step - loss: 0.4784 - auc_1: 0.8402
[0.47842937707901, 0.8402255773544312]

Skalibrowany model zespołu kratownicy

Gdy liczba elementów jest duża, można użyć modelu zespolonego, który tworzy wiele mniejszych kratek dla podzbiorów elementów i uśrednia ich wydajność zamiast tworzyć tylko jedną ogromną siatkę. Modele kratownic zespolonych są konstruowane przy użyciu tfl.configs.CalibratedLatticeEnsembleConfig . Skalibrowany model zespołu sieci kratowej stosuje odcinkowo-liniową i kategorialną kalibrację elementu wejściowego, po której następuje zestaw modeli kratowych i opcjonalna wyjściowa kalibracja odcinkowo-liniowa.

Jawna inicjalizacja zespołu kratownicy

Jeśli już wiesz, które podzbiory obiektów chcesz wprowadzić do swoich kratownic, możesz jawnie ustawić kraty za pomocą nazw elementów. Ten przykład tworzy skalibrowany model zespołu kratownicy z 5 kratami i 3 elementami na kratownicę.

# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combined non-linearly and averaged using multiple lattice layers.
explicit_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=feature_configs,
    lattices=[['trestbps', 'chol', 'ca'], ['fbs', 'restecg', 'thal'],
              ['fbs', 'cp', 'oldpeak'], ['exang', 'slope', 'thalach'],
              ['restecg', 'age', 'sex']],
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label])
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config.
explicit_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    explicit_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    explicit_ensemble_model, show_layer_names=False, rankdir='LR')

png

Tak jak poprzednio, kompilujemy, dopasowujemy i oceniamy nasz model.

explicit_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
explicit_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(explicit_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 1s 3ms/step - loss: 0.4281 - auc_2: 0.8659
[0.42808252573013306, 0.8659147620201111]

Random Lattice Ensemble

Jeśli nie masz pewności, które podzbiory obiektów należy wprowadzić do sieci, inną opcją jest użycie losowych podzbiorów cech dla każdej kraty. Ten przykład tworzy skalibrowany model zespołu kratownicy z 5 kratami i 3 elementami na kratownicę.

# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combined non-linearly and averaged using multiple lattice layers.
random_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=feature_configs,
    lattices='random',
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    random_seed=42)
# Now we must set the random lattice structure and construct the model.
tfl.premade_lib.set_random_lattice_ensemble(random_ensemble_model_config)
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config.
random_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    random_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    random_ensemble_model, show_layer_names=False, rankdir='LR')

png

Tak jak poprzednio, kompilujemy, dopasowujemy i oceniamy nasz model.

random_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
random_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(random_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 1s 3ms/step - loss: 0.3929 - auc_3: 0.9217
[0.3929273188114166, 0.9216791987419128]

Zespół krat losowych warstwy RTL

Używając losowego zespołu kratownicy, można określić, że model używa pojedynczej warstwy tfl.layers.RTL . Zwracamy uwagę, że tfl.layers.RTL obsługuje tylko ograniczenia monotoniczności i musi mieć ten sam rozmiar kraty dla wszystkich funkcji i bez regularyzacji dla poszczególnych cech. Zauważ, że użycie warstwy tfl.layers.RTL umożliwia skalowanie do znacznie większych tfl.layers.Lattice niż używanie oddzielnych instancji tfl.layers.Lattice .

Ten przykład tworzy skalibrowany model zespołu kratownicy z 5 kratami i 3 elementami na kratownicę.

# Make sure our feature configs have the same lattice size, no per-feature
# regularization, and only monotonicity constraints.
rtl_layer_feature_configs = copy.deepcopy(feature_configs)
for feature_config in rtl_layer_feature_configs:
  feature_config.lattice_size = 2
  feature_config.unimodality = 'none'
  feature_config.reflects_trust_in = None
  feature_config.dominates = None
  feature_config.regularizer_configs = None
# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combined non-linearly and averaged using multiple lattice layers.
rtl_layer_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=rtl_layer_feature_configs,
    lattices='rtl_layer',
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    random_seed=42)
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config. Note that we do not have to specify the lattices by calling
# a helper function (like before with random) because the RTL Layer will take
# care of that for us.
rtl_layer_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    rtl_layer_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    rtl_layer_ensemble_model, show_layer_names=False, rankdir='LR')

png

Tak jak poprzednio, kompilujemy, dopasowujemy i oceniamy nasz model.

rtl_layer_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
rtl_layer_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(rtl_layer_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 0s 3ms/step - loss: 0.4286 - auc_4: 0.8690
[0.42856135964393616, 0.8690476417541504]

Crystals Lattice Ensemble

Premade zapewnia również heurystyczny algorytm porządkowania cech, zwany kryształami . Aby użyć algorytmu Kryształy, najpierw trenujemy model prefitting, który szacuje parami interakcje cech. Następnie układamy ostateczny zbiór w taki sposób, aby elementy z bardziej nieliniowymi interakcjami znajdowały się w tych samych kratach.

Biblioteka Premade oferuje funkcje pomocnicze do tworzenia konfiguracji modelu prefitting i wyodrębniania struktury kryształów. Zauważ, że model prefittingu nie musi być w pełni wytrenowany, więc kilka epok powinno wystarczyć.

W tym przykładzie tworzony jest skalibrowany model zespołu kratownicy z 5 kratownicami i 3 elementami na kratownicę.

# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combines non-linearly and averaged using multiple lattice layers.
crystals_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=feature_configs,
    lattices='crystals',
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    random_seed=42)
# Now that we have our model config, we can construct a prefitting model config.
prefitting_model_config = tfl.premade_lib.construct_prefitting_model_config(
    crystals_ensemble_model_config)
# A CalibratedLatticeEnsemble premade model constructed from the given
# prefitting model config.
prefitting_model = tfl.premade.CalibratedLatticeEnsemble(
    prefitting_model_config)
# We can compile and train our prefitting model as we like.
prefitting_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
prefitting_model.fit(
    train_xs,
    train_ys,
    epochs=PREFITTING_NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=False)
# Now that we have our trained prefitting model, we can extract the crystals.
tfl.premade_lib.set_crystals_lattice_ensemble(crystals_ensemble_model_config,
                                              prefitting_model_config,
                                              prefitting_model)
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config.
crystals_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    crystals_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    crystals_ensemble_model, show_layer_names=False, rankdir='LR')

png

Tak jak poprzednio, kompilujemy, dopasowujemy i oceniamy nasz model.

crystals_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
crystals_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(crystals_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 1s 3ms/step - loss: 0.4671 - auc_5: 0.8283
[0.46707457304000854, 0.8283208608627319]