diff --git a/sparse_coding_torch/keras_model.py b/sparse_coding_torch/keras_model.py index 4850759c1db56aad56bfcd1a3121cb3749acca5b..a55afc3750772a9541bed63f7ff9604872ad50db 100644 --- a/sparse_coding_torch/keras_model.py +++ b/sparse_coding_torch/keras_model.py @@ -258,7 +258,7 @@ class PNBClassifier(keras.layers.Layer): self.max_pool = keras.layers.MaxPooling2D(pool_size=8, strides=8) # self.conv_1 = keras.layers.Conv2D(32, kernel_size=8, strides=4, activation='relu', padding='valid') - self.conv_2 = keras.layers.Conv2D(48, kernel_size=4, strides=2, activation='relu', padding='valid') + self.conv_2 = keras.layers.Conv2D(48, kernel_size=(8, 16), strides=4, activation='relu', padding='valid') # self.conv_3 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid') # self.conv_4 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid') @@ -266,8 +266,8 @@ class PNBClassifier(keras.layers.Layer): self.dropout = keras.layers.Dropout(0.5) -# self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True) -# self.ff_2 = keras.layers.Dense(500, activation='relu', use_bias=True) + self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True) + self.ff_2 = keras.layers.Dense(500, activation='relu', use_bias=True) self.ff_3 = keras.layers.Dense(100, activation='relu', use_bias=True) self.ff_4 = keras.layers.Dense(1) @@ -280,10 +280,10 @@ class PNBClassifier(keras.layers.Layer): # x = self.conv_3(x) # x = self.conv_4(x) x = self.flatten(x) -# x = self.ff_1(x) -# x = self.dropout(x) -# x = self.ff_2(x) -# x = self.dropout(x) + x = self.ff_1(x) + x = self.dropout(x) + x = self.ff_2(x) + x = self.dropout(x) x = self.ff_3(x) x = self.dropout(x) x = self.ff_4(x) diff --git a/sparse_coding_torch/train_classifier.py b/sparse_coding_torch/train_classifier.py index 8f0276cac663126541cda57ca90ce2d5db0c775d..a18886236609a34e53fd8cf00612db4810a15f7b 100644 --- a/sparse_coding_torch/train_classifier.py +++ b/sparse_coding_torch/train_classifier.py @@ -50,7 +50,7 @@ if __name__ == "__main__": if args.dataset == 'pnb': image_height = 285 - image_width = 235 + image_width = 470 elif args.dataset == 'ptx': image_height = 100 image_width = 200 diff --git a/sparse_coding_torch/video_loader.py b/sparse_coding_torch/video_loader.py index e7060d1c76d90271aa0ce7252b4a964e1ea99104..a855e5e472742ac1f750b89c99fb56428f8e1ccd 100644 --- a/sparse_coding_torch/video_loader.py +++ b/sparse_coding_torch/video_loader.py @@ -98,11 +98,11 @@ def get_yolo_regions(yolo_model, clip, is_right): lower_y = upper_y - 285 if is_right: - lower_x = center_x - 235 + lower_x = center_x - 470 upper_x = center_x else: lower_x = center_x - upper_x = center_x + 235 + upper_x = center_x + 470 trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x] @@ -172,8 +172,8 @@ class PNBLoader(Dataset): self.augmentation = augmentation self.labels = [name for name in listdir(video_path) if isdir(join(video_path, name))] - clip_cache_file = 'clip_cache_pnb.pt' - clip_cache_final_file = 'clip_cache_pnb_final.pt' + clip_cache_file = 'clip_cache_pnb_double.pt' + clip_cache_final_file = 'clip_cache_pnb_final_double.pt' region_labels = load_pnb_region_labels(join(video_path, 'sme_region_labels.csv')) @@ -222,15 +222,15 @@ class PNBLoader(Dataset): # cv2.imwrite('test.png', vc_sub[0, 0, :, :].unsqueeze(2).numpy()) for clip in get_yolo_regions(yolo_model, vc_sub, is_right): -# if self.transform: -# clip = self.transform(clip) + if self.transform: + clip = self.transform(clip) # print(clip[0, 0, :, :].size()) # cv2.imwrite('test_yolo.png', clip[0, 0, :, :].unsqueeze(2).numpy()) # print(clip.shape) # tv.io.write_video('test_yolo.mp4', clip.swapaxes(0,1).swapaxes(1,2).swapaxes(2,3).numpy(), fps=20) - print(path) - raise Exception +# print(path) +# raise Exception self.clips.append(('Negatives', clip, self.videos[vid_idx][2])) diff --git a/yolov4 b/yolov4 deleted file mode 160000 index 9f16748aa3f45ff240608da4bd9b1216a29127f5..0000000000000000000000000000000000000000 --- a/yolov4 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 9f16748aa3f45ff240608da4bd9b1216a29127f5