From 66822725fc7f9c3994181b7906686cd32a3d1563 Mon Sep 17 00:00:00 2001
From: hannandarryl <hannandarryl@gmail.com>
Date: Wed, 14 Dec 2022 21:39:24 +0000
Subject: [PATCH] ONSD: rework classifier conv stack, add crop-based augmentation and per-fold checkpointing

---
 sparse_coding_torch/onsd/classifier_model.py | 26 ++++----
 sparse_coding_torch/onsd/train_MLP.py        | 26 ++++++--
 sparse_coding_torch/onsd/train_classifier.py | 65 ++++++++++----------
 3 files changed, 68 insertions(+), 49 deletions(-)

diff --git a/sparse_coding_torch/onsd/classifier_model.py b/sparse_coding_torch/onsd/classifier_model.py
index 70dbe71..2d61e1a 100644
--- a/sparse_coding_torch/onsd/classifier_model.py
+++ b/sparse_coding_torch/onsd/classifier_model.py
@@ -15,12 +15,12 @@ class ONSDClassifier(keras.layers.Layer):
         
 #         self.sparse_filters = tf.squeeze(keras.models.load_model(sparse_checkpoint).weights[0], axis=0)
 
-        self.conv_1 = keras.layers.Conv2D(32, kernel_size=(8, 8), strides=2, activation='relu', padding='valid')
-        self.conv_2 = keras.layers.Conv2D(32, kernel_size=(8, 8), strides=2, activation='relu', padding='valid')
-        self.conv_3 = keras.layers.Conv2D(32, kernel_size=(8, 8), strides=2, activation='relu', padding='valid')
-#         self.conv_4 = keras.layers.Conv2D(32, kernel_size=(8, 8), strides=2, activation='relu', padding='valid')
-#         self.conv_5 = keras.layers.Conv2D(32, kernel_size=(4, 4), strides=1, activation='relu', padding='valid')
-#         self.conv_6 = keras.layers.Conv2D(32, kernel_size=(8, 8), strides=2, activation='relu', padding='valid')
+        self.conv_1 = keras.layers.Conv2D(32, kernel_size=(8, 8), strides=(2), activation='relu', padding='valid')
+        self.conv_2 = keras.layers.Conv2D(32, kernel_size=(4, 4), strides=(2), activation='relu', padding='valid')
+        self.conv_3 = keras.layers.Conv2D(32, kernel_size=(4, 4), strides=(2), activation='relu', padding='valid')
+        self.conv_4 = keras.layers.Conv2D(32, kernel_size=(4, 4), strides=(2), activation='relu', padding='valid')
+        self.conv_5 = keras.layers.Conv2D(32, kernel_size=(4, 4), strides=(2), activation='relu', padding='valid')
+#         self.conv_6 = keras.layers.Conv2D(32, kernel_size=(4, 4), strides=(12), activation='relu', padding='valid')
 #         self.conv_1 = keras.layers.Conv1D(10, kernel_size=3, strides=1, activation='relu', padding='valid')
 #         self.conv_2 = keras.layers.Conv1D(10, kernel_size=3, strides=1, activation='relu', padding='valid')
 
@@ -35,19 +35,21 @@ class ONSDClassifier(keras.layers.Layer):
         self.ff_2 = keras.layers.Dense(100, activation='relu', use_bias=True)
         self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True)
         self.ff_final_1 = keras.layers.Dense(1)
-        self.ff_final_2 = keras.layers.Dense(1)
+#         self.ff_final_2 = keras.layers.Dense(1)
         self.do_dropout = True
 
 #     @tf.function
     def call(self, activations):
+        activations = tf.squeeze(activations, axis=1)
+#         activations = tf.transpose(activations, [0, 2, 3, 1])
 #         x = tf.nn.conv2d(activations, self.sparse_filters, strides=(1, 4), padding='VALID')
 #         x = tf.nn.relu(x)
         x = self.conv_1(activations)
         x = self.conv_2(x)
-        x = self.dropout(x, self.do_dropout)
+#         x = self.dropout(x, self.do_dropout)
         x = self.conv_3(x)
-#         x = self.conv_4(x)
-#         x = self.conv_5(x)
+        x = self.conv_4(x)
+        x = self.conv_5(x)
 #         x = self.conv_6(x)
         x = self.flatten(x)
 #         x = self.ff_1(x)
@@ -57,9 +59,9 @@ class ONSDClassifier(keras.layers.Layer):
         x = self.ff_3(x)
 #         x = self.dropout(x)
         class_pred = self.ff_final_1(x)
-        width_pred = tf.math.sigmoid(self.ff_final_2(x))
+#         width_pred = tf.math.sigmoid(self.ff_final_2(x))
 
-        return class_pred, width_pred
+        return class_pred#, width_pred
     
 class ONSDConv(keras.layers.Layer):
     def __init__(self, do_regression):
diff --git a/sparse_coding_torch/onsd/train_MLP.py b/sparse_coding_torch/onsd/train_MLP.py
index afdc46f..f847cda 100644
--- a/sparse_coding_torch/onsd/train_MLP.py
+++ b/sparse_coding_torch/onsd/train_MLP.py
@@ -27,6 +27,7 @@ import copy
 import matplotlib.pyplot as plt
 import itertools
 import csv
+import json
 
 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
 import absl.logging
@@ -261,6 +262,7 @@ if __name__ == "__main__":
     recon_model = keras.models.load_model(args.sparse_checkpoint)
     
     data_augmentation = keras.Sequential([
+        keras.layers.Resizing(image_height, image_width)
 #         keras.layers.RandomFlip('horizontal'),
 # #         keras.layers.RandomFlip('vertical'),
 #         keras.layers.RandomRotation(5),
@@ -268,7 +270,7 @@ if __name__ == "__main__":
     ])
         
     
-    splits, dataset = load_onsd_videos(args.batch_size, input_size=(image_height, image_width), crop_size=(crop_height, crop_width), yolo_model=yolo_model, mode=args.splits, n_splits=args.n_splits, do_regression=args.do_regression)
+    splits, dataset = load_onsd_videos(args.batch_size, crop_size=(crop_height, crop_width), yolo_model=yolo_model, mode=args.splits, n_splits=args.n_splits, do_regression=args.do_regression)
     positive_class = 'Positives'
     
     all_video_labels = [f.split('/')[-3] for f in dataset.get_all_videos()]
@@ -318,13 +320,14 @@ if __name__ == "__main__":
     test_frame_true = []
     test_frame_pred = []
     
-    with open(os.path.join(output_dir, 'test_ids.txt'),'w') as f:
-        pass
+#     with open(os.path.join(output_dir, 'test_ids.txt'),'w') as f:
+#         pass
 
     i_fold = 0
+    fold_to_videos_map = {}
     for train_idx, test_idx in splits:
-        with open(os.path.join(output_dir, 'test_ids.txt'), 'a+') as test_id_out:
-            test_id_out.write(str(test_idx) + '\n')
+#         with open(os.path.join(output_dir, 'test_ids.txt'), 'a+') as test_id_out:
+#             test_id_out.write(str(test_idx) + '\n')
         train_loader = copy.deepcopy(dataset)
         train_loader.set_indicies(train_idx)
         test_loader = copy.deepcopy(dataset)
@@ -395,8 +398,14 @@ if __name__ == "__main__":
             
         classifier_model.compile(optimizer=keras.optimizers.Adam(learning_rate=args.lr), loss=criterion)
         
+        callbacks = [
+            keras.callbacks.ModelCheckpoint(os.path.join(args.output_dir, "model_fold_{}.h5".format(i_fold)), save_best_only=False, save_weights_only=True)
+        ]
+        
         if args.train:
-            classifier_model.fit(train_x, train_y, batch_size=args.batch_size, epochs=args.epochs, verbose=False)
+            classifier_model.fit(train_x, train_y, batch_size=args.batch_size, epochs=args.epochs, verbose=False, callbacks=callbacks)
+        else:
+            classifier_model.load_weights(os.path.join(args.output_dir, "model_fold_{}.h5".format(i_fold)))
 
         y_true_train = train_y
         if args.do_regression:
@@ -442,6 +451,8 @@ if __name__ == "__main__":
         ])
 
         test_videos = list(test_loader.get_all_videos())# + [v[1] for v in difficult_vids[i_fold]]
+        
+        fold_to_videos_map[i_fold] = test_videos
 
         test_labels = [vid_f.split('/')[-3] for vid_f in test_videos]
 
@@ -517,4 +528,7 @@ if __name__ == "__main__":
             
     print("Final video accuracy={:.2f}, video f1={:.2f}, frame train accuracy={:.2f}, frame test accuracy={:.2f}".format(final_acc, final_f1, train_frame_acc, test_frame_acc))
     print(final_conf)
+    
+    with open(os.path.join(args.output_dir, 'fold_to_videos.json'), 'w+') as fold_vid_out:
+        json.dump(fold_to_videos_map, fold_vid_out)
 
diff --git a/sparse_coding_torch/onsd/train_classifier.py b/sparse_coding_torch/onsd/train_classifier.py
index f8f3904..759daf0 100644
--- a/sparse_coding_torch/onsd/train_classifier.py
+++ b/sparse_coding_torch/onsd/train_classifier.py
@@ -235,12 +235,12 @@ def calculate_onsd_scores_frame_classifier(input_videos, labels, yolo_model, cla
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--batch_size', default=32, type=int)
-    parser.add_argument('--kernel_height', default=15, type=int)
-    parser.add_argument('--kernel_width', default=15, type=int)
+    parser.add_argument('--kernel_height', default=30, type=int)
+    parser.add_argument('--kernel_width', default=60, type=int)
     parser.add_argument('--kernel_depth', default=1, type=int)
-    parser.add_argument('--num_kernels', default=32, type=int)
+    parser.add_argument('--num_kernels', default=8, type=int)
     parser.add_argument('--stride', default=1, type=int)
-    parser.add_argument('--max_activation_iter', default=150, type=int)
+    parser.add_argument('--max_activation_iter', default=300, type=int)
     parser.add_argument('--activation_lr', default=1e-2, type=float)
     parser.add_argument('--lr', default=5e-5, type=float)
     parser.add_argument('--epochs', default=40, type=int)
@@ -260,9 +260,10 @@ if __name__ == "__main__":
     parser.add_argument('--train_sparse', action='store_true')
     parser.add_argument('--mixing_ratio', type=float, default=1.0)
     parser.add_argument('--sparse_lr', type=float, default=0.003)
-    parser.add_argument('--crop_height', type=int, default=200)
-    parser.add_argument('--crop_width', type=int, default=200)
-    parser.add_argument('--scale_factor', type=int, default=1)
+    parser.add_argument('--crop_height', type=int, default=30)
+    parser.add_argument('--crop_width', type=int, default=300)
+    parser.add_argument('--image_height', type=int, default=30)
+    parser.add_argument('--image_width', type=int, default=250)
     parser.add_argument('--clip_depth', type=int, default=1)
     parser.add_argument('--frames_to_skip', type=int, default=1)
     
@@ -271,8 +272,8 @@ if __name__ == "__main__":
     crop_height = args.crop_height
     crop_width = args.crop_width
 
-    image_height = int(crop_height / args.scale_factor)
-    image_width = int(crop_width / args.scale_factor)
+    image_height = args.image_height
+    image_width = args.image_width
     clip_depth = args.clip_depth
 
     batch_size = args.batch_size
@@ -306,18 +307,19 @@ if __name__ == "__main__":
     sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
     recon_model = keras.models.load_model(args.sparse_checkpoint)
     
+    crop_amount = (crop_width - image_width)
+    assert crop_amount % 2 == 0
+    crop_amount = crop_amount // 2
+        
     data_augmentation = keras.Sequential([
-#         keras.layers.RandomFlip('horizontal'),
-# #         keras.layers.RandomFlip('vertical'),
-#         keras.layers.RandomRotation(5),
-#         keras.layers.RandomBrightness(0.1)
+        keras.layers.RandomTranslation(0, 0.08),
+        keras.layers.Cropping2D((0, crop_amount))
     ])
-#     transforms = torchvision.transforms.Compose(
-#     [torchvision.transforms.RandomAffine(scale=)
-#     ])
+    
+    just_crop = keras.layers.Cropping2D((0, crop_amount))
         
     
-    splits, dataset = load_onsd_videos(args.batch_size, input_size=(image_height, image_width), crop_size=(crop_height, crop_width), yolo_model=yolo_model, mode=args.splits, n_splits=args.n_splits)
+    splits, dataset = load_onsd_videos(args.batch_size, crop_size=(crop_height, crop_width), yolo_model=yolo_model, mode=args.splits, n_splits=args.n_splits)
     positive_class = 'Positives'
     
 #     difficult_vids = split_difficult_vids(dataset.get_difficult_vids(), args.n_splits)
@@ -369,7 +371,7 @@ if __name__ == "__main__":
         if args.checkpoint:
             classifier_model = keras.models.load_model(args.checkpoint)
         else:
-            classifier_inputs = keras.Input(shape=((clip_depth - args.kernel_depth) // 1 + 1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
+            classifier_inputs = keras.Input(shape=output.shape[1:])
             classifier_outputs = ONSDClassifier(args.sparse_checkpoint)(classifier_inputs)
 
             classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs)
@@ -380,7 +382,7 @@ if __name__ == "__main__":
         best_so_far = float('inf')
 
         class_criterion = keras.losses.BinaryCrossentropy(from_logits=True, reduction=keras.losses.Reduction.SUM)
-        width_criterion = keras.losses.MeanSquaredError(reduction=keras.losses.Reduction.SUM)
+#         width_criterion = keras.losses.MeanSquaredError(reduction=keras.losses.Reduction.SUM)
 
 
         train_losses = []
@@ -399,8 +401,6 @@ if __name__ == "__main__":
                 y_true_train = None
                 y_pred_train = None
 
-#                 for images, labels, width in tqdm(balanced_ds.shuffle(len(train_tf)).batch(args.batch_size)):
-#                 for images, labels, width in tqdm(balanced_ds.take(len(train_tf)).shuffle(len(train_tf)).batch(args.batch_size)):
                 classifier_model.do_dropout = True
                 for images, labels, width in tqdm(train_tf.shuffle(len(train_tf)).batch(args.batch_size)):
                     images = tf.expand_dims(data_augmentation(tf.transpose(images, [0, 2, 3, 1])), axis=1)
@@ -463,7 +463,7 @@ if __name__ == "__main__":
 #                     eval_loader = train_tf
                 classifier_model.do_dropout = False
                 for images, labels, width in tqdm(test_tf.batch(args.batch_size)):
-                    images = tf.expand_dims(tf.transpose(images, [0, 2, 3, 1]), axis=1)
+                    images = tf.expand_dims(just_crop(tf.transpose(images, [0, 2, 3, 1])), axis=1)
                 
                     activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
 
@@ -482,11 +482,11 @@ if __name__ == "__main__":
                         y_true = tf.concat((y_true, labels), axis=0)
                         y_pred = tf.concat((y_pred, tf.math.round(tf.math.sigmoid(pred))), axis=0)
                         
-                    for p, g in zip(width_pred, width):
-                        if g == 0:
-                            continue
-                        width_p.append(p * dataset.max_width)
-                        width_gt.append(g * dataset.max_width)
+#                     for p, g in zip(width_pred, width):
+#                         if g == 0:
+#                             continue
+#                         width_p.append(p * dataset.max_width)
+#                         width_gt.append(g * dataset.max_width)
 
                 t2 = time.perf_counter()
 
@@ -505,8 +505,8 @@ if __name__ == "__main__":
 
                 train_accuracy = accuracy_score(y_true_train, y_pred_train)
                 
-                test_mae = keras.losses.MeanAbsoluteError()(width_gt, width_p)
-#                 test_mae = 0.0
+#                 test_mae = keras.losses.MeanAbsoluteError()(width_gt, width_p)
+                test_mae = 0.0
                 
                 train_losses.append(epoch_loss)
                 test_losses.append(test_loss)
@@ -539,7 +539,7 @@ if __name__ == "__main__":
         transform = torchvision.transforms.Compose(
         [torchvision.transforms.Grayscale(1),
          MinMaxScaler(0, 255),
-         torchvision.transforms.Resize((image_height, image_width))
+         torchvision.transforms.CenterCrop((image_height, image_width))
         ])
 
         test_videos = list(test_loader.get_all_videos())# + [v[1] for v in difficult_vids[i_fold]]
@@ -547,7 +547,10 @@ if __name__ == "__main__":
         test_labels = [vid_f.split('/')[-3] for vid_f in test_videos]
 
         classifier_model.do_dropout = False
-        y_pred, y_true, fn, fp = calculate_onsd_scores(test_videos, test_labels, yolo_model, classifier_model, sparse_model, recon_model, transform, image_width, image_height, dataset.max_width)
+        max_width = 0
+        if hasattr(dataset, 'max_width'):
+            max_width = dataset.max_width
+        y_pred, y_true, fn, fp = calculate_onsd_scores(test_videos, test_labels, yolo_model, classifier_model, sparse_model, recon_model, transform, crop_width, crop_height, max_width)
 #         y_pred, y_true, fn, fp = calculate_onsd_scores_measured(test_videos, yolo_model, classifier_model, sparse_model, recon_model, transform, crop_width, crop_height)
             
         t2 = time.perf_counter()
-- 
GitLab