Predictive prefetching with TensorFlow.js

In this tutorial, you’ll run an example web application that uses TensorFlow.js to do predictive prefetching of resources. Built with Angular, the example is inspired by the Google Merchandise Store but doesn't share any data or implementation details with it.

The example uses a pre-trained model to make predictions. In a real-world scenario, you would need to train a model using analytics from your website. You can use TFX to do such training. To learn more about training a custom model for predictive prefetching, see this blog post.

The example code is available on GitHub.

Prerequisites

To complete this tutorial, you need the following installed in your development environment:

  * Node.js
  * Yarn
  * A recent version of Chrome (this tutorial uses Chrome DevTools)

Install the example

Get the source code and install dependencies:

  1. Clone or download the tfjs-examples repository.
  2. Change into the angular-predictive-prefetching/client directory and install dependencies:

    cd tfjs-examples/angular-predictive-prefetching/client && yarn
    
  3. Change into the angular-predictive-prefetching/server directory and install dependencies:

    cd ../server && yarn
    

Run the example

Start both the server and the client:

  1. Start the server: In the server directory, run yarn start.

  2. Start the client:

    1. Open another terminal window.
    2. Change into the tfjs-examples/angular-predictive-prefetching/client directory.
    3. Run the following commands:

      yarn build
      cd dist/merch-store
      npx serve -s .
      

      You might be prompted to install the serve package. If so, enter y to install it.

  3. Navigate to http://localhost:3000 in your browser. You should see a mock version of the Google Merchandise Store.

Explore with DevTools

Use Chrome DevTools to see prefetching in action:

  1. Open DevTools and select Console.
  2. Navigate to a few different pages in the application to prime the app, then select Sale in the left navigation. You should see log output like this:

    Navigating from: 'sale'
    'quickview' -> 0.381757915019989
    'apparel-unisex' -> 0.3150934875011444
    'store.html' -> 0.1957530975341797
    '' -> 0.052346792072057724
    'signin.html' -> 0.0007763378671370447
    

    This output shows predictions for the page that you (the user) will visit next. The application fetches resources based on these predictions.

  3. To see the fetch requests, select Network. The output is a bit noisy, but you should be able to find requests for the resources of the predicted pages. For example, after predicting quickview, the application makes a request to http://localhost:8000/api/merch/quickview.

How predictive prefetching works

The example app uses a pre-trained model to predict the page that a user will visit next. When the user navigates to a new page, the app queries the model and then prefetches images associated with predicted pages.

The app does the predictive prefetching in a service worker, so that it can query the model without blocking the main thread. Based on the user's navigation history, the service worker makes predictions for future navigation and prefetches relevant product images.

The service worker is registered in the main file of the Angular app, main.ts:

if ('serviceWorker' in navigator) {
  navigator.serviceWorker.register('/prefetch.service-worker.js', { scope: '/' });
}

The snippet above downloads the prefetch.service-worker.js script and runs it in the background.

In merch-display.component.ts, the app forwards navigation events to the service worker:

this.route.params.subscribe((routeParams) => {
  this.getMerch(routeParams.category);
  if (this._serviceWorker) {
    this._serviceWorker.postMessage({ page: routeParams.category });
  }
});

In the snippet above, the app subscribes to changes in the route parameters. When the route changes, the component forwards the page category to the service worker.
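
The tutorial doesn't show how the _serviceWorker reference is obtained. A minimal sketch of one way to get it, assuming the component grabs the worker that controls the page once prefetch.service-worker.js is active:

// Sketch (assumption): obtain a reference to the controlling service worker
// so the component can post navigation messages to it.
this._serviceWorker = ('serviceWorker' in navigator)
  ? navigator.serviceWorker.controller
  : null;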

The service worker script, prefetch.service-worker.js, handles messages from the main thread, makes predictions based on them, and prefetches the relevant resources.
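
Inside the worker, messages from the page arrive through the message event. The tutorial doesn't show the handler itself; a minimal sketch, assuming the posted data carries the page category and that a session id is derived separately:

// Sketch (assumption): wire incoming navigation messages to the prefetch
// logic defined later in this script.
self.addEventListener("message", (event) => {
  const { page } = event.data;
  // How the example derives the session/user id is not shown here;
  // a fixed placeholder is used for illustration.
  const sessionId = 0;
  prefetch(page, sessionId);
});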

The service worker uses loadGraphModel to load the pre-trained model:

const MODEL_URL = "/assets/model.json";

let model = null;
tf.loadGraphModel(MODEL_URL).then((m) => (model = m));
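
For tf to exist inside the worker, the TensorFlow.js library has to be loaded into the worker's global scope first. The tutorial doesn't show that step; a common approach (an assumption here, with an illustrative URL) is importScripts:

// Sketch (assumption): pull the TensorFlow.js bundle into the service
// worker's global scope before calling tf.loadGraphModel.
importScripts("https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js");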

The prediction happens in the following function expression:

const predict = async (path, userId) => {
  // The model might still be loading; skip prediction until it's ready.
  if (!model) {
    return;
  }
  // Encode the current page and the user session as int32 tensors,
  // matching the model's two named inputs.
  const page = pages.indexOf(path);
  const pageId = tf.tensor1d([parseInt(page)], "int32");

  const sessionIndex = tf.tensor1d([parseInt(userId)], "int32");

  const result = model.predict({
    cur_page: pageId,
    session_index: sessionIndex,
  });
  // Read the output probabilities and keep the five most likely pages.
  const values = result.dataSync();
  const orders = sortWithIndices(values).slice(0, 5);
  return orders;
};
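
The pages array and the sortWithIndices helper are defined elsewhere in the service worker script and aren't shown in this tutorial. A minimal sketch of what they might look like, assuming the model outputs one probability per known page, in the same order as pages (the names and order here are illustrative):

// Sketch (assumption): page categories in the order the model was trained on.
const pages = ["", "apparel-unisex", "quickview", "sale", "signin.html", "store.html"];

// Sketch (assumption): pair each probability with its page name and sort
// by probability, highest first.
const sortWithIndices = (values) =>
  Array.from(values)
    .map((probability, i) => [probability, pages[i]])
    .sort(([a], [b]) => b - a);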

The predict function is then invoked by the prefetch function:

const prefetch = async (path, sessionId) => {
  const predictions = await predict(path, sessionId);
  const formattedPredictions = predictions
    .map(([a, b]) => `'${b}' -> ${a}`)
    .join("\n");
  console.log(`Navigating from: '${path}'`);
  console.log(formattedPredictions);
  const connectionSpeed = navigator.connection.effectiveType;
  const threshold = connectionSpeeds[connectionSpeed];
  const cache = await caches.open(ImageCache);
  predictions.forEach(async ([probability, category]) => {
    if (probability >= threshold) {
      const merchs = (await getMerchList(category)).map(getUrl);
      [...new Set(merchs)].forEach((url) => {
        const request = new Request(url, {
          mode: "no-cors",
        });
        fetch(request).then((response) => cache.put(request, response));
      });
    }
  });
};

First, prefetch predicts the pages that the user might visit next. Then it iterates over the predictions. For each prediction, if the probability exceeds a certain threshold based on connection speed, the function fetches resources for the predicted page. By fetching these resources before the next page request, the app can potentially serve content faster and provide a better user experience.
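
The connectionSpeeds map translates the values of navigator.connection.effectiveType into probability thresholds, so the app prefetches more aggressively on fast connections. The map itself isn't shown in this tutorial; a minimal sketch with assumed values:

// Sketch (assumption): the probability a prediction must reach before its
// resources are prefetched, keyed by effective connection type. The values
// in the example repository may differ.
const connectionSpeeds = {
  "slow-2g": 0.9,
  "2g": 0.75,
  "3g": 0.5,
  "4g": 0.3,
};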

What’s next

In this tutorial, the example app uses a pre-trained model to make predictions. You can use TFX to train a model for predictive prefetching. To learn more, see Speed-up your sites with web-page prefetching using Machine Learning.