Tuning a wide and deep model using Google Cloud


In this example, we will use CloudTuner and Google Cloud to tune a wide and deep model based on the tunable model introduced in Structured data learning with Wide, Deep, and Cross networks. The example uses the dataset from the CAIIS Dogfood Day 2020 Kaggle competition.

Import required modules

import datetime
import uuid

import numpy as np
import pandas as pd
import tensorflow as tf
import os
import sys
import subprocess

from tensorflow.keras import datasets, layers, models
from sklearn.model_selection import train_test_split

# Install the latest version of tensorflow_cloud and other required packages.
# Note: os.environ.get() returns the default (True) when the variable is unset
# and a non-empty string when it is set, so this condition is true both in the
# notebook and in the remote container.
if os.environ.get("TF_KERAS_RUNNING_REMOTELY", True):
    # sys.executable makes pip install into the interpreter running this code.
    subprocess.run(
        [sys.executable, '-m', 'pip', 'install', '-q',
         'tensorflow-cloud', 'google-cloud-storage', 'fsspec', 'gcsfs'],
        check=True)

import tensorflow_cloud as tfc
print(tfc.__version__)
# Output: 0.1.15
print(tf.version.VERSION)
# Output: 2.6.0

Project Configurations

Set the project parameters. For more details on Google Cloud-specific parameters, please refer to the Google Cloud Project Setup Instructions.

# Set Google Cloud Specific parameters

# TODO: Please set GCP_PROJECT_ID to your own Google Cloud project ID.
GCP_PROJECT_ID = 'YOUR_PROJECT_ID' 

# TODO: Change the Service Account Name to your own Service Account
SERVICE_ACCOUNT_NAME = 'YOUR_SERVICE_ACCOUNT_NAME'
SERVICE_ACCOUNT = f'{SERVICE_ACCOUNT_NAME}@{GCP_PROJECT_ID}.iam.gserviceaccount.com'

# TODO: set GCS_BUCKET to your own Google Cloud Storage (GCS) bucket.
GCS_BUCKET = 'YOUR_GCS_BUCKET_NAME'

# DO NOT CHANGE: Currently only the 'us-central1' region is supported.
REGION = 'us-central1'
# Set Tuning Specific parameters

# OPTIONAL: You can change the job name to any string.
JOB_NAME = 'wide_and_deep'

# OPTIONAL: Set the number of concurrent tuning jobs that you would like to run.
NUM_JOBS = 5

# TODO: Set the study ID for this run. STUDY_ID can be any unique string.
# Reusing the same STUDY_ID will cause the tuner to continue tuning the
# same study. This can be used to continue a terminated job,
# or to load stats from a previous study.
STUDY_NUMBER = '00001'
STUDY_ID = f'{GCP_PROJECT_ID}_{JOB_NAME}_{STUDY_NUMBER}'

# Set the location where training logs and checkpoints will be stored.
GCS_BASE_PATH = f'gs://{GCS_BUCKET}/{JOB_NAME}/{STUDY_ID}'
TENSORBOARD_LOGS_DIR = os.path.join(GCS_BASE_PATH, "logs")
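To see how the derived names fit together, here is a small sketch that rebuilds STUDY_ID, GCS_BASE_PATH and TENSORBOARD_LOGS_DIR with the same f-strings as the cell above. The helper derive_paths and the project and bucket names are hypothetical placeholders, not real resources.

```python
import os

def derive_paths(project_id, bucket, job_name="wide_and_deep", study_number="00001"):
    """Rebuild the study ID and GCS paths the same way the config cell does."""
    study_id = f"{project_id}_{job_name}_{study_number}"
    base_path = f"gs://{bucket}/{job_name}/{study_id}"
    # On POSIX systems os.path.join uses "/", which matches GCS object paths.
    logs_dir = os.path.join(base_path, "logs")
    return study_id, base_path, logs_dir

# Hypothetical placeholder values.
study_id, base_path, logs_dir = derive_paths("my-project", "my-bucket")
print(study_id)   # my-project_wide_and_deep_00001
print(base_path)  # gs://my-bucket/wide_and_deep/my-project_wide_and_deep_00001
print(logs_dir)   # gs://my-bucket/wide_and_deep/my-project_wide_and_deep_00001/logs
```

Because the study number is part of STUDY_ID, changing STUDY_NUMBER yields a new study and a separate log directory, while reusing it points the tuner back at the same study.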

Authenticating the notebook to use your Google Cloud Project

For Kaggle Notebooks, click "Add-ons" -> "Google Cloud SDK" before running the cell below.

# Use tfc.remote() to ensure this code only runs in the notebook.
if not tfc.remote():

    # Authentication for Kaggle Notebooks
    if "kaggle_secrets" in sys.modules:
        from kaggle_secrets import UserSecretsClient
        UserSecretsClient().set_gcloud_credentials(project=GCP_PROJECT_ID)

    # Authentication for Colab Notebooks
    if "google.colab" in sys.modules:
        from google.colab import auth
        auth.authenticate_user()
        os.environ["GOOGLE_CLOUD_PROJECT"] = GCP_PROJECT_ID

Load the data

Read the raw data and split it into training and test datasets. For this step, you will need to copy the dataset to your GCS bucket so that it can be accessed during training. This example uses the dataset from the CAIIS Dogfood Day 2020 competition: https://www.kaggle.com/c/caiis-dogfood-day-2020

To do this, you can run the following commands to download the dataset and copy it to your GCS bucket, or manually download the dataset via the Kaggle UI and upload the train.csv file to your GCS bucket via the GCS UI.

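# Download the dataset (the Kaggle CLI saves it as a zip archive)
kaggle competitions download -c caiis-dogfood-day-2020

# Extract the archive into the directory used by the copy command below
unzip caiis-dogfood-day-2020.zip -d caiis-dogfood-day-2020

# Copy the training file to your bucket
gsutil cp ./caiis-dogfood-day-2020/train.csv $GCS_BASE_PATH/caiis-dogfood-day-2020/train.csv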
train_URL = f'{GCS_BASE_PATH}/caiis-dogfood-day-2020/train.csv'
data = pd.read_csv(train_URL)
train, test = train_test_split(data, test_size=0.1)
# A utility function to create a tf.data dataset from a pandas DataFrame
def df_to_dataset(df, shuffle=True, batch_size=32):
  df = df.copy()
  labels = df.pop('target')
  ds = tf.data.Dataset.from_tensor_slices((dict(df), labels))
  if shuffle:
    ds = ds.shuffle(buffer_size=len(df))
  ds = ds.batch(batch_size)
  return ds
sm_batch_size = 1000  # A small batch size is used for demonstration purposes
train_ds = df_to_dataset(train, batch_size=sm_batch_size)
test_ds = df_to_dataset(test, shuffle=False, batch_size=sm_batch_size)
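As a back-of-the-envelope check, the split sizes and the number of batches per epoch can be computed directly. This sketch assumes scikit-learn's rounding rule for a float test_size (the test set gets the ceiling of test_size * n, the training set gets the rest) and that the final partial batch is kept, which is what Dataset.batch does unless drop_remainder=True. The helper name and the row count are hypothetical.

```python
import math

def split_and_batch_counts(n_rows, test_size=0.1, batch_size=1000):
    """Estimate train/test sizes and batches per epoch for the cell above."""
    n_test = math.ceil(test_size * n_rows)  # test rows are rounded up
    n_train = n_rows - n_test               # the remainder goes to training
    # The last, smaller batch is kept, so round the batch count up.
    train_batches = math.ceil(n_train / batch_size)
    test_batches = math.ceil(n_test / batch_size)
    return n_train, n_test, train_batches, test_batches

print(split_and_batch_counts(25_000))  # (22500, 2500, 23, 3)
```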

Preprocess the data

Set up preprocessing layers for the categorical and numerical input data. For more details on preprocessing layers, please refer to Working with preprocessing layers.

from tensorflow.keras.layers.experimental import preprocessing

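def create_model_inputs():
    # Create one Keras Input per feature column, skipping the 'id' and 'target' columns.
    inputs = {}
    for name, column in data.items():
        if name in ('id', 'target'):
            continue
        # String (object) columns become tf.string inputs; all others are tf.float32.
        if column.dtype == object:
            dtype = tf.string
        else:
            dtype = tf.float32

        inputs[name] = tf.keras.Input(shape=(1,), name=name, dtype=dtype)

    return inputs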
# Preprocess the numeric inputs: concatenate them and run them through a normalization layer.
def preprocess_numeric_inputs(inputs):

    numeric_inputs = {name: inp for name, inp in inputs.items()
                      if inp.dtype == tf.float32}

    x = layers.Concatenate()(list(numeric_inputs.values()))
    norm = preprocessing.Normalization()
    # Learn the per-column mean and variance from the raw numeric columns
    # (the column order matches the concatenation order above).
    norm.adapt(np.array(data[list(numeric_inputs.keys())]))
    return norm(x)
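Conceptually, adapt() records a per-column mean and variance, and the layer then standardizes each feature. A rough NumPy equivalent is sketched below; it assumes population variance and a small epsilon guard, which approximates but is not guaranteed to match the Keras internals exactly. The helper names and sample values are made up for illustration.

```python
import numpy as np

def adapt_normalization(x):
    """Per-column statistics, similar in spirit to Normalization.adapt()."""
    x = np.asarray(x, dtype=np.float64)
    return x.mean(axis=0), x.var(axis=0)  # population variance (ddof=0)

def normalize(x, mean, var, eps=1e-7):
    """Standardize features; eps guards against zero-variance columns."""
    return (np.asarray(x, dtype=np.float64) - mean) / np.maximum(np.sqrt(var), eps)

# Made-up numeric columns with very different scales.
raw = np.array([[1.0, 100.0], [2.0, 200.0], [3.0, 300.0], [4.0, 400.0]])
mean, var = adapt_normalization(raw)
scaled = normalize(raw, mean, var)
print(scaled.mean(axis=0))  # close to [0, 0]
print(scaled.std(axis=0))   # close to [1, 1]
```

Standardizing matters here because the numeric columns are concatenated before the first Dense layer, so a column with a much larger scale could otherwise dominate early training.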
# Preprocess the categorical inputs: string -> integer index -> one-hot vector.
def preprocess_categorical_inputs(inputs):
    categorical_inputs = []
    for name, inp in inputs.items():
        if inp.dtype == tf.float32:
            continue

        lookup = preprocessing.StringLookup(vocabulary=np.unique(data[name]))
        # vocabulary_size() includes the out-of-vocabulary index.
        one_hot = preprocessing.CategoryEncoding(num_tokens=lookup.vocabulary_size())

        x = lookup(inp)
        x = one_hot(x)
        categorical_inputs.append(x)

    return layers.concatenate(categorical_inputs)
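The categorical path maps each string to an integer index and then expands that index into a one-hot vector. The NumPy sketch below mimics this under the assumption, matching the StringLookup defaults in recent TensorFlow 2.x releases, that index 0 is reserved for out-of-vocabulary values. The helper names, vocabulary and inputs are hypothetical.

```python
import numpy as np

def build_lookup(vocabulary):
    """Map each known string to 1..N; index 0 is reserved for unknown values."""
    return {token: i + 1 for i, token in enumerate(vocabulary)}

def one_hot_encode(values, lookup):
    """Return a (len(values), len(lookup) + 1) one-hot matrix."""
    num_tokens = len(lookup) + 1  # +1 for the out-of-vocabulary slot
    indices = np.array([lookup.get(v, 0) for v in values])
    encoded = np.zeros((len(values), num_tokens), dtype=np.float32)
    encoded[np.arange(len(values)), indices] = 1.0
    return encoded

vocab = np.unique(["red", "green", "blue", "green"])  # sorted: blue, green, red
lookup = build_lookup(vocab)
print(one_hot_encode(["green", "purple"], lookup))
# [[0. 0. 1. 0.]
#  [1. 0. 0. 0.]]
```

Concatenating one such block per categorical column forms the wide part of the model, so its width is the sum of the per-column vocabulary sizes, including the out-of-vocabulary slots.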

Define the model architecture and hyperparameters

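In this section, we define our tuning parameters using Keras Tuner HyperParameters and a model-building function. The model-building function takes an argument hp from which you can sample hyperparameters, such as hp.Int('units', min_value=32, max_value=512, step=32) (an integer from a certain range). Because the search space here is declared up front in HPS, the model-building function reads the values with hp.get() instead of defining them inline.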

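# Newer Keras Tuner releases also expose this package as keras_tuner.
import kerastuner

# Configure the search space
HPS = kerastuner.engine.hyperparameters.HyperParameters()
HPS.Float('learning_rate', min_value=1e-4, max_value=1e-2, sampling='log')

HPS.Int('num_layers', min_value=2, max_value=5)
# Per-layer hyperparameters are declared for the maximum of 5 layers;
# the model-building function only reads the first num_layers of them.
for i in range(5):
    HPS.Float('dropout_rate_' + str(i), min_value=0.0, max_value=0.3, step=0.1)
    HPS.Choice('num_units_' + str(i), [32, 64, 128, 256])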
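To make the search space concrete, the sketch below draws one random configuration from a roughly equivalent space using only Python's random module. The helper sample_config is hypothetical and purely illustrative; it is not how CloudTuner chooses its trials.

```python
import math
import random

def sample_config(rng):
    """Draw one configuration from a space similar to HPS (illustration only)."""
    # Log-uniform draw between 1e-4 and 1e-2, mirroring sampling='log'.
    log_lr = rng.uniform(math.log(1e-4), math.log(1e-2))
    config = {
        "learning_rate": math.exp(log_lr),
        "num_layers": rng.randint(2, 5),  # both bounds inclusive
    }
    # One dropout rate and one width per possible layer, as in the loop above.
    for i in range(5):
        config[f"dropout_rate_{i}"] = rng.choice([0.0, 0.1, 0.2, 0.3])
        config[f"num_units_{i}"] = rng.choice([32, 64, 128, 256])
    return config

rng = random.Random(0)  # fixed seed for a reproducible draw
config = sample_config(rng)
print(config)
```

Sampling the learning rate on a log scale gives each decade (1e-4 to 1e-3 and 1e-3 to 1e-2) the same probability, which a uniform draw over the same range would not.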
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Adam


def create_wide_and_deep_model(hp):

    inputs = create_model_inputs()
    wide = preprocess_categorical_inputs(inputs)
    wide = layers.BatchNormalization()(wide)

    deep = preprocess_numeric_inputs(inputs)
    for i in range(hp.get('num_layers')):
        deep = layers.Dense(hp.get('num_units_' + str(i)))(deep)
        deep = layers.BatchNormalization()(deep)
        deep = layers.ReLU()(deep)
        deep = layers.Dropout(hp.get('dropout_rate_' + str(i)))(deep)

    both = layers.concatenate([wide, deep])
    outputs = layers.Dense(1, activation='sigmoid')(both)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    metrics = [
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        'accuracy',
        'mse'
    ]

    model.compile(
        optimizer=Adam(learning_rate=hp.get('learning_rate')),
        loss='binary_crossentropy',
        metrics=metrics)
    return model
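The essential wide-and-deep idea is that the one-hot wide features and the output of the deep tower are concatenated and fed to a single sigmoid unit. A minimal inference-time NumPy sketch of that data flow follows; it uses random weights, omits biases and batch normalization for brevity, and leaves out dropout (which is inactive at inference). The function name and all shapes are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def wide_and_deep_forward(wide, numeric, deep_units, rng):
    """Concatenate wide features with a ReLU deep tower, then apply one sigmoid unit."""
    deep = numeric
    for units in deep_units:
        w = rng.normal(scale=0.1, size=(deep.shape[1], units))
        deep = np.maximum(deep @ w, 0.0)  # Dense (no bias) followed by ReLU
    both = np.concatenate([wide, deep], axis=1)
    w_out = rng.normal(scale=0.1, size=(both.shape[1], 1))
    return sigmoid(both @ w_out), both.shape[1]

rng = np.random.default_rng(0)
wide = np.eye(4)[[0, 2, 3]]        # 3 examples, one-hot over 4 categories
numeric = rng.normal(size=(3, 5))  # 3 examples, 5 standardized numeric columns
probs, width = wide_and_deep_forward(wide, numeric, [64, 32], rng)
print(probs.shape, width)  # (3, 1) 36
```

In the real model, the input width of the final layer therefore grows with both the categorical vocabularies and the tuned num_units of the last deep layer.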

Configure a CloudTuner

In this section we configure the cloud tuner for both remote and local execution. The main difference between the two is the distribution strategy.

from tensorflow_cloud import CloudTuner

distribution_strategy = None
if tfc.remote():
    # Use MirroredStrategy to use a single instance with multiple GPUs
    # during remote execution, while using no strategy for local runs.
    distribution_strategy = tf.distribute.MirroredStrategy()

tuner = CloudTuner(
    create_wide_and_deep_model,
    project_id=GCP_PROJECT_ID,
    project_name=JOB_NAME,
    region=REGION,
    objective='accuracy',
    hyperparameters=HPS,
    max_trials=100,
    directory=GCS_BASE_PATH,
    study_id=STUDY_ID,
    overwrite=True,
    distribution_strategy=distribution_strategy)
# Configure TensorBoard logs
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir=TENSORBOARD_LOGS_DIR)]

# Run the tuning search only during remote execution. You can run the tuner locally first to validate that it works.
if tfc.remote():
    tuner.search(train_ds, epochs=20, validation_data=test_ds, callbacks=callbacks)

# You can uncomment the code below to run the tuner.search() locally to validate
# everything works before submitting the job to Cloud. Stop the job manually
# after one epoch.

# else:
#     tuner.search(train_ds, epochs=1, validation_data=test_ds, callbacks=callbacks)

Start the remote training

This step prepares the code from this notebook for remote execution and starts NUM_JOBS parallel runs remotely to train the model. Once the jobs are submitted, you can go to the next step to monitor the jobs' progress via TensorBoard.

tfc.run_cloudtuner(
    distribution_strategy='auto',
    docker_config=tfc.DockerConfig(
        image_build_bucket=GCS_BUCKET
        ),
    chief_config=tfc.MachineConfig(
        cpu_cores=16,
        memory=60,
    ),
    job_labels={'job': JOB_NAME},
    service_account=SERVICE_ACCOUNT,
    num_jobs=NUM_JOBS
)

Training Results

Reconnect your Colab instance

Most remote training jobs are long running. If you are using Colab, it may time out before the training results are available. In that case, rerun the following sections in order to reconnect and configure your Colab instance to access the training results:

  1. Import required modules
  2. Project Configurations
  3. Authenticating the notebook to use your Google Cloud Project

Load TensorBoard

While the training is in progress, you can use TensorBoard to view the results. Note that the results will show only after your training has started, which may take a few minutes.

%load_ext tensorboard
%tensorboard --logdir $TENSORBOARD_LOGS_DIR

You can access the training assets as follows. Note that the results will show only after your tuning job has completed at least one trial, which may take a few minutes.

if not tfc.remote():
    tuner.results_summary(1)
    best_model = tuner.get_best_models(1)[0]
    best_hyperparameters = tuner.get_best_hyperparameters(1)[0]

    # References to best trial assets
    best_trial_id = tuner.oracle.get_best_trials(1)[0].trial_id
    best_trial_dir = tuner.get_trial_dir(best_trial_id)
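Because the search space declares dropout and unit values for all five possible layers, the best trial's hyperparameter values also include entries the model never used. A small hypothetical helper can filter them down to the active layers; it expects a plain dict such as best_hyperparameters.values, and the example dict below is made up rather than the output of a real study.

```python
def active_layers(values):
    """Return (num_units, dropout_rate) pairs for the layers the model actually built."""
    return [(values[f"num_units_{i}"], values[f"dropout_rate_{i}"])
            for i in range(values["num_layers"])]

# Made-up values for illustration only.
example_values = {
    "learning_rate": 0.001, "num_layers": 2,
    "num_units_0": 128, "dropout_rate_0": 0.1,
    "num_units_1": 64, "dropout_rate_1": 0.0,
    "num_units_2": 256, "dropout_rate_2": 0.3,  # ignored: num_layers is 2
    "num_units_3": 32, "dropout_rate_3": 0.2,   # ignored
    "num_units_4": 32, "dropout_rate_4": 0.0,   # ignored
}
print(active_layers(example_values))  # [(128, 0.1), (64, 0.0)]
```

With a completed study, you would pass best_hyperparameters.values instead of the example dict.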