siamese/mnist_siamese_example.py

300 lines
10 KiB
Python
Raw Normal View History

"""
This is a modified version of the Keras mnist example.
https://keras.io/examples/mnist_cnn/
Instead of using a fixed number of epochs, this version continues to train until a stopping criterion is reached.
A siamese neural network is used to pre-train an embedding for the network. The resulting embedding is then extended
with a softmax output layer for categorical predictions.
Model performance should be around 99.84% after training. The resulting model is identical in structure to the one in
the example yet shows considerable improvement in relative error confirming that the embedding learned by the siamese
network is useful.
"""
from __future__ import print_function
import tensorflow.keras as keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Activation, Concatenate
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Flatten, Dense
from siamese import SiameseNetwork
import os, math, numpy as np
from PIL import Image
import pdb
# Run configuration. Note that `num_classes` and `epochs` are assigned again
# further down in this file, right before the models are built.
epochs = 999999
num_classes = 131
batch_size = 128

# Input image dimensions: every image is resized to img_rows x img_cols.
img_rows = 28
img_cols = 28
def _index_groups_to_array(groups):
    """Pack per-class index lists into an array without failing on ragged input."""
    if len({len(group) for group in groups}) <= 1:
        # Rectangular (or empty) input: a regular 2-D array, as before.
        return np.asarray(groups)
    # Classes of different sizes cannot form a rectangular array; fall back to
    # an object array holding one index list per class.
    packed = np.empty(len(groups), dtype=object)
    for position, group in enumerate(groups):
        packed[position] = group
    return packed


def createTrainingData(base_dir='../towards/data/fruits-360/Training/',
                       train_test_split=0.7,
                       no_of_files_in_each_class=80):
    """Load a folder-per-class image dataset and split it per class.

    Every sub-directory of ``base_dir`` is one class; its label is its position
    in the alphabetically sorted folder list. At most
    ``no_of_files_in_each_class`` images are read from each class, converted to
    RGB, resized to the module-level ``img_rows`` x ``img_cols`` and scaled
    to [0, 1].

    Within each class the first ``n - floor(train_test_split * n)`` images go
    to the validation set and the remaining ones to the training set, so no
    image is dropped.

    Args:
        base_dir: Directory containing one sub-directory per class.
        train_test_split: Fraction (0..1) of each class used for training.
        no_of_files_in_each_class: Maximum number of images read per class.

    Returns:
        ``(x_train, y_train), (x_val, y_val), cat_train`` where ``cat_train[c]``
        holds the positions inside ``x_train`` of the samples of class ``c``.

    Raises:
        ValueError: If ``train_test_split`` is outside [0, 1].
    """
    if not 0 <= train_test_split <= 1:
        raise ValueError("train_test_split must be between 0 and 1")

    # Sorted so the folder -> label mapping is reproducible (os.listdir order
    # is arbitrary); stray non-directory entries are skipped.
    folder_list = sorted(
        entry for entry in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, entry))
    )
    print( len(folder_list), "categories found in the dataset")

    x = []
    y = []
    cat_list = []  # cat_list[c] -> positions in x of the images of class c
    for y_label, folder_name in enumerate(folder_list):
        folder_path = os.path.join(base_dir, folder_name)
        files_list = sorted(os.listdir(folder_path))
        temp = []
        for file_name in files_list[:no_of_files_in_each_class]:
            temp.append(len(x))
            # The context manager closes the underlying file handle.
            with Image.open(os.path.join(folder_path, file_name)) as image:
                # PIL expects (width, height), i.e. (columns, rows).
                resized = image.convert('RGB').resize((img_cols, img_rows))
                x.append(np.asarray(resized))
            y.append(y_label)
        cat_list.append(temp)

    x = np.asarray(x)/255.0
    y = np.asarray(y)
    print('X, Y shape',x.shape, y.shape, _index_groups_to_array(cat_list).shape)

    # Per-class split, driven by the real class membership rather than by
    # assuming every class holds exactly `no_of_files_in_each_class` images.
    x_train, y_train, cat_train = [], [], []
    x_val, y_val, cat_test = [], [], []
    for class_indices in cat_list:
        class_size = len(class_indices)
        # The epsilon keeps products that land a hair below an integer
        # because of float rounding from being floored one too low.
        n_train = math.floor(train_test_split * class_size + 1e-9)
        n_val = class_size - n_train
        val_indices = class_indices[:n_val]
        train_indices = class_indices[n_val:]

        cat_test.append(list(range(len(x_val), len(x_val) + n_val)))
        cat_train.append(list(range(len(x_train), len(x_train) + n_train)))
        x_val.extend(x[i] for i in val_indices)
        y_val.extend(y[i] for i in val_indices)
        x_train.extend(x[i] for i in train_indices)
        y_train.extend(y[i] for i in train_indices)

    x_val = np.array(x_val)
    y_val = np.array(y_val)
    x_train = np.array(x_train)
    y_train = np.array(y_train)
    cat_train = _index_groups_to_array(cat_train)
    cat_test = _index_groups_to_array(cat_test)
    print('X&Y shape of training data :',x_train.shape, 'and',
          y_train.shape, cat_train.shape)
    print('X&Y shape of testing data :' , x_val.shape, 'and',
          y_val.shape, cat_test.shape)
    return (x_train, y_train), (x_val, y_val), cat_train
# the data, split between train and test sets
# (x_train, y_train), (x_test, y_test) = mnist.load_data()
# channels = 1
(x_train, y_train), (x_test, y_test), cat_train = createTrainingData()
channels = 3

# The images are loaded channels-last; this reshape pins (and validates) the
# (samples, rows, cols, channels) layout.
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, channels)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, channels)
if K.image_data_format() == 'channels_first':
    # Move the channel axis with a transpose: a plain reshape to
    # (samples, channels, rows, cols) would reinterpret the buffer and
    # scramble the pixels.
    x_train = np.transpose(x_train, (0, 3, 1, 2))
    x_test = np.transpose(x_test, (0, 3, 1, 2))
    input_shape = (channels, img_rows, img_cols)
else:
    input_shape = (img_rows, img_cols, channels)

# createTrainingData() already divides the pixel values by 255.0, so they must
# not be divided again here.
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
def create_own_base_model(input_shape):
    """Build a four-stage convolutional embedding network.

    Each stage is a Conv2D followed by 2x2 max-pooling; the flattened result
    is passed through a 4096-unit sigmoid Dense layer, batch normalisation and
    a ReLU.

    :param input_shape: shape of a single input image (without batch axis).
    :return: Keras ``Model`` mapping an image to its embedding.
    """
    image_in = Input(shape=input_shape)

    # The first stage additionally carries the `input_shape` argument.
    features = Conv2D(32, kernel_size=(10, 10), input_shape=input_shape)(image_in)
    features = MaxPooling2D(pool_size=(2, 2))(features)

    # Remaining stages: (number of filters, kernel size).
    for n_filters, kernel in ((64, (7, 7)), (128, (4, 4)), (256, (4, 4))):
        features = Conv2D(n_filters, kernel_size=kernel)(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)

    features = Flatten()(features)
    features = Dense(4096, activation='sigmoid')(features)
    features = BatchNormalization()(features)
    features = Activation(activation='relu')(features)
    return Model(image_in, features)
def create_base_model(input_shape):
    """Build the small convolutional embedding network.

    Two Conv2D -> BatchNormalization -> ReLU -> 2x2 max-pool stages are
    followed by a 128-unit Dense layer with batch normalisation and ReLU.

    :param input_shape: shape of a single input image (without batch axis).
    :return: Keras ``Model`` mapping an image to its 128-unit embedding.
    """
    image_in = Input(shape=input_shape)

    # Layers are listed in application order and folded over the input below.
    pipeline = [
        Conv2D(32, kernel_size=(3, 3), input_shape=input_shape),
        BatchNormalization(),
        Activation(activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(64, kernel_size=(3, 3)),
        BatchNormalization(),
        Activation(activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(128),
        BatchNormalization(),
        Activation(activation='relu'),
    ]

    tensor = image_in
    for layer in pipeline:
        tensor = layer(tensor)
    return Model(image_in, tensor)
def create_head_model(embedding_shape):
    """Build the head that scores a pair of embeddings.

    The two embeddings are concatenated and passed through
    Dense(8) -> BatchNormalization -> sigmoid -> Dense(1) ->
    BatchNormalization -> sigmoid, giving one value per pair.

    :param embedding_shape: output shape of the base model including the
        leading batch axis; the batch axis is dropped here.
    :return: Keras ``Model`` taking ``[embedding_a, embedding_b]``.
    """
    per_sample_shape = embedding_shape[1:]
    left_in = Input(shape=per_sample_shape)
    right_in = Input(shape=per_sample_shape)

    score = Concatenate()([left_in, right_in])
    for units in (8, 1):
        score = Dense(units)(score)
        score = BatchNormalization()(score)
        score = Activation(activation='sigmoid')(score)

    return Model([left_in, right_in], score)
def get_batch(x_train, y_train, x_test, y_test, cat_train, batch_size=64):
    """Sample a random batch of image pairs for siamese training.

    Half of the pairs (rounded down) are labelled 0 and built from two images
    of the same class; the others are labelled 1 and built from images of two
    different classes. The pair order is shuffled.

    Args:
        x_train: Array of images, indexed along the first axis.
        y_train, x_test, y_test: Unused; kept for interface compatibility.
        cat_train: Per-class index groups; ``cat_train[c]`` lists the positions
            in ``x_train`` of the samples of class ``c``. May be a 2-D array or
            a sequence of index lists of different lengths.
        batch_size: Number of pairs to generate.

    Returns:
        ``(batch_x, batch_y)`` where ``batch_x`` is a list of two arrays of
        shape ``(batch_size,) + x_train.shape[1:]`` and ``batch_y`` holds the
        0/1 pair labels.

    Raises:
        ValueError: If a different-class pair is needed but fewer than two
            classes are available.
    """
    # One flat index array per class; works for ndarray rows and plain lists.
    class_indices = [np.asarray(group, dtype=int).ravel() for group in cat_train]
    n_classes = len(class_indices)
    if batch_size > 0 and n_classes < 2:
        raise ValueError("get_batch needs at least two classes to build "
                         "different-class pairs")

    batch_y = np.zeros(batch_size)
    batch_y[batch_size // 2:] = 1
    np.random.shuffle(batch_y)

    # Anchor class of every pair.
    class_list = np.random.randint(0, n_classes, batch_size)

    # Size the buffers from the data instead of assuming 100x100x3 images.
    image_shape = tuple(np.shape(x_train)[1:])
    batch_x = [np.zeros((batch_size,) + image_shape),
               np.zeros((batch_size,) + image_shape)]

    for i, anchor_class in enumerate(class_list):
        own_indices = class_indices[anchor_class]
        batch_x[0][i] = x_train[np.random.choice(own_indices)]
        # Label 0 -> partner from the same class, otherwise from any other class.
        if batch_y[i] == 0:
            candidates = own_indices
        else:
            candidates = np.concatenate(
                class_indices[:anchor_class] + class_indices[anchor_class + 1:]
            )
        batch_x[1][i] = x_train[np.random.choice(candidates)]
    return (batch_x, batch_y)
num_classes = 131
epochs = 2000

# ---------------------------------------------------------------------------
# Stage 1: pre-train the embedding with the siamese network.
# ---------------------------------------------------------------------------
base_model = create_base_model(input_shape)
head_model = create_head_model(base_model.output_shape)

siamese_network = SiameseNetwork(base_model, head_model)
siamese_network.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

siamese_checkpoint_path = "./siamese_checkpoint"

siamese_callbacks = [
    # EarlyStopping(monitor='val_accuracy', patience=10, verbose=0),
    ModelCheckpoint(siamese_checkpoint_path, monitor='val_accuracy', save_best_only=True, verbose=0)
]

# batch_size = 64
# for epoch in range(1, epochs):
#     batch_x, batch_y = get_batch(x_train, y_train, x_test, y_test, cat_train, train_size, batch_size)
#     loss = siamese_network.train_on_batch(batch_x, batch_y)
#     print('Epoch:', epoch, ', Loss:', loss)

# Integer class labels are passed here; presumably SiameseNetwork.fit builds
# the training pairs itself -- confirm against the siamese package.
siamese_network.fit(x_train, y_train,
                    validation_data=(x_test, y_test),
                    batch_size=45,
                    epochs=epochs,
                    callbacks=siamese_callbacks)

# try:
#     siamese_network = keras.models.load_model(siamese_checkpoint_path)
# except Exception as e:
#     print(e)
#     print("!!!!!!")
#     siamese_network.load_weights(siamese_checkpoint_path)

# ---------------------------------------------------------------------------
# Stage 2: put a classification layer on top of the pre-trained embedding.
# ---------------------------------------------------------------------------
embedding = base_model.outputs[-1]

# Pin the one-hot width to num_classes so that the training labels, the test
# labels and the Dense(num_classes) output always agree, even when one split
# does not contain the highest label.
y_train = keras.utils.to_categorical(y_train, num_classes=num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes=num_classes)

# Add the classification layer (sigmoid outputs trained with binary
# cross-entropy) to the pre-trained embedding network.
# NOTE(review): for mutually exclusive classes a softmax activation with
# categorical cross-entropy is the conventional choice -- confirm the intent.
embedding = Dense(num_classes)(embedding)
embedding = BatchNormalization()(embedding)
embedding = Activation(activation='sigmoid')(embedding)

model = Model(base_model.inputs[0], embedding)
model.compile(loss=keras.losses.binary_crossentropy,
              optimizer=keras.optimizers.Adam(),
              metrics=['accuracy'])

model_checkpoint_path = "./model_checkpoint"

model__callbacks = [
    # EarlyStopping(monitor='val_accuracy', patience=10, verbose=0),
    ModelCheckpoint(model_checkpoint_path, monitor='val_accuracy', save_best_only=True, verbose=0)
]

# for e in range(1, epochs):
#     batch_x, batch_y = get_batch(x_train, y_train, x_test, y_test, cat_train, train_size, batch_size)
#     loss = model.train_on_batch(batch_x, batch_y)
#     print('Epoch:', epoch, ', Loss:', loss)

model.fit(x_train, y_train,
          batch_size=128,
          epochs=epochs,
          callbacks=model__callbacks,
          validation_data=(x_test, y_test))

# try:
#     model = keras.models.load_model(model_checkpoint_path)
# except Exception as e:
#     print(e)
#     print("!!!!!!")
#     model.load_weights(model_checkpoint_path)

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])