# densenet-cifar10-2.4.1.py
"""Trains a 100-Layer DenseNet on the CIFAR10 dataset.
With data augmentation:
Greater than 93.55% test accuracy in 200 epochs
225sec per epoch on GTX 1080Ti
Densely Connected Convolutional Networks
https://arxiv.org/pdf/1608.06993.pdf
http://openaccess.thecvf.com/content_cvpr_2017/papers/
Huang_Densely_Connected_Convolutional_CVPR_2017_paper.pdf
Network below is similar to 100-Layer DenseNet-BC (k=12)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization
from tensorflow.keras.layers import MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import Input, Flatten, Dropout
from tensorflow.keras.layers import concatenate, Activation
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import plot_model
from tensorflow.keras.utils import to_categorical
import os
import numpy as np
import math
# --- training configuration ---
data_augmentation = True
epochs = 200
batch_size = 32

# --- network configuration ---
num_dense_blocks = 3
num_classes = 10
use_max_pool = False

# DenseNet-BC results with dataset augmentation:
#
#   growth rate | depth | accuracy (paper) | accuracy (this script)
#   ------------+-------+------------------+-----------------------
#            12 |   100 |           95.49% | 93.74%
#            24 |   250 |           96.38% | requires big mem GPU
#            40 |   190 |           96.54% | requires big mem GPU
depth = 100
growth_rate = 12

# Layers left for the dense blocks once the 4 non-bottleneck layers are
# removed (presumably stem conv, two transition convs and the classifier
# -- confirm against the paper); each bottleneck layer uses two convolutions.
layers_in_dense_blocks = depth - 4
num_bottleneck_layers = layers_in_dense_blocks // (num_dense_blocks * 2)

# number of filters produced by the convolution ahead of the first dense block
num_filters_bef_dense_block = growth_rate * 2

# scale applied to the tracked filter count at every transition layer
compression_factor = 0.5
# fetch the CIFAR10 train/test splits (images and integer class labels)
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# per-sample input dimensions: everything after the leading sample axis
input_shape = x_train.shape[1:]

# normalize data: cast to float32 and divide by 255
x_train, x_test = (
    split.astype('float32') / 255 for split in (x_train, x_test)
)

# report dataset sizes; labels are still in their raw form at this point
print('x_train shape:', x_train.shape)
print(len(x_train), 'train samples')
print(len(x_test), 'test samples')
print('y_train shape:', y_train.shape)

# convert class vectors to binary (one-hot) class matrices
y_train = to_categorical(y_train, num_classes=num_classes)
y_test = to_categorical(y_test, num_classes=num_classes)
def lr_schedule(epoch):
    """Learning Rate Schedule

    The base learning rate of 1e-3 is scaled down once `epoch` exceeds
    80, 120, 160 and 180. Intended to be passed to a
    LearningRateScheduler callback, which calls it once per epoch.
    The selected rate is also printed.

    # Arguments
        epoch (int): The epoch index supplied by the callback

    # Returns
        lr (float): learning rate for this epoch
    """
    base_lr = 1e-3
    # (epoch threshold, multiplier), ordered from the latest stage backwards
    # so that the first match is the most advanced stage reached
    stages = (
        (180, 0.5e-3),
        (160, 1e-3),
        (120, 1e-2),
        (80, 1e-1),
    )
    lr = base_lr
    for threshold, factor in stages:
        if epoch > threshold:
            lr = base_lr * factor
            break
    print('Learning rate: ', lr)
    return lr
# start model definition
# densenet CNNs (composite function) are made of BN-ReLU-Conv2D
inputs = Input(shape=input_shape)
x = BatchNormalization()(inputs)
x = Activation('relu')(x)
x = Conv2D(num_filters_bef_dense_block,
           kernel_size=3,
           padding='same',
           kernel_initializer='he_normal')(x)
# the raw input is carried forward alongside the first conv's feature maps
x = concatenate([inputs, x])

# stack of dense blocks bridged by transition layers
for i in range(num_dense_blocks):
    # a dense block is a stack of bottleneck layers
    for _ in range(num_bottleneck_layers):
        # bottleneck part: BN-ReLU-Conv2D(1x1) producing 4 * growth_rate maps
        y = BatchNormalization()(x)
        y = Activation('relu')(y)
        y = Conv2D(4 * growth_rate,
                   kernel_size=1,
                   padding='same',
                   kernel_initializer='he_normal')(y)
        # dropout is only used when data augmentation is turned off
        if not data_augmentation:
            y = Dropout(0.2)(y)
        # BN-ReLU-Conv2D(3x3) producing growth_rate new feature maps
        y = BatchNormalization()(y)
        y = Activation('relu')(y)
        y = Conv2D(growth_rate,
                   kernel_size=3,
                   padding='same',
                   kernel_initializer='he_normal')(y)
        if not data_augmentation:
            y = Dropout(0.2)(y)
        # dense connectivity: append the new maps to everything seen so far
        x = concatenate([x, y])

    # no transition layer after the last dense block
    if i == num_dense_blocks - 1:
        continue

    # transition layer compresses num of feature maps and reduces the size by 2
    # the tracked filter count grows by what the block added, then is scaled
    # by compression_factor to get the width of the 1x1 transition conv
    num_filters_bef_dense_block += num_bottleneck_layers * growth_rate
    num_filters_bef_dense_block = int(num_filters_bef_dense_block * compression_factor)
    # NOTE(review): there is no ReLU between BN and Conv2D here, unlike the
    # bottleneck layers above -- confirm this is intended
    y = BatchNormalization()(x)
    y = Conv2D(num_filters_bef_dense_block,
               kernel_size=1,
               padding='same',
               kernel_initializer='he_normal')(y)
    if not data_augmentation:
        y = Dropout(0.2)(y)
    x = AveragePooling2D()(y)

# add classifier on top
# after average pooling, size of feature map is 1 x 1
x = AveragePooling2D(pool_size=8)(x)
y = Flatten()(x)
outputs = Dense(num_classes,
                kernel_initializer='he_normal',
                activation='softmax')(y)
# wrap the layer graph into a Model and compile it for training
# (original author's note: the paper uses SGD but RMSprop works better
# for DenseNet)
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer=RMSprop(learning_rate=1e-3),
              loss='categorical_crossentropy',
              metrics=['acc'])
model.summary()

# optional architecture plot; enable this if pydot can be installed
# pip install pydot
#plot_model(model, to_file="cifar10-densenet.png", show_shapes=True)
# prepare model saving directory
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = 'cifar10_densenet_model.{epoch:02d}.h5'
# exist_ok=True replaces the isdir() check followed by makedirs(), which
# could fail if the directory appeared between the two calls
os.makedirs(save_dir, exist_ok=True)
filepath = os.path.join(save_dir, model_name)

# prepare callbacks for model saving and for learning rate reducer
# the checkpoint monitors 'val_acc' and only saves when it improves
checkpoint = ModelCheckpoint(filepath=filepath,
                             monitor='val_acc',
                             verbose=1,
                             save_best_only=True)
# NOTE(review): LearningRateScheduler sets the rate from lr_schedule for
# every epoch, so reductions applied by ReduceLROnPlateau presumably do not
# persist into the next epoch -- confirm whether both are intended
lr_scheduler = LearningRateScheduler(lr_schedule)
lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
                               cooldown=0,
                               patience=5,
                               min_lr=0.5e-6)
callbacks = [checkpoint, lr_reducer, lr_scheduler]
# run training, with or without data augmentation
# in both branches the test split doubles as the validation set
if not data_augmentation:
    print('Not using data augmentation.')
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              shuffle=True,
              callbacks=callbacks)
else:
    print('Using real-time data augmentation.')
    # preprocessing and realtime data augmentation
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        rotation_range=0,  # randomly rotate images in the range (deg 0 to 180)
        width_shift_range=0.1,  # randomly shift images horizontally
        height_shift_range=0.1,  # randomly shift images vertically
        horizontal_flip=True,  # randomly flip images
        vertical_flip=False)  # randomly flip images

    # compute quantities required for featurewise normalization
    # (std, mean, and principal components if ZCA whitening is applied)
    # kept so that turning on any featurewise_* / zca option above works
    # without further changes
    datagen.fit(x_train)

    # round up so the final, partial batch is still part of every epoch
    steps_per_epoch = math.ceil(len(x_train) / batch_size)
    # fit the model on the batches generated by datagen.flow().
    model.fit(x=datagen.flow(x_train, y_train, batch_size=batch_size),
              verbose=1,
              epochs=epochs,
              validation_data=(x_test, y_test),
              steps_per_epoch=steps_per_epoch,
              callbacks=callbacks)
# evaluate the trained model on the test split and report loss / accuracy
scores = model.evaluate(x_test, y_test, verbose=0)
test_loss, test_accuracy = scores[0], scores[1]
print('Test loss:', test_loss)
print('Test accuracy:', test_accuracy)