updates for products
This commit is contained in:
parent
3ef8c74f2e
commit
0bf090c2a8
@ -11,17 +11,14 @@ count = 0
|
||||
|
||||
|
||||
|
||||
print("loading coco annotations...")
|
||||
coco = json.load(open('./coco/annotations/instances_default.json'))
|
||||
print("done")
|
||||
|
||||
def findAnnotationName(annotationId):
|
||||
for c in coco['categories']:
|
||||
def findAnnotationName(annotationId, annotations):
|
||||
for c in annotations['categories']:
|
||||
if c['id'] == annotationId:
|
||||
return c['name']
|
||||
|
||||
def findAnnotationToId(ident):
|
||||
for annotation in coco['annotations']:
|
||||
def findAnnotationToId(ident, annotations):
|
||||
for annotation in annotations['annotations']:
|
||||
img_an = annotation['image_id']
|
||||
if img_an == ident:
|
||||
return annotation
|
||||
@ -34,14 +31,14 @@ def show(pil, pause=0.2):
|
||||
plt.close()
|
||||
|
||||
|
||||
def parseImage(coImg):
|
||||
global no_label, small, passed, coco
|
||||
def parseImage(coImg, annotations, subset):
|
||||
global no_label, small, passed
|
||||
# open image file
|
||||
|
||||
path = "coco/roto_frank/images/" + coImg['file_name'].split('/')[4]
|
||||
path = "coco/"+subset+"/images/" + coImg['file_name'].split('/')[4]
|
||||
img = Image.open(path)
|
||||
|
||||
an = findAnnotationToId(coImg['id'])
|
||||
an = findAnnotationToId(coImg['id'], annotations)
|
||||
if an == None:
|
||||
no_label += 1
|
||||
return
|
||||
@ -53,15 +50,20 @@ def parseImage(coImg):
|
||||
small += 1
|
||||
return
|
||||
|
||||
imagePath = f"classified_roto/{findAnnotationName(an['category_id'])}/{an['id']}.png"
|
||||
imagePath = f"classified/classified_{subset}/{findAnnotationName(an['category_id'], annotations)}/{an['id']}.png"
|
||||
os.makedirs(os.path.dirname(imagePath), exist_ok=True)
|
||||
crop.save(imagePath)
|
||||
passed += 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
def parseSubset(subset):
|
||||
global count
|
||||
|
||||
for coImg in coco['images']:
|
||||
parseImage(coImg)
|
||||
print("loading" + subset + "annotations...")
|
||||
annotations = json.load(open('./coco/'+subset+'/annotations/instances_default.json'))
|
||||
print("done")
|
||||
|
||||
for coImg in annotations['images']:
|
||||
parseImage(coImg, annotations,subset)
|
||||
count += 1
|
||||
if count % 100 == 0:
|
||||
print("status:")
|
||||
@ -70,3 +72,15 @@ if __name__ == "__main__":
|
||||
print(f"passed: {passed}")
|
||||
print("-----")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parseSubset('donaulager')
|
||||
parseSubset('instantina')
|
||||
parseSubset('lkw_walter_2019')
|
||||
parseSubset('lkw_walter_2020')
|
||||
parseSubset('tkl')
|
||||
parseSubset('vw_bratislava')
|
||||
parseSubset('vw_portugal')
|
||||
parseSubset('watt')
|
||||
parseSubset('wls')
|
||||
|
||||
|
@ -12,7 +12,7 @@ count = 0
|
||||
|
||||
|
||||
print("loading coco annotations...")
|
||||
coco = json.load(open('./coco/annotations/instances_val2014.json'))
|
||||
coco = json.load(open('./coco/annotations/instances_default.json'))
|
||||
print("done")
|
||||
|
||||
def findAnnotationName(annotationId):
|
||||
@ -38,7 +38,7 @@ def parseImage(coImg):
|
||||
global no_label, small, passed, coco
|
||||
# open image file
|
||||
|
||||
path = "coco/val2014/" + coImg['file_name']
|
||||
path = "coco/roto_frank/images/" + coImg['file_name'].split('/')[4]
|
||||
img = Image.open(path)
|
||||
|
||||
an = findAnnotationToId(coImg['id'])
|
||||
@ -46,9 +46,6 @@ def parseImage(coImg):
|
||||
no_label += 1
|
||||
return
|
||||
|
||||
|
||||
path = "coco/val2014/" + coImg['file_name']
|
||||
img = Image.open(path)
|
||||
c = an['bbox']
|
||||
crop = img.crop((c[0], c[1], c[0]+c[2], c[1]+c[3]))
|
||||
|
||||
@ -56,7 +53,7 @@ def parseImage(coImg):
|
||||
small += 1
|
||||
return
|
||||
|
||||
imagePath = f"classified/{findAnnotationName(an['category_id'])}/{an['id']}.png"
|
||||
imagePath = f"classified_roto/{findAnnotationName(an['category_id'])}/{an['id']}.png"
|
||||
os.makedirs(os.path.dirname(imagePath), exist_ok=True)
|
||||
crop.save(imagePath)
|
||||
passed += 1
|
||||
|
20
siamese.py
20
siamese.py
@ -202,10 +202,13 @@ class SiameseNetwork:
|
||||
|
||||
img_rows = self.input_shape[0]
|
||||
img_cols = self.input_shape[1]
|
||||
img1 = np.asarray(Image.open(x[element_index_1]).convert('RGB').resize((img_rows, img_cols)))/255.0
|
||||
img2 = np.asarray(Image.open(x[element_index_2]).convert('RGB').resize((img_rows, img_cols)))/255.0
|
||||
# img1 = x[element_index_1]
|
||||
# img2 = x[element_index_2]
|
||||
|
||||
if type(x[element_index_1]) == str:
|
||||
img1 = np.asarray(Image.open(x[element_index_1]).convert('RGB').resize((img_rows, img_cols)))/255.0
|
||||
img2 = np.asarray(Image.open(x[element_index_2]).convert('RGB').resize((img_rows, img_cols)))/255.0
|
||||
else:
|
||||
img1 = x[element_index_1]
|
||||
img2 = x[element_index_2]
|
||||
positive_pairs.append([img1,img2])
|
||||
positive_labels.append([1.0])
|
||||
return positive_pairs, positive_labels
|
||||
@ -241,8 +244,13 @@ class SiameseNetwork:
|
||||
|
||||
img_rows = self.input_shape[0]
|
||||
img_cols = self.input_shape[1]
|
||||
img1 = np.asarray(Image.open(x[element_index_1]).convert('RGB').resize((img_rows, img_cols)))/255.0
|
||||
img2 = np.asarray(Image.open(x[element_index_2]).convert('RGB').resize((img_rows, img_cols)))/255.0
|
||||
|
||||
if type(x[element_index_1]) == str:
|
||||
img1 = np.asarray(Image.open(x[element_index_1]).convert('RGB').resize((img_rows, img_cols)))/255.0
|
||||
img2 = np.asarray(Image.open(x[element_index_2]).convert('RGB').resize((img_rows, img_cols)))/255.0
|
||||
else:
|
||||
img1 = x[element_index_1]
|
||||
img2 = x[element_index_2]
|
||||
|
||||
negative_pairs.append([img1,img2])
|
||||
negative_labels.append([0.0])
|
||||
|
@ -21,9 +21,9 @@ num_classes = 131
|
||||
img_rows, img_cols = 100, 100
|
||||
|
||||
def createTrainingData():
|
||||
base_dir = './classified/'
|
||||
base_dir = './data/COCO/products/'
|
||||
train_test_split = 0.7
|
||||
no_of_files_in_each_class = 10
|
||||
no_of_files_in_each_class = 20
|
||||
|
||||
#Read all the folders in the directory
|
||||
folder_list = os.listdir(base_dir)
|
||||
@ -39,6 +39,8 @@ def createTrainingData():
|
||||
#Using just 5 images per category
|
||||
for folder_name in folder_list:
|
||||
files_list = os.listdir(os.path.join(base_dir, folder_name))
|
||||
if len(files_list) < no_of_files_in_each_class:
|
||||
continue
|
||||
temp=[]
|
||||
for file_name in files_list[:no_of_files_in_each_class]:
|
||||
temp.append(len(x))
|
||||
@ -220,18 +222,19 @@ head_model = create_own_head_model(base_model.output_shape)
|
||||
siamese_network = SiameseNetwork(base_model, head_model)
|
||||
siamese_network.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
|
||||
|
||||
siamese_checkpoint_path = "./siamese_checkpoint"
|
||||
|
||||
siamese_checkpoint_path = "../siamese100_products"
|
||||
model_path = "/variables/variables"
|
||||
siamese_callbacks = [
|
||||
# EarlyStopping(monitor='val_accuracy', patience=10, verbose=0),
|
||||
ModelCheckpoint(siamese_checkpoint_path, monitor='val_accuracy', save_best_only=True, verbose=0)
|
||||
]
|
||||
|
||||
# batch_size = 64
|
||||
# for epoch in range(1, epochs):
|
||||
# batch_x, batch_y = get_batch(x_train, y_train, x_test, y_test, cat_train, train_size, batch_size)
|
||||
# loss = siamese_network.train_on_batch(batch_x, batch_y)
|
||||
# print('Epoch:', epoch, ', Loss:', loss)
|
||||
try:
|
||||
print("loading weights for model")
|
||||
siamese_network.load_weights(siamese_checkpoint_path+model_path)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
siamese_network.fit(x_train, y_train,
|
||||
validation_data=(x_test, y_test),
|
||||
@ -239,13 +242,6 @@ siamese_network.fit(x_train, y_train,
|
||||
epochs=epochs,
|
||||
callbacks=siamese_callbacks)
|
||||
|
||||
# try:
|
||||
# siamese_network = keras.models.load_model(siamese_checkpoint_path)
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# print("!!!!!!")
|
||||
# siamese_network.load_weights(siamese_checkpoint_path)
|
||||
|
||||
|
||||
score = siamese_network.evaluate(x_test, y_test, batch_size=60, verbose=0)
|
||||
print('Test loss:', score[0])
|
||||
|
@ -28,7 +28,7 @@ num_classes = 131
|
||||
img_rows, img_cols = 100, 100
|
||||
|
||||
def createTrainingData():
|
||||
base_dir = './data/combined/'
|
||||
base_dir = './COCO/products/'
|
||||
train_test_split = 0.7
|
||||
no_of_files_in_each_class = 200
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user