81 lines
2.4 KiB
Python
81 lines
2.4 KiB
Python
|
|
"""
|
|
Tests for the siamese neural network module
|
|
"""
|
|
|
|
import numpy as np
|
|
import keras
|
|
from keras import Model, Input
|
|
from keras.layers import Concatenate, Dense, BatchNormalization, Activation
|
|
|
|
from siamese import SiameseNetwork
|
|
|
|
|
|
def test_siamese():
|
|
"""
|
|
Test that all components the siamese network work correctly by executing a
|
|
training run against generated data.
|
|
"""
|
|
|
|
num_classes = 5
|
|
input_shape = (3,)
|
|
epochs = 1000
|
|
|
|
# Generate some data
|
|
x_train = np.random.rand(100, 3)
|
|
y_train = np.random.randint(num_classes, size=100)
|
|
|
|
x_test = np.random.rand(30, 3)
|
|
y_test = np.random.randint(num_classes, size=30)
|
|
|
|
# Define base and head model
|
|
def create_base_model(input_shape):
|
|
model_input = Input(shape=input_shape)
|
|
|
|
embedding = Dense(4)(model_input)
|
|
embedding = BatchNormalization()(embedding)
|
|
embedding = Activation(activation='relu')(embedding)
|
|
|
|
return Model(model_input, embedding)
|
|
|
|
def create_head_model(embedding_shape):
|
|
embedding_a = Input(shape=embedding_shape)
|
|
embedding_b = Input(shape=embedding_shape)
|
|
|
|
head = Concatenate()([embedding_a, embedding_b])
|
|
head = Dense(4)(head)
|
|
head = BatchNormalization()(head)
|
|
head = Activation(activation='sigmoid')(head)
|
|
|
|
head = Dense(1)(head)
|
|
head = BatchNormalization()(head)
|
|
head = Activation(activation='sigmoid')(head)
|
|
|
|
return Model([embedding_a, embedding_b], head)
|
|
|
|
# Create siamese neural network
|
|
base_model = create_base_model(input_shape)
|
|
head_model = create_head_model(base_model.output_shape)
|
|
siamese_network = SiameseNetwork(base_model, head_model)
|
|
|
|
# Prepare siamese network for training
|
|
siamese_network.compile(loss='binary_crossentropy',
|
|
optimizer=keras.optimizers.adam())
|
|
|
|
# Evaluate network before training to establish a baseline
|
|
score_before = siamese_network.evaluate_generator(
|
|
x_train, y_train, batch_size=64
|
|
)
|
|
|
|
# Train network
|
|
siamese_network.fit(x_train, y_train,
|
|
validation_data=(x_test, y_test),
|
|
batch_size=64,
|
|
epochs=epochs)
|
|
|
|
# Evaluate network
|
|
score_after = siamese_network.evaluate(x_train, y_train, batch_size=64)
|
|
|
|
# Ensure that the training loss score improved as a result of the training
|
|
assert(score_before > score_after)
|