Robust machine learning on streaming data using Kafka and Tensorflow-IO



This tutorial focuses on streaming data from a Kafka cluster into a tf.data.Dataset, which is then used in conjunction with tf.keras for training and inference.

Kafka is primarily a distributed event-streaming platform that provides scalable, fault-tolerant streaming of data across data pipelines. It is a core technical component at many large enterprises where mission-critical data delivery is a primary requirement.


Install the required tensorflow-io and kafka packages

print("Installing the tensorflow-io package !")
!pip install -q tensorflow-io

print("Installing the kafka-python package !")
!pip install -q kafka-python
Installing the tensorflow-io package !
Installing the kafka-python package !

Import packages

import os
from datetime import datetime
import threading
import json
from kafka import KafkaProducer
from kafka.errors import KafkaError
from sklearn.model_selection import train_test_split
import pandas as pd
import tensorflow as tf
import tensorflow_io as tfio

Validate tf and tfio imports

print("tensorflow-io version: {}".format(tfio.__version__))
print("tensorflow version: {}".format(tf.__version__))
tensorflow-io version: 0.15.0
tensorflow version: 2.3.0

Download and set up Kafka and Zookeeper instances

For demo purposes, the following instances are set up locally:

  • Kafka (Brokers:
  • Zookeeper (Node:
curl -sSOL
tar -xzf confluent-community-5.4.1-2.12.tar.gz

The default configurations provided by the Confluent package are used to spin up the instances.

cd confluent-5.4.1 && bin/zookeeper-server-start -daemon etc/kafka/
cd confluent-5.4.1 && bin/kafka-server-start -daemon etc/kafka/
cd confluent-5.4.1 && bin/schema-registry-start -daemon etc/schema-registry/
echo "Waiting for 10 secs until kafka, zookeeper and schema registry services are up and running"
sleep 10

Waiting for 10 secs until kafka, zookeeper and schema registry services are up and running

Once the instances are started as daemon processes, grep for kafka in the process list. The three java processes correspond to the zookeeper, kafka and schema-registry instances.

ps -ef | grep kafka
kbuilder 22384 19519  2 19:52 ?        00:00:00 python /tmpfs/src/gfile/ --input_notebook=/tmpfs/src/temp/docs/tutorials/kafka.ipynb --timeout=15000
kbuilder 22484     1 16 19:52 ?        00:00:01 java -Xmx512M -Xms512M -server -XX:+UseG1GC -XX:MaxGCPauseMillis=20 -XX:InitiatingHeapOccupancyPercent=35 -XX:+ExplicitGCInvokesConcurrent -Djava.awt.headless=true -Xlog:gc*:file=/tmpfs/src/temp/docs/tutorials/confluent-5.4.1/bin/../logs/zookeeper-gc.log:time,tags:filecount=10,filesize=102400 -Dkafka.logs.dir=/tmpfs/src/temp/docs/tutorials/confluent-5.4.1/bin/../logs -Dlog4j.configuration=file:bin/../etc/kafka/ -cp /tmpfs/src/temp/docs/tutorials/confluent-5.4.1/bin/../share/java/kafka/*:/tmpfs/src/temp/docs/tutorials/confluent-5.4.1/bin/../support-metrics-client/build/dependant-libs-2.12.10/*:/tmpfs/src/temp/docs/tutorials/confluent-5.4.1/bin/../support-metrics-client/build/libs/*:/usr/share/java/support-metrics-client/* org.apache.zookeeper.server.quorum.QuorumPeerMain etc/kafka/
kbuilder 22533     1 77 19:52 ?        00:00:08 java -Xmx1G -Xms1G -server -XX:+UseG1GC -XX:MaxGCPauseMillis=20 -XX:InitiatingHeapOccupancyPercent=35 -XX:+ExplicitGCInvokesConcurrent -Djava.awt.headless=true -Xlog:gc*:file=/tmpfs/src/temp/docs/tutorials/confluent-5.4.1/bin/../logs/kafkaServer-gc.log:time,tags:filecount=10,filesize=102400 -Dkafka.logs.dir=/tmpfs/src/temp/docs/tutorials/confluent-5.4.1/bin/../logs -Dlog4j.configuration=file:bin/../etc/kafka/ -cp /tmpfs/src/temp/docs/tutorials/confluent-5.4.1/bin/../share/java/kafka/*:/tmpfs/src/temp/docs/tutorials/confluent-5.4.1/bin/../support-metrics-client/build/dependant-libs-2.12.10/*:/tmpfs/src/temp/docs/tutorials/confluent-5.4.1/bin/../support-metrics-client/build/libs/*:/usr/share/java/support-metrics-client/* etc/kafka/
kbuilder 22782 22388  0 19:52 pts/0    00:00:00 /bin/bash -c ps -ef | grep kafka
kbuilder 22784 22782  0 19:52 pts/0    00:00:00 grep kafka

Create the kafka topics with the following specs:

  • susy-train: partitions=1, replication-factor=1
  • susy-test: partitions=2, replication-factor=1
confluent-5.4.1/bin/kafka-topics --create --zookeeper --replication-factor 1 --partitions 1 --topic susy-train
confluent-5.4.1/bin/kafka-topics --create --zookeeper --replication-factor 1 --partitions 2 --topic susy-test

Created topic susy-train.
Created topic susy-test.

Describe the topics for details on their configuration

confluent-5.4.1/bin/kafka-topics --bootstrap-server --describe --topic susy-train
confluent-5.4.1/bin/kafka-topics --bootstrap-server --describe --topic susy-test

Topic: susy-train   PartitionCount: 1   ReplicationFactor: 1    Configs: segment.bytes=1073741824
    Topic: susy-train   Partition: 0    Leader: 0   Replicas: 0 Isr: 0
Topic: susy-test    PartitionCount: 2   ReplicationFactor: 1    Configs: segment.bytes=1073741824
    Topic: susy-test    Partition: 0    Leader: 0   Replicas: 0 Isr: 0
    Topic: susy-test    Partition: 1    Leader: 0   Replicas: 0 Isr: 0

A replication factor of 1 means that the data is not replicated, because this setup has only a single broker. In production systems, a Kafka cluster can have hundreds of brokers, and that is where fault tolerance through replication comes into the picture.
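As a rough illustration of why replication needs more than one broker, the sketch below encodes two Kafka rules in plain Python: a topic's replication factor cannot exceed the number of available brokers (topic creation is rejected otherwise), and with replication factor RF each partition has RF copies, so its data can survive up to RF - 1 broker failures. The helper `tolerated_broker_failures` is hypothetical, and the sketch ignores settings such as `min.insync.replicas`.

```python
def tolerated_broker_failures(num_brokers: int, replication_factor: int) -> int:
    """Return how many broker failures a partition's data can survive (hypothetical helper).

    A replication factor larger than the broker count is rejected, mirroring Kafka's
    refusal to create such a topic; with RF copies, up to RF - 1 can be lost.
    """
    if replication_factor < 1:
        raise ValueError("replication factor must be at least 1")
    if replication_factor > num_brokers:
        raise ValueError(
            f"replication factor {replication_factor} exceeds {num_brokers} available broker(s)"
        )
    return replication_factor - 1


# The single-broker setup in this tutorial: no broker failure can be tolerated.
print(tolerated_broker_failures(num_brokers=1, replication_factor=1))  # 0
# A common production choice of RF=3 on a larger cluster.
print(tolerated_broker_failures(num_brokers=100, replication_factor=3))  # 2
```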

Please refer to the docs for more details.


Being an event-streaming platform, Kafka enables data from various sources to be written into it. For instance:

  • Web traffic logs
  • Astronomical measurements
  • IoT sensor data
  • Product reviews and many more.

For the purpose of this tutorial, let's download the SUSY dataset and feed the data into Kafka manually. The goal of this classification problem is to distinguish between a signal process that produces supersymmetric particles and a background process that does not.

curl -sSOL

Explore the dataset

The first column is the class label (1 for signal, 0 for background), followed by the 18 features (8 low-level features then 10 high-level features). The first 8 features are kinematic properties measured by the particle detectors in the accelerator. The last 10 features are functions of the first 8 features. These are high-level features derived by physicists to help discriminate between the two classes.
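The row layout described above can be made concrete with a small parser. This is a standard-library sketch; the helper `split_susy_row` is hypothetical and the sample row is made up for illustration (its values are not taken from the dataset).

```python
import csv
import io

NUM_LOW_LEVEL = 8    # kinematic properties measured by the detectors
NUM_HIGH_LEVEL = 10  # features derived from the low-level ones


def split_susy_row(fields):
    """Split one SUSY record into (label, low-level features, high-level features)."""
    values = [float(v) for v in fields]
    expected = 1 + NUM_LOW_LEVEL + NUM_HIGH_LEVEL
    if len(values) != expected:
        raise ValueError(f"expected {expected} columns, got {len(values)}")
    label = int(values[0])                 # 1 = signal, 0 = background
    low = values[1:1 + NUM_LOW_LEVEL]      # columns 1..8
    high = values[1 + NUM_LOW_LEVEL:]      # columns 9..18
    return label, low, high


# A made-up row with the same 19-column layout, purely for illustration.
sample = "1.0," + ",".join(str(0.1 * i) for i in range(1, 19))
row = next(csv.reader(io.StringIO(sample)))
label, low, high = split_susy_row(row)
print(label, len(low), len(high))  # 1 8 10
```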

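# column names follow the UCI SUSY feature list
COLUMNS = ['class',  #  labels
           #  low-level features
           'lepton_1_pT', 'lepton_1_eta', 'lepton_1_phi', 'lepton_2_pT', 'lepton_2_eta', 'lepton_2_phi', 'missing_energy_magnitude', 'missing_energy_phi',
           #  high-level derived features
           'MET_rel', 'axial_MET', 'M_R', 'M_TR_2', 'R', 'MT2', 'S_R', 'M_Delta_R', 'dPhi_r_b', 'cos(theta_r1)']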

The entire dataset consists of 5 million rows. For this tutorial, however, only a fraction of it (100,000 rows) is used, so that less time is spent moving the data and more time on understanding the functionality of the API.

susy_iterator = pd.read_csv('SUSY.csv.gz', header=None, names=COLUMNS, chunksize=100000)
susy_df = next(susy_iterator)
# Number of datapoints and columns
len(susy_df), len(susy_df.columns)
(100000, 19)
# Number of datapoints belonging to each class (0: background noise, 1: signal)
len(susy_df[susy_df["class"]==0]), len(susy_df[susy_df["class"]==1])
(54025, 45975)
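The class counts above also give a quick reference point for the accuracy reported later: a trivial model that always predicts the majority class (background) would be right on roughly 54% of these rows. The arithmetic, using the counts printed above:

```python
# Class counts printed above for the 100,000-row sample.
counts = {0: 54025, 1: 45975}

total = sum(counts.values())
majority_class = max(counts, key=counts.get)
majority_baseline = counts[majority_class] / total

print(total, majority_class, f"{majority_baseline:.1%}")  # 100000 0 54.0%
```

The ~78% accuracy reached during training is therefore well above this trivial baseline.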

Split the dataset

train_df, test_df = train_test_split(susy_df, test_size=0.4, shuffle=True)
print("Number of training samples: ",len(train_df))
print("Number of testing sample: ",len(test_df))

x_train_df = train_df.drop(["class"], axis=1)
y_train_df = train_df["class"]

x_test_df = test_df.drop(["class"], axis=1)
y_test_df = test_df["class"]

# The labels are set as the kafka message keys so that the data is stored
# in multiple partitions, thus enabling efficient data retrieval
# using consumer groups.
x_train = list(filter(None, x_train_df.to_csv(index=False).split("\n")[1:]))
y_train = list(filter(None, y_train_df.to_csv(index=False).split("\n")[1:]))

x_test = list(filter(None, x_test_df.to_csv(index=False).split("\n")[1:]))
y_test = list(filter(None, y_test_df.to_csv(index=False).split("\n")[1:]))

Number of training samples:  60000
Number of testing sample:  40000

NUM_COLUMNS = len(x_train_df.columns)
len(x_train), len(y_train), len(x_test), len(y_test)
(60000, 60000, 40000, 40000)
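Using the labels as message keys matters because Kafka producers route keyed messages by hashing the key, so every message with the same key lands in the same partition. The sketch below imitates that routing in plain Python. It uses `zlib.crc32` as a stand-in hash, which is an assumption for illustration only: Kafka's default partitioner (and kafka-python's) uses murmur2, so the real key-to-partition mapping differs. The helper `choose_partition` is hypothetical.

```python
import zlib


def choose_partition(key: bytes, num_partitions: int) -> int:
    """Map a message key to a partition (hypothetical helper).

    CRC-32 is a stand-in hash; Kafka's default partitioner uses murmur2,
    so real assignments differ. What carries over is that the mapping is deterministic.
    """
    return zlib.crc32(key) % num_partitions


# Labels are sent as keys (b"0" / b"1"); susy-test has 2 partitions.
for key in (b"0", b"1"):
    print(key, "->", choose_partition(key, num_partitions=2))

# Every message carrying the same key is routed to the same partition.
labels = [b"0", b"1", b"1", b"0", b"0"]
seen = {}
for key in labels:
    seen.setdefault(key, set()).add(choose_partition(key, 2))
print(all(len(parts) == 1 for parts in seen.values()))  # True
```

A design consequence: with only two distinct keys, at most two partitions of a topic can ever receive data. The offsets reported at the end of this tutorial (21,624 and 18,376 messages in the two susy-test partitions) are consistent with one class per partition, given the roughly 54/46 class split.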

Store the train and test data in kafka

Storing the data in kafka simulates an environment for continuous remote data retrieval for training and inference purposes.

def error_callback(exc):
  raise Exception('Error while sending data to kafka: {0}'.format(str(exc)))

def write_to_kafka(topic_name, items):
  count = 0
  producer = KafkaProducer(bootstrap_servers=[''])
  for message, key in items:
    # send() is asynchronous; the errback fires if delivery fails
    producer.send(topic_name, key=key.encode('utf-8'), value=message.encode('utf-8')).add_errback(error_callback)
    count += 1
  # block until all buffered messages have been delivered
  producer.flush()
  print("Wrote {0} messages into topic: {1}".format(count, topic_name))

write_to_kafka("susy-train", zip(x_train, y_train))
write_to_kafka("susy-test", zip(x_test, y_test))

Wrote 60000 messages into topic: susy-train
Wrote 40000 messages into topic: susy-test
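`KafkaProducer.send()` only enqueues a record and returns a future; delivery happens on a background thread, which is why a producer should be flushed before reporting how many messages it wrote. The pattern can be imitated with the standard library alone. `FakeProducer` below is a hypothetical stand-in, not kafka-python's API: it uses a `ThreadPoolExecutor` for the background sends and `Future.add_done_callback` in place of kafka-python's `add_errback`.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor


class FakeProducer:
    """Hypothetical stand-in for an asynchronous producer (not kafka-python's API)."""

    def __init__(self):
        self._executor = ThreadPoolExecutor(max_workers=2)
        self._pending = []
        self._lock = threading.Lock()
        self.delivered = []  # records that reached the simulated broker

    def _deliver(self, topic, key, value):
        time.sleep(0.001)  # simulate network latency
        with self._lock:
            self.delivered.append((topic, key, value))

    def send(self, topic, key, value):
        # Returns a future immediately; delivery happens on a background thread.
        future = self._executor.submit(self._deliver, topic, key, value)
        self._pending.append(future)
        return future

    def flush(self):
        # Block until every pending send has completed; re-raises the first error it meets.
        for future in self._pending:
            future.result()
        self._pending.clear()

    def close(self):
        self.flush()
        self._executor.shutdown()


def on_done(future):
    # Rough analogue of an error callback: report failed deliveries.
    if future.exception() is not None:
        print("delivery failed:", future.exception())


producer = FakeProducer()
rows = [("0.1,0.2", "0"), ("0.3,0.4", "1"), ("0.5,0.6", "0")]  # made-up (message, key) pairs
for message, key in rows:
    producer.send("susy-train", key.encode("utf-8"), message.encode("utf-8")).add_done_callback(on_done)

producer.flush()
print("delivered:", len(producer.delivered))  # delivered: 3
producer.close()
```

Without the `flush()` call, the count printed right after the loop could be lower than the number of records sent, since some deliveries may still be in flight.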

Define the tfio train dataset

The IODataset class is used for streaming data from Kafka into TensorFlow. The class inherits from tf.data.Dataset and thus has all the useful functionality of tf.data.Dataset out of the box.

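def decode_kafka_item(item):
  # parse the CSV-encoded message into NUM_COLUMNS float features
  message = tf.io.decode_csv(item.message, [[0.0] for i in range(NUM_COLUMNS)])
  # the key holds the class label as a string ("0" or "1")
  key = tf.strings.to_number(item.key)
  return (message, key)

BATCH_SIZE = 64           # consistent with the 938 steps per epoch for 60,000 samples below
SHUFFLE_BUFFER_SIZE = 64  # illustrative value
train_ds = tfio.IODataset.from_kafka('susy-train', partition=0, offset=0)
train_ds = train_ds.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE)
train_ds = train_ds.map(decode_kafka_item)
train_ds = train_ds.batch(BATCH_SIZE)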
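To make the train pipeline concrete, the sketch below mirrors its three steps in plain Python: a bounded shuffle buffer (the strategy `tf.data.Dataset.shuffle` follows: fill a buffer of `buffer_size` elements, then repeatedly emit a random one and refill its slot from the input), a CSV decode step in which empty fields fall back to the record default as they do in `tf.io.decode_csv`, and batching. The helper names are hypothetical and the records are made up. It also checks the step counts: 60,000 samples in batches of 64 give the 938 steps per epoch in the training log.

```python
import csv
import io
import math
import random


def shuffle_buffer(items, buffer_size, rng):
    """Bounded-buffer shuffle: at most buffer_size elements are held at a time."""
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        buffer[idx] = item
    rng.shuffle(buffer)  # drain what remains once the input is exhausted
    yield from buffer


def decode_item(message, key, num_columns, default=0.0):
    """Parse a CSV message into floats (empty fields use the default) and the key into a number."""
    fields = next(csv.reader(io.StringIO(message)))
    if len(fields) != num_columns:
        raise ValueError(f"expected {num_columns} fields, got {len(fields)}")
    features = [float(f) if f != "" else default for f in fields]
    return features, float(key)


def batch(items, batch_size):
    """Group items into lists of batch_size; the last batch may be smaller."""
    chunk = []
    for item in items:
        chunk.append(item)
        if len(chunk) == batch_size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk


rng = random.Random(0)
records = [(f"{i}.0,,{i}.5", str(i % 2)) for i in range(10)]  # made-up 3-column rows, middle field empty
shuffled = shuffle_buffer(records, buffer_size=4, rng=rng)
batches = list(batch((decode_item(m, k, num_columns=3) for m, k in shuffled), batch_size=4))
print([len(b) for b in batches])  # [4, 4, 2]

# Steps per epoch for the tutorial's numbers (the final partial batch counts as a step).
print(math.ceil(60000 / 64), math.ceil(40000 / 64))  # 938 625
```

The same arithmetic explains the 625 steps later reported when evaluating on the 40,000 test messages.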

Build and train the model

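# Set the parameters (Adam and binary cross-entropy are assumed choices; 'accuracy' matches the log below)
OPTIMIZER = "adam"
LOSS = tf.keras.losses.BinaryCrossentropy()
METRICS = ['accuracy']

# design/build the model (Dropout layers as in the summary below; the rates are illustrative)
model = tf.keras.Sequential([
  tf.keras.layers.Input(shape=(NUM_COLUMNS,)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(256, activation='relu'),
  tf.keras.layers.Dropout(0.4),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.4),
  tf.keras.layers.Dense(1, activation='sigmoid')
])
model.summary()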

Model: "sequential"
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 128)               2432      
dropout (Dropout)            (None, 128)               0         
dense_1 (Dense)              (None, 256)               33024     
dropout_1 (Dropout)          (None, 256)               0         
dense_2 (Dense)              (None, 128)               32896     
dropout_2 (Dropout)          (None, 128)               0         
dense_3 (Dense)              (None, 1)                 129       
Total params: 68,481
Trainable params: 68,481
Non-trainable params: 0

# compile the model
model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)
# fit the model
model.fit(train_ds, epochs=10)
Epoch 1/10
938/938 [==============================] - 7s 8ms/step - loss: 0.6217 - accuracy: 0.7521
Epoch 2/10
938/938 [==============================] - 7s 8ms/step - loss: 0.6132 - accuracy: 0.7747
Epoch 3/10
938/938 [==============================] - 7s 7ms/step - loss: 0.6113 - accuracy: 0.7786
Epoch 4/10
938/938 [==============================] - 7s 7ms/step - loss: 0.6108 - accuracy: 0.7794
Epoch 5/10
938/938 [==============================] - 7s 7ms/step - loss: 0.6100 - accuracy: 0.7820
Epoch 6/10
938/938 [==============================] - 7s 7ms/step - loss: 0.6096 - accuracy: 0.7825
Epoch 7/10
938/938 [==============================] - 7s 7ms/step - loss: 0.6096 - accuracy: 0.7819
Epoch 8/10
938/938 [==============================] - 7s 7ms/step - loss: 0.6094 - accuracy: 0.7830
Epoch 9/10
938/938 [==============================] - 7s 7ms/step - loss: 0.6087 - accuracy: 0.7841
Epoch 10/10
938/938 [==============================] - 7s 7ms/step - loss: 0.6084 - accuracy: 0.7841

<tensorflow.python.keras.callbacks.History at 0x7f4aed94eac8>

Since only a fraction of the dataset is being used, accuracy is limited to ~78% during the training phase. Feel free to store additional data in Kafka for better model performance. Also, since the goal was only to demonstrate the functionality of the tfio Kafka datasets, a smaller and less complicated neural network was used. One can increase the complexity of the model, modify the learning strategy, tune hyperparameters, etc. for exploration purposes. For a baseline approach, please refer to this article.

Infer on the test data

For faster and scalable inference, let's use the streaming.KafkaGroupIODataset class.

Define the tfio test dataset

test_ds = tfio.experimental.streaming.KafkaGroupIODataset(

test_ds =
test_ds = test_ds.batch(BATCH_SIZE)

Though this class can be used for training, there are caveats to address. Once all the messages have been read from Kafka and the latest offsets committed using streaming.KafkaGroupIODataset, the consumer does not restart reading the messages from the beginning. Thus, while training, it is only possible to train for a single epoch with the data continuously flowing in. This kind of functionality has limited use cases during the training phase, namely those in which a data point is no longer required once the model has consumed it and can be discarded.

However, this functionality shines when it comes to robust inference with exactly-once semantics.
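The single-epoch caveat follows from how consumer groups track progress: the group commits, per partition, the offset of the next message to read, and a new read resumes from that committed offset rather than from the beginning. The in-memory model below is a hypothetical sketch of that bookkeeping, not the tfio or Kafka API; it shows a second pass over the same topic returning nothing, while newly arriving data is still picked up.

```python
class InMemoryTopic:
    """Hypothetical topic: one append-only list of messages per partition."""

    def __init__(self, num_partitions):
        self.partitions = [[] for _ in range(num_partitions)]

    def append(self, partition, message):
        self.partitions[partition].append(message)


class GroupConsumer:
    """Reads every partition, resuming at the group's committed offsets (hypothetical)."""

    def __init__(self, topic, committed):
        self.topic = topic
        self.committed = committed  # shared dict: partition -> next offset to read

    def read_all(self):
        out = []
        for p, log in enumerate(self.topic.partitions):
            start = self.committed.get(p, 0)  # a new group starts at the beginning, like auto.offset.reset=earliest
            out.extend(log[start:])
            self.committed[p] = len(log)      # commit the position after the last message read
        return out


topic = InMemoryTopic(num_partitions=2)
for i in range(5):
    topic.append(i % 2, f"msg-{i}")

group_offsets = {}  # state kept per consumer group (e.g. "testcg")
first_epoch = GroupConsumer(topic, group_offsets).read_all()
second_epoch = GroupConsumer(topic, group_offsets).read_all()
print(len(first_epoch), len(second_epoch))  # 5 0

topic.append(0, "msg-5")  # data arriving later is still picked up
print(GroupConsumer(topic, group_offsets).read_all())  # ['msg-5']
```

Because progress lives in the committed offsets rather than in the consumer process, a consumer that restarts after a commit resumes at the unprocessed messages, which is the property the inference use case relies on.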

Evaluate the performance on the test data

res = model.evaluate(test_ds)
print("test loss, test acc:", res)

625/625 [==============================] - 18s 29ms/step - loss: 0.6093 - accuracy: 0.7749
test loss, test acc: [0.609319269657135, 0.7748749852180481]

Track the offset lag of the testcg consumer group

confluent-5.4.1/bin/kafka-consumer-groups --bootstrap-server --describe --group testcg

GROUP           TOPIC           PARTITION  CURRENT-OFFSET  LOG-END-OFFSET  LAG             CONSUMER-ID                                  HOST            CLIENT-ID
testcg          susy-test       0          21624           21624           0               rdkafka-99f87920-7641-484d-ba6c-415aa2e784d5 /    rdkafka
testcg          susy-test       1          18376           18376           0               rdkafka-99f87920-7641-484d-ba6c-415aa2e784d5 /    rdkafka

Once CURRENT-OFFSET matches LOG-END-OFFSET for all partitions (that is, LAG is 0), the consumer(s) have finished fetching all the messages from the Kafka topic.
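The LAG column is simply LOG-END-OFFSET minus CURRENT-OFFSET for each partition. The values from the table above can be checked directly; the tuples below are copied from that output.

```python
# (partition, current_offset, log_end_offset) from the kafka-consumer-groups output above
rows = [(0, 21624, 21624), (1, 18376, 18376)]

lag = {partition: end - current for partition, current, end in rows}
consumed = sum(current for _, current, _ in rows)
done = all(value == 0 for value in lag.values())

print(lag, consumed, done)  # {0: 0, 1: 0} 40000 True
```

The committed offsets add up to the 40,000 test messages written to susy-test, matching the 625 evaluation steps of 64 messages each.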