Mô hình tạo sẵn mạng lưới TF

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Tổng quat

Premade Models là cách nhanh chóng và dễ dàng để xây dựng TFL tf.keras.model trường đối với trường hợp sử dụng điển hình. Hướng dẫn này phác thảo các bước cần thiết để xây dựng Mô hình TFL tạo sẵn và huấn luyện / kiểm tra mô hình đó.

Cài đặt

Cài đặt gói TF Lattice:

pip install -q tensorflow-lattice pydot

Nhập các gói bắt buộc:

import tensorflow as tf

import copy
import logging
import numpy as np
import pandas as pd
import sys
import tensorflow_lattice as tfl
logging.disable(sys.maxsize)

Tải xuống tập dữ liệu UCI Statlog (Heart):

csv_file = tf.keras.utils.get_file(
    'heart.csv', 'http://storage.googleapis.com/download.tensorflow.org/data/heart.csv')
df = pd.read_csv(csv_file)
train_size = int(len(df) * 0.8)
train_dataframe = df[:train_size]
test_dataframe = df[train_size:]
df.head()
Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/heart.csv
16384/13273 [=====================================] - 0s 0us/step

Trích xuất và chuyển đổi các tính năng và nhãn thành tensors:

# Features:
# - age
# - sex
# - cp        chest pain type (4 values)
# - trestbps  resting blood pressure
# - chol      serum cholestoral in mg/dl
# - fbs       fasting blood sugar > 120 mg/dl
# - restecg   resting electrocardiographic results (values 0,1,2)
# - thalach   maximum heart rate achieved
# - exang     exercise induced angina
# - oldpeak   ST depression induced by exercise relative to rest
# - slope     the slope of the peak exercise ST segment
# - ca        number of major vessels (0-3) colored by flourosopy
# - thal      3 = normal; 6 = fixed defect; 7 = reversable defect
#
# This ordering of feature names will be the exact same order that we construct
# our model to expect.
feature_names = [
    'age', 'sex', 'cp', 'chol', 'fbs', 'trestbps', 'thalach', 'restecg',
    'exang', 'oldpeak', 'slope', 'ca', 'thal'
]
feature_name_indices = {name: index for index, name in enumerate(feature_names)}
# This is the vocab list and mapping we will use for the 'thal' categorical
# feature.
thal_vocab_list = ['normal', 'fixed', 'reversible']
thal_map = {category: i for i, category in enumerate(thal_vocab_list)}
# Custom function for converting thal categories to buckets
def convert_thal_features(thal_features):
  # Note that two examples in the test set are already converted.
  return np.array([
      thal_map[feature] if feature in thal_vocab_list else feature
      for feature in thal_features
  ])


# Custom function for extracting each feature.
def extract_features(dataframe,
                     label_name='target',
                     feature_names=feature_names):
  features = []
  for feature_name in feature_names:
    if feature_name == 'thal':
      features.append(
          convert_thal_features(dataframe[feature_name].values).astype(float))
    else:
      features.append(dataframe[feature_name].values.astype(float))
  labels = dataframe[label_name].values.astype(float)
  return features, labels
train_xs, train_ys = extract_features(train_dataframe)
test_xs, test_ys = extract_features(test_dataframe)
# Let's define our label minimum and maximum.
min_label, max_label = float(np.min(train_ys)), float(np.max(train_ys))
# Our lattice models may have predictions above 1.0 due to numerical errors.
# We can subtract this small epsilon value from our output_max to make sure we
# do not predict values outside of our label bound.
numerical_error_epsilon = 1e-5

Đặt các giá trị mặc định được sử dụng để đào tạo trong hướng dẫn này:

LEARNING_RATE = 0.01
BATCH_SIZE = 128
NUM_EPOCHS = 500
PREFITTING_NUM_EPOCHS = 10

Cấu hình tính năng

Tính năng hiệu chỉnh và cấu hình cho mỗi tính năng được thiết lập sử dụng tfl.configs.FeatureConfig . Cấu hình tính năng bao gồm các ràng buộc đơn điệu, quy tắc cho mỗi tính năng (xem tfl.configs.RegularizerConfig ), và kích thước lưới cho các mô hình mạng.

Lưu ý rằng chúng ta phải chỉ định đầy đủ cấu hình tính năng cho bất kỳ tính năng nào mà chúng ta muốn mô hình của mình nhận ra. Nếu không, mô hình sẽ không có cách nào để biết rằng một tính năng như vậy tồn tại.

Tính toán lượng tử

Mặc dù các thiết lập mặc định cho pwl_calibration_input_keypoints trong tfl.configs.FeatureConfig là 'quantiles', cho các mô hình premade chúng ta phải tự xác định keypoint đầu vào. Để làm như vậy, trước tiên chúng tôi xác định chức năng trợ giúp của riêng mình cho các lượng tử tính toán.

def compute_quantiles(features,
                      num_keypoints=10,
                      clip_min=None,
                      clip_max=None,
                      missing_value=None):
  # Clip min and max if desired.
  if clip_min is not None:
    features = np.maximum(features, clip_min)
    features = np.append(features, clip_min)
  if clip_max is not None:
    features = np.minimum(features, clip_max)
    features = np.append(features, clip_max)
  # Make features unique.
  unique_features = np.unique(features)
  # Remove missing values if specified.
  if missing_value is not None:
    unique_features = np.delete(unique_features,
                                np.where(unique_features == missing_value))
  # Compute and return quantiles over unique non-missing feature values.
  return np.quantile(
      unique_features,
      np.linspace(0., 1., num=num_keypoints),
      interpolation='nearest').astype(float)

Xác định cấu hình tính năng của chúng tôi

Bây giờ chúng ta có thể tính toán lượng tử của mình, chúng ta xác định cấu hình tính năng cho từng tính năng mà chúng ta muốn mô hình của mình lấy làm đầu vào.

# Feature configs are used to specify how each feature is calibrated and used.
feature_configs = [
    tfl.configs.FeatureConfig(
        name='age',
        lattice_size=3,
        monotonicity='increasing',
        # We must set the keypoints manually.
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['age']],
            num_keypoints=5,
            clip_max=100),
        # Per feature regularization.
        regularizer_configs=[
            tfl.configs.RegularizerConfig(name='calib_wrinkle', l2=0.1),
        ],
    ),
    tfl.configs.FeatureConfig(
        name='sex',
        num_buckets=2,
    ),
    tfl.configs.FeatureConfig(
        name='cp',
        monotonicity='increasing',
        # Keypoints that are uniformly spaced.
        pwl_calibration_num_keypoints=4,
        pwl_calibration_input_keypoints=np.linspace(
            np.min(train_xs[feature_name_indices['cp']]),
            np.max(train_xs[feature_name_indices['cp']]),
            num=4),
    ),
    tfl.configs.FeatureConfig(
        name='chol',
        monotonicity='increasing',
        # Explicit input keypoints initialization.
        pwl_calibration_input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],
        # Calibration can be forced to span the full output range by clamping.
        pwl_calibration_clamp_min=True,
        pwl_calibration_clamp_max=True,
        # Per feature regularization.
        regularizer_configs=[
            tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),
        ],
    ),
    tfl.configs.FeatureConfig(
        name='fbs',
        # Partial monotonicity: output(0) <= output(1)
        monotonicity=[(0, 1)],
        num_buckets=2,
    ),
    tfl.configs.FeatureConfig(
        name='trestbps',
        monotonicity='decreasing',
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['trestbps']], num_keypoints=5),
    ),
    tfl.configs.FeatureConfig(
        name='thalach',
        monotonicity='decreasing',
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['thalach']], num_keypoints=5),
    ),
    tfl.configs.FeatureConfig(
        name='restecg',
        # Partial monotonicity: output(0) <= output(1), output(0) <= output(2)
        monotonicity=[(0, 1), (0, 2)],
        num_buckets=3,
    ),
    tfl.configs.FeatureConfig(
        name='exang',
        # Partial monotonicity: output(0) <= output(1)
        monotonicity=[(0, 1)],
        num_buckets=2,
    ),
    tfl.configs.FeatureConfig(
        name='oldpeak',
        monotonicity='increasing',
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['oldpeak']], num_keypoints=5),
    ),
    tfl.configs.FeatureConfig(
        name='slope',
        # Partial monotonicity: output(0) <= output(1), output(1) <= output(2)
        monotonicity=[(0, 1), (1, 2)],
        num_buckets=3,
    ),
    tfl.configs.FeatureConfig(
        name='ca',
        monotonicity='increasing',
        pwl_calibration_num_keypoints=4,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['ca']], num_keypoints=4),
    ),
    tfl.configs.FeatureConfig(
        name='thal',
        # Partial monotonicity:
        # output(normal) <= output(fixed)
        # output(normal) <= output(reversible)
        monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],
        num_buckets=3,
        # We must specify the vocabulary list in order to later set the
        # monotonicities since we used names and not indices.
        vocabulary_list=thal_vocab_list,
    ),
]

Tiếp theo, chúng ta cần đảm bảo đặt đúng các đơn nguyên cho các tính năng mà chúng ta đã sử dụng từ vựng tùy chỉnh (chẳng hạn như 'thal' ở trên).

tfl.premade_lib.set_categorical_monotonicities(feature_configs)

Mô hình tuyến tính đã hiệu chỉnh

Để xây dựng một mô hình premade TFL, đầu tiên xây dựng một cấu hình mô hình từ tfl.configs . Một mô hình tuyến tính hiệu chuẩn được xây dựng bằng cách sử dụng tfl.configs.CalibratedLinearConfig . Nó áp dụng hiệu chuẩn từng đoạn tuyến tính và phân loại trên các tính năng đầu vào, tiếp theo là kết hợp tuyến tính và hiệu chuẩn tuyến tính từng đoạn tùy chọn đầu ra. Khi sử dụng hiệu chuẩn đầu ra hoặc khi giới hạn đầu ra được chỉ định, lớp tuyến tính sẽ áp dụng giá trị trung bình có trọng số trên các đầu vào đã hiệu chuẩn.

Ví dụ này tạo ra một mô hình tuyến tính đã được hiệu chỉnh trên 5 đối tượng địa lý đầu tiên.

# Model config defines the model structure for the premade model.
linear_model_config = tfl.configs.CalibratedLinearConfig(
    feature_configs=feature_configs[:5],
    use_bias=True,
    # We must set the output min and max to that of the label.
    output_min=min_label,
    output_max=max_label,
    output_calibration=True,
    output_calibration_num_keypoints=10,
    output_initialization=np.linspace(min_label, max_label, num=10),
    regularizer_configs=[
        # Regularizer for the output calibrator.
        tfl.configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),
    ])
# A CalibratedLinear premade model constructed from the given model config.
linear_model = tfl.premade.CalibratedLinear(linear_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(linear_model, show_layer_names=False, rankdir='LR')

png

Bây giờ, như với bất kỳ khác tf.keras.Model , chúng tôi biên soạn và phù hợp với mô hình dữ liệu của chúng tôi.

linear_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
linear_model.fit(
    train_xs[:5],
    train_ys,
    epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=False)
<tensorflow.python.keras.callbacks.History at 0x7ff2bf765860>

Sau khi đào tạo mô hình của chúng tôi, chúng tôi có thể đánh giá nó trên bộ thử nghiệm của chúng tôi.

print('Test Set Evaluation...')
print(linear_model.evaluate(test_xs[:5], test_ys))
Test Set Evaluation...
2/2 [==============================] - 0s 3ms/step - loss: 0.4849 - auc: 0.8214
[0.48487865924835205, 0.8214285969734192]

Mô hình mạng đã hiệu chỉnh

Một mô hình mạng hiệu chuẩn được xây dựng bằng tfl.configs.CalibratedLatticeConfig . Mô hình mạng tinh thể đã hiệu chuẩn áp dụng hiệu chuẩn từng đoạn tuyến tính và phân loại trên các tính năng đầu vào, tiếp theo là mô hình mạng tinh thể và hiệu chuẩn tuyến tính từng đoạn tùy chọn đầu ra.

Ví dụ này tạo ra một mô hình mạng tinh thể đã được hiệu chỉnh trên 5 tính năng đầu tiên.

# This is a calibrated lattice model: inputs are calibrated, then combined
# non-linearly using a lattice layer.
lattice_model_config = tfl.configs.CalibratedLatticeConfig(
    feature_configs=feature_configs[:5],
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    regularizer_configs=[
        # Torsion regularizer applied to the lattice to make it more linear.
        tfl.configs.RegularizerConfig(name='torsion', l2=1e-2),
        # Globally defined calibration regularizer is applied to all features.
        tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-2),
    ])
# A CalibratedLattice premade model constructed from the given model config.
lattice_model = tfl.premade.CalibratedLattice(lattice_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(lattice_model, show_layer_names=False, rankdir='LR')

png

Như trước đây, chúng tôi biên dịch, điều chỉnh và đánh giá mô hình của mình.

lattice_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
lattice_model.fit(
    train_xs[:5],
    train_ys,
    epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=False)
print('Test Set Evaluation...')
print(lattice_model.evaluate(test_xs[:5], test_ys))
Test Set Evaluation...
2/2 [==============================] - 1s 3ms/step - loss: 0.4784 - auc_1: 0.8402
[0.47842937707901, 0.8402255773544312]

Mô hình kết hợp mạng lưới đã hiệu chỉnh

Khi số lượng đối tượng lớn, bạn có thể sử dụng mô hình tổng hợp, mô hình này tạo ra nhiều mạng nhỏ hơn cho các tập hợp con của các đối tượng và tính trung bình đầu ra của chúng thay vì chỉ tạo một mạng khổng lồ duy nhất. Mô hình mạng Ensemble được xây dựng sử dụng tfl.configs.CalibratedLatticeEnsembleConfig . Mô hình tổng thể mạng tinh thể đã hiệu chuẩn áp dụng hiệu chuẩn từng đoạn tuyến tính và phân loại trên tính năng đầu vào, tiếp theo là một tập hợp các mô hình mạng tinh thể và hiệu chuẩn tuyến tính từng đoạn tùy chọn đầu ra.

Khởi tạo Ensemble lưới rõ ràng

Nếu bạn đã biết tập hợp con của các tính năng nào bạn muốn đưa vào mạng của mình, thì bạn có thể đặt các mạng một cách rõ ràng bằng cách sử dụng tên các tính năng. Ví dụ này tạo ra một mô hình tổng thể mạng tinh thể đã được hiệu chuẩn với 5 mạng tinh thể và 3 tính năng trên mỗi mạng tinh thể.

# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combined non-linearly and averaged using multiple lattice layers.
explicit_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=feature_configs,
    lattices=[['trestbps', 'chol', 'ca'], ['fbs', 'restecg', 'thal'],
              ['fbs', 'cp', 'oldpeak'], ['exang', 'slope', 'thalach'],
              ['restecg', 'age', 'sex']],
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label])
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config.
explicit_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    explicit_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    explicit_ensemble_model, show_layer_names=False, rankdir='LR')

png

Như trước đây, chúng tôi biên dịch, điều chỉnh và đánh giá mô hình của mình.

explicit_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
explicit_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(explicit_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 1s 3ms/step - loss: 0.4281 - auc_2: 0.8659
[0.42808252573013306, 0.8659147620201111]

Hệ thống mạng ngẫu nhiên

Nếu bạn không chắc chắn nên cung cấp tập hợp con của các tính năng nào vào mạng của mình, thì một tùy chọn khác là sử dụng tập hợp con ngẫu nhiên của các tính năng cho mỗi mạng. Ví dụ này tạo ra một mô hình tổng thể mạng tinh thể đã được hiệu chuẩn với 5 mạng tinh thể và 3 tính năng trên mỗi mạng tinh thể.

# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combined non-linearly and averaged using multiple lattice layers.
random_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=feature_configs,
    lattices='random',
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    random_seed=42)
# Now we must set the random lattice structure and construct the model.
tfl.premade_lib.set_random_lattice_ensemble(random_ensemble_model_config)
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config.
random_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    random_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    random_ensemble_model, show_layer_names=False, rankdir='LR')

png

Như trước đây, chúng tôi biên dịch, điều chỉnh và đánh giá mô hình của mình.

random_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
random_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(random_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 1s 3ms/step - loss: 0.3929 - auc_3: 0.9217
[0.3929273188114166, 0.9216791987419128]

Hệ thống mạng ngẫu nhiên lớp RTL

Khi sử dụng một ngẫu nhiên lưới quần, bạn có thể xác định rằng các mô hình sử dụng một đơn tfl.layers.RTL lớp. Chúng tôi lưu ý rằng tfl.layers.RTL chỉ hỗ trợ hạn chế đơn điệu và phải có kích thước lưới tương tự cho tất cả các tính năng và không theo quy tắc mỗi tính năng. Lưu ý rằng việc sử dụng một tfl.layers.RTL lớp cho phép bạn mở rộng để cụm công trình lớn hơn nhiều so với sử dụng riêng biệt tfl.layers.Lattice trường.

Ví dụ này tạo ra một mô hình tổng thể mạng tinh thể đã được hiệu chuẩn với 5 mạng tinh thể và 3 tính năng trên mỗi mạng tinh thể.

# Make sure our feature configs have the same lattice size, no per-feature
# regularization, and only monotonicity constraints.
rtl_layer_feature_configs = copy.deepcopy(feature_configs)
for feature_config in rtl_layer_feature_configs:
  feature_config.lattice_size = 2
  feature_config.unimodality = 'none'
  feature_config.reflects_trust_in = None
  feature_config.dominates = None
  feature_config.regularizer_configs = None
# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combined non-linearly and averaged using multiple lattice layers.
rtl_layer_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=rtl_layer_feature_configs,
    lattices='rtl_layer',
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    random_seed=42)
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config. Note that we do not have to specify the lattices by calling
# a helper function (like before with random) because the RTL Layer will take
# care of that for us.
rtl_layer_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    rtl_layer_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    rtl_layer_ensemble_model, show_layer_names=False, rankdir='LR')

png

Như trước đây, chúng tôi biên dịch, điều chỉnh và đánh giá mô hình của mình.

rtl_layer_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
rtl_layer_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(rtl_layer_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 0s 3ms/step - loss: 0.4286 - auc_4: 0.8690
[0.42856135964393616, 0.8690476417541504]

Crystals Lattice Ensemble

Premade cũng cung cấp một thuật toán sắp xếp tính năng heuristic, gọi là Crystal . Để sử dụng thuật toán Crystals, trước tiên, chúng tôi đào tạo một mô hình prefitting ước tính các tương tác tính năng theo từng cặp. Sau đó, chúng tôi sắp xếp quần thể cuối cùng sao cho các đặc điểm có nhiều tương tác phi tuyến tính hơn nằm trong cùng một mạng lưới.

Thư viện Premade cung cấp các chức năng trợ giúp để xây dựng cấu hình mô hình chuẩn bị sẵn và trích xuất cấu trúc tinh thể. Lưu ý rằng mô hình prefitting không cần được đào tạo đầy đủ, vì vậy chỉ cần một vài kỷ nguyên là đủ.

Ví dụ này tạo ra một mô hình tổng thể mạng tinh thể đã được hiệu chuẩn với 5 mạng tinh thể và 3 tính năng trên mỗi mạng tinh thể.

# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combines non-linearly and averaged using multiple lattice layers.
crystals_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=feature_configs,
    lattices='crystals',
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    random_seed=42)
# Now that we have our model config, we can construct a prefitting model config.
prefitting_model_config = tfl.premade_lib.construct_prefitting_model_config(
    crystals_ensemble_model_config)
# A CalibratedLatticeEnsemble premade model constructed from the given
# prefitting model config.
prefitting_model = tfl.premade.CalibratedLatticeEnsemble(
    prefitting_model_config)
# We can compile and train our prefitting model as we like.
prefitting_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
prefitting_model.fit(
    train_xs,
    train_ys,
    epochs=PREFITTING_NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=False)
# Now that we have our trained prefitting model, we can extract the crystals.
tfl.premade_lib.set_crystals_lattice_ensemble(crystals_ensemble_model_config,
                                              prefitting_model_config,
                                              prefitting_model)
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config.
crystals_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    crystals_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    crystals_ensemble_model, show_layer_names=False, rankdir='LR')

png

Như trước đây, chúng tôi biên dịch, điều chỉnh và đánh giá mô hình của mình.

crystals_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
crystals_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(crystals_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 1s 3ms/step - loss: 0.4671 - auc_5: 0.8283
[0.46707457304000854, 0.8283208608627319]