vae.py
'''TensorFlow implementation of http://arxiv.org/pdf/1312.6114v10.pdf'''
from __future__ import absolute_import, division, print_function

import math

import numpy as np
import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib import losses
from tensorflow.contrib.framework import arg_scope

from utils import encoder, decoder
from generator import Generator


class VAE(Generator):

    def __init__(self, hidden_size, batch_size, learning_rate):
        self.input_tensor = tf.placeholder(
            tf.float32, [None, 28 * 28])

        with arg_scope([layers.conv2d, layers.conv2d_transpose],
                       activation_fn=tf.nn.elu,
                       normalizer_fn=layers.batch_norm,
                       normalizer_params={'scale': True}):
            with tf.variable_scope("model") as scope:
                # The encoder outputs 2 * hidden_size units: the first half is
                # interpreted as the latent mean, the second half as the
                # log-variance of the approximate posterior.
                encoded = encoder(self.input_tensor, hidden_size * 2)

                mean = encoded[:, :hidden_size]
                stddev = tf.sqrt(tf.exp(encoded[:, hidden_size:]))

                # Reparameterization trick: sample z = mean + eps * stddev with
                # eps ~ N(0, I), so gradients can flow through mean and stddev.
                epsilon = tf.random_normal([tf.shape(mean)[0], hidden_size])
                input_sample = mean + epsilon * stddev

                output_tensor = decoder(input_sample)

            # Reuse the decoder weights to generate images from prior samples.
            with tf.variable_scope("model", reuse=True) as scope:
                self.sampled_tensor = decoder(tf.random_normal(
                    [batch_size, hidden_size]))

        vae_loss = self.__get_vae_cost(mean, stddev)
        rec_loss = self.__get_reconstruction_cost(
            output_tensor, self.input_tensor)

        loss = vae_loss + rec_loss
        self.train = layers.optimize_loss(
            loss, tf.contrib.framework.get_or_create_global_step(),
            learning_rate=learning_rate, optimizer='Adam', update_ops=[])

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def __get_vae_cost(self, mean, stddev, epsilon=1e-8):
        '''VAE loss (the KL-divergence term).

        See the paper.

        Args:
            mean: latent means, shape [batch_size, hidden_size]
            stddev: latent standard deviations, shape [batch_size, hidden_size]
            epsilon: small constant added inside the log for numerical stability
        '''
        return tf.reduce_sum(0.5 * (tf.square(mean) + tf.square(stddev) -
                                    2.0 * tf.log(stddev + epsilon) - 1.0))
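    # Note: the expression above is the closed-form KL divergence between the
    # approximate posterior N(mean, stddev^2) and the standard normal prior
    # N(0, I), summed over the latent dimensions and the batch:
    #   KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (mu^2 + sigma^2 - 2*log(sigma) - 1)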

    def __get_reconstruction_cost(self, output_tensor, target_tensor, epsilon=1e-8):
        '''Reconstruction loss.

        Cross-entropy reconstruction loss.

        Args:
            output_tensor: tensor produced by the decoder
            target_tensor: the target tensor that we want to reconstruct
            epsilon: small constant added inside the logs for numerical stability
        '''
        return tf.reduce_sum(-target_tensor * tf.log(output_tensor + epsilon) -
                             (1.0 - target_tensor) * tf.log(1.0 - output_tensor + epsilon))
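    # Note: this is the pixel-wise binary cross-entropy (Bernoulli negative
    # log-likelihood), summed over pixels and over the batch; it assumes the
    # decoder output lies in (0, 1), e.g. via a final sigmoid. epsilon guards
    # against taking log(0).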

    def update_params(self, input_tensor):
        '''Update parameters of the network.

        Args:
            input_tensor: a batch of flattened images [batch_size, 28*28]

        Returns:
            Current loss value
        '''
        return self.sess.run(self.train, {self.input_tensor: input_tensor})
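
Usage sketch (not part of vae.py): a minimal MNIST training loop, assuming TensorFlow 1.x and its bundled tensorflow.examples.tutorials.mnist loader; the hyperparameter values and epoch count below are illustrative, not taken from this file.

# Hypothetical driver script: trains the VAE on MNIST and samples from the prior.
from tensorflow.examples.tutorials.mnist import input_data

from vae import VAE

if __name__ == '__main__':
    mnist = input_data.read_data_sets('MNIST_data/')
    batch_size = 128
    model = VAE(hidden_size=10, batch_size=batch_size, learning_rate=1e-2)

    for epoch in range(10):
        num_batches = mnist.train.num_examples // batch_size
        total_loss = 0.0
        for _ in range(num_batches):
            # next_batch returns flattened images with pixel values in [0, 1]
            images, _ = mnist.train.next_batch(batch_size)
            total_loss += model.update_params(images)
        print('epoch %d: average loss %.2f' % (epoch, total_loss / num_batches))

    # Draw new images from the prior via the reused decoder.
    samples = model.sess.run(model.sampled_tensor)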