Composing Decision Forest and Neural Network models



Introduction

Welcome to the model composition tutorial for TensorFlow Decision Forests (TF-DF). This notebook shows you how to compose multiple decision forest and neural network models together using a common preprocessing layer and the Keras functional API.

You might want to compose models together to improve predictive performance (ensembling), to get the best of different modeling technologies (heterogeneous model ensembling), to train different parts of the model on different datasets (e.g. pre-training), or to create a stacked model (e.g. a model that operates on the predictions of another model).

This tutorial covers an advanced use case of model composition using the Functional API. You can find examples of simpler model composition scenarios in the "feature preprocessing" tutorial and in the "using a pretrained text embedding" tutorial.

Here is the structure of the model you'll build:

[Graphviz diagram of the composed model structure — not rendered here because the Graphviz `dot` executable was not available in the build environment.]

Your composed model has three stages:

  1. The first stage is a neural-network preprocessing layer shared by all the models in the next stage. In practice, such a preprocessing layer could either be a pre-trained embedding to fine-tune, or a randomly initialized neural network.
  2. The second stage is an ensemble of two decision forest and two neural network models.
  3. The last stage averages the predictions of the models in the second stage. It does not contain any learnable weights.
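The averaging in the last stage can be sketched with plain NumPy (a standalone illustration, not the tutorial's Keras code — the small prediction arrays below are made up for the example). Since the stage is just a mean, it has no learnable weights:

```python
import numpy as np

# Hypothetical per-model predictions for a batch of two examples.
m1 = np.array([0.9, 0.2])
m2 = np.array([0.8, 0.4])
m3 = np.array([0.7, 0.1])
m4 = np.array([0.6, 0.3])

# Stack along a new "model" axis, then average over it.
ensemble_pred = np.mean(np.stack([m1, m2, m3, m4], axis=0), axis=0)
print(ensemble_pred)  # [0.75 0.25]
```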

The neural networks are trained using the backpropagation algorithm and gradient descent. This algorithm has two important properties: (1) a neural network layer can be trained if it receives a loss gradient (more precisely, the gradient of the loss with respect to the layer's output), and (2) the algorithm "transmits" the loss gradient from the layer's output to the layer's input (this is the "chain rule"). Thanks to these two properties, backpropagation can jointly train multiple layers of neural networks stacked on top of each other.
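The chain rule behind property (2) can be sketched with plain NumPy (a hand-rolled illustration, not the tutorial's Keras code): the gradient arriving at a dense layer's output yields both a gradient for the layer's own weights and a gradient "transmitted" to the layer's input, so the layer below can be trained too.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))   # batch of 4 examples, 3 features
W = rng.normal(size=(3, 2))   # one dense layer (bias omitted for brevity)

y = x @ W                     # forward pass: the layer's output
grad_y = np.ones_like(y)      # gradient of the loss w.r.t. the layer's output

# Chain rule: one gradient updates this layer, the other flows downward.
grad_W = x.T @ grad_y         # gradient w.r.t. the layer's weights
grad_x = grad_y @ W.T         # gradient "transmitted" to the layer's input

assert grad_W.shape == W.shape
assert grad_x.shape == x.shape
```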

In this example, the decision forests are trained with the Random Forest (RF) algorithm. Unlike backpropagation, RF training does not "transmit" the loss gradient from its output to its input. For this reason, the classical RF algorithm cannot be used to train or fine-tune a neural network underneath it. In other words, the "decision forest" stages cannot be used to train the "Learnable NN pre-processing block". Instead, the training is done in two stages:

  1. Train the preprocessing and neural networks stage.
  2. Train the decision forest stages.

Install TensorFlow Decision Forests

Install TF-DF by running the following cell.

pip install tensorflow_decision_forests -U --quiet

Wurlitzer is needed to display the detailed training logs in Colab (when using verbose=2 in the model constructor).

pip install wurlitzer -U --quiet

Import libraries

import tensorflow_decision_forests as tfdf

import os
import numpy as np
import pandas as pd
import tensorflow as tf
import math
import matplotlib.pyplot as plt
2022-09-19 11:26:09.440101: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-09-19 11:26:10.148783: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-09-19 11:26:10.149022: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-09-19 11:26:10.149034: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

Dataset

You will use a simple synthetic dataset in this tutorial to make it easier to interpret the final model.

def make_dataset(num_examples, num_features, seed=1234):
  np.random.seed(seed)
  features = np.random.uniform(-1, 1, size=(num_examples, num_features))
  noise = np.random.uniform(size=(num_examples))

  left_side = np.sqrt(
      np.sum(np.multiply(np.square(features[:, 0:2]), [1, 2]), axis=1))
  right_side = features[:, 2] * 0.7 + np.sin(
      features[:, 3] * 10) * 0.5 + noise * 0.0 + 0.5

  labels = left_side <= right_side
  return features, labels.astype(int)

Generate some examples:

make_dataset(num_examples=5, num_features=4)
(array([[-0.6169611 ,  0.24421754, -0.12454452,  0.57071717],
        [ 0.55995162, -0.45481479, -0.44707149,  0.60374436],
        [ 0.91627871,  0.75186527, -0.28436546,  0.00199025],
        [ 0.36692587,  0.42540405, -0.25949849,  0.12239237],
        [ 0.00616633, -0.9724631 ,  0.54565324,  0.76528238]]),
 array([0, 0, 0, 1, 0]))

You can also plot them to get an idea of the synthetic pattern:

plot_features, plot_label = make_dataset(num_examples=50000, num_features=4)

plt.rcParams["figure.figsize"] = [8, 8]
common_args = dict(c=plot_label, s=1.0, alpha=0.5)

plt.subplot(2, 2, 1)
plt.scatter(plot_features[:, 0], plot_features[:, 1], **common_args)

plt.subplot(2, 2, 2)
plt.scatter(plot_features[:, 1], plot_features[:, 2], **common_args)

plt.subplot(2, 2, 3)
plt.scatter(plot_features[:, 0], plot_features[:, 2], **common_args)

plt.subplot(2, 2, 4)
plt.scatter(plot_features[:, 0], plot_features[:, 3], **common_args)
<matplotlib.collections.PathCollection at 0x7fd67c55b190>

[png: 2x2 grid of scatter plots of feature pairs, colored by label]

Note that this pattern is smooth and not axis aligned. This is an advantage for the neural network models, because it is easier for a neural network than for a decision tree to produce round, non-axis-aligned decision boundaries.

On the other hand, the model is trained on a small dataset of 2500 examples. This is an advantage for the decision forest models, because decision forests are much more sample-efficient: they make use of all the available information from the examples.

Our ensemble of neural networks and decision forests will use the best of both worlds.

Let's create a train and test tf.data.Dataset:

def make_tf_dataset(batch_size=64, **args):
  features, labels = make_dataset(**args)
  return tf.data.Dataset.from_tensor_slices(
      (features, labels)).batch(batch_size)


num_features = 10

train_dataset = make_tf_dataset(
    num_examples=2500, num_features=num_features, batch_size=100, seed=1234)
test_dataset = make_tf_dataset(
    num_examples=10000, num_features=num_features, batch_size=100, seed=5678)

Model structure

Define the model structure as follows:

# Input features.
raw_features = tf.keras.layers.Input(shape=(num_features,))

# Stage 1
# =======

# Common learnable pre-processing
preprocessor = tf.keras.layers.Dense(10, activation=tf.nn.relu6)
preprocess_features = preprocessor(raw_features)

# Stage 2
# =======

# Model #1: NN
m1_z1 = tf.keras.layers.Dense(5, activation=tf.nn.relu6)(preprocess_features)
m1_pred = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)(m1_z1)

# Model #2: NN
m2_z1 = tf.keras.layers.Dense(5, activation=tf.nn.relu6)(preprocess_features)
m2_pred = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)(m2_z1)


# Model #3: DF
model_3 = tfdf.keras.RandomForestModel(num_trees=1000, random_seed=1234)
m3_pred = model_3(preprocess_features)

# Model #4: DF
model_4 = tfdf.keras.RandomForestModel(
    num_trees=1000,
    #split_axis="SPARSE_OBLIQUE", # Uncomment this line to increase the quality of this model
    random_seed=4567)
m4_pred = model_4(preprocess_features)

# Since TF-DF uses deterministic learning algorithms, set each model's
# training seed to a different value; otherwise, the two
# `tfdf.keras.RandomForestModel`s would be exactly the same.

# Stage 3
# =======

mean_nn_only = tf.reduce_mean(tf.stack([m1_pred, m2_pred], axis=0), axis=0)
mean_nn_and_df = tf.reduce_mean(
    tf.stack([m1_pred, m2_pred, m3_pred, m4_pred], axis=0), axis=0)

# Keras Models
# ============

ensemble_nn_only = tf.keras.models.Model(raw_features, mean_nn_only)
ensemble_nn_and_df = tf.keras.models.Model(raw_features, mean_nn_and_df)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpophji3ym as temporary training directory
Warning: The model was called directly (i.e. using `model(data)` instead of using `model.predict(data)`) before being trained. The model will only return zeros until trained. The output shape might change after training Tensor("inputs:0", shape=(None, 10), dtype=float32)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpg88ve94z as temporary training directory
Warning: The model was called directly (i.e. using `model(data)` instead of using `model.predict(data)`) before being trained. The model will only return zeros until trained. The output shape might change after training Tensor("inputs:0", shape=(None, 10), dtype=float32)

Before you train the model, you can plot it to check whether it matches the initial diagram.

from keras.utils.vis_utils import plot_model

plot_model(ensemble_nn_and_df, to_file="/tmp/model.png", show_shapes=True)
You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.

Model training

First train the preprocessing and two neural network layers using the backpropagation algorithm.

%%time
ensemble_nn_only.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=["accuracy"])

ensemble_nn_only.fit(train_dataset, epochs=20, validation_data=test_dataset)
Epoch 1/20
25/25 [==============================] - 1s 16ms/step - loss: 0.6725 - accuracy: 0.5984 - val_loss: 0.6450 - val_accuracy: 0.6929
Epoch 2/20
25/25 [==============================] - 0s 9ms/step - loss: 0.6256 - accuracy: 0.7372 - val_loss: 0.6045 - val_accuracy: 0.7376
Epoch 3/20
25/25 [==============================] - 0s 9ms/step - loss: 0.5880 - accuracy: 0.7500 - val_loss: 0.5741 - val_accuracy: 0.7392
Epoch 4/20
25/25 [==============================] - 0s 9ms/step - loss: 0.5599 - accuracy: 0.7500 - val_loss: 0.5522 - val_accuracy: 0.7392
Epoch 5/20
25/25 [==============================] - 0s 10ms/step - loss: 0.5391 - accuracy: 0.7500 - val_loss: 0.5361 - val_accuracy: 0.7392
Epoch 6/20
25/25 [==============================] - 0s 9ms/step - loss: 0.5231 - accuracy: 0.7500 - val_loss: 0.5234 - val_accuracy: 0.7392
Epoch 7/20
25/25 [==============================] - 0s 10ms/step - loss: 0.5101 - accuracy: 0.7500 - val_loss: 0.5124 - val_accuracy: 0.7392
Epoch 8/20
25/25 [==============================] - 0s 10ms/step - loss: 0.4984 - accuracy: 0.7500 - val_loss: 0.5022 - val_accuracy: 0.7392
Epoch 9/20
25/25 [==============================] - 0s 9ms/step - loss: 0.4875 - accuracy: 0.7500 - val_loss: 0.4925 - val_accuracy: 0.7392
Epoch 10/20
25/25 [==============================] - 0s 9ms/step - loss: 0.4771 - accuracy: 0.7500 - val_loss: 0.4833 - val_accuracy: 0.7392
Epoch 11/20
25/25 [==============================] - 0s 9ms/step - loss: 0.4672 - accuracy: 0.7500 - val_loss: 0.4746 - val_accuracy: 0.7392
Epoch 12/20
25/25 [==============================] - 0s 10ms/step - loss: 0.4578 - accuracy: 0.7500 - val_loss: 0.4665 - val_accuracy: 0.7404
Epoch 13/20
25/25 [==============================] - 0s 9ms/step - loss: 0.4490 - accuracy: 0.7520 - val_loss: 0.4592 - val_accuracy: 0.7459
Epoch 14/20
25/25 [==============================] - 0s 10ms/step - loss: 0.4411 - accuracy: 0.7600 - val_loss: 0.4528 - val_accuracy: 0.7566
Epoch 15/20
25/25 [==============================] - 0s 9ms/step - loss: 0.4340 - accuracy: 0.7732 - val_loss: 0.4472 - val_accuracy: 0.7668
Epoch 16/20
25/25 [==============================] - 0s 9ms/step - loss: 0.4278 - accuracy: 0.7784 - val_loss: 0.4424 - val_accuracy: 0.7760
Epoch 17/20
25/25 [==============================] - 0s 9ms/step - loss: 0.4224 - accuracy: 0.7872 - val_loss: 0.4381 - val_accuracy: 0.7823
Epoch 18/20
25/25 [==============================] - 0s 9ms/step - loss: 0.4175 - accuracy: 0.7936 - val_loss: 0.4343 - val_accuracy: 0.7893
Epoch 19/20
25/25 [==============================] - 0s 9ms/step - loss: 0.4131 - accuracy: 0.8012 - val_loss: 0.4309 - val_accuracy: 0.7920
Epoch 20/20
25/25 [==============================] - 0s 9ms/step - loss: 0.4089 - accuracy: 0.8080 - val_loss: 0.4275 - val_accuracy: 0.7939
CPU times: user 7.56 s, sys: 1.58 s, total: 9.14 s
Wall time: 5.71 s
<keras.callbacks.History at 0x7fd67a3fc1f0>

Let's evaluate the model composed of the preprocessing and the two neural networks only:

evaluation_nn_only = ensemble_nn_only.evaluate(test_dataset, return_dict=True)
print("Accuracy (NN #1 and #2 only): ", evaluation_nn_only["accuracy"])
print("Loss (NN #1 and #2 only): ", evaluation_nn_only["loss"])
100/100 [==============================] - 0s 2ms/step - loss: 0.4275 - accuracy: 0.7939
Accuracy (NN #1 and #2 only):  0.7939000129699707
Loss (NN #1 and #2 only):  0.4275112450122833

Let's train the two Decision Forest components (one after another).

%%time
train_dataset_with_preprocessing = train_dataset.map(lambda x,y: (preprocessor(x), y))
test_dataset_with_preprocessing = test_dataset.map(lambda x,y: (preprocessor(x), y))

model_3.fit(train_dataset_with_preprocessing)
model_4.fit(train_dataset_with_preprocessing)
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7fd72b928dc0> and will run it as-is.
Cause: could not parse the source code of <function <lambda> at 0x7fd72b928dc0>: no matching AST found among candidates:

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7fd788103b80> and will run it as-is.
Cause: could not parse the source code of <function <lambda> at 0x7fd788103b80>: no matching AST found among candidates:

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
Reading training dataset...
Training dataset read in 0:00:03.065826. Found 2500 examples.
Training model...
[INFO kernel.cc:1176] Loading model from path /tmpfs/tmp/tmpophji3ym/model/ with prefix 000d92fb6204406a
Model trained in 0:00:01.644162
Compiling model...
[INFO abstract_model.cc:1248] Engine "RandomForestOptPred" built
[INFO kernel.cc:1022] Use fast generic engine
WARNING:tensorflow:AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7fd67db69550> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
Model compiled.
Reading training dataset...
Training dataset read in 0:00:00.218492. Found 2500 examples.
Training model...
[INFO kernel.cc:1176] Loading model from path /tmpfs/tmp/tmpg88ve94z/model/ with prefix 992512c7c7f24247
Model trained in 0:00:01.533185
Compiling model...
[INFO kernel.cc:1022] Use fast generic engine
Model compiled.
CPU times: user 20.3 s, sys: 1.1 s, total: 21.4 s
Wall time: 7.82 s
<keras.callbacks.History at 0x7fd72b6dd8e0>
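The AutoGraph warnings above are harmless: AutoGraph simply cannot parse the source of the inline lambdas and falls back to running them as-is. As the warning suggests, they can be silenced by decorating a named mapping function with `@tf.autograph.experimental.do_not_convert`. Here is a minimal sketch, using toy stand-ins for this tutorial's `preprocessor` and `train_dataset` (the real objects are defined earlier in the notebook):

```python
import tensorflow as tf

# Toy stand-ins (hypothetical) for the tutorial's `preprocessor` layer and
# `train_dataset`; only the mapping pattern below is the point.
preprocessor = tf.keras.layers.Rescaling(0.5)
train_dataset = tf.data.Dataset.from_tensor_slices(
    (tf.ones((4, 3)), tf.zeros((4,)))).batch(2)

# Decorating a named function (instead of passing a bare lambda to `map`)
# silences the "AutoGraph could not transform" warnings.
@tf.autograph.experimental.do_not_convert
def apply_preprocessing(x, y):
    return preprocessor(x), y

train_dataset_with_preprocessing = train_dataset.map(apply_preprocessing)
```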

And let's evaluate the Decision Forests individually.

model_3.compile(["accuracy"])
model_4.compile(["accuracy"])

evaluation_df3_only = model_3.evaluate(
    test_dataset_with_preprocessing, return_dict=True)
evaluation_df4_only = model_4.evaluate(
    test_dataset_with_preprocessing, return_dict=True)

print("Accuracy (DF #3 only): ", evaluation_df3_only["accuracy"])
print("Accuracy (DF #4 only): ", evaluation_df4_only["accuracy"])
100/100 [==============================] - 1s 10ms/step - loss: 0.0000e+00 - accuracy: 0.8056
100/100 [==============================] - 1s 10ms/step - loss: 0.0000e+00 - accuracy: 0.8064
Accuracy (DF #3 only):  0.8055999875068665
Accuracy (DF #4 only):  0.8064000010490417

Let's evaluate the entire model composition:

ensemble_nn_and_df.compile(
    loss=tf.keras.losses.BinaryCrossentropy(), metrics=["accuracy"])

evaluation_nn_and_df = ensemble_nn_and_df.evaluate(
    test_dataset, return_dict=True)

print("Accuracy (2xNN and 2xDF): ", evaluation_nn_and_df["accuracy"])
print("Loss (2xNN and 2xDF): ", evaluation_nn_and_df["loss"])
100/100 [==============================] - 1s 10ms/step - loss: 0.4043 - accuracy: 0.8072
Accuracy (2xNN and 2xDF):  0.807200014591217
Loss (2xNN and 2xDF):  0.4043433368206024

To finish, let's finetune the neural network layers a bit more. Note that we do not finetune the pre-trained embedding, as the DF models depend on it (unless we were to retrain them afterward as well).

In summary, you have:

Accuracy (NN #1 and #2 only): 0.793900
Accuracy (DF #3 only):        0.805600
Accuracy (DF #4 only):        0.806400
----------------------------------------
Accuracy (2xNN and 2xDF): 0.807200
                  +0.013300 over NN #1 and #2 only
                  +0.001600 over DF #3 only
                  +0.000800 over DF #4 only
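For reference, the improvement figures above can be recomputed directly from the reported accuracies:

```python
# Accuracies reported above.
acc = {
    "NN #1 and #2 only": 0.7939,
    "DF #3 only": 0.8056,
    "DF #4 only": 0.8064,
}
acc_ensemble = 0.8072  # 2xNN and 2xDF

for name, value in acc.items():
    print(f"+{acc_ensemble - value:.6f} over {name}")
```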

Here, you can see that the composed model performs better than any of its individual parts. This is why ensembles work so well.

What's next?

In this example, you saw how to combine decision forests with neural networks. An extra step would be to further train the neural network and the decision forests together.

In addition, for the sake of clarity, the decision forests received only the preprocessed input. However, decision forests are generally great at consuming raw data. The model could be improved by also feeding the raw features to the decision forest models.

In this example, the final model is the average of the predictions of the individual models. This solution works well if all of the models perform more or less equally well. However, if one of the sub-models is very good, aggregating it with the other models might actually be detrimental (or vice-versa; for example, try to reduce the number of examples from 1k and see how much it hurts the neural networks, or enable the SPARSE_OBLIQUE split in the second Random Forest model).
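To illustrate that last point, here is a small synthetic sketch (toy numbers, not from this tutorial) showing how averaging a strong model with a weaker, partly miscalibrated one can reduce accuracy:

```python
import numpy as np

# Toy binary-classification probabilities (hypothetical).
y_true = np.array([0, 1, 1, 0, 1, 0, 1, 1])
strong = np.array([0.1, 0.9, 0.8, 0.2, 0.9, 0.1, 0.8, 0.9])  # near-perfect
weak = np.array([0.6, 0.4, 0.1, 0.6, 0.4, 0.9, 0.6, 0.4])    # often wrong

def accuracy(p):
    # Threshold the probabilities at 0.5 and compare to the labels.
    return np.mean((p >= 0.5).astype(int) == y_true)

print(accuracy(strong))               # 1.0
print(accuracy((strong + weak) / 2))  # 0.75: the weak model flips 2 examples
```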