From d8d2b5463eb1af6162069cfa2c4e99fe87022f90 Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 12 Aug 2021 15:25:37 +0200 Subject: [PATCH] adds preloading weights into existing siamese network --- evaluate.py | 20 ++++++++++++-------- train_lambda_coco.py | 11 +++++++++-- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/evaluate.py b/evaluate.py index 3ad92fc..709886c 100644 --- a/evaluate.py +++ b/evaluate.py @@ -1,21 +1,25 @@ import tensorflow.keras as keras from PIL import Image import numpy as np -import pdb +import ipdb -def getI(path): - return np.asarray(Image.open(path).convert('RGB').resize((100, 100))) / 255 +class g: + def __init__(self, path): + self.image = np.asarray(Image.open(path).convert('RGB').resize((100, 100))) / 255 + + def show(self): + self.image.show() def predict(image1, image2): - return model.predict([np.array([image2]), np.array([image1])]) + return model.predict([np.array([image2.image]), np.array([image1.image])]) -model = keras.models.load_model('./siamese_checkpoint') -image1 = getI('data/fruits/fruits-360/Training/Avocado/r_254_100.jpg') -image2 = getI('data/fruits/fruits-360/Training/Avocado/r_250_100.jpg') +model = keras.models.load_model('../siamese_100x100_pretrainedb_vgg16') +image1 = g('data/fruits/fruits-360/Training/Avocado/r_254_100.jpg') +image2 = g('data/fruits/fruits-360/Training/Avocado/r_250_100.jpg') print(predict(image1, image2)) -pdb.set_trace() +ipdb.set_trace() diff --git a/train_lambda_coco.py b/train_lambda_coco.py index 89d18af..8fa66f5 100644 --- a/train_lambda_coco.py +++ b/train_lambda_coco.py @@ -12,8 +12,15 @@ from siamese import SiameseNetwork import pdb import os, math, numpy as np + + +# import Image handling and +# set do not import metadata +import PIL from PIL import Image +PIL.ImageFile.LOAD_TRUNCATED_IMAGES=True + batch_size = 128 num_classes = 131 @@ -21,9 +28,9 @@ num_classes = 131 img_rows, img_cols = 100, 100 def createTrainingData(): - base_dir = './classified/' + base_dir = './data/combined/' train_test_split = 0.7 - no_of_files_in_each_class = 200 + no_of_files_in_each_class = 200 #Read all the folders in the directory folder_list = os.listdir(base_dir)