Converting from HDF5 to tfrecord and reading tfrecords into tensorflow
Posted on Mon 29 April 2019 in Tensorflow
HDF5 is a popular file format for handling large, complex datasets, often exactly the kind of data we want to use to train machine learning models in tensorflow. I had to piece together several resources to convert my own dataset and read it into tensorflow, so I wanted to put a very simple, quick example online for others.
Making an example hdf5 dataset
First, let's make a quick hdf5 dataset out of MNIST (which we can load through tf.keras.datasets). To make the dataset diverse, we'll use some uint8, float32, int64, and string type data.
import tensorflow as tf
# load dataset
(train_images, train_labels), _ = tf.keras.datasets.mnist.load_data()
# take a look at the data
import numpy as np
np.shape(train_images), train_images.dtype, np.shape(train_labels), train_labels.dtype
# lets make some text labels for our dataset
text_label_dict = {0:"zero", 1:"one", 2:"two", 3:"three", 4:"four", 5:"five", 6:"six", 7:"seven", 8:"eight", 9:"nine"}
text_labels = [text_label_dict[i] for i in train_labels]
text_labels[:3]
# lets get everything into a dictionary so we can save it using deepdish
hdf5_dict = {
"text_labels": text_labels,
"train_images": train_images,
"float32_labels": train_labels.astype(np.float32),
"int64_labels": train_labels.astype(np.int64)
}
# now lets save it as hdf5 using deepdish for simplicity
import deepdish as dd
dd.io.save("myhdf5.hdf5", hdf5_dict, compression=None)
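If you just want to peek at what deepdish wrote without loading everything back, h5py can list the top-level keys for you (the exact layout of each value depends on how deepdish chose to serialize it, so treat this as a quick sanity check):
import h5py
# list the top-level groups/datasets deepdish created in the file
with h5py.File("myhdf5.hdf5", "r") as f:
    print(list(f.keys()))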
Converting the HDF5 to tfrecord
Now that our data is in HDF5 format, let's load it back up and convert it to tfrecord. First, we need to define functions to convert each "feature" (e.g. "one", 1, or a handwritten digit) into a tensorflow feature object. There are three types of feature: _bytes_feature, _float_feature, and _int64_feature. All of our data needs to be converted into one of these formats.
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
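One thing to watch out for: _bytes_feature expects bytes, not a Python str, so text has to be encoded before it gets wrapped. A quick illustrative check of the three helpers:
# strings must be encoded to bytes before they can go into a BytesList
print(_bytes_feature("seven".encode("utf-8")))
print(_float_feature(7.0))
print(_int64_feature(7))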
# lets load our data back up
mydata = dd.io.load("myhdf5.hdf5")
mydata.keys()
# lets see what it looks like to convert one sample into a feature
print(mydata['float32_labels'][0], type(_float_feature(mydata['float32_labels'][0])))
Now, to convert a whole row of our dataset into features, we need to serialize that example. To do this, I wrote the function serialize_example, which takes a dictionary mapping each feature's name to its data and to the type of feature it will be written as.
def serialize_example(example):
    """Serialize an item in a dataset
    Arguments:
        example {dict} -- dict of dicts, keyed by feature name, each with fields "_type" and "data"
    Returns:
        bytes -- the serialized tf.train.Example message
    """
    dset_item = {}
    for feature in example.keys():
        dset_item[feature] = example[feature]["_type"](example[feature]["data"])
    example_proto = tf.train.Example(features=tf.train.Features(feature=dset_item))
    return example_proto.SerializeToString()
Let's test it on a single row:
row = 0
example = serialize_example(
{
"float32_labels": {
"data": mydata["float32_labels"][row],
"_type": _float_feature,
},
"int64_labels": {
"data": mydata["int64_labels"][row],
"_type": _int64_feature,
},
"text_labels": {
"data": np.string_(mydata["text_labels"][row]).astype("|S7"),
"_type": _bytes_feature,
},
"train_images": {
"data": mydata["train_images"][row].flatten().tobytes(),
"_type": _bytes_feature,
},
}
)
print('{}...'.format(example[:100]))
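We can also parse those bytes straight back into an Example message, just to confirm the round trip works:
# parse the serialized bytes back into an Example message as a round-trip check
roundtrip = tf.train.Example()
roundtrip.ParseFromString(example)
print(list(roundtrip.features.feature.keys()))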
Write to a tfrecord file
from tqdm.autonotebook import tqdm # just a progressbar
n_observations = len(mydata["float32_labels"]) # how many items are in your dataset
# loop through hdf5 of examples, save to tfrecord
with tf.io.TFRecordWriter(str('myfile.tfrecord')) as writer:
    # for each example
    for exi in tqdm(range(n_observations)):
        # create an item in the dataset converted to the correct formats (float, int, byte)
        example = serialize_example(
            {
                "float32_labels": {
                    "data": mydata["float32_labels"][exi],
                    "_type": _float_feature,
                },
                "int64_labels": {
                    "data": mydata["int64_labels"][exi],
                    "_type": _int64_feature,
                },
                "text_labels": {
                    "data": np.string_(mydata["text_labels"][exi]).astype("|S7"),
                    "_type": _bytes_feature,
                },
                "train_images": {
                    "data": mydata["train_images"][exi].flatten().tobytes(),
                    "_type": _bytes_feature,
                },
            }
        )
        # write the defined example into the dataset
        writer.write(example)
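As a quick sanity check that everything was written, we can count the records back out of the file (this assumes eager execution, i.e. tensorflow 2):
# count the serialized examples we just wrote to disk
n_records = sum(1 for _ in tf.data.TFRecordDataset("myfile.tfrecord"))
print(n_records)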
Reading the data back
I'm showing two ways to read the dataset back. The first is a quick way to grab the data, and the second is more relevant for reading data into your tensorflow pipeline.
Read dataset with numpy + ParseFromString
# dataset class https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset
raw_dataset = tf.data.TFRecordDataset([str("myfile.tfrecord")])
# create a 'Dataset' with at most 'count' elements
dset = raw_dataset.take(count=10)
# grab a single element from that dataset
element = list(dset)[1]
# a "Feature message" https://www.tensorflow.org/api_docs/python/tf/train/Example
example = tf.train.Example()
# parse the element in to the example message
example.ParseFromString(element.numpy())
list(example.features.feature)
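Each feature in the parsed message is read back through the list field that matches its type, for example:
# scalar features come back through the matching *_list field of the message
print(example.features.feature['int64_labels'].int64_list.value[0])
print(example.features.feature['float32_labels'].float_list.value[0])
print(example.features.feature['text_labels'].bytes_list.value[0].decode('utf-8'))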
import matplotlib.pyplot as plt
%matplotlib inline
fig, ax = plt.subplots(ncols = 5, figsize=(15,3))
for i in range(5):
    # grab a single element from that dataset
    element = list(dset)[i]
    # a "Feature message" https://www.tensorflow.org/api_docs/python/tf/train/Example
    example = tf.train.Example()
    # parse the element into the example message
    example.ParseFromString(element.numpy())
    # grab the image bytes
    img_buff = dict(example.features.feature)['train_images']
    # convert the buffer into a uint8 image
    image = tf.io.decode_raw(img_buff.bytes_list.value[0], tf.uint8).numpy().reshape(28,28)
    # show the image
    ax[i].matshow(image, cmap=plt.cm.Greys)
    string_label = (dict(example.features.feature)['text_labels'].bytes_list.value[0]).decode("utf-8")
    ax[i].set_title(string_label)
    ax[i].axis('off')
Read the dataset directly into tensorflow
This will be more useful for feeding data directly into your graph. We need to parse the data back into its original format, which tensorflow doesn't store in the tfrecord. The function below takes an example from the dataset, reads it, and parses it back into its original data type.
if int(tf.__version__[0]) < 2:
    from tensorflow import FixedLenFeature, parse_single_example
else:
    from tensorflow.io import FixedLenFeature, parse_single_example

def _dtype_to_tf_feattype(dtype):
    """ convert tf dtype to correct tf.train.Feature format
    """
    if dtype in [tf.float32, tf.int64]:
        return dtype
    else:
        return tf.string
def _parse_function(example_proto, data_types):
    """ parse dataset from tfrecord, and convert to correct format
    """
    # list features
    features = {
        lab: FixedLenFeature([], _dtype_to_tf_feattype(dtype))
        for lab, dtype in data_types.items()
    }
    # parse features
    parsed_features = parse_single_example(example_proto, features)
    feat_dtypes = [tf.float32, tf.string, tf.int64]
    # convert the features if they are in the wrong format
    parse_list = [
        parsed_features[lab]
        if dtype in feat_dtypes
        else tf.io.decode_raw(parsed_features[lab], dtype)
        for lab, dtype in data_types.items()
    ]
    return parse_list
# read the dataset
raw_dataset = tf.data.TFRecordDataset([str("myfile.tfrecord")])
data_types = {
"float32_labels": tf.float32,
"int64_labels": tf.int64,
"text_labels": tf.string,
"train_images": tf.uint8,
}
# parse each data type to the raw dataset
dataset = raw_dataset.map(lambda x: _parse_function(x, data_types=data_types))
# shuffle the dataset
dataset = dataset.shuffle(buffer_size=10000)
# create batches
dataset = dataset.batch(10)
float32_labs, int64_labs, string_labs, images = next(iter(dataset))
fig, ax = plt.subplots(ncols = 5, figsize=(15,3))
for i in range(5):
    # show the image
    ax[i].matshow(images[i].numpy().reshape(28,28), cmap=plt.cm.Greys)
    string_label = string_labs[i].numpy().decode("utf-8")
    ax[i].set_title(string_label)
    ax[i].axis('off')
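From here the parsed dataset can be plugged straight into training. Below is a minimal sketch with a throwaway Keras classifier on the int64 labels; the model and the to_model_inputs helper are just placeholders to show the pipeline feeds model.fit, and are not part of the original conversion code:
# the dataset is already batched, so each element is a batch of parsed features
def to_model_inputs(float32_labs, int64_labs, text_labs, images):
    # decode_raw gave us flat uint8 vectors; reshape to 28x28 and scale to [0, 1]
    images = tf.cast(tf.reshape(images, (-1, 28, 28)), tf.float32) / 255.0
    return images, int64_labs

train_ds = dataset.map(to_model_inputs)

# a tiny placeholder classifier, just to show the dataset plugs into model.fit
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
model.fit(train_ds, epochs=1)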