siamese/train_coco.py

269 lines
9.1 KiB
Python

"""
This is a modified version of the Keras mnist example.
https://keras.io/examples/mnist_cnn/
Instead of using a fixed number of epochs this version continues to train until a stop criteria is reached.
A siamese neural network is used to pre-train an embedding for the network. The resulting embedding is then extended
with a softmax output layer for categorical predictions.
Model performance should be around 99.84% after training. The resulting model is identical in structure to the one in
the example yet shows considerable improvement in relative error confirming that the embedding learned by the siamese
network is useful.
"""
from __future__ import print_function
import tensorflow.keras as keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Activation, Concatenate
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Flatten, Dense
from siamese import SiameseNetwork
import pdb
import os, math, numpy as np
from PIL import Image
# Default mini-batch size.
batch_size = 128
# Number of image categories.
num_classes = 131
# Input image dimensions: every loaded image is resized to img_rows x img_cols.
img_rows = img_cols = 100
def createTrainingData(base_dir='./classified/', train_test_split=0.7,
                       no_of_files_in_each_class=400, image_size=None):
    """Load a folder-per-class image dataset and split it into train/validation sets.

    Every sub-directory of ``base_dir`` is one class; entries that are not
    directories are ignored. Classes with fewer than
    ``no_of_files_in_each_class`` files are skipped. For the remaining classes
    the first ``no_of_files_in_each_class`` files (in sorted order) are loaded
    as RGB, resized to ``image_size`` and scaled to [0, 1]. Labels 0, 1, 2, ...
    are assigned to the kept classes in sorted folder order.

    Within every class the first ``floor(train_test_split * n)`` images go to
    the training set and all remaining images to the validation set, so every
    loaded image ends up in exactly one of the two sets.

    Args:
        base_dir: root directory containing one sub-directory per class.
        train_test_split: fraction of each class used for training, in [0, 1].
        no_of_files_in_each_class: number of images loaded per class.
        image_size: size tuple passed to PIL ``resize``; defaults to
            ``(img_rows, img_cols)``.

    Returns:
        ``(x_train, y_train), (x_val, y_val), cat_train`` where ``cat_train``
        is a 2-D integer array whose row ``c`` lists the indices into
        ``x_train`` of the training images of class ``c``.

    Raises:
        ValueError: if ``train_test_split`` is outside [0, 1].
    """
    if not 0.0 <= train_test_split <= 1.0:
        raise ValueError('train_test_split must be in [0, 1]')
    if image_size is None:
        image_size = (img_rows, img_cols)
    per_class = no_of_files_in_each_class

    # Only directories are classes. Sorting makes labels reproducible because
    # os.listdir returns entries in arbitrary order.
    folder_list = sorted(
        name for name in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, name)))
    print(len(folder_list), "categories found in the dataset")

    x, y = [], []
    num_loaded = 0
    for folder_name in folder_list:
        folder_path = os.path.join(base_dir, folder_name)
        files_list = sorted(os.listdir(folder_path))
        if len(files_list) < per_class:
            continue
        for file_name in files_list[:per_class]:
            # ``with`` closes the file handle that Image.open keeps open.
            with Image.open(os.path.join(folder_path, file_name)) as img:
                x.append(np.asarray(img.convert('RGB').resize(image_size)))
            y.append(num_loaded)
        num_loaded += 1
    x = np.asarray(x)
    y = np.asarray(y)
    print('X, Y shape', x.shape, y.shape, (num_loaded, per_class))

    # Small epsilon guards against float artefacts such as
    # 0.29 * 100 == 28.999999999999996 flooring to 28.
    train_per_class = min(per_class, math.floor(train_test_split * per_class + 1e-9))
    val_per_class = per_class - train_per_class
    # Image k of class c is stored at flat index c * per_class + k.
    class_starts = np.arange(num_loaded)[:, None] * per_class
    offsets = np.arange(per_class)
    train_idx = (class_starts + offsets[:train_per_class]).ravel()
    val_idx = (class_starts + offsets[train_per_class:]).ravel()

    x_train = x[train_idx] / 255.0
    y_train = y[train_idx]
    x_val = x[val_idx] / 255.0
    y_val = y[val_idx]
    # Training images are stored class by class, so row c of cat_train is the
    # c-th run of train_per_class consecutive indices (likewise for cat_test).
    cat_train = np.arange(num_loaded * train_per_class).reshape(num_loaded, train_per_class)
    cat_test = np.arange(num_loaded * val_per_class).reshape(num_loaded, val_per_class)

    print('X&Y shape of training data :', x_train.shape, 'and',
          y_train.shape, cat_train.shape)
    print('X&Y shape of testing data :', x_val.shape, 'and',
          y_val.shape, cat_test.shape)
    return (x_train, y_train), (x_val, y_val), cat_train
# the data, split between train and test sets
# (x_train, y_train), (x_test, y_test) = mnist.load_data()
# channels = 1
(x_train, y_train), (x_test, y_test), cat_train = createTrainingData()
channels = 3
if K.image_data_format() == 'channels_first':
    # createTrainingData builds images with np.asarray on PIL RGB images, so
    # they are channels-last: (N, rows, cols, channels). Reshaping straight to
    # (N, channels, rows, cols) would scramble the pixel data. That was only
    # harmless in the single-channel MNIST version. Move the channel axis
    # instead.
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, channels).transpose(0, 3, 1, 2)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, channels).transpose(0, 3, 1, 2)
    input_shape = (channels, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, channels)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, channels)
    input_shape = (img_rows, img_cols, channels)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
def create_own_base_model(input_shape):
    """Return a VGG16 convolutional base with ImageNet weights and no top.

    Args:
        input_shape: shape of a single input image, used to build the Input
            tensor that the network is attached to.
    """
    image_input = Input(shape=input_shape)
    # NOTE(review): per the Keras docs `classes` is only consulted when
    # include_top=True, so it has no effect here.
    return keras.applications.vgg16.VGG16(
        include_top=False,
        input_tensor=image_input,
        weights='imagenet',
        classes=1,
    )
def create_base_model(input_shape):
    """Build a small CNN that maps one image to a 128-unit ReLU embedding.

    Two conv stages (32 then 64 filters, each Conv2D 3x3 -> BatchNorm -> ReLU
    -> 2x2 max-pool) are followed by Flatten -> Dense(128) -> BatchNorm ->
    ReLU.
    """
    def conv_stage(tensor, filters, **conv_kwargs):
        # One Conv2D -> BatchNorm -> ReLU -> MaxPool stage.
        tensor = Conv2D(filters, kernel_size=(3, 3), **conv_kwargs)(tensor)
        tensor = BatchNormalization()(tensor)
        tensor = Activation(activation='relu')(tensor)
        return MaxPooling2D(pool_size=(2, 2))(tensor)

    image = Input(shape=input_shape)
    features = conv_stage(image, 32, input_shape=input_shape)
    features = conv_stage(features, 64)
    features = Flatten()(features)
    features = Dense(128)(features)
    features = BatchNormalization()(features)
    features = Activation(activation='relu')(features)
    return Model(image, features)
def create_own_head_model(embedding_shape):
    """Build the siamese head that scores a pair of base-model outputs.

    Each input (shape ``embedding_shape[1:]``, i.e. without the batch axis) is
    flattened and projected separately through Dense(128) -> BatchNorm ->
    ReLU. The two projections are concatenated and passed through Dense(8)
    and Dense(1) stages, each followed by BatchNorm and a sigmoid.
    """
    inputs = [Input(shape=embedding_shape[1:]), Input(shape=embedding_shape[1:])]

    def project(tensor):
        # Per-branch projection; each branch gets its own (unshared) layers.
        tensor = Flatten()(tensor)
        tensor = Dense(128)(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation(activation='relu')(tensor)

    merged = Concatenate()([project(branch) for branch in inputs])
    for units in (8, 1):
        merged = Dense(units)(merged)
        merged = BatchNormalization()(merged)
        merged = Activation(activation='sigmoid')(merged)
    return Model(inputs, merged)
def create_head_model(embedding_shape):
    """Build a siamese head that scores a pair of flat embeddings.

    The two embeddings (each of shape ``embedding_shape[1:]``) are concatenated
    and passed through Dense(8) and Dense(1) stages, each followed by
    BatchNorm and a sigmoid.
    """
    pair = [Input(shape=embedding_shape[1:]) for _ in range(2)]
    score = Concatenate()(pair)
    for width in (8, 1):
        score = Dense(width)(score)
        score = BatchNormalization()(score)
        score = Activation(activation='sigmoid')(score)
    return Model(pair, score)
def get_batch(x_train, y_train, x_test, y_test, cat_train, batch_size=64):
    """Sample a random batch of image pairs for siamese training.

    Half of the pairs (rounded down) are labelled 0 and drawn from the same
    class. The rest are labelled 1 and drawn from two different classes. The
    labels are shuffled, so the position of each kind is random.

    Args:
        x_train: image array indexed along axis 0; the pair buffers take their
            per-image shape from ``x_train.shape[1:]``.
        y_train, x_test, y_test: unused; kept for call compatibility.
        cat_train: 2-D array where ``cat_train[c]`` holds the indices into
            ``x_train`` of the images of class ``c``.
        batch_size: number of pairs to draw.

    Returns:
        ``([first_images, second_images], labels)``. Both image arrays have
        shape ``(batch_size,) + x_train.shape[1:]``. ``labels`` is a float
        array of 0s (same class) and 1s (different class).

    Raises:
        ValueError: if ``cat_train`` has fewer than two classes, since a
            different-class pair cannot be formed.
    """
    n_classes = len(cat_train)
    if n_classes < 2:
        raise ValueError('get_batch needs at least two classes in cat_train')

    labels = np.zeros(batch_size)
    labels[batch_size // 2:] = 1
    np.random.shuffle(labels)
    anchor_classes = np.random.randint(0, n_classes, batch_size)

    pair_shape = (batch_size,) + tuple(x_train.shape[1:])
    pairs = [np.zeros(pair_shape), np.zeros(pair_shape)]
    for i, (label, cls) in enumerate(zip(labels, anchor_classes)):
        pairs[0][i] = x_train[np.random.choice(cat_train[cls])]
        if label == 0:
            partner_cls = cls
        else:
            # Shifting by 1..n_classes-1 (mod n_classes) lands uniformly on one
            # of the *other* classes. With equally sized classes (a 2-D
            # cat_train) this matches drawing uniformly from every image outside
            # ``cls``, without rebuilding that index list for every sample.
            partner_cls = (cls + np.random.randint(1, n_classes)) % n_classes
        pairs[1][i] = x_train[np.random.choice(cat_train[partner_cls])]
    return pairs, labels
# Training configuration. This re-assigns the same num_classes value set near
# the top of the file.
num_classes = 131
# Number of epochs passed to fit(). EarlyStopping is commented out in the
# callbacks below, so training runs for all of them.
epochs = 2000
# ImageNet-pretrained VGG16 base plus a two-branch head sized from the base's
# output shape.
base_model = create_own_base_model(input_shape)
head_model = create_own_head_model(base_model.output_shape)
siamese_network = SiameseNetwork(base_model, head_model)
# Binary loss: the head ends in a single sigmoid unit.
siamese_network.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
siamese_checkpoint_path = "./siamese_checkpoint"
# Save a checkpoint only when validation accuracy improves.
siamese_callbacks = [
# EarlyStopping(monitor='val_accuracy', patience=10, verbose=0),
ModelCheckpoint(siamese_checkpoint_path, monitor='val_accuracy', save_best_only=True, verbose=0)
]
# Disabled manual pair-sampling loop. NOTE(review): as written, the get_batch
# call passes seven positional arguments (including train_size, which is not
# defined anywhere visible), but get_batch accepts at most six. It would fail
# if uncommented.
# batch_size = 64
# for epoch in range(1, epochs):
# batch_x, batch_y = get_batch(x_train, y_train, x_test, y_test, cat_train, train_size, batch_size)
# loss = siamese_network.train_on_batch(batch_x, batch_y)
# print('Epoch:', epoch, ', Loss:', loss)
# NOTE(review): raw images and integer class labels are passed straight to
# SiameseNetwork.fit; presumably it builds same/different pairs internally.
# Confirm against the siamese module. The same x_test/y_test serve both as
# validation data (and thus for checkpoint selection) and for the final score
# below, so that score is not measured on held-out data.
siamese_network.fit(x_train, y_train,
validation_data=(x_test, y_test),
batch_size=45,
epochs=epochs,
callbacks=siamese_callbacks)
# Disabled: reload the best checkpoint before evaluating. As it stands,
# evaluate() below runs on the in-memory network as left by fit(), not on the
# saved best checkpoint.
# try:
# siamese_network = keras.models.load_model(siamese_checkpoint_path)
# except Exception as e:
# print(e)
# print("!!!!!!")
# siamese_network.load_weights(siamese_checkpoint_path)
# Assumes evaluate() returns [loss, accuracy] in the order given by
# compile(metrics=['accuracy']). TODO confirm for SiameseNetwork.
score = siamese_network.evaluate(x_test, y_test, batch_size=60, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])