adds preloading weights into existing siamese network
This commit is contained in:
		
							
								
								
									
										20
									
								
								evaluate.py
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								evaluate.py
									
									
									
									
									
								
							@@ -1,21 +1,25 @@
 | 
				
			|||||||
import tensorflow.keras as keras
 | 
					import tensorflow.keras as keras
 | 
				
			||||||
from PIL import Image
 | 
					from PIL import Image
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
import pdb
 | 
					import ipdb
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def getI(path):
 | 
					class g:
 | 
				
			||||||
    return np.asarray(Image.open(path).convert('RGB').resize((100, 100))) / 255
 | 
					    def __init__(self, path):
 | 
				
			||||||
 | 
					        self.image = np.asarray(Image.open(path).convert('RGB').resize((100, 100))) / 255
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def show(self):
 | 
				
			||||||
 | 
					        self.image.show()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def predict(image1, image2):
 | 
					def predict(image1, image2):
 | 
				
			||||||
    return model.predict([np.array([image2]), np.array([image1])]) 
 | 
					    return model.predict([np.array([image2.image]), np.array([image1.image])]) 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
model = keras.models.load_model('./siamese_checkpoint')
 | 
					model = keras.models.load_model('../siamese_100x100_pretrainedb_vgg16')
 | 
				
			||||||
image1 = getI('data/fruits/fruits-360/Training/Avocado/r_254_100.jpg')
 | 
					image1 = g('data/fruits/fruits-360/Training/Avocado/r_254_100.jpg')
 | 
				
			||||||
image2 = getI('data/fruits/fruits-360/Training/Avocado/r_250_100.jpg')
 | 
					image2 = g('data/fruits/fruits-360/Training/Avocado/r_250_100.jpg')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
print(predict(image1, image2))
 | 
					print(predict(image1, image2))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
pdb.set_trace()
 | 
					ipdb.set_trace()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,8 +12,15 @@ from siamese import SiameseNetwork
 | 
				
			|||||||
import pdb
 | 
					import pdb
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import os, math, numpy as np
 | 
					import os, math, numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# import Image handling and 
 | 
				
			||||||
 | 
					# set do not import metadata
 | 
				
			||||||
 | 
					import PIL
 | 
				
			||||||
from PIL import Image
 | 
					from PIL import Image
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					PIL.ImageFile.LOAD_TRUNCATED_IMAGES=True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
batch_size = 128
 | 
					batch_size = 128
 | 
				
			||||||
num_classes = 131
 | 
					num_classes = 131
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -21,7 +28,7 @@ num_classes = 131
 | 
				
			|||||||
img_rows, img_cols = 100, 100
 | 
					img_rows, img_cols = 100, 100
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def createTrainingData():
 | 
					def createTrainingData():
 | 
				
			||||||
    base_dir = './classified/'
 | 
					    base_dir = './data/combined/'
 | 
				
			||||||
    train_test_split = 0.7
 | 
					    train_test_split = 0.7
 | 
				
			||||||
    no_of_files_in_each_class = 200 
 | 
					    no_of_files_in_each_class = 200 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user