adds preloading weights into existing siamese network
This commit is contained in:
parent
d6e6abc8af
commit
d8d2b5463e
20
evaluate.py
20
evaluate.py
@ -1,21 +1,25 @@
|
|||||||
import tensorflow.keras as keras
|
import tensorflow.keras as keras
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pdb
|
import ipdb
|
||||||
|
|
||||||
def getI(path):
|
class g:
|
||||||
return np.asarray(Image.open(path).convert('RGB').resize((100, 100))) / 255
|
def __init__(self, path):
|
||||||
|
self.image = np.asarray(Image.open(path).convert('RGB').resize((100, 100))) / 255
|
||||||
|
|
||||||
|
def show(self):
|
||||||
|
self.image.show()
|
||||||
|
|
||||||
def predict(image1, image2):
|
def predict(image1, image2):
|
||||||
return model.predict([np.array([image2]), np.array([image1])])
|
return model.predict([np.array([image2.image]), np.array([image1.image])])
|
||||||
|
|
||||||
model = keras.models.load_model('./siamese_checkpoint')
|
model = keras.models.load_model('../siamese_100x100_pretrainedb_vgg16')
|
||||||
image1 = getI('data/fruits/fruits-360/Training/Avocado/r_254_100.jpg')
|
image1 = g('data/fruits/fruits-360/Training/Avocado/r_254_100.jpg')
|
||||||
image2 = getI('data/fruits/fruits-360/Training/Avocado/r_250_100.jpg')
|
image2 = g('data/fruits/fruits-360/Training/Avocado/r_250_100.jpg')
|
||||||
|
|
||||||
print(predict(image1, image2))
|
print(predict(image1, image2))
|
||||||
|
|
||||||
pdb.set_trace()
|
ipdb.set_trace()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,8 +12,15 @@ from siamese import SiameseNetwork
|
|||||||
import pdb
|
import pdb
|
||||||
|
|
||||||
import os, math, numpy as np
|
import os, math, numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# import Image handling and
|
||||||
|
# set do not import metadata
|
||||||
|
import PIL
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
PIL.ImageFile.LOAD_TRUNCATED_IMAGES=True
|
||||||
|
|
||||||
batch_size = 128
|
batch_size = 128
|
||||||
num_classes = 131
|
num_classes = 131
|
||||||
|
|
||||||
@ -21,9 +28,9 @@ num_classes = 131
|
|||||||
img_rows, img_cols = 100, 100
|
img_rows, img_cols = 100, 100
|
||||||
|
|
||||||
def createTrainingData():
|
def createTrainingData():
|
||||||
base_dir = './classified/'
|
base_dir = './data/combined/'
|
||||||
train_test_split = 0.7
|
train_test_split = 0.7
|
||||||
no_of_files_in_each_class = 200
|
no_of_files_in_each_class = 200
|
||||||
|
|
||||||
#Read all the folders in the directory
|
#Read all the folders in the directory
|
||||||
folder_list = os.listdir(base_dir)
|
folder_list = os.listdir(base_dir)
|
||||||
|
Loading…
Reference in New Issue
Block a user