Uplifting with Decision Forests


Welcome to the Uplifting Tutorial for TensorFlow Decision Forests (TF-DF). In this tutorial, you will learn what uplifting is, why it is so important, and how to do it in TF-DF.

This tutorial assumes you are familiar with the fundamentals of TF-DF, in particular the installation procedure. The beginner tutorial is a great place to start learning about TF-DF.

In this colab, you will:

  • Learn what uplift modeling is.
  • Train an Uplift Random Forest model on the Hillstrom Email Marketing dataset.
  • Evaluate the quality of this model.

Installing TensorFlow Decision Forests

Install TF-DF by running the following cell.

Wurlitzer is needed to display the detailed training logs in Colabs (when using verbose=2 in the model constructor).

pip install tensorflow_decision_forests wurlitzer

Importing libraries

import tensorflow_decision_forests as tfdf

import os
import numpy as np
import pandas as pd
import tensorflow as tf
import math
import matplotlib.pyplot as plt

The hidden code cell limits the output height in colab.

# Check the version of TensorFlow Decision Forests
print("Found TensorFlow Decision Forests v" + tfdf.__version__)
Found TensorFlow Decision Forests v1.9.0

What is uplift modeling?

Uplift modeling is a statistical modeling technique to predict the incremental impact of an action on a subject. The action is often referred to as a treatment that may or may not be applied.

Uplift modeling is often used in targeted marketing campaigns to predict the increase in the likelihood of a person making a purchase (or any other desired action) based on the marketing exposure they receive.

For example, uplift modeling can predict the effect of an email. The effect is defined as the conditional probability difference
\begin{align}
\text{effect}(\text{email}) = &\Pr(\text{outcome}=\text{purchase}\ \vert\ \text{treatment}=\text{with email})\\
&- \Pr(\text{outcome}=\text{purchase}\ \vert\ \text{treatment}=\text{no email}),
\end{align}
where \(\Pr(\text{outcome}=\text{purchase}\ \vert\ ...)\) is the probability of purchase depending on whether or not an email was received.
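
As a purely illustrative sketch (the probability values below are made up and not taken from any dataset), the effect is simply the difference between the two conditional purchase probabilities:

# Hypothetical conditional purchase probabilities, for illustration only.
p_purchase_with_email = 0.058  # Pr(outcome=purchase | treatment=with email)
p_purchase_no_email = 0.041    # Pr(outcome=purchase | treatment=no email)

# The effect (uplift) of the email is the difference of the two probabilities.
effect_email = p_purchase_with_email - p_purchase_no_email
print(f"Effect of the email: {effect_email:.3f}")  # 0.017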

Compare this to a classification model: With a classification model, one can predict the probability of a purchase. However, customers with a high probability are likely to spend money in the store regardless of whether or not they received an email.

Similarly, one can use numerical uplifting to predict the numerical increase in spend when receiving an email. In comparison, a regression model can only predict the expected spend, which is a less useful metric in many cases.

Defining uplift models in TF-DF

TF-DF expects uplifting datasets to be presented in a "flat" format. A dataset of customers might look like this:

treatment  outcome  feature_1  feature_2
        0        1        0.1  blue
        0        0        0.2  blue
        1        1        0.3  blue
        1        1        0.4  blue

The treatment is a binary variable indicating whether or not the example has received treatment. In the above example, the treatment indicates if the customer has received an email or not. The outcome (label) indicates the status of the example after receiving the treatment (or not). TF-DF supports categorical outcomes for categorical uplifting and numerical outcomes for numerical uplifting.
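
For illustration only, here is a minimal sketch (using hypothetical values matching the toy table above) of how such a flat dataset could be assembled into a tf.data.Dataset. The treatment stays a regular input column and is only singled out later through the uplift_treatment argument of the model constructor.

# A minimal, hypothetical "flat" uplift dataset matching the toy table above.
toy_features = {
    'treatment': [0, 0, 1, 1],                      # binary treatment indicator
    'feature_1': [0.1, 0.2, 0.3, 0.4],
    'feature_2': ['blue', 'blue', 'blue', 'blue'],
}
toy_outcomes = [1, 0, 1, 1]                         # categorical outcome (label)

# The treatment column is passed like any other feature; it is designated as
# the treatment only via `uplift_treatment` when constructing the model.
toy_ds = tf.data.Dataset.from_tensor_slices((toy_features, toy_outcomes)).batch(4)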

Training an uplifting model

In this example, we will use the Hillstrom Email Marketing dataset.

This dataset contains 64,000 customers who last purchased within twelve months. The customers were involved in an e-mail test:

  • 1/3 were randomly chosen to receive an e-mail campaign featuring Mens merchandise.
  • 1/3 were randomly chosen to receive an e-mail campaign featuring Womens merchandise.
  • 1/3 were randomly chosen to not receive an e-mail campaign.

During a period of two weeks following the e-mail campaign, results were tracked. The task is to tell if the Mens or Womens e-mail campaign was successful.

Read more about the dataset in its documentation. This tutorial uses the dataset as curated by TensorFlow Datasets.

# Install the TensorFlow Datasets package
pip install tensorflow-datasets -U --quiet
# Load the dataset
import tensorflow_datasets as tfds
raw_train, raw_test = tfds.load('hillstrom', split=['train[:80%]', 'train[80%:]'])

# Display the first 10 examples of the test fold.
pd.DataFrame(list(raw_test.batch(10).take(1))[0])
2024-04-20 11:22:09.063782: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-04-20 11:22:09.069098: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence

Dataset preprocessing

Since TF-DF currently only supports binary treatments, combine the "Men's Email" and the "Women's Email" campaign. This tutorial uses the binary variable conversion as outcome. This means that the problem is a Categorical Uplifting problem. If we were using the numerical variable spend, the problem would be a Numerical Uplifting problem.

def prepare_dataset(example):
  # Use a binary treatment class.
  example['treatment'] = 1 if example['segment'] == b'Mens E-Mail' or example['segment'] == b'Womens E-Mail' else 0
  outcome = example['conversion']
  # Restrict the dataset to the input features.
  input_features = ['channel', 'history', 'mens', 'womens', 'newbie', 'recency', 'zip_code', 'treatment']
  example = {feature: example[feature] for feature in input_features}
  return example, outcome

train_ds = raw_train.map(prepare_dataset).batch(100)
test_ds = raw_test.map(prepare_dataset).batch(100)
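
As an aside, a Numerical Uplifting variant of this preprocessing is sketched below. It assumes the spend field exposed by the TFDS version of the dataset is used as the outcome, together with the NUMERICAL_UPLIFT task; this variant is not run in this tutorial.

# Hypothetical Numerical Uplifting variant: use the `spend` field as outcome.
def prepare_numerical_dataset(example):
  example['treatment'] = 1 if example['segment'] == b'Mens E-Mail' or example['segment'] == b'Womens E-Mail' else 0
  outcome = example['spend']  # numerical outcome instead of `conversion`
  input_features = ['channel', 'history', 'mens', 'womens', 'newbie', 'recency', 'zip_code', 'treatment']
  return {feature: example[feature] for feature in input_features}, outcome

# The matching model would use the NUMERICAL_UPLIFT task, e.g.:
# tfdf.keras.RandomForestModel(task=tfdf.keras.Task.NUMERICAL_UPLIFT,
#                              uplift_treatment='treatment')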

Model training

Finally, train and evaluate the model as usual. Note that TF-DF only supports Random Forest models for uplifting.

%set_cell_height 300

# Configure the model and its hyper-parameters.
model = tfdf.keras.RandomForestModel(
    verbose=2,
    task=tfdf.keras.Task.CATEGORICAL_UPLIFT,
    uplift_treatment='treatment'
)

# Train the model.
model.fit(train_ds)
<IPython.core.display.Javascript object>
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmppyeh4gae as temporary training directory
Reading training dataset...
Training tensor examples:
Features: {'channel': <tf.Tensor 'data:0' shape=(None,) dtype=string>, 'history': <tf.Tensor 'data_1:0' shape=(None,) dtype=float32>, 'mens': <tf.Tensor 'data_2:0' shape=(None,) dtype=int64>, 'womens': <tf.Tensor 'data_3:0' shape=(None,) dtype=int64>, 'newbie': <tf.Tensor 'data_4:0' shape=(None,) dtype=int64>, 'recency': <tf.Tensor 'data_5:0' shape=(None,) dtype=int64>, 'zip_code': <tf.Tensor 'data_6:0' shape=(None,) dtype=string>, 'treatment': <tf.Tensor 'data_7:0' shape=(None,) dtype=int32>}
Label: Tensor("data_8:0", shape=(None,), dtype=int64)
Weights: None
Normalized tensor features:
 {'channel': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data:0' shape=(None,) dtype=string>), 'history': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'data_1:0' shape=(None,) dtype=float32>), 'mens': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast:0' shape=(None,) dtype=float32>), 'womens': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_1:0' shape=(None,) dtype=float32>), 'newbie': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_2:0' shape=(None,) dtype=float32>), 'recency': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_3:0' shape=(None,) dtype=float32>), 'zip_code': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_6:0' shape=(None,) dtype=string>)}
Training dataset read in 0:00:04.974222. Found 51200 examples.
Training model...
Standard output detected as not visible to the user e.g. running in a notebook. Creating a training log redirection. If training gets stuck, try calling tfdf.keras.set_training_logs_redirection(False).
[INFO 24-04-20 11:22:14.2334 UTC kernel.cc:771] Start Yggdrasil model training
[INFO 24-04-20 11:22:14.2335 UTC kernel.cc:772] Collect training examples
[INFO 24-04-20 11:22:14.2335 UTC kernel.cc:785] Dataspec guide:
column_guides {
  column_name_pattern: "^__LABEL$"
  type: CATEGORICAL
}
default_column_guide {
  categorial {
    max_vocab_count: 2000
  }
  discretized_numerical {
    maximum_num_bins: 255
  }
}
ignore_columns_without_guides: false
detect_numerical_as_discretized_numerical: false

[INFO 24-04-20 11:22:14.2339 UTC kernel.cc:391] Number of batches: 512
[INFO 24-04-20 11:22:14.2339 UTC kernel.cc:392] Number of examples: 51200
[INFO 24-04-20 11:22:14.2463 UTC kernel.cc:792] Training dataset:
Number of records: 51200
Number of columns: 9

Number of columns by type:
    NUMERICAL: 5 (55.5556%)
    CATEGORICAL: 4 (44.4444%)

Columns:

NUMERICAL: 5 (55.5556%)
    2: "history" NUMERICAL mean:241.833 min:29.99 max:3345.93 sd:255.292
    3: "mens" NUMERICAL mean:0.550391 min:0 max:1 sd:0.497454
    4: "newbie" NUMERICAL mean:0.503086 min:0 max:1 sd:0.49999
    5: "recency" NUMERICAL mean:5.75514 min:1 max:12 sd:3.50281
    7: "womens" NUMERICAL mean:0.549687 min:0 max:1 sd:0.497525

CATEGORICAL: 4 (44.4444%)
    0: "__LABEL" CATEGORICAL integerized vocab-size:3 no-ood-item
    1: "channel" CATEGORICAL has-dict vocab-size:4 zero-ood-items most-frequent:"Web" 22576 (44.0938%)
    6: "treatment" CATEGORICAL integerized vocab-size:3 no-ood-item
    8: "zip_code" CATEGORICAL has-dict vocab-size:4 zero-ood-items most-frequent:"Surburban" 22966 (44.8555%)

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.

[INFO 24-04-20 11:22:14.2464 UTC kernel.cc:808] Configure learner
[INFO 24-04-20 11:22:14.2466 UTC kernel.cc:822] Training config:
learner: "RANDOM_FOREST"
features: "^channel$"
features: "^history$"
features: "^mens$"
features: "^newbie$"
features: "^recency$"
features: "^womens$"
features: "^zip_code$"
label: "^__LABEL$"
task: CATEGORICAL_UPLIFT
random_seed: 123456
uplift_treatment: "treatment"
metadata {
  framework: "TF Keras"
}
pure_serving_model: false
[yggdrasil_decision_forests.model.random_forest.proto.random_forest_config] {
  num_trees: 300
  decision_tree {
    max_depth: 16
    min_examples: 5
    in_split_min_examples_check: true
    keep_non_leaf_label_distribution: true
    num_candidate_attributes: 0
    missing_value_policy: GLOBAL_IMPUTATION
    allow_na_conditions: false
    categorical_set_greedy_forward {
      sampling: 0.1
      max_num_items: -1
      min_item_frequency: 1
    }
    growing_strategy_local {
    }
    categorical {
      cart {
      }
    }
    axis_aligned_split {
    }
    internal {
      sorting_strategy: PRESORTED
    }
    uplift {
      min_examples_in_treatment: 5
      split_score: KULLBACK_LEIBLER
    }
  }
  winner_take_all_inference: true
  compute_oob_performances: true
  compute_oob_variable_importances: false
  num_oob_variable_importances_permutations: 1
  bootstrap_training_dataset: true
  bootstrap_size_ratio: 1
  adapt_bootstrap_size_ratio_for_maximum_training_duration: false
  sampling_with_replacement: true
}

[INFO 24-04-20 11:22:14.2469 UTC kernel.cc:825] Deployment config:
cache_path: "/tmpfs/tmp/tmppyeh4gae/working_cache"
num_threads: 32
try_resume_training: true

[INFO 24-04-20 11:22:14.2472 UTC kernel.cc:887] Train model
[INFO 24-04-20 11:22:14.2473 UTC random_forest.cc:416] Training random forest on 51200 example(s) and 7 feature(s).
[WARNING 24-04-20 11:22:14.3731 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:14.3741 UTC random_forest.cc:802] Training of tree  1/300 (tree index:2) done qini:0.000172044 auuc:0.0025137
[WARNING 24-04-20 11:22:14.4012 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:14.4027 UTC random_forest.cc:802] Training of tree  15/300 (tree index:31) done qini:1.41341e-05 auuc:0.0023575
[WARNING 24-04-20 11:22:14.5302 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:14.5327 UTC random_forest.cc:802] Training of tree  25/300 (tree index:23) done qini:-2.19346e-05 auuc:0.00235455
[WARNING 24-04-20 11:22:14.6034 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:14.6058 UTC random_forest.cc:802] Training of tree  35/300 (tree index:33) done qini:0.00013211 auuc:0.0025086
[WARNING 24-04-20 11:22:14.6887 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:14.6910 UTC random_forest.cc:802] Training of tree  45/300 (tree index:45) done qini:-2.28572e-05 auuc:0.00235363
[WARNING 24-04-20 11:22:14.7656 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:14.7680 UTC random_forest.cc:802] Training of tree  55/300 (tree index:55) done qini:-8.67727e-05 auuc:0.00228972
[WARNING 24-04-20 11:22:14.8354 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:14.8379 UTC random_forest.cc:802] Training of tree  65/300 (tree index:56) done qini:-0.000112323 auuc:0.00226417
[WARNING 24-04-20 11:22:14.9052 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:14.9077 UTC random_forest.cc:802] Training of tree  75/300 (tree index:74) done qini:-0.000109942 auuc:0.00226655
[WARNING 24-04-20 11:22:14.9680 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:14.9704 UTC random_forest.cc:802] Training of tree  101/300 (tree index:101) done qini:-0.000112409 auuc:0.00226408
[WARNING 24-04-20 11:22:15.1148 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:15.1196 UTC random_forest.cc:802] Training of tree  121/300 (tree index:118) done qini:-0.000299795 auuc:0.00207669
[WARNING 24-04-20 11:22:15.2280 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:15.2305 UTC random_forest.cc:802] Training of tree  131/300 (tree index:138) done qini:-0.000153133 auuc:0.00222336
[WARNING 24-04-20 11:22:15.3108 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:15.3155 UTC random_forest.cc:802] Training of tree  141/300 (tree index:139) done qini:-0.000173194 auuc:0.0022033
[WARNING 24-04-20 11:22:15.3853 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:15.3877 UTC random_forest.cc:802] Training of tree  168/300 (tree index:162) done qini:-0.000130945 auuc:0.00224554
[WARNING 24-04-20 11:22:15.5471 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:15.5519 UTC random_forest.cc:802] Training of tree  178/300 (tree index:178) done qini:-0.000145457 auuc:0.00223103
[WARNING 24-04-20 11:22:15.6367 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:15.6414 UTC random_forest.cc:802] Training of tree  188/300 (tree index:189) done qini:-0.000124566 auuc:0.00225192
[WARNING 24-04-20 11:22:15.6876 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:15.6901 UTC random_forest.cc:802] Training of tree  217/300 (tree index:213) done qini:-0.000161956 auuc:0.00221453
[WARNING 24-04-20 11:22:15.8731 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:15.8795 UTC random_forest.cc:802] Training of tree  227/300 (tree index:229) done qini:-0.000133605 auuc:0.00224288
[WARNING 24-04-20 11:22:15.9403 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:15.9428 UTC random_forest.cc:802] Training of tree  237/300 (tree index:239) done qini:-0.000101549 auuc:0.00227494
[WARNING 24-04-20 11:22:16.0044 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:16.0068 UTC random_forest.cc:802] Training of tree  247/300 (tree index:253) done qini:-0.000141334 auuc:0.00223516
[WARNING 24-04-20 11:22:16.0749 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:16.0773 UTC random_forest.cc:802] Training of tree  257/300 (tree index:257) done qini:-0.000135416 auuc:0.00224107
[WARNING 24-04-20 11:22:16.1446 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:16.1471 UTC random_forest.cc:802] Training of tree  267/300 (tree index:261) done qini:-0.000131112 auuc:0.00224538
[WARNING 24-04-20 11:22:16.2109 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:16.2132 UTC random_forest.cc:802] Training of tree  277/300 (tree index:275) done qini:-0.000149751 auuc:0.00222674
[WARNING 24-04-20 11:22:16.2724 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:16.2746 UTC random_forest.cc:802] Training of tree  287/300 (tree index:283) done qini:-0.000168736 auuc:0.00220775
[WARNING 24-04-20 11:22:16.3282 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:16.3306 UTC random_forest.cc:802] Training of tree  297/300 (tree index:299) done qini:-0.000181665 auuc:0.00219482
[WARNING 24-04-20 11:22:16.3623 UTC random_forest.cc:1105] Internal error: Non empty oob evaluation
[INFO 24-04-20 11:22:16.3646 UTC random_forest.cc:802] Training of tree  300/300 (tree index:298) done qini:-0.000173258 auuc:0.00220323
[INFO 24-04-20 11:22:16.3680 UTC random_forest.cc:882] Final OOB metrics: qini:-0.000173258 auuc:0.00220323
[INFO 24-04-20 11:22:16.3843 UTC kernel.cc:919] Export model in log directory: /tmpfs/tmp/tmppyeh4gae with prefix 568256236db544eb
[INFO 24-04-20 11:22:16.4274 UTC kernel.cc:937] Save model in resources
[INFO 24-04-20 11:22:16.4309 UTC abstract_model.cc:881] Model self evaluation:
Number of predictions (without weights): 51200
Number of predictions (with weights): 51200
Task: CATEGORICAL_UPLIFT
Label: __LABEL

Number of treatments: 2
AUUC: 0.00220323
Qini: -0.000173258

[INFO 24-04-20 11:22:16.4580 UTC kernel.cc:1233] Loading model from path /tmpfs/tmp/tmppyeh4gae/model/ with prefix 568256236db544eb
[INFO 24-04-20 11:22:16.6557 UTC decision_forest.cc:734] Model loaded with 300 root(s), 60190 node(s), and 7 input feature(s).
[INFO 24-04-20 11:22:16.6557 UTC abstract_model.cc:1344] Engine "RandomForestGeneric" built
[INFO 24-04-20 11:22:16.6557 UTC kernel.cc:1061] Use fast generic engine
Model trained in 0:00:02.442514
Compiling model...
Model compiled.
<tf_keras.src.callbacks.History at 0x7f5dd40eb7c0>
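
As with other TF-DF models, you can also print a textual description of the trained forest (number of trees, input features, etc.) with model.summary(); the output is omitted here.

# Optional: print a textual description of the trained model.
model.summary()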

Evaluating Uplift models

Metrics for Uplift models

The two most important metrics for evaluating uplift models are the AUUC (Area Under the Uplift Curve) metric and the Qini (Area Under the Qini Curve) metric. This is similar to the use of AUC and accuracy for classification problems. For both metrics, the larger they are, the better.

Both AUUC and Qini are not normalized metrics. This means that the best possible value of the metric can vary from dataset to dataset. This is different from, for example, the AUC metric, which always varies between 0 and 1.

A formal definition of AUUC is below. For more information about these metrics, see Guelman and Betlei et al.

Model Self-Evaluation

TF-DF Random Forest models perform self-evaluation on the out-of-bag examples of the training dataset. For uplift models, they expose the AUUC and the Qini metric. You can directly retrieve the two metrics on the training dataset through the inspector.

Later, we are going to recompute the AUUC metric "manually" on the test dataset. Note that the two metrics are not expected to be exactly equal (out-of-bag on train vs. test) since the AUUC is not a normalized metric.

# The self-evaluation is available through the model inspector
insp = model.make_inspector()
insp.evaluation()
Evaluation(num_examples=51200, accuracy=None, loss=None, rmse=None, ndcg=None, aucs=None, auuc=0.0022032308892709586, qini=-0.00017325819500263418)
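
The individual fields of the evaluation object shown above can also be read off directly, for example to log or compare them later:

# Access the self-evaluation metrics individually.
self_evaluation = insp.evaluation()
print("OOB AUUC:", self_evaluation.auuc)
print("OOB Qini:", self_evaluation.qini)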

Manually computing the AUUC

In this section, we manually compute the AUUC and plot the uplift curves.

The next few paragraphs explain the AUUC metric in more detail and may be skipped.

Computing the AUUC

Suppose you have a labeled dataset with \(|T|\) examples with treatment and \(|C|\) examples without treatment, called control examples. For each example, the uplift model \(f\) produces the conditional probability that a treatment on the example will yield a positive outcome.

Suppose a decision-maker needs to decide which clients to send an email to, using an uplift model \(f\). The model produces a (conditional) probability that the email will result in a conversion. The decision-maker might therefore simply pick a number \(k\) of emails to send and send those \(k\) emails to the clients with the highest predicted probability.

Using a labeled test dataset, it is possible to study the impact of \(k\) on the success of the campaign. First, we are interested in the ratio \(\frac{|C \cap T|}{|T|}\) of clients that received an email and converted over the total number of clients that received an email. Here, \(C\) is the set of clients that converted and \(T\) is the set of clients that received an email, with \(C \cap T\) restricted to the \(k\) clients with the highest predictions. We plot this ratio against \(k\).

Ideally, we would like this curve to increase steeply. This would mean that the model prioritizes sending emails to those clients that will generate a conversion when receiving an email.

# Compute all predictions on the test dataset
predictions = model.predict(test_ds).flatten()
# Extract outcomes and treatments
outcomes = np.concatenate([outcome.numpy() for _, outcome in test_ds])
treatment = np.concatenate([example['treatment'].numpy() for example,_ in test_ds])
control = 1 - treatment

num_treatments = np.sum(treatment)
# Clients without treatment are called 'control' group
num_control = np.sum(control)
num_examples = len(predictions)

# Sort labels and treatments according to predictions in descending order
prediction_order = predictions.argsort()[::-1]
outcomes_sorted = outcomes[prediction_order]
treatment_sorted = treatment[prediction_order]
control_sorted = control[prediction_order]
ratio_treatment = np.cumsum(np.multiply(outcomes_sorted, treatment_sorted), axis=0)/num_treatments

fig, ax = plt.subplots()
ax.plot(ratio_treatment, label='Conversion ratio of treatment')
ax.set_xlabel('k')
ax.set_ylabel('Ratio of conversion')
ax.legend()
512/512 [==============================] - 3s 5ms/step
2024-04-20 11:22:25.165808: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-20 11:22:25.975008: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
<matplotlib.legend.Legend at 0x7f5cbc41ac70>

[Plot: conversion ratio of the treatment group against k]

Similarly, we can also compute and plot the conversion ratio of those not receiving an email, called the control group. Ideally, this curve is initially flat: this would mean that the model does not prioritize sending emails to clients that will generate a conversion despite not receiving an email.

ratio_control = np.cumsum(np.multiply(outcomes_sorted, control_sorted), axis=0)/num_control
ax.plot(ratio_control, label='Conversion ratio of control')
ax.legend()
fig

[Plot: conversion ratios of the treatment and control groups against k]

The AUUC metric measures the area between these two curves, normalizing the x-axis between 0 and 1.

x = np.linspace(0, 1, num_examples)
plt.plot(x,ratio_treatment, label='Conversion ratio of treatment')
plt.plot(x,ratio_control, label='Conversion ratio of control')
plt.fill_between(x, ratio_treatment, ratio_control, where=(ratio_treatment > ratio_control), color='C0', alpha=0.3)
plt.fill_between(x, ratio_treatment, ratio_control, where=(ratio_treatment < ratio_control), color='C1', alpha=0.3)
plt.xlabel('k')
plt.ylabel('Ratio of conversion')
plt.legend()

# Approximate the integral of the difference between the two curves.
auuc = np.trapz(ratio_treatment-ratio_control, dx=1/num_examples)
print(f'The AUUC on the test dataset is {auuc}')
The AUUC on the test dataset is 0.007513928513572819

[Plot: area between the treatment and control conversion ratio curves]
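
For completeness, the Qini value reported in the self-evaluation can be approximated in a similar way. The sketch below assumes the convention that Qini is the AUUC minus the area under the "random targeting" baseline (the straight line from 0 to the overall uplift); this convention is an assumption and is not taken from the TF-DF documentation.

# Hedged sketch: approximate the Qini value, assuming it equals the AUUC minus
# the area under the random-targeting baseline (a straight line from 0 to the
# overall uplift). This convention is an assumption, not a TF-DF definition.
overall_uplift = ratio_treatment[-1] - ratio_control[-1]
qini = auuc - overall_uplift / 2
print(f'The approximate Qini on the test dataset is {qini}')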