Variational Autoencoder in TensorFlow (Jupyter Notebook)
Posted on Sat 07 July 2018 in Machine Learning
Load packages
In [1]:
import numpy as np
import tensorflow as tf
import tensorflow.contrib as tf_contrib
from tqdm import tqdm_notebook as tqdm # progress bar
import matplotlib.pyplot as plt
%matplotlib inline
Specify which GPU to use
In [2]:
# Allocate GPU
import os
gpu_to_use = 1
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_to_use)
Create the Variational Autoencoder object
In [3]:
weight_init = (
    tf_contrib.layers.variance_scaling_initializer()
)  # Kaiming (He) initialization for encoder / decoder weights
weight_regularizer = tf_contrib.layers.l2_regularizer(scale=0.0001)
# note: to take effect these would be passed to fully_connected via
# weights_initializer= / weights_regularizer=; they are defined here for convenience
class VAE(object):
    def __init__(self, params):
        self.params = params
        self.initialize_network()

    def initialize_network(self):
        """Defines the network architecture."""
        # initialize graph and session
        self.graph = tf.Graph()
        self.config = tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=True
        )
        self.sess = tf.InteractiveSession(graph=self.graph, config=self.config)
        # Global step needs to be defined to coordinate multi-GPU training
        self.global_step = tf.get_variable(
            "global_step", [], initializer=tf.constant_initializer(0), trainable=False
        )
        self.define_network()  # define the network
        self.sess.run(tf.global_variables_initializer())  # initialize the weights
        self.saver = tf.train.Saver()  # initialize network saver
        print("Network Initialized")
    def encoder(self, X, scope="encoder", verbose=True):
        with tf.variable_scope(scope):
            # flatten the input images
            encoder_net = [
                tf.reshape(
                    X, [self.params["batch_size"], np.prod(self.params["dims"])]
                )
            ]
            # two fully connected hidden layers
            for layer_i in np.arange(2):
                encoder_net.append(
                    tf_contrib.layers.fully_connected(
                        encoder_net[-1],
                        num_outputs=200,
                        activation_fn=self.params["activation_fn"],
                    )
                )
            # mean of the approximate posterior q(z|x)
            z_mean = tf_contrib.layers.fully_connected(
                encoder_net[-1],
                num_outputs=self.params["n_hidden"],
                activation_fn=None,
                scope="z_mean",
            )
            # despite the name, this layer outputs log(sigma^2), not sigma;
            # the sampling and KL terms below treat it as the log-variance
            z_std = tf_contrib.layers.fully_connected(
                encoder_net[-1],
                num_outputs=self.params["n_hidden"],
                activation_fn=None,
                scope="z_std",
            )
        return z_mean, z_std
    def decoder(self, Z, scope="decoder", verbose=True):
        with tf.variable_scope(scope):
            decoder_net = [Z]
            # two fully connected hidden layers mirroring the encoder
            for layer_i in np.arange(2):
                decoder_net.append(
                    tf_contrib.layers.fully_connected(
                        decoder_net[-1],
                        num_outputs=200,
                        activation_fn=self.params["activation_fn"],
                    )
                )
            # sigmoid output, so each pixel is a Bernoulli probability in [0, 1]
            x_reconstruction = tf_contrib.layers.fully_connected(
                decoder_net[-1],
                num_outputs=int(np.prod(self.params["dims"])),
                activation_fn=tf.nn.sigmoid,
            )
        return x_reconstruction
    def define_network(self):
        # define the input placeholder for a batch of flattened images
        self.x_real = tf.placeholder(
            tf.float32, [self.params["batch_size"], np.prod(self.params["dims"])]
        )
        # run through the encoder to get the parameters of q(z|x)
        self.z_mean, self.z_std = self.encoder(self.x_real)
        # reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # so gradients can flow back through the sampling step
        samples = tf.random_normal(
            shape=tf.shape(self.z_std), mean=0, stddev=1, dtype=tf.float32
        )
        self.z = self.z_mean + tf.exp(0.5 * self.z_std) * samples  # sigma = exp(log_var / 2)
        # run through the decoder
        self.x_recon = self.decoder(self.z)
        # reconstruction loss: per-example Bernoulli cross-entropy over pixels
        # (the 1e-8 guards against log(0))
        self.recon_loss = -tf.reduce_sum(
            self.x_real * tf.log(1e-8 + self.x_recon)
            + (1 - self.x_real) * tf.log(1e-8 + 1 - self.x_recon),
            1,
        )
        # latent loss: closed-form KL divergence between q(z|x) and N(0, I)
        self.latent_loss = -0.5 * tf.reduce_sum(
            1 + self.z_std - tf.square(self.z_mean) - tf.exp(self.z_std), axis=1
        )
        # total loss: batch mean, with beta weighting the latent term
        self.loss = tf.reduce_mean(
            self.recon_loss + self.params["beta"] * self.latent_loss
        )
        # keep batch means of the two terms around for logging
        self.latent_loss = tf.reduce_mean(self.latent_loss)
        self.recon_loss = tf.reduce_mean(self.recon_loss)
        # prepare the optimizer
        self.opt = tf.train.AdamOptimizer(
            learning_rate=self.params["learning_rate"], epsilon=self.params["adam_eps"]
        )
        # calculate gradients of the loss with respect to the trainable parameters
        self.trainable_params = tf.trainable_variables()
        self.grads = self.opt.compute_gradients(
            self.loss, var_list=self.trainable_params
        )
        # apply gradients
        self.train = self.opt.apply_gradients(self.grads, global_step=self.global_step)
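For reference, the loss defined above is the beta-weighted negative evidence lower bound (ELBO):
$$\mathcal{L}(x) = \underbrace{-\,\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big]}_{\text{reconstruction loss}} \;+\; \beta\,\underbrace{D_{\mathrm{KL}}\big(q(z \mid x)\,\|\,\mathcal{N}(0, I)\big)}_{\text{latent loss}}$$
With a diagonal Gaussian posterior, the KL term has the closed form
$$D_{\mathrm{KL}} = -\tfrac{1}{2}\sum_{j=1}^{d}\big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big),$$
which is exactly what latent_loss computes, with z_mean holding $\mu$ and z_std holding $\log\sigma^2$. Setting beta = 1 recovers the standard VAE; beta > 1 (as in the parameters below) is the beta-VAE objective, which trades reconstruction quality for a latent code closer to the prior.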
Train
In [ ]:
dims = [28, 28]
In [ ]:
# set network parameters
params = {
    "adam_eps": 1.0e-8,
    "dims": dims,
    "batch_size": 1000,
    "activation_fn": tf.nn.elu,
    "learning_rate": 0.001,
    "n_hidden": 2,  # number of neurons in the latent code layer
    "beta": 10,  # weight on the latent (KL) term; beta > 1 gives a beta-VAE
}
In [ ]:
model = VAE(params) # create the model
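The class builds a tf.train.Saver but never calls it; checkpointing could be wired up as below (the path here is illustrative, not from the original post):
In [ ]:
# save the weights, tagged with the current global step (illustrative path)
model.saver.save(model.sess, "./checkpoints/vae.ckpt", global_step=model.global_step)
# restore the latest checkpoint into a freshly constructed VAE(params):
# model.saver.restore(model.sess, tf.train.latest_checkpoint("./checkpoints"))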
In [ ]:
# load mnist dataset
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data", one_hot=False)
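The tensorflow.examples.tutorials module was deprecated and removed in later TensorFlow releases. If it is unavailable, a roughly equivalent loader can be sketched with tf.keras; note that the rest of this notebook uses the next_batch API, so batching would need to be done by hand:
In [ ]:
# alternative MNIST loader for newer TF versions (a sketch). Pixels are
# flattened and scaled to [0, 1] to match the sigmoid/Bernoulli decoder.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype(np.float32) / 255.0
x_test = x_test.reshape(-1, 784).astype(np.float32) / 255.0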
In [15]:
n_samples = mnist.train.num_examples
n_samples
Out[15]:
55000
In [16]:
### For visualization
def plot_samples(ax, samples):
    for index, sample in enumerate(samples):
        ax[index].imshow(sample.reshape(28, 28), cmap="gray")
        ax[index].axis("off")


recon_loss_list = []
latent_loss_list = []
example_data = mnist.test.next_batch(params["batch_size"])[0]
nex = 10  # number of samples to plot

for epoch in tqdm(range(51)):
    recon_loss_list.append([])
    latent_loss_list.append([])
    for batch_i in tqdm(
        np.arange(int(n_samples / params["batch_size"])), leave=False
    ):
        batch = mnist.train.next_batch(params["batch_size"])[0]
        _, recon_loss, lat_loss = model.sess.run(
            (model.train, model.recon_loss, model.latent_loss),
            feed_dict={model.x_real: batch},
        )
        latent_loss_list[-1].append(lat_loss)
        recon_loss_list[-1].append(recon_loss)
    if epoch % 10 == 0:
        # report the mean losses for the current epoch
        print(
            "Epoch:",
            "%04d" % (epoch + 1),
            "recon loss =",
            "{:.9f}".format(np.mean(recon_loss_list[-1])),
            "latent loss =",
            "{:.9f}".format(np.mean(latent_loss_list[-1])),
        )
        # plot originals (top row) against their reconstructions (bottom row)
        fig, ax = plt.subplots(nrows=2, ncols=nex, figsize=(nex, 2))
        plot_samples(ax[0], example_data[:nex])
        plot_samples(
            ax[1],
            np.squeeze(
                model.sess.run([model.x_recon], {model.x_real: example_data})
            )[:nex],
        )
        plt.show()
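The per-epoch losses collected above can be plotted to check convergence; a minimal sketch using the recon_loss_list and latent_loss_list accumulated during training:
In [ ]:
# mean reconstruction and KL loss per epoch
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))
ax[0].plot([np.mean(e) for e in recon_loss_list])
ax[0].set_title("reconstruction loss")
ax[0].set_xlabel("epoch")
ax[1].plot([np.mean(e) for e in latent_loss_list])
ax[1].set_title("latent (KL) loss")
ax[1].set_xlabel("epoch")
plt.show()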
Plot latent space representations
In [17]:
labels = []
points = []
# encode the test set (use the test-set size here, not n_samples from the train set)
for i in tqdm(range(int(mnist.test.num_examples / params["batch_size"]))):
    images = mnist.test.next_batch(params["batch_size"])
    labels.append(images[1])
    points.append(model.sess.run(model.z, {model.x_real: images[0]}))
points = np.concatenate(points)
In [18]:
def plot_latent(ax, codes, labels):
    ax.scatter(codes[:, 0], codes[:, 1], s=1, c=labels, alpha=0.1)
    ax.set_aspect("equal")


# left: encoded test digits; right: samples from the N(0, I) prior for comparison
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))
plot_latent(ax[0], points, np.concatenate(labels))
plot_latent(
    ax[1],
    np.random.normal(0, 1, len(points) * 2).reshape(len(points), 2),
    np.zeros(len(points)),
)
# for axi in ax: axi.set_ylim([-8, 8]); axi.set_xlim([-8, 8])
plt.show()
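Since the latent space is 2-D, the learned manifold can also be visualized by decoding a grid of latent points. The sketch below feeds model.z directly, which TensorFlow 1.x allows for any graph tensor; because the graph was built with a fixed batch size, the grid is padded up to one full batch:
In [ ]:
n_grid = 15  # 15 x 15 grid of latent points
grid_vals = np.linspace(-2, 2, n_grid)
z_grid = np.array(
    [[zx, zy] for zy in grid_vals for zx in grid_vals], dtype=np.float32
)
# pad up to the fixed batch size the graph expects
z_batch = np.zeros((params["batch_size"], params["n_hidden"]), dtype=np.float32)
z_batch[: len(z_grid)] = z_grid
x_decoded = model.sess.run(model.x_recon, {model.z: z_batch})[: len(z_grid)]
# tile the decoded digits into a single canvas
canvas = np.zeros((28 * n_grid, 28 * n_grid))
for i in range(n_grid):
    for j in range(n_grid):
        canvas[i * 28 : (i + 1) * 28, j * 28 : (j + 1) * 28] = x_decoded[
            i * n_grid + j
        ].reshape(28, 28)
plt.figure(figsize=(8, 8))
plt.imshow(canvas, cmap="gray")
plt.axis("off")
plt.show()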