remodels siamese network with vgg16 and 100x100 input
- worse performance than with initial design - vgg16 pretrained weights are used for the base network, which is then piped into a custom head model, which - flattens the layer (previously done in the base model) + Dense Layer + Normalization + Activation - training split with 360 fruits used, same as previous mode - maximum prediction level around 0.95 after ca 60 epochs
This commit is contained in:
parent
e5058cc8cc
commit
fbc6ee8187
13
evaluate.py
Normal file
13
evaluate.py
Normal file
@ -0,0 +1,13 @@
|
||||
import tensorflow.keras as keras
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import pdb
|
||||
model = keras.models.load_model('./siamese_checkpoint')
|
||||
image1 = np.asarray(Image.open('../towards/data/fruits-360/Training/Avocado/r_254_100.jpg').convert('RGB').resize((100,
|
||||
100))) / 255
|
||||
image2 = np.asarray(Image.open('../towards/data/fruits-360/Training/Avocado/r_250_100.jpg').convert('RGB').resize((100,
|
||||
100))) / 255
|
||||
|
||||
output = model.predict([np.array([image2]), np.array([image1])])
|
||||
pdb.set_trace()
|
||||
|
4
notes.md
4
notes.md
@ -28,8 +28,8 @@ the steps taken so far, which lead to a successfull detection of an image
|
||||
import tensorflow.keras as keras
|
||||
from PIL import image
|
||||
model = keras.models.load_model('./siamese_checkpoint')
|
||||
image1 = np.asarray(Image.open('../towards/data/fruits-360/Training/Avocado/r_254_100.jpg').convert('RGB').resize((28, 28))) / 255
|
||||
image2 = np.asarray(Image.open('../towards/data/fruits-360/Training/Avocado/r_250_100.jpg').convert('RGB').resize((28, 28))) / 255
|
||||
image1 = np.asarray(Image.open('../towards/data/fruits-360/Training/Avocado/r_254_100.jpg').convert('RGB').resize((28, 28)))
|
||||
image2 = np.asarray(Image.open('../towards/data/fruits-360/Training/Avocado/r_250_100.jpg').convert('RGB').resize((28, 28)))
|
||||
# note that the double division through 255 is only because the model bas taught with this double division, depends on
|
||||
# the input numbers of course
|
||||
|
||||
|
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -26,12 +26,11 @@ from siamese import SiameseNetwork
|
||||
import os, math, numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
batch_size = 128
|
||||
num_classes = 131
|
||||
|
||||
# input image dimensions
|
||||
img_rows, img_cols = 28, 28
|
||||
img_rows, img_cols = 100, 100
|
||||
|
||||
def createTrainingData():
|
||||
base_dir = 'data/fruits-360/Training/'
|
||||
@ -133,22 +132,8 @@ x_train = x_train.astype('float32')
|
||||
x_test = x_test.astype('float32')
|
||||
|
||||
def create_own_base_model(input_shape):
|
||||
model_input = Input(shape=input_shape)
|
||||
|
||||
embedding = Conv2D(32, kernel_size=(10, 10), input_shape=input_shape)(model_input)
|
||||
embedding = MaxPooling2D(pool_size=(2, 2))(embedding)
|
||||
embedding = Conv2D(64, kernel_size=(7, 7))(embedding)
|
||||
embedding = MaxPooling2D(pool_size=(2, 2))(embedding)
|
||||
embedding = Conv2D(128, kernel_size=(4, 4))(embedding)
|
||||
embedding = MaxPooling2D(pool_size=(2, 2))(embedding)
|
||||
embedding = Conv2D(256, kernel_size=(4, 4))(embedding)
|
||||
embedding = MaxPooling2D(pool_size=(2, 2))(embedding)
|
||||
embedding = Flatten()(embedding)
|
||||
embedding = Dense(4096, activation='sigmoid')(embedding)
|
||||
embedding = BatchNormalization()(embedding)
|
||||
embedding = Activation(activation='relu')(embedding)
|
||||
|
||||
return Model(model_input, embedding)
|
||||
return keras.applications.vgg16.VGG16(include_top=False, input_tensor=Input(shape=input_shape), weights='imagenet',
|
||||
classes=1)
|
||||
|
||||
def create_base_model(input_shape):
|
||||
model_input = Input(shape=input_shape)
|
||||
@ -168,6 +153,30 @@ def create_base_model(input_shape):
|
||||
|
||||
return Model(model_input, embedding)
|
||||
|
||||
def create_own_head_model(embedding_shape):
|
||||
embedding_a = Input(shape=embedding_shape[1:])
|
||||
embedding_b = Input(shape=embedding_shape[1:])
|
||||
|
||||
embedding_a_mod = Flatten()(embedding_a)
|
||||
embedding_a_mod = Dense(128)(embedding_a_mod)
|
||||
embedding_a_mod = BatchNormalization()(embedding_a_mod)
|
||||
embedding_a_mod = Activation(activation='relu')(embedding_a_mod)
|
||||
|
||||
embedding_b_mod = Flatten()(embedding_b)
|
||||
embedding_b_mod = Dense(128)(embedding_b_mod)
|
||||
embedding_b_mod = BatchNormalization()(embedding_b_mod)
|
||||
embedding_b_mod = Activation(activation='relu')(embedding_b_mod)
|
||||
|
||||
head = Concatenate()([embedding_a_mod, embedding_b_mod])
|
||||
head = Dense(8)(head)
|
||||
head = BatchNormalization()(head)
|
||||
head = Activation(activation='sigmoid')(head)
|
||||
|
||||
head = Dense(1)(head)
|
||||
head = BatchNormalization()(head)
|
||||
head = Activation(activation='sigmoid')(head)
|
||||
|
||||
return Model([embedding_a, embedding_b], head)
|
||||
|
||||
def create_head_model(embedding_shape):
|
||||
embedding_a = Input(shape=embedding_shape[1:])
|
||||
@ -214,10 +223,10 @@ def get_batch(x_train, y_train, x_test, y_test, cat_train, batch_size=64):
|
||||
|
||||
|
||||
num_classes = 131
|
||||
epochs = 20
|
||||
epochs = 500
|
||||
|
||||
base_model = create_base_model(input_shape)
|
||||
head_model = create_head_model(base_model.output_shape)
|
||||
base_model = create_own_base_model(input_shape)
|
||||
head_model = create_own_head_model(base_model.output_shape)
|
||||
|
||||
siamese_network = SiameseNetwork(base_model, head_model)
|
||||
siamese_network.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
|
||||
@ -225,7 +234,7 @@ siamese_network.compile(loss='binary_crossentropy', optimizer='adam', metrics=['
|
||||
siamese_checkpoint_path = "./siamese_checkpoint"
|
||||
|
||||
siamese_callbacks = [
|
||||
# EarlyStopping(monitor='val_accuracy', patience=10, verbose=0),
|
||||
EarlyStopping(monitor='val_accuracy', patience=10, verbose=0),
|
||||
ModelCheckpoint(siamese_checkpoint_path, monitor='val_accuracy', save_best_only=True, verbose=0)
|
||||
]
|
||||
|
Loading…
x
Reference in New Issue
Block a user