Converting from HDF5 to tfrecord and reading tfrecords into tensorflow

Posted on Mon 29 April 2019 in Tensorflow

HDF5 is a popular file format for handling large, complex datasets, often the kind of datasets we want to use to train machine learning models in TensorFlow. I had to piece together several resources to convert my dataset and read it into TensorFlow, so I wanted to put a very simple, quick example online for others.

Making an example hdf5 dataset

First, let's make a quick HDF5 dataset out of MNIST (which we can load through tf.keras). To make the dataset diverse, we'll store uint8, float32, int64, and string data.

In [1]:
import tensorflow as tf
# load dataset
(train_images, train_labels), _ = tf.keras.datasets.mnist.load_data()
In [2]:
# take a look at the data
import numpy as np
np.shape(train_images), train_images.dtype, np.shape(train_labels), train_labels.dtype
Out[2]:
((60000, 28, 28), dtype('uint8'), (60000,), dtype('uint8'))
In [3]:
# let's make some text labels for our dataset
text_label_dict = {0:"zero", 1:"one", 2:"two", 3:"three", 4:"four", 5:"five", 6:"six", 7:"seven", 8:"eight", 9:"nine"}
text_labels = [text_label_dict[i] for i in train_labels]
text_labels[:3]
Out[3]:
['five', 'zero', 'four']
In [4]:
# put everything into a dictionary so we can save it using deepdish
hdf5_dict = {
    "text_labels": text_labels,
    "train_images": train_images,
    "float32_labels": train_labels.astype(np.float32),
    "int64_labels": train_labels.astype(np.int64)
}
In [5]:
# now let's save it as HDF5, using deepdish for simplicity
import deepdish as dd
dd.io.save("myhdf5.hdf5", hdf5_dict, compression=None)

Converting the HDF5 to TFRecord

Now that our data is in HDF5 format, let's load it back up and convert it to TFRecord. First, we need functions that convert each "feature" (e.g. "one", 1, or a handwritten digit) into a TensorFlow Feature object. A tf.train.Feature holds one of three list types (BytesList, FloatList, or Int64List), so we define three helpers: _bytes_feature, _float_feature, and _int64_feature. All of our data needs to be converted into one of these formats.

In [6]:
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
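Choosing the helper by hand is fine for four fields; with more fields, one option is to pick it from each value's NumPy dtype. The sketch below is a hypothetical helper (feature_kind is not part of TensorFlow or of the original post) that only returns the name of the list type, so it runs without TensorFlow; you would map "bytes", "float", and "int64" onto the three helpers above.

```python
import numpy as np


def feature_kind(value):
    """Hypothetical helper: name the tf.train list type a single value maps onto.

    Returns "bytes", "float", or "int64", matching _bytes_feature,
    _float_feature, and _int64_feature.
    """
    arr = np.asarray(value)
    if arr.dtype.kind in ("S", "U"):
        return "bytes"   # str / bytes -> BytesList (a str must be encoded first)
    if arr.ndim > 0:
        return "bytes"   # whole arrays: store the raw buffer via .tobytes(), as done for the images
    if arr.dtype.kind == "f":
        return "float"   # FloatList holds 32-bit floats, so float64 values lose precision
    if arr.dtype.kind in ("b", "i", "u"):
        return "int64"   # bool, signed and unsigned ints -> Int64List
    raise TypeError(f"no tf.train feature type for dtype {arr.dtype}")


# stand-in values shaped like the four fields of one row of our dataset
row = {
    "float32_labels": np.float32(5.0),
    "int64_labels": np.int64(5),
    "text_labels": "five",
    "train_images": np.zeros((28, 28), dtype=np.uint8),
}
print({name: feature_kind(v) for name, v in row.items()})
```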
In [7]:
# let's load our data back up
mydata = dd.io.load("myhdf5.hdf5")
mydata.keys()
Out[7]:
dict_keys(['float32_labels', 'int64_labels', 'text_labels', 'train_images'])
In [8]:
# lets see what it looks like to convert one sample into a feature
print(mydata['float32_labels'][0], type(_float_feature(mydata['float32_labels'][0])))
5.0 <class 'tensorflow.core.example.feature_pb2.Feature'>

Next, to convert a whole row of our dataset (one sample with all of its fields) into a single record, we need to serialize it as a tf.train.Example. To do this, I wrote the function serialize_example, which takes a dictionary mapping each field name to its data and to the helper used to convert that data into a feature.

In [9]:
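def serialize_example(example):
    """Serialize one row of the dataset into a tf.train.Example byte string.

    Arguments:
      example {dict} -- maps each feature name to a dict with fields
        "data" (the value) and "_type" (the feature helper to apply to it)

    Returns:
      bytes -- the serialized tf.train.Example
    """
    dset_item = {}
    for feature in example.keys():
        # wrap the raw value in the tf.train.Feature type chosen for it
        dset_item[feature] = example[feature]["_type"](example[feature]["data"])
    # build the Example once, after every feature has been converted
    example_proto = tf.train.Example(features=tf.train.Features(feature=dset_item))
    return example_proto.SerializeToString()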

Let's test it on a single row:

In [10]:
row = 0
example = serialize_example(
    {
        "float32_labels": {
            "data": mydata["float32_labels"][row],
            "_type": _float_feature,
        },
        "int64_labels": {
            "data": mydata["int64_labels"][row],
            "_type": _int64_feature,
        },
        "text_labels": {
            # np.bytes_ replaces np.string_, which NumPy 2.0 removed
            "data": np.bytes_(mydata["text_labels"][row]).astype("|S7"),
            "_type": _bytes_feature,
        },
        "train_images": {
            "data": mydata["train_images"][row].flatten().tobytes(),
            "_type": _bytes_feature,
        },
    }
)
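The same dictionary is spelled out again in the writing loop below. A small hypothetical helper, build_row (not part of TensorFlow or the original post), could build it from a table of per-field conversions so the field list lives in one place. The runnable demo uses stand-in strings in place of the feature helpers so it works without TensorFlow; the commented table shows how it might look with the real helpers.

```python
def build_row(data, row, field_specs):
    """Hypothetical helper: build the dict that serialize_example expects for one row.

    field_specs maps each field name to a (prepare, feature_fn) pair, where
    `prepare` converts the stored value into something feature_fn accepts
    and `feature_fn` is one of _float_feature / _int64_feature / _bytes_feature.
    """
    return {
        name: {"data": prepare(data[name][row]), "_type": feature_fn}
        for name, (prepare, feature_fn) in field_specs.items()
    }


# With the real helpers, the table for this dataset could look like:
# field_specs = {
#     "float32_labels": (lambda v: v, _float_feature),
#     "int64_labels": (lambda v: v, _int64_feature),
#     "text_labels": (lambda s: np.bytes_(s), _bytes_feature),
#     "train_images": (lambda img: img.flatten().tobytes(), _bytes_feature),
# }
# example = serialize_example(build_row(mydata, exi, field_specs))

# Stand-in demo: strings take the place of the feature helpers.
toy_data = {"int64_labels": [5, 0, 4], "text_labels": ["five", "zero", "four"]}
toy_specs = {
    "int64_labels": (int, "int64 helper"),
    "text_labels": (lambda s: s.encode("utf-8"), "bytes helper"),
}
print(build_row(toy_data, 2, toy_specs))
```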
In [11]:
print('{}...'.format(example[:100]))
b'\n\xf6\x06\n\x1a\n\x0efloat32_labels\x12\x08\x12\x06\n\x04\x00\x00\xa0@\n\xa7\x06\n\x0ctrain_images\x12\x96\x06\n\x93\x06\n\x90\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'...

Writing to a TFRecord file

In [12]:
from tqdm.autonotebook import tqdm # just a progressbar
n_observations = len(mydata["float32_labels"])  # how many items are in your dataset
# loop through hdf5 of examples, save to tfrecord
with tf.io.TFRecordWriter(str('myfile.tfrecord')) as writer:
    # for each example
    for exi in tqdm(range(n_observations)):
        # create an item in the dataset converted to the correct formats (float, int, byte)
        example = serialize_example(
            {
                "float32_labels": {
                    "data": mydata["float32_labels"][exi],
                    "_type": _float_feature,
                },
                "int64_labels": {
                    "data": mydata["int64_labels"][exi],
                    "_type": _int64_feature,
                },
                "text_labels": {
                    # np.bytes_ replaces np.string_, which NumPy 2.0 removed
                    "data": np.bytes_(mydata["text_labels"][exi]).astype("|S7"),
                    "_type": _bytes_feature,
                },
                "train_images": {
                    "data": mydata["train_images"][exi].flatten().tobytes(),
                    "_type": _bytes_feature,
                },
            }
        )
        # write the defined example into the dataset
        writer.write(example)
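TFRecordWriter stores the serialized examples with a simple framing. According to TensorFlow's TFRecord guide, each record is a little-endian uint64 length, a masked CRC-32C of that length, the bytes themselves, and a masked CRC-32C of the bytes. The pure-Python sketch below follows that layout, which can help when sanity-checking a file without TensorFlow. crc32c, masked_crc, write_records, and read_records are hypothetical helpers, not TensorFlow API. The bitwise CRC is slow, so pass verify=False when you only want to walk a large file.

```python
import struct


def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli polynomial, reflected form 0x82F63B78)."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord's masked CRC: rotate the CRC right by 15 bits, then add a constant (mod 2**32)."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_records(path, records):
    """Hypothetical writer: frame each bytes record as length, crc(length), data, crc(data)."""
    with open(path, "wb") as f:
        for data in records:
            header = struct.pack("<Q", len(data))  # little-endian uint64 length
            f.write(header)
            f.write(struct.pack("<I", masked_crc(header)))
            f.write(data)
            f.write(struct.pack("<I", masked_crc(data)))


def read_records(path, verify=True):
    """Hypothetical reader: yield each record's bytes, optionally checking both CRCs."""
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                return  # clean end of file
            if len(header) < 8:
                raise ValueError("truncated record header")
            (length,) = struct.unpack("<Q", header)
            (length_crc,) = struct.unpack("<I", f.read(4))
            data = f.read(length)
            (data_crc,) = struct.unpack("<I", f.read(4))
            if verify and (length_crc != masked_crc(header) or data_crc != masked_crc(data)):
                raise ValueError("CRC mismatch: corrupt record")
            yield data
```

For example, sum(1 for _ in read_records("myfile.tfrecord", verify=False)) should equal n_observations if the sketch matches the writer's layout.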

Reading the data back

I'm showing two ways to read the dataset back. The first is a quick way to inspect individual records, and the second is more relevant for feeding data into your TensorFlow pipeline.

Read dataset with numpy + ParseFromString

In [13]:
# dataset class https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset
raw_dataset = tf.data.TFRecordDataset([str("myfile.tfrecord")])
In [14]:
# create a 'Dataset' with at most 'count' elements
dset = raw_dataset.take(count=10)
In [15]:
# grab a single element from that dataset
element = list(dset)[1]
# an Example protocol buffer message https://www.tensorflow.org/api_docs/python/tf/train/Example
example = tf.train.Example()
# parse the element in to the example message
example.ParseFromString(element.numpy())
list(example.features.feature)
Out[15]:
['float32_labels', 'int64_labels', 'text_labels', 'train_images']
In [16]:
import matplotlib.pyplot as plt
%matplotlib inline
In [17]:
fig, ax = plt.subplots(ncols = 5, figsize=(15,3))
# iterate over the first five records once, instead of rebuilding list(dset) every pass
for i, element in enumerate(dset.take(5)):
    # parse the serialized element into an Example message
    # (https://www.tensorflow.org/api_docs/python/tf/train/Example)
    example = tf.train.Example()
    example.ParseFromString(element.numpy())
    # pull out the image feature
    img_buff = example.features.feature['train_images']
    # decode the raw bytes back into uint8 pixels and restore the 28x28 shape
    image = tf.io.decode_raw(img_buff.bytes_list.value[0], tf.uint8).numpy().reshape(28,28)
    # show the image
    ax[i].matshow(image, cmap=plt.cm.Greys)
    string_label = example.features.feature['text_labels'].bytes_list.value[0].decode("utf-8")
    ax[i].set_title(string_label)
    ax[i].axis('off')

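In the loop above, tf.io.decode_raw reinterprets the stored byte string as a flat array of the requested dtype. NumPy's np.frombuffer performs the same reinterpretation, which makes the round trip easy to check offline. The sketch uses a synthetic stand-in image, not real MNIST data. The key point is that neither the dtype nor the shape is stored in the record, so you must decode with the dtype you wrote and then reshape.

```python
import numpy as np

# stand-in 28x28 uint8 image (synthetic values, not from MNIST)
image = (np.arange(28 * 28) % 256).astype(np.uint8).reshape(28, 28)

# what we wrote into the "train_images" feature: 784 raw bytes, no dtype or shape
raw = image.flatten().tobytes()

# decoding with the dtype we wrote recovers the pixels; the shape is restored by hand
decoded = np.frombuffer(raw, dtype=np.uint8).reshape(28, 28)

# decoding with the wrong dtype does not fail, it silently reinterprets the bytes
# (784 bytes read as 4-byte float32 values gives 196 numbers)
wrong = np.frombuffer(raw, dtype=np.float32)

print(np.array_equal(decoded, image), len(raw), wrong.shape)
```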
Reading the dataset directly into TensorFlow

This approach is more useful for feeding data directly into your graph. TFRecord files don't store the original data types (the images, for example, were written as raw bytes), so we need to parse the data back into its original format. The function below takes an example from the dataset, parses it, and converts each field back to its original data type.

In [18]:
if int(tf.__version__[0]) < 2:
    from tensorflow import FixedLenFeature, parse_single_example
else:
    from tensorflow.io import FixedLenFeature, parse_single_example
In [19]:
def _dtype_to_tf_feattype(dtype):
    """ map a tf dtype to the type FixedLenFeature parses: float32 and int64
    are stored natively, everything else was written as bytes (tf.string)
    """
    if dtype in [tf.float32, tf.int64]:
        return dtype
    else:
        return tf.string
In [20]:
def _parse_function(example_proto, data_types):
    """ parse dataset from tfrecord, and convert to correct format
    """
    # list features
    features = {
        lab: FixedLenFeature([], _dtype_to_tf_feattype(dtype))
        for lab, dtype in data_types.items()
    }
    # parse features
    parsed_features = parse_single_example(example_proto, features)
    feat_dtypes = [tf.float32, tf.string, tf.int64]
    
    # convert the features if they are in the wrong format
    parse_list = [
        parsed_features[lab]
        if dtype in feat_dtypes
        else tf.io.decode_raw(parsed_features[lab], dtype)
        for lab, dtype in data_types.items()
    ]
    # return a tuple so tf.data treats the fields as separate components
    return tuple(parse_list)
In [21]:
# read the dataset
raw_dataset = tf.data.TFRecordDataset([str("myfile.tfrecord")])
In [22]:
data_types = {
    "float32_labels": tf.float32,
    "int64_labels": tf.int64,
    "text_labels": tf.string,
    "train_images": tf.uint8,
}
In [23]:
# apply the parse function to every record in the raw dataset
dataset = raw_dataset.map(lambda x: _parse_function(x, data_types=data_types))
In [24]:
# shuffle the dataset
dataset = dataset.shuffle(buffer_size=10000)
# create batches
dataset = dataset.batch(10)
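shuffle(buffer_size=10000) does not shuffle all 60,000 examples at once. The tf.data documentation describes it as filling a buffer with buffer_size elements and sampling randomly from that buffer, replacing each sampled element with the next one from the input. The pure-Python generator below is a simplified sketch of that idea (buffered_shuffle is hypothetical, not TensorFlow's implementation). It shows that an element can move at most buffer_size - 1 positions earlier than its place in the file, so a uniform shuffle needs a buffer at least as large as the dataset.

```python
import random


def buffered_shuffle(iterable, buffer_size, seed=None):
    """Simplified sketch of a fixed-size shuffle buffer (hypothetical, not tf.data's code)."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer before emitting anything
            continue
        # buffer is full: emit a random element and put the new item in its slot
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # input exhausted: drain what is left in random order
    rng.shuffle(buffer)
    yield from buffer


# with a 10-element buffer, item i can appear no earlier than output position i - 9
out = list(buffered_shuffle(range(100), buffer_size=10, seed=0))
print(out[:10])
```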
In [25]:
float32_labs, int64_labs, string_labs, images  = next(iter(dataset))
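next(iter(dataset)) returns one batch. Because _parse_function returns the four parsed fields in the order of data_types, each batch is a 4-tuple that can be unpacked directly. In plain Python, batching a stream of such tuples amounts to grouping rows and transposing them into columns. The sketch below uses a hypothetical batch_tuples helper that mirrors Dataset.batch with its default drop_remainder=False, where the last batch may be smaller.

```python
def batch_tuples(rows, batch_size):
    """Hypothetical helper: group tuple rows into batches of columns."""
    batch = []
    for row in rows:
        batch.append(row)
        if len(batch) == batch_size:
            # transpose: a list of (a, b, ...) rows -> a tuple of columns
            yield tuple(list(col) for col in zip(*batch))
            batch = []
    if batch:
        # final partial batch is kept, like drop_remainder=False
        yield tuple(list(col) for col in zip(*batch))


# stand-in rows of (float label, int label, text label)
rows = [(float(i), i, str(i)) for i in range(25)]
batches = list(batch_tuples(rows, 10))
floats, ints, strings = batches[0]  # unpack one batch, as in the cell above
print(len(batches), ints)
```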
In [26]:
fig, ax = plt.subplots(ncols = 5, figsize=(15,3))
for i in range(5):
    # show the image
    ax[i].matshow(images[i].numpy().reshape(28,28), cmap=plt.cm.Greys)
    string_label = string_labs[i].numpy().decode("utf-8") 
    ax[i].set_title(string_label)
    ax[i].axis('off')