تاریخ را ذخیره کنید! Google I / O 18-20 مه بازمی گردد اکنون ثبت نام کنید
این صفحه به‌وسیله ‏Cloud Translation API‏ ترجمه شده است.
Switch to English

مدلهای پیش ساخته TF Lattice

مشاهده در TensorFlow.org در Google Colab اجرا کنید مشاهده منبع در GitHub دانلود دفترچه یادداشت

بررسی اجمالی

مدل های Premade روش های سریع و ساده ای برای ساخت نمونه های TFL tf.keras.model برای موارد استفاده معمولی هستند. این راهنما مراحل لازم برای ساخت یک مدل TFL Premade و آموزش / آزمایش آن را مشخص می کند.

برپایی

نصب بسته TF Lattice:

pip install -q tensorflow-lattice pydot

وارد کردن بسته های مورد نیاز:

import tensorflow as tf

import copy
import logging
import numpy as np
import pandas as pd
import sys
import tensorflow_lattice as tfl
logging.disable(sys.maxsize)

بارگیری مجموعه داده های UCI Statlog (قلب):

csv_file = tf.keras.utils.get_file(
    'heart.csv', 'http://storage.googleapis.com/download.tensorflow.org/data/heart.csv')
df = pd.read_csv(csv_file)
train_size = int(len(df) * 0.8)
train_dataframe = df[:train_size]
test_dataframe = df[train_size:]
df.head()
Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/heart.csv
16384/13273 [=====================================] - 0s 0us/step

استخراج و تبدیل ویژگی ها و برچسب ها به تنسور:

# Features:
# - age
# - sex
# - cp        chest pain type (4 values)
# - trestbps  resting blood pressure
# - chol      serum cholestoral in mg/dl
# - fbs       fasting blood sugar > 120 mg/dl
# - restecg   resting electrocardiographic results (values 0,1,2)
# - thalach   maximum heart rate achieved
# - exang     exercise induced angina
# - oldpeak   ST depression induced by exercise relative to rest
# - slope     the slope of the peak exercise ST segment
# - ca        number of major vessels (0-3) colored by flourosopy
# - thal      3 = normal; 6 = fixed defect; 7 = reversable defect
#
# This ordering of feature names will be the exact same order that we construct
# our model to expect.
feature_names = [
    'age', 'sex', 'cp', 'chol', 'fbs', 'trestbps', 'thalach', 'restecg',
    'exang', 'oldpeak', 'slope', 'ca', 'thal'
]
feature_name_indices = {name: index for index, name in enumerate(feature_names)}
# This is the vocab list and mapping we will use for the 'thal' categorical
# feature.
thal_vocab_list = ['normal', 'fixed', 'reversible']
thal_map = {category: i for i, category in enumerate(thal_vocab_list)}
# Custom function for converting thal categories to buckets
def convert_thal_features(thal_features):
  # Note that two examples in the test set are already converted.
  return np.array([
      thal_map[feature] if feature in thal_vocab_list else feature
      for feature in thal_features
  ])


# Custom function for extracting each feature.
def extract_features(dataframe,
                     label_name='target',
                     feature_names=feature_names):
  features = []
  for feature_name in feature_names:
    if feature_name == 'thal':
      features.append(
          convert_thal_features(dataframe[feature_name].values).astype(float))
    else:
      features.append(dataframe[feature_name].values.astype(float))
  labels = dataframe[label_name].values.astype(float)
  return features, labels
train_xs, train_ys = extract_features(train_dataframe)
test_xs, test_ys = extract_features(test_dataframe)
# Let's define our label minimum and maximum.
min_label, max_label = float(np.min(train_ys)), float(np.max(train_ys))
# Our lattice models may have predictions above 1.0 due to numerical errors.
# We can subtract this small epsilon value from our output_max to make sure we
# do not predict values outside of our label bound.
numerical_error_epsilon = 1e-5

تنظیم مقادیر پیش فرض مورد استفاده برای آموزش در این راهنما:

LEARNING_RATE = 0.01
BATCH_SIZE = 128
NUM_EPOCHS = 500
PREFITTING_NUM_EPOCHS = 10

پیکربندی های ویژگی

تنظیمات ویژگی و تنظیمات هر ویژگی با استفاده از tfl.configs.FeatureConfig تنظیم می شوند . پیکربندی های ویژگی شامل محدودیت های یکنواختی ، تنظیم هر ویژگی (برای دیدن مدل های شبکه به اندازه tfl.configs.RegularizerConfig مراجعه کنید) و اندازه ها.

توجه داشته باشید که باید پیکربندی ویژگی را برای هر ویژگی که می خواهیم مدل ما آن را تشخیص دهد ، کاملاً مشخص کنیم. در غیر این صورت مدل هیچ راهی برای دانستن وجود چنین ویژگی نخواهد داشت.

محاسبه Quantiles

اگرچه تنظیمات پیش فرض pwl_calibration_input_keypoints در tfl.configs.FeatureConfig "quantiles" است ، اما برای مدل های پیش ساخته ما باید دستی کلیدهای ورودی را تعریف کنیم. برای انجام این کار ، ما ابتدا عملکرد کمکی خود را برای محاسبه مقادیر کمی تعریف می کنیم.

def compute_quantiles(features,
                      num_keypoints=10,
                      clip_min=None,
                      clip_max=None,
                      missing_value=None):
  # Clip min and max if desired.
  if clip_min is not None:
    features = np.maximum(features, clip_min)
    features = np.append(features, clip_min)
  if clip_max is not None:
    features = np.minimum(features, clip_max)
    features = np.append(features, clip_max)
  # Make features unique.
  unique_features = np.unique(features)
  # Remove missing values if specified.
  if missing_value is not None:
    unique_features = np.delete(unique_features,
                                np.where(unique_features == missing_value))
  # Compute and return quantiles over unique non-missing feature values.
  return np.quantile(
      unique_features,
      np.linspace(0., 1., num=num_keypoints),
      interpolation='nearest').astype(float)

تعریف تنظیمات ویژگی ما

اکنون که می توانیم مقدارهای خود را محاسبه کنیم ، برای هر ویژگی یک پیکربندی ویژگی را تعریف می کنیم که می خواهیم مدل ما به عنوان ورودی آن باشد.

# Feature configs are used to specify how each feature is calibrated and used.
feature_configs = [
    tfl.configs.FeatureConfig(
        name='age',
        lattice_size=3,
        monotonicity='increasing',
        # We must set the keypoints manually.
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['age']],
            num_keypoints=5,
            clip_max=100),
        # Per feature regularization.
        regularizer_configs=[
            tfl.configs.RegularizerConfig(name='calib_wrinkle', l2=0.1),
        ],
    ),
    tfl.configs.FeatureConfig(
        name='sex',
        num_buckets=2,
    ),
    tfl.configs.FeatureConfig(
        name='cp',
        monotonicity='increasing',
        # Keypoints that are uniformly spaced.
        pwl_calibration_num_keypoints=4,
        pwl_calibration_input_keypoints=np.linspace(
            np.min(train_xs[feature_name_indices['cp']]),
            np.max(train_xs[feature_name_indices['cp']]),
            num=4),
    ),
    tfl.configs.FeatureConfig(
        name='chol',
        monotonicity='increasing',
        # Explicit input keypoints initialization.
        pwl_calibration_input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],
        # Calibration can be forced to span the full output range by clamping.
        pwl_calibration_clamp_min=True,
        pwl_calibration_clamp_max=True,
        # Per feature regularization.
        regularizer_configs=[
            tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),
        ],
    ),
    tfl.configs.FeatureConfig(
        name='fbs',
        # Partial monotonicity: output(0) <= output(1)
        monotonicity=[(0, 1)],
        num_buckets=2,
    ),
    tfl.configs.FeatureConfig(
        name='trestbps',
        monotonicity='decreasing',
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['trestbps']], num_keypoints=5),
    ),
    tfl.configs.FeatureConfig(
        name='thalach',
        monotonicity='decreasing',
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['thalach']], num_keypoints=5),
    ),
    tfl.configs.FeatureConfig(
        name='restecg',
        # Partial monotonicity: output(0) <= output(1), output(0) <= output(2)
        monotonicity=[(0, 1), (0, 2)],
        num_buckets=3,
    ),
    tfl.configs.FeatureConfig(
        name='exang',
        # Partial monotonicity: output(0) <= output(1)
        monotonicity=[(0, 1)],
        num_buckets=2,
    ),
    tfl.configs.FeatureConfig(
        name='oldpeak',
        monotonicity='increasing',
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['oldpeak']], num_keypoints=5),
    ),
    tfl.configs.FeatureConfig(
        name='slope',
        # Partial monotonicity: output(0) <= output(1), output(1) <= output(2)
        monotonicity=[(0, 1), (1, 2)],
        num_buckets=3,
    ),
    tfl.configs.FeatureConfig(
        name='ca',
        monotonicity='increasing',
        pwl_calibration_num_keypoints=4,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['ca']], num_keypoints=4),
    ),
    tfl.configs.FeatureConfig(
        name='thal',
        # Partial monotonicity:
        # output(normal) <= output(fixed)
        # output(normal) <= output(reversible)
        monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],
        num_buckets=3,
        # We must specify the vocabulary list in order to later set the
        # monotonicities since we used names and not indices.
        vocabulary_list=thal_vocab_list,
    ),
]

در مرحله بعدی باید اطمینان حاصل کنیم که یکنواختی ها را به درستی برای ویژگی هایی که از واژگان سفارشی استفاده می کنیم (مانند "thal" در بالا) تنظیم کنیم.

tfl.premade_lib.set_categorical_monotonicities(feature_configs)

مدل خطی کالیبره شده

برای ساخت یک مدل پیش ساخته TFL ، ابتدا از tfl.configs یک پیکربندی مدل بسازید . یک مدل خطی کالیبره شده با استفاده از tfl.configs.CalibratedLinearConfig ساخته شده است . این کالیبراسیون را به صورت قطعه ای و خطی بر روی ویژگی های ورودی اعمال می کند و به دنبال آن یک ترکیب خطی و یک کالیبراسیون قطعه ای و خطی خروجی اختیاری اعمال می شود. هنگام استفاده از کالیبراسیون خروجی یا تعیین مرزهای خروجی ، لایه خطی میانگین ورودی را روی ورودی های کالیبره شده اعمال می کند.

این مثال یک مدل خطی کالیبره شده را روی 5 ویژگی اول ایجاد می کند.

# Model config defines the model structure for the premade model.
linear_model_config = tfl.configs.CalibratedLinearConfig(
    feature_configs=feature_configs[:5],
    use_bias=True,
    # We must set the output min and max to that of the label.
    output_min=min_label,
    output_max=max_label,
    output_calibration=True,
    output_calibration_num_keypoints=10,
    output_initialization=np.linspace(min_label, max_label, num=10),
    regularizer_configs=[
        # Regularizer for the output calibrator.
        tfl.configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),
    ])
# A CalibratedLinear premade model constructed from the given model config.
linear_model = tfl.premade.CalibratedLinear(linear_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(linear_model, show_layer_names=False, rankdir='LR')

png

اکنون ، مانند هر مدل tf.keras دیگر ، ما مدل را با داده های خود کامپایل و متناسب می کنیم.

linear_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
linear_model.fit(
    train_xs[:5],
    train_ys,
    epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=False)
<tensorflow.python.keras.callbacks.History at 0x7ff2bf765860>

پس از آموزش مدل خود ، می توانیم آن را در مجموعه آزمایشات خود ارزیابی کنیم.

print('Test Set Evaluation...')
print(linear_model.evaluate(test_xs[:5], test_ys))
Test Set Evaluation...
2/2 [==============================] - 0s 3ms/step - loss: 0.4849 - auc: 0.8214
[0.48487865924835205, 0.8214285969734192]

مدل شبکه کالیبره شده

یک مدل شبکه کالیبره شده با استفاده از tfl.configs.CalibratedLatticeConfig ساخته شده است . یک مدل شبکه کالیبره شده کالیبراسیون قطعه ای و خطی را بر روی ویژگی های ورودی اعمال می کند ، به دنبال آن یک مدل شبکه ای و یک کالیبراسیون قطعه ای خطی خروجی اختیاری ارائه می شود.

این مثال یک مدل شبکه کالیبره شده را روی 5 ویژگی اول ایجاد می کند.

# This is a calibrated lattice model: inputs are calibrated, then combined
# non-linearly using a lattice layer.
lattice_model_config = tfl.configs.CalibratedLatticeConfig(
    feature_configs=feature_configs[:5],
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    regularizer_configs=[
        # Torsion regularizer applied to the lattice to make it more linear.
        tfl.configs.RegularizerConfig(name='torsion', l2=1e-2),
        # Globally defined calibration regularizer is applied to all features.
        tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-2),
    ])
# A CalibratedLattice premade model constructed from the given model config.
lattice_model = tfl.premade.CalibratedLattice(lattice_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(lattice_model, show_layer_names=False, rankdir='LR')

png

مانند گذشته ، ما مدل خود را تدوین ، متناسب و ارزیابی می کنیم.

lattice_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
lattice_model.fit(
    train_xs[:5],
    train_ys,
    epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=False)
print('Test Set Evaluation...')
print(lattice_model.evaluate(test_xs[:5], test_ys))
Test Set Evaluation...
2/2 [==============================] - 1s 3ms/step - loss: 0.4784 - auc_1: 0.8402
[0.47842937707901, 0.8402255773544312]

مدل گروه مشبک کالیبره شده

هنگامی که تعداد ویژگی ها زیاد است ، می توانید از یک مدل گروهی استفاده کنید ، که به جای ایجاد فقط یک شبکه بزرگ عظیم ، چندین شبکه کوچکتر برای زیر مجموعه ویژگی ها ایجاد می کند و خروجی آنها را متوسط ​​می کند. مدل های شبکه شبکه با استفاده از tfl.configs.CalibratedLatticeEnsembleConfig ساخته می شوند . یک مدل گروه شبکه کالیبره شده کالیبراسیون به صورت جزئی و خطی را روی ویژگی ورودی اعمال می کند و به دنبال آن مجموعه ای از مدل های شبکه و یک کالیبراسیون قطعه ای و خطی خروجی اختیاری ارائه می شود.

آغاز صریح گروه مشبک

اگر قبلاً می دانید که کدام زیر مجموعه ویژگی ها را می خواهید در شبکه های خود تغذیه کنید ، می توانید به وضوح شبکه ها را با استفاده از نام ویژگی ها تنظیم کنید. این مثال یک مدل گروه شبکه کالیبره شده با 5 شبکه و 3 ویژگی در هر شبکه ایجاد می کند.

# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combined non-linearly and averaged using multiple lattice layers.
explicit_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=feature_configs,
    lattices=[['trestbps', 'chol', 'ca'], ['fbs', 'restecg', 'thal'],
              ['fbs', 'cp', 'oldpeak'], ['exang', 'slope', 'thalach'],
              ['restecg', 'age', 'sex']],
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label])
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config.
explicit_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    explicit_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    explicit_ensemble_model, show_layer_names=False, rankdir='LR')

png

مانند گذشته ، ما مدل خود را تدوین ، متناسب و ارزیابی می کنیم.

explicit_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
explicit_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(explicit_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 1s 3ms/step - loss: 0.4281 - auc_2: 0.8659
[0.42808252573013306, 0.8659147620201111]

گروه مشبک تصادفی

اگر مطمئن نیستید که کدام زیر مجموعه از ویژگی ها را در شبکه های شما تغذیه می کنید ، گزینه دیگر استفاده از زیر مجموعه های تصادفی ویژگی برای هر شبکه است. این مثال یک مدل گروه شبکه کالیبره شده با 5 شبکه و 3 ویژگی در هر شبکه ایجاد می کند.

# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combined non-linearly and averaged using multiple lattice layers.
random_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=feature_configs,
    lattices='random',
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    random_seed=42)
# Now we must set the random lattice structure and construct the model.
tfl.premade_lib.set_random_lattice_ensemble(random_ensemble_model_config)
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config.
random_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    random_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    random_ensemble_model, show_layer_names=False, rankdir='LR')

png

مانند گذشته ، ما مدل خود را تدوین ، متناسب و ارزیابی می کنیم.

random_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
random_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(random_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 1s 3ms/step - loss: 0.3929 - auc_3: 0.9217
[0.3929273188114166, 0.9216791987419128]

گروه مشبک تصادفی RTL Layer

هنگام استفاده از یک مجموعه مشبک تصادفی ، می توانید تعیین کنید که مدل از یک لایه tfl.layers.RTL استفاده کند. ما توجه داریم که tfl.layers.RTL فقط از محدودیت های یکنواختی پشتیبانی می کند و باید اندازه شبکه را برای همه ویژگی ها یکسان داشته باشد و هیچ نظم در هر ویژگی نداشته باشد. توجه داشته باشید که با استفاده از یک لایه tfl.layers.RTL شما امکان می دهد تا نسبت به استفاده از نمونه های جداگانه tfl.layers.Lattice گروه های بزرگتر مقیاس tfl.layers.Lattice .

این مثال یک مدل گروه شبکه کالیبره شده با 5 شبکه و 3 ویژگی در هر شبکه ایجاد می کند.

# Make sure our feature configs have the same lattice size, no per-feature
# regularization, and only monotonicity constraints.
rtl_layer_feature_configs = copy.deepcopy(feature_configs)
for feature_config in rtl_layer_feature_configs:
  feature_config.lattice_size = 2
  feature_config.unimodality = 'none'
  feature_config.reflects_trust_in = None
  feature_config.dominates = None
  feature_config.regularizer_configs = None
# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combined non-linearly and averaged using multiple lattice layers.
rtl_layer_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=rtl_layer_feature_configs,
    lattices='rtl_layer',
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    random_seed=42)
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config. Note that we do not have to specify the lattices by calling
# a helper function (like before with random) because the RTL Layer will take
# care of that for us.
rtl_layer_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    rtl_layer_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    rtl_layer_ensemble_model, show_layer_names=False, rankdir='LR')

png

مانند گذشته ، ما مدل خود را تدوین ، متناسب و ارزیابی می کنیم.

rtl_layer_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
rtl_layer_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(rtl_layer_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 0s 3ms/step - loss: 0.4286 - auc_4: 0.8690
[0.42856135964393616, 0.8690476417541504]

گروه مشبک بلورها

Premade همچنین یک الگوریتم تنظیم ویژگی اکتشافی ، به نام Crystals ارائه می دهد . برای استفاده از الگوریتم Crystals ، ابتدا یک مدل ترجیحی را آموزش می دهیم که برهم کنش ویژگی های دوتایی را تخمین می زند. سپس گروه نهایی را به گونه ای ترتیب می دهیم که ویژگی هایی با فعل و انفعالات غیر خطی بیشتر در همان شبکه ها باشند.

کتابخانه Premade توابع کمکی را برای ساخت پیکربندی مدل قبل و استخراج ساختار بلورها ارائه می دهد. توجه داشته باشید که مدل دلفریب نیازی به آموزش کامل ندارد ، بنابراین چند دوره باید کافی باشد.

این مثال یک مدل گروه شبکه کالیبره شده با 5 شبکه و 3 ویژگی در هر شبکه ایجاد می کند.

# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combines non-linearly and averaged using multiple lattice layers.
crystals_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=feature_configs,
    lattices='crystals',
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    random_seed=42)
# Now that we have our model config, we can construct a prefitting model config.
prefitting_model_config = tfl.premade_lib.construct_prefitting_model_config(
    crystals_ensemble_model_config)
# A CalibratedLatticeEnsemble premade model constructed from the given
# prefitting model config.
prefitting_model = tfl.premade.CalibratedLatticeEnsemble(
    prefitting_model_config)
# We can compile and train our prefitting model as we like.
prefitting_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
prefitting_model.fit(
    train_xs,
    train_ys,
    epochs=PREFITTING_NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=False)
# Now that we have our trained prefitting model, we can extract the crystals.
tfl.premade_lib.set_crystals_lattice_ensemble(crystals_ensemble_model_config,
                                              prefitting_model_config,
                                              prefitting_model)
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config.
crystals_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    crystals_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    crystals_ensemble_model, show_layer_names=False, rankdir='LR')

png

مانند گذشته ، ما مدل خود را تدوین ، متناسب و ارزیابی می کنیم.

crystals_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
crystals_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(crystals_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
2/2 [==============================] - 1s 3ms/step - loss: 0.4671 - auc_5: 0.8283
[0.46707457304000854, 0.8283208608627319]