diff --git a/notebooks/visualize_filters_keras.ipynb b/notebooks/visualize_filters_keras.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..4a5918309ac20d7c55e1dee9ddaafee9256fcba4
--- /dev/null
+++ b/notebooks/visualize_filters_keras.ipynb
@@ -0,0 +1,270 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0603e9bd-ac66-4984-b1de-f09367956968",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sparse_coding_torch.keras_model import SparseCode, ReconSparse\n",
+    "from sparse_coding_torch.load_data import load_pnb_videos\n",
+    "import tensorflow.keras as keras\n",
+    "from sparse_coding_torch.train_sparse_model import plot_video, plot_filters\n",
+    "from IPython.display import HTML\n",
+    "import tensorflow as tf"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a02865e4-6685-4401-98bd-e1a71675d122",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "image_height = 360\n",
+    "image_width = 304\n",
+    "batch_size = 1\n",
+    "kernel_size = 15\n",
+    "num_kernels = 48\n",
+    "stride=1\n",
+    "max_activation_iter = 150\n",
+    "activation_lr=1e-2\n",
+    "lam=0.05\n",
+    "run_2d=False\n",
+    "sparse_checkpoint = '../sparse_coding_torch/output/sparse_pnb_48/sparse_conv3d_model-best.pt/'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "72e88838-67a8-4deb-86b0-82f5e4eca9db",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "inputs = keras.Input(shape=(5, image_height, image_width, 1))\n",
+    "        \n",
+    "filter_inputs = keras.Input(shape=(5, kernel_size, kernel_size, 1, num_kernels), dtype='float32')\n",
+    "\n",
+    "output = SparseCode(batch_size=batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=num_kernels, kernel_size=kernel_size, stride=stride, lam=lam, activation_lr=activation_lr, max_activation_iter=max_activation_iter, run_2d=run_2d)(inputs, filter_inputs)\n",
+    "\n",
+    "sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)\n",
+    "\n",
+    "recon_model = keras.models.load_model(sparse_checkpoint)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0ed41c99-07a4-4ffb-964c-34bf7ab50a2d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ani = plot_filters(recon_model.get_weights()[0])\n",
+    "HTML(ani.to_html5_video())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ef944b87-dfb2-4aff-a0b7-497cb86b9742",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_loader, _ = load_pnb_videos(batch_size, classify_mode=True, mode='all_train', device=None, n_splits=1, sparse_model=None)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d69a8ad5-4037-490c-bef5-50d93190ece8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pos_label, pos_clip, vid = list(train_loader)[1]\n",
+    "neg_label, neg_clip, vid = list(train_loader)[2]\n",
+    "print(pos_label)\n",
+    "print(neg_label)\n",
+    "ani = plot_video(neg_clip.squeeze(0))\n",
+    "HTML(ani.to_html5_video())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6f27ad28-599b-48aa-8f29-175cb8db82e5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "all_negative_activations = []\n",
+    "all_positive_activations = []\n",
+    "\n",
+    "for label, clip, vid in train_loader:\n",
+    "    clip = clip.permute(0, 2, 3, 4, 1).numpy()\n",
+    "    activations = sparse_model([clip, tf.expand_dims(recon_model.trainable_weights[0], axis=0)])\n",
+    "    if label[0] == 'Negatives':\n",
+    "        all_negative_activations.append(activations)\n",
+    "    else:\n",
+    "        all_positive_activations.append(activations)\n",
+    "\n",
+    "print(len(all_negative_activations))\n",
+    "print(len(all_positive_activations))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "76f1ae59-49ad-4e9f-97f2-05f6ee073231",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "all_negative_activations = tf.reshape(tf.stack(all_negative_activations), (-1, 64))\n",
+    "all_positive_activations = tf.reshape(tf.stack(all_positive_activations), (-1, 64))\n",
+    "\n",
+    "negative_scores = tf.math.reduce_sum(all_negative_activations, axis=0)\n",
+    "negative_scores = negative_scores / tf.math.reduce_max(negative_scores)\n",
+    "positive_scores = tf.math.reduce_sum(all_positive_activations, axis=0)\n",
+    "positive_scores = positive_scores / tf.math.reduce_max(positive_scores)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a4ba6bd8-898c-4e72-ab6c-ca3028fd38d0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import matplotlib.pyplot as plt\n",
+    "from matplotlib import cm\n",
+    "from matplotlib.animation import FuncAnimation\n",
+    "\n",
+    "fig, ax = plt.subplots(nrows=2, ncols=1)\n",
+    "\n",
+    "ax[0].bar(range(len(negative_scores)), negative_scores)\n",
+    "ax[1].bar(range(len(positive_scores)), positive_scores)\n",
+    "\n",
+    "\n",
+    "plt.tight_layout()\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "24c0ff4c-198a-4de3-9808-2debb6d405c5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "largest_diff = int(tf.math.argmax(tf.math.abs(positive_scores - negative_scores)))\n",
+    "print(largest_diff)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fd5bfcd5-f503-43d9-9dd1-b5e72bd24e9c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def plot_filter(video):\n",
+    "\n",
+    "    fig = plt.gcf()\n",
+    "    ax = plt.gca()\n",
+    "\n",
+    "    DPI = fig.get_dpi()\n",
+    "#     fig.set_size_inches(video.shape[2]/float(DPI), video.shape[3]/float(DPI))\n",
+    "\n",
+    "    ax.set_title(\"Video\")\n",
+    "\n",
+    "    T = video.shape[1]\n",
+    "    im = ax.imshow(video[0, 0, :, :],\n",
+    "                     cmap=cm.Greys_r)\n",
+    "\n",
+    "    def update(i):\n",
+    "        t = i % T\n",
+    "        im.set_data(video[0, t, :, :])\n",
+    "\n",
+    "    return FuncAnimation(plt.gcf(), update, interval=1000/20)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "071ac8be-a95b-41c8-a688-d4c7c4122560",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "filter_diff = recon_model.trainable_weights[0][:, :, :, :, 4]\n",
+    "filter_diff = tf.expand_dims(tf.squeeze(filter_diff, -1), 0)\n",
+    "print(filter_diff.shape)\n",
+    "ani = plot_filter(filter_diff)\n",
+    "# HTML(ani.to_html5_video())\n",
+    "ani.save(\"/home/dwh48@drexel.edu/sparse_coding_torch/sparse_coding_torch/output/sparse_pnb_48/needle_filter.mp4\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "241beed8-4ad8-4f1f-8e0a-cc0d313af4ee",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "label, clip, vid = list(train_loader)[1]\n",
+    "clip = clip.permute(0, 2, 3, 4, 1).numpy()\n",
+    "ani = plot_video(clip)\n",
+    "# HTML(ani.to_html5_video())\n",
+    "ani.save(\"/home/dwh48@drexel.edu/sparse_coding_torch/sparse_coding_torch/output/sparse_pnb_48/needle_video.mp4\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "95e2db90-7a78-43ad-9311-e28558991d69",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "activations = sparse_model([clip, tf.expand_dims(recon_model.trainable_weights[0], axis=0)])[:, :, :, :, 4]\n",
+    "ani = plot_video(activations)\n",
+    "# HTML(ani.to_html5_video())\n",
+    "ani.save(\"/home/dwh48@drexel.edu/sparse_coding_torch/sparse_coding_torch/output/sparse_pnb_48/activation_map.mp4\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f05a6dc8-365e-4d40-bd74-e1cf3be8dabf",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5e9a2f94-9ffd-439c-b88c-ee99b179dec6",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python (pocus_project)",
+   "language": "python",
+   "name": "darryl_pocus"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/ptx_tensorflow/best_classifier.pt/keras_metadata.pb b/ptx_tensorflow/best_classifier.pt/keras_metadata.pb
new file mode 100644
index 0000000000000000000000000000000000000000..d54abc37f0a834091324029b1b955c88fb709af3
--- /dev/null
+++ b/ptx_tensorflow/best_classifier.pt/keras_metadata.pb
@@ -0,0 +1,10 @@
+
+�root"_tf_keras_network*�{"name": "model_2", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "class_name": "Functional", "config": {"name": "model_2", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 1, 43, 93, 64]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_4"}, "name": "input_4", "inbound_nodes": []}, {"class_name": "PTXClassifier", "config": {"name": "ptx_classifier", "trainable": true, "dtype": "float32"}, "name": "ptx_classifier", "inbound_nodes": [[["input_4", 0, 0, {}]]]}], "input_layers": [["input_4", 0, 0]], "output_layers": [["ptx_classifier", 0, 0]]}, "shared_object_id": 2, "input_spec": [{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 1, 43, 93, 64]}, "ndim": 5, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 1, 43, 93, 64]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 1, 43, 93, 64]}, "float32", "input_4"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 1, 43, 93, 64]}, "float32", "input_4"]}, "keras_version": "2.8.0", "backend": "tensorflow", "model_config": {"class_name": "Functional", "config": {"name": "model_2", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 1, 43, 93, 64]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_4"}, "name": "input_4", "inbound_nodes": [], "shared_object_id": 0}, {"class_name": "PTXClassifier", "config": {"name": "ptx_classifier", "trainable": true, "dtype": "float32"}, "name": "ptx_classifier", "inbound_nodes": [[["input_4", 0, 0, {}]]], "shared_object_id": 1}], "input_layers": [["input_4", 0, 0]], "output_layers": [["ptx_classifier", 0, 0]]}}}2
+�root.layer-0"_tf_keras_input_layer*�{"class_name": "InputLayer", "name": "input_4", "dtype": "float32", "sparse": false, "ragged": false, "batch_input_shape": {"class_name": "__tuple__", "items": [null, 1, 43, 93, 64]}, "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 1, 43, 93, 64]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_4"}}2
+�root.layer_with_weights-0"_tf_keras_layer*�{"name": "ptx_classifier", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "PTXClassifier", "config": {"name": "ptx_classifier", "trainable": true, "dtype": "float32"}, "inbound_nodes": [[["input_4", 0, 0, {}]]], "shared_object_id": 1}2
+�"root.layer_with_weights-0.max_pool"_tf_keras_layer*�{"name": "max_pooling2d", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "MaxPooling2D", "config": {"name": "max_pooling2d", "trainable": true, "dtype": "float32", "pool_size": {"class_name": "__tuple__", "items": [4, 4]}, "padding": "valid", "strides": {"class_name": "__tuple__", "items": [4, 4]}, "data_format": "channels_last"}, "shared_object_id": 4, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": 4, "max_ndim": null, "min_ndim": null, "axes": {}}, "shared_object_id": 5}}2
+�	 root.layer_with_weights-0.conv_1"_tf_keras_layer*�	{"name": "conv2d", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Conv2D", "config": {"name": "conv2d", "trainable": true, "dtype": "float32", "filters": 24, "kernel_size": {"class_name": "__tuple__", "items": [8, 8]}, "strides": {"class_name": "__tuple__", "items": [4, 4]}, "padding": "valid", "data_format": "channels_last", "dilation_rate": {"class_name": "__tuple__", "items": [1, 1]}, "groups": 1, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 6}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 7}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "shared_object_id": 8, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 4, "axes": {"-1": 64}}, "shared_object_id": 9}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 10, 23, 64]}}2
+�
!root.layer_with_weights-0.flatten"_tf_keras_layer*�{"name": "flatten", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "shared_object_id": 10, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 1, "axes": {}}, "shared_object_id": 11}}2
+�!root.layer_with_weights-0.dropout"_tf_keras_layer*�{"name": "dropout", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Dropout", "config": {"name": "dropout", "trainable": true, "dtype": "float32", "rate": 0.5, "noise_shape": null, "seed": null}, "shared_object_id": 12, "build_input_shape": {"class_name": "TensorShape", "items": [null, 20]}}2
+�root.layer_with_weights-0.ff_3"_tf_keras_layer*�{"name": "dense", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 20, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 13}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 14}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "shared_object_id": 15, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 96}}, "shared_object_id": 16}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 96]}}2
+�root.layer_with_weights-0.ff_4"_tf_keras_layer*�{"name": "dense_1", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 1, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 17}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 18}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "shared_object_id": 19, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 20}}, "shared_object_id": 20}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 20]}}2
\ No newline at end of file
diff --git a/ptx_tensorflow/best_classifier.pt/saved_model.pb b/ptx_tensorflow/best_classifier.pt/saved_model.pb
new file mode 100644
index 0000000000000000000000000000000000000000..3aa9d020e2183ae635522ea281b0fadca87f269b
Binary files /dev/null and b/ptx_tensorflow/best_classifier.pt/saved_model.pb differ
diff --git a/ptx_tensorflow/best_classifier.pt/variables/variables.data-00000-of-00001 b/ptx_tensorflow/best_classifier.pt/variables/variables.data-00000-of-00001
new file mode 100644
index 0000000000000000000000000000000000000000..ef0bf637487b38bc827748fded255cc73757ea01
Binary files /dev/null and b/ptx_tensorflow/best_classifier.pt/variables/variables.data-00000-of-00001 differ
diff --git a/ptx_tensorflow/best_classifier.pt/variables/variables.index b/ptx_tensorflow/best_classifier.pt/variables/variables.index
new file mode 100644
index 0000000000000000000000000000000000000000..cc6e2d0a811f388effbdb84520d78b2c658d00f0
Binary files /dev/null and b/ptx_tensorflow/best_classifier.pt/variables/variables.index differ
diff --git a/ptx_tensorflow/sparse.pt/keras_metadata.pb b/ptx_tensorflow/sparse.pt/keras_metadata.pb
new file mode 100644
index 0000000000000000000000000000000000000000..ae8e9a697296cc307f567f9c234a8460569f426a
--- /dev/null
+++ b/ptx_tensorflow/sparse.pt/keras_metadata.pb
@@ -0,0 +1,4 @@
+
+�root"_tf_keras_network*�{"name": "model_1", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "class_name": "Functional", "config": {"name": "model_1", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 1, 43, 93, 64]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_2"}, "name": "input_2", "inbound_nodes": []}, {"class_name": "ReconSparse", "config": {"layer was saved without config": true}, "name": "recon_sparse", "inbound_nodes": [[["input_2", 0, 0, {}]]]}], "input_layers": [["input_2", 0, 0]], "output_layers": [["recon_sparse", 0, 0]]}, "shared_object_id": 1, "input_spec": [{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 1, 43, 93, 64]}, "ndim": 5, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 1, 43, 93, 64]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 1, 43, 93, 64]}, "float32", "input_2"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 1, 43, 93, 64]}, "float32", "input_2"]}, "keras_version": "2.8.0", "backend": "tensorflow", "model_config": {"class_name": "Functional"}}2
+�root.layer-0"_tf_keras_input_layer*�{"class_name": "InputLayer", "name": "input_2", "dtype": "float32", "sparse": false, "ragged": false, "batch_input_shape": {"class_name": "__tuple__", "items": [null, 1, 43, 93, 64]}, "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 1, 43, 93, 64]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_2"}}2
+�root.layer_with_weights-0"_tf_keras_model*�{"name": "recon_sparse", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "class_name": "ReconSparse", "config": {"layer was saved without config": true}, "is_graph_network": false, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 1, 43, 93, 64]}, "float32", "input_1"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 1, 43, 93, 64]}, "float32", "input_1"]}, "keras_version": "2.8.0", "backend": "tensorflow", "model_config": {"class_name": "ReconSparse"}}2
\ No newline at end of file
diff --git a/ptx_tensorflow/sparse.pt/saved_model.pb b/ptx_tensorflow/sparse.pt/saved_model.pb
new file mode 100644
index 0000000000000000000000000000000000000000..2d2a65605e9a54a97d56f994cd0c00500f0b189b
Binary files /dev/null and b/ptx_tensorflow/sparse.pt/saved_model.pb differ
diff --git a/ptx_tensorflow/sparse.pt/variables/variables.data-00000-of-00001 b/ptx_tensorflow/sparse.pt/variables/variables.data-00000-of-00001
new file mode 100644
index 0000000000000000000000000000000000000000..d34cddad2a995332cfb2993af770d2fe0474dd08
Binary files /dev/null and b/ptx_tensorflow/sparse.pt/variables/variables.data-00000-of-00001 differ
diff --git a/ptx_tensorflow/sparse.pt/variables/variables.index b/ptx_tensorflow/sparse.pt/variables/variables.index
new file mode 100644
index 0000000000000000000000000000000000000000..a83339ea022e1ab94dba95b233e8cbe296f8633e
Binary files /dev/null and b/ptx_tensorflow/sparse.pt/variables/variables.index differ
diff --git a/run.py b/run.py
deleted file mode 100644
index 4ce10ed5d7ddcd34767702a63eb1812a649bd2c3..0000000000000000000000000000000000000000
--- a/run.py
+++ /dev/null
@@ -1,161 +0,0 @@
-import torch
-import os
-from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
-from sparse_coding_torch.small_data_classifier import SmallDataClassifierConv3d
-import time
-import numpy as np
-import torchvision
-from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
-from torchvision.datasets.video_utils import VideoClips
-import csv
-from datetime import datetime
-from yolov4.get_bounding_boxes import YoloModel
-import argparse
-
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser(description='Process some integers.')
-    parser.add_argument('--fast', action='store_true',
-                    help='optimized for runtime')
-    parser.add_argument('--accurate', action='store_true',
-                    help='optimized for accuracy')
-    parser.add_argument('--verbose', action='store_true',
-                    help='output verbose')
-    args = parser.parse_args()
-    #print(args.accumulate(args.integers))
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    batch_size = 1
-
-    frozen_sparse = ConvSparseLayer(in_channels=1,
-                                   out_channels=64,
-                                   kernel_size=(5, 15, 15),
-                                   stride=1,
-                                   padding=(0, 7, 7),
-                                   convo_dim=3,
-                                   rectifier=True,
-                                   lam=0.05,
-                                   max_activation_iter=150,
-                                   activation_lr=1e-2)
-
-    sparse_param = torch.load('sparse.pt', map_location=device)
-    frozen_sparse.load_state_dict(sparse_param['model_state_dict'])
-    frozen_sparse.to(device)
-
-    predictive_model = SmallDataClassifierConv3d()
-    predictive_model.to(device)
-
-    checkpoint = {k.replace('module.', ''): v for k,v in torch.load('classifier.pt', map_location=device)['model_state_dict'].items()}
-    predictive_model.load_state_dict(checkpoint)
-        
-    yolo_model = YoloModel()
-        
-    transform = torchvision.transforms.Compose(
-    [VideoGrayScaler(),
-     MinMaxScaler(0, 255),
-     torchvision.transforms.Normalize((0.2592,), (0.1251,)),
-     torchvision.transforms.CenterCrop((100, 200))
-    ])
-
-    frozen_sparse.eval()
-    predictive_model.eval()
-    
-    all_predictions = []
-    
-    all_files = list(os.listdir('input_videos'))
-    
-    for f in all_files:
-        print('Processing', f)
-        #start_time = time.time()
-        
-        clipstride = 15
-        if args.fast:
-            clipstride = 20
-        if args.accurate:
-            clipstride = 10
-        
-        vc = VideoClips([os.path.join('input_videos', f)],
-                        clip_length_in_frames=5,
-                        frame_rate=20,
-                       frames_between_clips=clipstride)
-    
-        ### START time after loading video ###
-        start_time = time.time()
-        clip_predictions = []
-        i = 0
-        cliplist = []
-        countclips = 0
-        for i in range(vc.num_clips()):
-
-            clip, _, _, _ = vc.get_clip(i)
-            clip = clip.swapaxes(1, 3).swapaxes(0, 1).swapaxes(2, 3).numpy()
-            
-            bounding_boxes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1)).squeeze(0)
-            if bounding_boxes.size == 0:
-                continue
-            #widths = []
-            countclips = countclips + len(bounding_boxes)
-            
-            widths = [(bounding_boxes[i][3] - bounding_boxes[i][1]) for i in range(len(bounding_boxes))]
-            
-            #for i in range(len(bounding_boxes)):
-            #    widths.append(bounding_boxes[i][3] - bounding_boxes[i][1])
-
-            ind =  np.argmax(np.array(widths))
-            #for bb in bounding_boxes:
-            bb = bounding_boxes[ind]
-            center_x = (bb[3] + bb[1]) / 2 * 1920
-            center_y = (bb[2] + bb[0]) / 2 * 1080
-
-            width=400
-            height=400
-
-            lower_y = round(center_y - height / 2)
-            upper_y = round(center_y + height / 2)
-            lower_x = round(center_x - width / 2)
-            upper_x = round(center_x + width / 2)
-
-            trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x]
-
-            trimmed_clip = torch.tensor(trimmed_clip).to(torch.float)
-
-            trimmed_clip = transform(trimmed_clip)
-            trimmed_clip.pin_memory()
-            cliplist.append(trimmed_clip)
-
-        if len(cliplist) > 0:
-            with torch.no_grad():
-                trimmed_clip = torch.stack(cliplist)
-                trimmed_clip = trimmed_clip.to(device, non_blocking=True)
-                activations = frozen_sparse(trimmed_clip)
-
-                pred, activations = predictive_model(activations)
-                #print(torch.nn.Sigmoid()(pred))
-                clip_predictions = (torch.nn.Sigmoid()(pred).round().detach().cpu().flatten().to(torch.long))
-
-            if args.verbose:
-                print(clip_predictions)
-                print("num of clips: ", countclips)
-            final_pred = torch.mode(clip_predictions)[0].item()
-            if len(clip_predictions) % 2 == 0 and torch.sum(clip_predictions).item() == len(clip_predictions)//2:
-                #print("I'm here")
-                final_pred = (torch.nn.Sigmoid()(pred)).mean().round().detach().cpu().to(torch.long).item()
-                
-            
-            if final_pred == 1:
-                str_pred = 'No Sliding'
-            else:
-                str_pred = 'Sliding'
-
-        else:
-            str_pred = "No Sliding"
-            
-        end_time = time.time()
-        
-        all_predictions.append({'FileName': f, 'Prediction': str_pred, 'TotalTimeSec': end_time - start_time})
-        
-    with open('output_' + datetime.now().strftime("%Y%m%d-%H%M%S") + '.csv', 'w+', newline='') as csv_out:
-        writer = csv.DictWriter(csv_out, fieldnames=all_predictions[0].keys())
-        
-        writer.writeheader()
-        writer.writerows(all_predictions)
diff --git a/run_ptx.py b/run_ptx.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b5ee354fcb6138f8c96870b7c891f1699e64bd7
--- /dev/null
+++ b/run_ptx.py
@@ -0,0 +1,180 @@
+import torch
+import os
+from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse
+import time
+import numpy as np
+import torchvision
+from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
+from torchvision.datasets.video_utils import VideoClips
+import csv
+from datetime import datetime
+from yolov4.get_bounding_boxes import YoloModel
+import argparse
+import tensorflow as tf
+import tensorflow.keras as keras
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_dir', default='/shared_data/bamc_data/PTX_Sliding', type=str)
+    parser.add_argument('--kernel_size', default=15, type=int)
+    parser.add_argument('--kernel_depth', default=5, type=int)
+    parser.add_argument('--num_kernels', default=64, type=int)
+    parser.add_argument('--stride', default=2, type=int)
+    parser.add_argument('--max_activation_iter', default=100, type=int)
+    parser.add_argument('--activation_lr', default=1e-2, type=float)
+    parser.add_argument('--lam', default=0.05, type=float)
+    parser.add_argument('--sparse_checkpoint', default='converted_checkpoints/sparse.pt', type=str)
+    parser.add_argument('--checkpoint', default='converted_checkpoints/classifier.pt', type=str)
+    parser.add_argument('--run_2d', action='store_true')
+    parser.add_argument('--dataset', default='ptx', type=str)
+    
+    args = parser.parse_args()
+    #print(args.accumulate(args.integers))
+    batch_size = 1
+    
+    if args.dataset == 'pnb':
+        image_height = 250
+        image_width = 600
+    elif args.dataset == 'ptx':
+        image_height = 100
+        image_width = 200
+    else:
+        raise Exception('Invalid dataset')
+
+    if args.run_2d:
+        inputs = keras.Input(shape=(image_height, image_width, 5))
+    else:
+        inputs = keras.Input(shape=(5, image_height, image_width, 1))
+
+    filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')
+
+    output = SparseCode(batch_size=batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs)
+
+    sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
+    
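+    # Reconstruction model input matches the encoder's activation map shape: (1, (H - kernel_size) // stride + 1, (W - kernel_size) // stride + 1, num_kernels)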
+    recon_inputs = keras.Input(shape=(1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
+    
+    recon_outputs = ReconSparse(batch_size=batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs)
+    
+    recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs)
+
+    if args.sparse_checkpoint:
+        recon_model = keras.models.load_model(args.sparse_checkpoint)
+        
+    if args.checkpoint:
+        classifier_model = keras.models.load_model(args.checkpoint)
+    else:
+        classifier_inputs = keras.Input(shape=(1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
+
+        if args.dataset == 'pnb':
+            classifier_outputs = PNBClassifier()(classifier_inputs)
+        elif args.dataset == 'ptx':
+            classifier_outputs = PTXClassifier()(classifier_inputs)
+        else:
+            raise Exception('No classifier exists for that dataset')
+
+        classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs)
+        
+    yolo_model = YoloModel()
+        
+    transform = torchvision.transforms.Compose(
+    [VideoGrayScaler(),
+     MinMaxScaler(0, 255),
+     torchvision.transforms.Normalize((0.2592,), (0.1251,)),
+     torchvision.transforms.CenterCrop((100, 200))
+    ])
+    
+    all_predictions = []
+    
+    all_files = list(os.listdir(args.input_dir))
+    
+    for f in all_files:
+        print('Processing', f)
+        #start_time = time.time()
+        
+        clipstride = 15
+        
+        vc = VideoClips([os.path.join(args.input_dir, f)],
+                        clip_length_in_frames=5,
+                        frame_rate=20,
+                       frames_between_clips=clipstride)
+    
+        ### START time after loading video ###
+        start_time = time.time()
+        clip_predictions = []
+        i = 0
+        cliplist = []
+        countclips = 0
+        for i in range(vc.num_clips()):
+
+            clip, _, _, _ = vc.get_clip(i)
+            clip = clip.swapaxes(1, 3).swapaxes(0, 1).swapaxes(2, 3).numpy()
+            
+            bounding_boxes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1)).squeeze(0)
+            if bounding_boxes.size == 0:
+                continue
+            #widths = []
+            countclips = countclips + len(bounding_boxes)
+            
+            widths = [(bounding_boxes[i][3] - bounding_boxes[i][1]) for i in range(len(bounding_boxes))]
+            
+            #for i in range(len(bounding_boxes)):
+            #    widths.append(bounding_boxes[i][3] - bounding_boxes[i][1])
+
+            ind =  np.argmax(np.array(widths))
+            #for bb in bounding_boxes:
+            bb = bounding_boxes[ind]
+            center_x = (bb[3] + bb[1]) / 2 * 1920
+            center_y = (bb[2] + bb[0]) / 2 * 1080
+
+            width=400
+            height=400
+
+            lower_y = round(center_y - height / 2)
+            upper_y = round(center_y + height / 2)
+            lower_x = round(center_x - width / 2)
+            upper_x = round(center_x + width / 2)
+
+            trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x]
+
+            trimmed_clip = torch.tensor(trimmed_clip).to(torch.float)
+
+            trimmed_clip = transform(trimmed_clip)
+            trimmed_clip.pin_memory()
+            cliplist.append(trimmed_clip)
+
+        if len(cliplist) > 0:
+            with torch.no_grad():
+                trimmed_clip = torch.stack(cliplist)
+                images = trimmed_clip.permute(0, 2, 3, 4, 1).numpy()
+                activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.weights[0], axis=0))]))
+
+                pred = classifier_model(activations)
+                #print(torch.nn.Sigmoid()(pred))
+                clip_predictions = tf.math.round(tf.math.sigmoid(pred))
+
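+            # Video-level prediction: majority vote over the per-clip predictions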
+            final_pred = torch.mode(torch.tensor(clip_predictions.numpy()).view(-1))[0].item()
+            if len(clip_predictions) % 2 == 0 and tf.math.reduce_sum(clip_predictions) == len(clip_predictions)//2:
+                # Tie between clip votes: fall back to rounding the mean sigmoid score, as in the original run.py
+                final_pred = int(tf.math.round(tf.math.reduce_mean(tf.math.sigmoid(pred))).numpy())
+                
+            if final_pred == 1:
+                str_pred = 'No Sliding'
+            else:
+                str_pred = 'Sliding'
+
+        else:
+            str_pred = "No Sliding"
+            
+        print(str_pred)
+            
+        end_time = time.time()
+        
+        all_predictions.append({'FileName': f, 'Prediction': str_pred, 'TotalTimeSec': end_time - start_time})
+        
+    with open('output_' + datetime.now().strftime("%Y%m%d-%H%M%S") + '.csv', 'w+', newline='') as csv_out:
+        writer = csv.DictWriter(csv_out, fieldnames=all_predictions[0].keys())
+        
+        writer.writeheader()
+        writer.writerows(all_predictions)
diff --git a/run_tflite_pnb.py b/run_tflite_pnb.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b0f4fd0d57654a46a2d4e7266c067bbd00a4d03
--- /dev/null
+++ b/run_tflite_pnb.py
@@ -0,0 +1,86 @@
+import torch
+import os
+import time
+import numpy as np
+import torchvision
+from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler, get_yolo_regions
+from torchvision.datasets.video_utils import VideoClips
+import csv
+from datetime import datetime
+from yolov4.get_bounding_boxes import YoloModel
+import argparse
+import tensorflow as tf
+import scipy.stats
+import cv2
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description='Python program for processing PNB data')
+    parser.add_argument('--classifier', type=str, default='keras/mobile_output/tf_lite_model.tflite')
+    parser.add_argument('--input_dir', type=str, default='input_videos')
+    args = parser.parse_args()
+
+    interpreter = tf.lite.Interpreter(args.classifier)
+    interpreter.allocate_tensors()
+
+    input_details = interpreter.get_input_details()
+    output_details = interpreter.get_output_details()
+
+    yolo_model = YoloModel()
+
+    transform = torchvision.transforms.Compose(
+    [VideoGrayScaler(),
+     MinMaxScaler(0, 255),
+     torchvision.transforms.Resize((360, 304))
+    ])
+
+    all_predictions = []
+
+    all_files = list(os.listdir(args.input_dir))
+
+    for f in all_files:
+        print('Processing', f)
+        
+        vc = torchvision.io.read_video(os.path.join(args.input_dir, f))[0].permute(3, 0, 1, 2)
+        
+        vc_sub = vc[:, -5:, :, :]
+        if vc_sub.size(1) < 5:
+            raise Exception(f + ' does not contain enough frames for processing')
+            
+        ### START time after loading video ###
+        start_time = time.time()
+        
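+        # Crop the YOLO-proposed region from the clip; if no region is detected, default to a 'Positive' prediction below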
+        clip = get_yolo_regions(yolo_model, vc_sub)
+        if clip:
+            clip = clip[0]
+            clip = transform(clip)
+
+            interpreter.set_tensor(input_details[0]['index'], clip.numpy())
+
+            interpreter.invoke()
+
+            output_array = np.array(interpreter.get_tensor(output_details[0]['index']))
+
+            pred = output_array[0][0]
+            print(pred)
+
+            final_pred = pred.round()
+
+            if final_pred == 1:
+                str_pred = 'Positive'
+            else:
+                str_pred = 'Negative'
+        else:
+            str_pred = "Positive"
+
+        end_time = time.time()
+
+        print(str_pred)
+
+        all_predictions.append({'FileName': f, 'Prediction': str_pred, 'TotalTimeSec': end_time - start_time})
+
+    with open('output_' + datetime.now().strftime("%Y%m%d-%H%M%S") + '.csv', 'w+', newline='') as csv_out:
+        writer = csv.DictWriter(csv_out, fieldnames=all_predictions[0].keys())
+
+        writer.writeheader()
+        writer.writerows(all_predictions)
diff --git a/run_tflite.py b/run_tflite_ptx.py
similarity index 94%
rename from run_tflite.py
rename to run_tflite_ptx.py
index d4e23ecf2e82ea70eb5e2ba204d32d83480d0bcb..09ec0045027c1ca8da76383d5765dad11217a044 100644
--- a/run_tflite.py
+++ b/run_tflite_ptx.py
@@ -44,7 +44,7 @@ if __name__ == "__main__":
 
     all_predictions = []
 
-    all_files = list(os.listdir('input_videos'))
+    all_files = list(os.listdir('/shared_data/bamc_data/PTX_Sliding'))
 
     for f in all_files:
         print('Processing', f)
@@ -56,7 +56,7 @@ if __name__ == "__main__":
         if args.accurate:
             clipstride = 10
 
-        vc = VideoClips([os.path.join('input_videos', f)],
+        vc = VideoClips([os.path.join('/shared_data/bamc_data/PTX_Sliding', f)],
                         clip_length_in_frames=5,
                         frame_rate=20,
                        frames_between_clips=clipstride)
diff --git a/sparse_coding_torch/convert_pytorch_to_keras.py b/sparse_coding_torch/convert_pytorch_to_keras.py
new file mode 100644
index 0000000000000000000000000000000000000000..9abf4ac07fdaab26ce25aeefc483eaf09639faf3
--- /dev/null
+++ b/sparse_coding_torch/convert_pytorch_to_keras.py
@@ -0,0 +1,65 @@
+import argparse
+from tensorflow import keras
+import tensorflow as tf
+from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse, load_pytorch_weights
+import torch
+import os
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--sparse_checkpoint', default=None, type=str)
+    parser.add_argument('--classifier_checkpoint', default=None, type=str)
+    parser.add_argument('--kernel_size', default=15, type=int)
+    parser.add_argument('--kernel_depth', default=5, type=int)
+    parser.add_argument('--num_kernels', default=64, type=int)
+    parser.add_argument('--stride', default=2, type=int)
+    parser.add_argument('--dataset', default='ptx', type=str)
+    parser.add_argument('--input_image_height', default=100, type=int)
+    parser.add_argument('--input_image_width', default=200, type=int)
+    parser.add_argument('--output_dir', default='./converted_checkpoints', type=str)
+    parser.add_argument('--lam', default=0.05, type=float)
+    parser.add_argument('--activation_lr', default=1e-2, type=float)
+    parser.add_argument('--max_activation_iter', default=100, type=int)
+    parser.add_argument('--run_2d', action='store_true')
+    
+    args = parser.parse_args()
+    
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+    
+    if args.classifier_checkpoint:
+        classifier_inputs = keras.Input(shape=(1, (args.input_image_height - args.kernel_size) // args.stride + 1, (args.input_image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
+
+        if args.dataset == 'pnb':
+            classifier_outputs = PNBClassifier()(classifier_inputs)
+            classifier_name = 'pnb_classifier'
+        elif args.dataset == 'ptx':
+            classifier_outputs = PTXClassifier()(classifier_inputs)
+            classifier_name = 'ptx_classifier'
+        else:
+            raise Exception('No classifier exists for that dataset')
+
+        classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs)
+        
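+        # Re-order the PyTorch conv kernel axes into Keras' (kH, kW, in, out) layout and transpose the dense-layer kernels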
+        pytorch_checkpoint = torch.load(args.classifier_checkpoint, map_location='cpu')['model_state_dict']
+        conv_weights = [pytorch_checkpoint['module.compress_activations_conv_1.weight'].squeeze(2).swapaxes(0, 2).swapaxes(1, 3).swapaxes(2, 3).numpy(), pytorch_checkpoint['module.compress_activations_conv_1.bias'].numpy()]
+        classifier_model.get_layer(classifier_name).conv_1.set_weights(conv_weights)
+        ff_3_weights = [pytorch_checkpoint['module.fc3.weight'].swapaxes(1,0).numpy(), pytorch_checkpoint['module.fc3.bias'].numpy()]
+        classifier_model.get_layer(classifier_name).ff_3.set_weights(ff_3_weights)
+        ff_4_weights = [pytorch_checkpoint['module.fc4.weight'].swapaxes(1,0).numpy(), pytorch_checkpoint['module.fc4.bias'].numpy()]
+        classifier_model.get_layer(classifier_name).ff_4.set_weights(ff_4_weights)
+        
+        classifier_model.save(os.path.join(args.output_dir, "classifier.pt"))
+        
+    if args.sparse_checkpoint:
+        input_shape = [1, (args.input_image_height - args.kernel_size) // args.stride + 1, (args.input_image_width - args.kernel_size) // args.stride + 1, args.num_kernels]
+        recon_inputs = keras.Input(shape=input_shape)
+    
+        recon_outputs = ReconSparse(batch_size=1, image_height=args.input_image_height, image_width=args.input_image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs)
+
+        recon_model = keras.Model(inputs=recon_inputs, outputs=recon_outputs)
+        
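+        # Swap the randomly initialised dictionary filters for the converted PyTorch weights before saving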
+        pytorch_weights = load_pytorch_weights(args.sparse_checkpoint)
+        recon_model.get_layer('recon_sparse').filters = tf.Variable(initial_value=pytorch_weights, dtype='float32', trainable=True)
+        
+        recon_model.save(os.path.join(args.output_dir, "sparse.pt"))
\ No newline at end of file
diff --git a/sparse_coding_torch/keras_model.py b/sparse_coding_torch/keras_model.py
index d932cfe612800529a7d19b460832399b09b84bcc..3f00ae302197225f660a9c24b851927964c5e3d4 100644
--- a/sparse_coding_torch/keras_model.py
+++ b/sparse_coding_torch/keras_model.py
@@ -15,7 +15,8 @@ def load_pytorch_weights(file_path):
     return weight_tensor
 
 # @tf.function
-def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, image_height, image_width, stride):
+def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, image_height, image_width, stride):
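+    # Infer the batch size from the activations at runtime instead of using a fixed constructor argument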
+    batch_size = tf.shape(activations)[0]
     out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID')
     out_2 = tf.nn.conv2d_transpose(activations, filters_2, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID')
     out_3 = tf.nn.conv2d_transpose(activations, filters_3, output_shape=(batch_size, image_height, image_width, 1), strides=stride, padding='VALID')
@@ -27,8 +28,9 @@ def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations,
     return recon
 
 # @tf.function
-def do_recon_3d(filters, activations, batch_size, image_height, image_width, stride):
+def do_recon_3d(filters, activations, image_height, image_width, stride):
 #     activations = tf.pad(activations, paddings=[[0,0], [2, 2], [0, 0], [0, 0], [0, 0]])
+    batch_size = tf.shape(activations)[0]
     recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, image_height, image_width, 1), strides=[1, stride, stride], padding='VALID')
 
     return recon
@@ -99,7 +101,7 @@ class SparseCode(keras.layers.Layer):
         if self.run_2d:
-            recon = do_recon(filters[0], filters[1], filters[2], filters[3], filters[4], activations, self.batch_size, self.image_height, self.image_width, self.stride)
+            recon = do_recon(filters[0], filters[1], filters[2], filters[3], filters[4], activations, self.image_height, self.image_width, self.stride)
         else:
-            recon = do_recon_3d(filters, activations, self.batch_size, self.image_height, self.image_width, self.stride)
+            recon = do_recon_3d(filters, activations, self.image_height, self.image_width, self.stride)
 
         e = images - recon
         g = -1 * u
@@ -135,13 +137,13 @@ class SparseCode(keras.layers.Layer):
     def call(self, images, filters):
         filters = tf.squeeze(filters, axis=0)
         if self.run_2d:
-            output_shape = (self.batch_size, (self.image_height - self.kernel_size) // self.stride + 1, (self.image_width - self.kernel_size) // self.stride + 1, self.out_channels)
+            output_shape = (len(images), (self.image_height - self.kernel_size) // self.stride + 1, (self.image_width - self.kernel_size) // self.stride + 1, self.out_channels)
         else:
-            output_shape = (self.batch_size, 1, (self.image_height - self.kernel_size) // self.stride + 1, (self.image_width - self.kernel_size) // self.stride + 1, self.out_channels)
+            output_shape = (len(images), 1, (self.image_height - self.kernel_size) // self.stride + 1, (self.image_width - self.kernel_size) // self.stride + 1, self.out_channels)
 
-        u = tf.zeros(shape=output_shape)
-        m = tf.zeros(shape=output_shape)
-        v = tf.zeros(shape=output_shape)
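+        # stop_gradient keeps the initial activation (u) and moment (m, v) tensors out of the gradient tape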
+        u = tf.stop_gradient(tf.zeros(shape=output_shape))
+        m = tf.stop_gradient(tf.zeros(shape=output_shape))
+        v = tf.stop_gradient(tf.zeros(shape=output_shape))
 #         tf.print('activations before:', tf.reduce_sum(u))
 
         b1 = tf.constant(0.9, dtype='float32')
@@ -182,6 +184,8 @@ class ReconSparse(keras.Model):
             self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
             self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
         else:
+#             pytorch_weights = load_pytorch_weights('sparse.pt')
+#             self.filters = tf.Variable(initial_value=pytorch_weights, dtype='float32', trainable=False)
             self.filters = tf.Variable(initial_value=initializer(shape=(5, kernel_size, kernel_size, in_channels, out_channels), dtype='float32'), trainable=True)
 
         if run_2d:
@@ -193,38 +197,74 @@ class ReconSparse(keras.Model):
 #     @tf.function
     def call(self, activations):
         if self.run_2d:
-            recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.batch_size, self.image_height, self.image_width, self.stride)
+            recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.image_height, self.image_width, self.stride)
         else:
-            recon = do_recon_3d(self.filters, activations, self.batch_size, self.image_height, self.image_width, self.stride)
+            recon = do_recon_3d(self.filters, activations, self.image_height, self.image_width, self.stride)
             
         return recon
 
-class Classifier(keras.layers.Layer):
+class PTXClassifier(keras.layers.Layer):
     def __init__(self):
-        super(Classifier, self).__init__()
+        super(PTXClassifier, self).__init__()
 
         self.max_pool = keras.layers.MaxPooling2D(pool_size=4, strides=4)
-        self.conv = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='valid')
+        self.conv_1 = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='valid')
+#         self.conv_2 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid')
 
         self.flatten = keras.layers.Flatten()
 
         self.dropout = keras.layers.Dropout(0.5)
 
-        # self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True)
-        # self.ff_2 = keras.layers.Dense(100, activation='relu', use_bias=True)
+#         self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True)
+#         self.ff_2 = keras.layers.Dense(500, activation='relu', use_bias=True)
         self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True)
-        self.ff_4 = keras.layers.Dense(1, activation='sigmoid')
+        self.ff_4 = keras.layers.Dense(1)
 
-    @tf.function
+#     @tf.function
+    def call(self, activations):
+        activations = tf.squeeze(activations, axis=1)
+        x = self.max_pool(activations)
+        x = self.conv_1(x)
+#         x = self.conv_2(x)
+        x = self.flatten(x)
+#         x = self.ff_1(x)
+#         x = self.dropout(x)
+#         x = self.ff_2(x)
+#         x = self.dropout(x)
+        x = self.ff_3(x)
+        x = self.dropout(x)
+        x = self.ff_4(x)
+
+        return x
+    
+class PNBClassifier(keras.layers.Layer):
+    def __init__(self):
+        super(PNBClassifier, self).__init__()
+
+        self.max_pool = keras.layers.MaxPooling2D(pool_size=4, strides=4)
+        self.conv_1 = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='valid')
+#         self.conv_2 = keras.layers.Conv2D(24, kernel_size=4, strides=2, activation='relu', padding='valid')
+
+        self.flatten = keras.layers.Flatten()
+
+        self.dropout = keras.layers.Dropout(0.5)
+
+#         self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True)
+#         self.ff_2 = keras.layers.Dense(500, activation='relu', use_bias=True)
+        self.ff_3 = keras.layers.Dense(100, activation='relu', use_bias=True)
+        self.ff_4 = keras.layers.Dense(1)
+
+#     @tf.function
     def call(self, activations):
         activations = tf.squeeze(activations, axis=1)
         x = self.max_pool(activations)
-        x = self.conv(x)
+        x = self.conv_1(x)
+#         x = self.conv_2(x)
         x = self.flatten(x)
-        # # x = self.ff_1(x)
-        # # x = self.dropout(x)
-        # # x = self.ff_2(x)
-        # # x = self.dropout(x)
+#         x = self.ff_1(x)
+#         x = self.dropout(x)
+#         x = self.ff_2(x)
+#         x = self.dropout(x)
         x = self.ff_3(x)
         x = self.dropout(x)
         x = self.ff_4(x)
diff --git a/sparse_coding_torch/load_data.py b/sparse_coding_torch/load_data.py
index c13c71ac563fec912b2d5e1864b0cbbd90d43ba5..4dd4228aa5445830df110cd51c3db56b27671895 100644
--- a/sparse_coding_torch/load_data.py
+++ b/sparse_coding_torch/load_data.py
@@ -3,24 +3,26 @@ import torchvision
 import torch
 from sklearn.model_selection import train_test_split
 from sparse_coding_torch.video_loader import MinMaxScaler
-from sparse_coding_torch.video_loader import YoloClipLoader, get_video_participants, PNBLoader
+from sparse_coding_torch.video_loader import YoloClipLoader, get_ptx_participants, PNBLoader, get_pnb_participants
 from sparse_coding_torch.video_loader import VideoGrayScaler
+from typing import Sequence, Iterator
 import csv
 from sklearn.model_selection import train_test_split, GroupShuffleSplit, LeaveOneGroupOut, LeaveOneOut, StratifiedGroupKFold, StratifiedKFold, KFold
 
 def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=None, n_splits=None, sparse_model=None, whole_video=False, positive_videos=None):   
     video_path = "/shared_data/YOLO_Updated_PL_Model_Results/"
 
-    video_to_participant = get_video_participants()
+    video_to_participant = get_ptx_participants()
     
     transforms = torchvision.transforms.Compose(
     [VideoGrayScaler(),
+#      MinMaxScaler(0, 255),
      torchvision.transforms.Normalize((0.2592,), (0.1251,)),
     ])
     augment_transforms = torchvision.transforms.Compose(
     [torchvision.transforms.RandomRotation(45),
-     torchvision.transforms.RandomHorizontalFlip()
-#      torchvision.transforms.CenterCrop((100, 200))
+     torchvision.transforms.RandomHorizontalFlip(),
+     torchvision.transforms.CenterCrop((100, 200))
     ])
     if whole_video:
         dataset = YoloVideoLoader(video_path, num_clips=num_clips, num_positives=num_positives, transform=transforms, augment_transform=augment_transforms, sparse_model=sparse_model, device=device)
@@ -43,7 +45,7 @@ def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=Non
                                                sampler=train_sampler)
         test_loader = None
         
-        return train_loader, test_loader
+        return train_loader, test_loader, dataset
     elif mode == 'k_fold':
         gss = StratifiedGroupKFold(n_splits=n_splits)
 
@@ -51,34 +53,69 @@ def load_yolo_clips(batch_size, mode, num_clips=1, num_positives=100, device=Non
         
         return gss.split(np.arange(len(targets)), targets, groups), dataset
     else:
-        return None
+        gss = GroupShuffleSplit(n_splits=n_splits, test_size=0.2)
 
+        groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()]
+        
+        train_idx, test_idx = list(gss.split(np.arange(len(targets)), targets, groups))[0]
+        
+        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                               sampler=train_sampler)
+        
+        test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)
+        test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                               sampler=test_sampler)
+        
+        return train_loader, test_loader, dataset
     
-def load_pnb_videos(batch_size, mode, classify_mode=False, device=None, n_splits=None, sparse_model=None):   
+class SubsetWeightedRandomSampler(torch.utils.data.Sampler[int]):
+    weights: torch.Tensor
+    num_samples: int
+    replacement: bool
+
+    def __init__(self, weights: Sequence[float], indicies: Sequence[int],
+                 replacement: bool = True, generator=None) -> None:
+        if not isinstance(replacement, bool):
+            raise ValueError("replacement should be a boolean value, but got "
+                             "replacement={}".format(replacement))
+        self.weights = torch.as_tensor(weights, dtype=torch.double)
+        self.indicies = indicies
+        self.replacement = replacement
+        self.generator = generator
+
+    def __iter__(self) -> Iterator[int]:
+        rand_tensor = torch.multinomial(self.weights, len(self.indicies), self.replacement, generator=self.generator)
+        for i in rand_tensor:
+            yield self.indicies[i]
+
+    def __len__(self) -> int:
+        return len(self.indicies)
+    
+def load_pnb_videos(batch_size, mode=None, classify_mode=False, balance_classes=False, device=None, n_splits=None, sparse_model=None):   
     video_path = "/shared_data/bamc_pnb_data/full_training_data"
     
     transforms = torchvision.transforms.Compose(
     [VideoGrayScaler(),
      MinMaxScaler(0, 255),
-     torchvision.transforms.Resize((360, 304))
+     torchvision.transforms.Resize((250, 600))
     ])
     augment_transforms = torchvision.transforms.Compose(
-    [torchvision.transforms.RandomAffine(45),
+    [torchvision.transforms.RandomRotation(30),
      torchvision.transforms.RandomHorizontalFlip(),
-     torchvision.transforms.ColorJitter(brightness=0.5),
-     torchvision.transforms.RandomAdjustSharpness(0, p=0.15),
-     torchvision.transforms.RandomAffine(degrees=0, translate=(0.05, 0))
+     torchvision.transforms.ColorJitter(brightness=0.1),
+#      torchvision.transforms.RandomAdjustSharpness(0, p=0.15),
+     torchvision.transforms.RandomAffine(degrees=0, translate=(0.01, 0))
 #      torchvision.transforms.CenterCrop((100, 200))
     ])
-    dataset = PNBLoader(video_path, classify_mode, num_frames=5, frame_rate=20, transform=transforms, augmentation=augment_transforms)
+    dataset = PNBLoader(video_path, classify_mode, balance_classes=balance_classes, num_frames=5, frame_rate=20, transform=transforms, augmentation=augment_transforms)
     
     targets = dataset.get_labels()
     
     if mode == 'leave_one_out':
         gss = LeaveOneGroupOut()
 
-        groups = [v for v in dataset.get_filenames()]
-#         groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()]
+        groups = get_pnb_participants(dataset.get_filenames())
         
         return gss.split(np.arange(len(targets)), targets, groups), dataset
     elif mode == 'all_train':
@@ -88,13 +125,26 @@ def load_pnb_videos(batch_size, mode, classify_mode=False, device=None, n_splits
                                                sampler=train_sampler)
         test_loader = None
         
-        return train_loader, test_loader
+        return train_loader, test_loader, dataset
     elif mode == 'k_fold':
-        gss = StratifiedKFold(n_splits=n_splits, shuffle=True)
+        gss = StratifiedGroupKFold(n_splits=n_splits, shuffle=True)
 
-#         groups = [video_to_participant[v.lower().replace('_clean', '')] for v in dataset.get_filenames()]
-        groups = [v for v in dataset.get_filenames()]
+        groups = get_pnb_participants(dataset.get_filenames())
         
-        return gss.split(np.arange(len(targets)), targets), dataset
+        return gss.split(np.arange(len(targets)), targets, groups), dataset
     else:
-        return None
\ No newline at end of file
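+        # Fallback when no named split mode is given: one participant-grouped
+        # 80/20 train/test split.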
+        gss = GroupShuffleSplit(n_splits=n_splits, test_size=0.2)
+
+        groups = get_pnb_participants(dataset.get_filenames())
+        
+        train_idx, test_idx = list(gss.split(np.arange(len(targets)), targets, groups))[0]
+        
+        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
+        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                               sampler=train_sampler)
+        
+        test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)
+        test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                               sampler=test_sampler)
+        
+        return train_loader, test_loader, dataset
\ No newline at end of file
diff --git a/sparse_coding_torch/train_classifier.py b/sparse_coding_torch/train_classifier.py
index 1c6d15a03401ecff9aa89574d9cc7ab0b6ee5cdf..c78d6fd11d1b160f3a04d3e77432ed71af88be5b 100644
--- a/sparse_coding_torch/train_classifier.py
+++ b/sparse_coding_torch/train_classifier.py
@@ -4,8 +4,8 @@ import torch.nn.functional as F
 from tqdm import tqdm
 import argparse
 import os
-from sparse_coding_torch.load_data import load_yolo_clips, load_pnb_videos
-from sparse_coding_torch.keras_model import SparseCode, Classifier, ReconSparse
+from sparse_coding_torch.load_data import load_yolo_clips, load_pnb_videos, SubsetWeightedRandomSampler
+from sparse_coding_torch.keras_model import SparseCode, PNBClassifier, PTXClassifier, ReconSparse
 import time
 import numpy as np
 from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
@@ -14,6 +14,31 @@ import pickle
 import tensorflow.keras as keras
 import tensorflow as tf
 
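+# Let TensorFlow allocate GPU memory on demand (allow_growth) instead of
+# claiming it all at start-up, so the sparse-coding and classifier models can
+# share the GPU with other processes.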
+configproto = tf.compat.v1.ConfigProto()
+configproto.gpu_options.polling_inactive_delay_msecs = 5000
+configproto.gpu_options.allow_growth = True
+sess = tf.compat.v1.Session(config=configproto) 
+tf.compat.v1.keras.backend.set_session(sess)
+
+def get_sample_weights(train_idx, dataset):
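+    # Inverse-frequency weights for the PNB clip labels: each positive clip is
+    # weighted by the fraction of negatives and vice versa, so a weighted
+    # sampler draws the two classes at roughly equal rates. Assumes the string
+    # labels 'Positives' / 'Negatives'.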
+    dataset = list(dataset)
+
+    num_positive = len([clip[0] for clip in dataset if clip[0] == 'Positives'])
+    negative_weight = num_positive / len(dataset)
+    positive_weight = 1.0 - negative_weight
+    
+    weights = []
+    for idx in train_idx:
+        label = dataset[idx][0]
+        if label == 'Positives':
+            weights.append(positive_weight)
+        elif label == 'Negatives':
+            weights.append(negative_weight)
+        else:
+            raise Exception('Sampler encountered invalid label')
+    
+    return weights
+
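+# Example wiring of the helpers above (hypothetical; shown for illustration only):
+#   weights = get_sample_weights(train_idx, dataset)
+#   sampler = SubsetWeightedRandomSampler(weights, train_idx)
+#   train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=sampler)
+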
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--batch_size', default=12, type=int)
@@ -22,25 +47,33 @@ if __name__ == "__main__":
     parser.add_argument('--num_kernels', default=64, type=int)
     parser.add_argument('--stride', default=1, type=int)
     parser.add_argument('--max_activation_iter', default=150, type=int)
-    parser.add_argument('--activation_lr', default=1e-1, type=float)
-    parser.add_argument('--lr', default=5e-5, type=float)
-    parser.add_argument('--epochs', default=10, type=int)
+    parser.add_argument('--activation_lr', default=1e-2, type=float)
+    parser.add_argument('--lr', default=5e-2, type=float)
+    parser.add_argument('--epochs', default=40, type=int)
     parser.add_argument('--lam', default=0.05, type=float)
     parser.add_argument('--output_dir', default='./output', type=str)
     parser.add_argument('--sparse_checkpoint', default=None, type=str)
     parser.add_argument('--checkpoint', default=None, type=str)
-    parser.add_argument('--splits', default='k_fold', type=str, help='k_fold or leave_one_out or all_train')
-    parser.add_argument('--seed', default=42, type=int)
+    parser.add_argument('--splits', default=None, type=str, help='k_fold or leave_one_out or all_train')
+    parser.add_argument('--seed', default=26, type=int)
     parser.add_argument('--train', action='store_true')
     parser.add_argument('--num_positives', default=100, type=int)
-    parser.add_argument('--n_splits', default=5, type=int)
+    parser.add_argument('--n_splits', default=1, type=int)
     parser.add_argument('--save_train_test_splits', action='store_true')
     parser.add_argument('--run_2d', action='store_true')
+    parser.add_argument('--balance_classes', action='store_true')
+    parser.add_argument('--dataset', default='pnb', type=str)
     
     args = parser.parse_args()
     
-    image_height = 360
-    image_width = 304
+    if args.dataset == 'pnb':
+        image_height = 250
+        image_width = 600
+    elif args.dataset == 'ptx':
+        image_height = 100
+        image_width = 200
+    else:
+        raise Exception('Invalid dataset')
     
     random.seed(args.seed)
     np.random.seed(args.seed)
@@ -55,6 +88,17 @@ if __name__ == "__main__":
 
     all_errors = []
     
+    if args.run_2d:
+        inputs = keras.Input(shape=(image_height, image_width, 5))
+    else:
+        inputs = keras.Input(shape=(5, image_height, image_width, 1))
+
+    filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')
+
+    output = SparseCode(batch_size=args.batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs)
+
+    sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
+    
     recon_inputs = keras.Input(shape=(1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
     
     recon_outputs = ReconSparse(batch_size=args.batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(recon_inputs)
@@ -64,200 +108,299 @@ if __name__ == "__main__":
     if args.sparse_checkpoint:
         recon_model = keras.models.load_model(args.sparse_checkpoint)
         
-    splits, dataset = load_pnb_videos(args.batch_size, classify_mode=True, mode='k_fold', device=None, n_splits=args.n_splits, sparse_model=None)
-    i_fold = 0
+    positive_class = None
+    if args.dataset == 'pnb':
+        train_loader, test_loader, dataset = load_pnb_videos(args.batch_size, classify_mode=True, balance_classes=args.balance_classes, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None)
+        positive_class = 'Positives'
+    elif args.dataset == 'ptx':
+        train_loader, test_loader, dataset = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode=args.splits, device=None, n_splits=args.n_splits, sparse_model=None, whole_video=False, positive_videos='positive_videos.json')
+        positive_class = 'No_Sliding'
+    else:
+        raise Exception('Invalid dataset')
     
     overall_true = []
     overall_pred = []
     fn_ids = []
     fp_ids = []
         
-    for train_idx, test_idx in splits:
-        
-        train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
-        test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)
-
-        train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
-                                                   # shuffle=True,
-                                                   sampler=train_sampler)
+    if args.checkpoint:
+        classifier_model = keras.models.load_model(args.checkpoint)
+    else:
+        if os.path.exists(os.path.join(args.output_dir, 'best_classifier.pt')):
+            classifier_model = keras.models.load_model(os.path.join(args.output_dir, 'best_classifier.pt'))
 
-        test_loader = torch.utils.data.DataLoader(dataset, batch_size=1,
-                                                        # shuffle=True,
-                                                        sampler=test_sampler)
-        
         classifier_inputs = keras.Input(shape=(1, (image_height - args.kernel_size) // args.stride + 1, (image_width - args.kernel_size) // args.stride + 1, args.num_kernels))
 
-        classifier_outputs = Classifier()(classifier_inputs)
+        if args.dataset == 'pnb':
+            classifier_outputs = PNBClassifier()(classifier_inputs)
+        elif args.dataset == 'ptx':
+            classifier_outputs = PTXClassifier()(classifier_inputs)
+        else:
+            raise Exception('No classifier exists for that dataset')
 
         classifier_model = keras.Model(inputs=classifier_inputs, outputs=classifier_outputs)
 
+    with open(os.path.join(output_dir, 'classifier_summary.txt'), 'w+') as out_f:
+        # model.summary() prints to stdout and returns None, so pass a print_fn
+        # that writes each line of the summary into the file instead.
+        classifier_model.summary(print_fn=lambda line: out_f.write(line + '\n'))
+    
+    prediction_optimizer = keras.optimizers.Adam(learning_rate=args.lr)
 
-        best_so_far = float('inf')
+    best_so_far = float('inf')
 
-        criterion = keras.losses.BinaryCrossentropy(from_logits=False)
+    criterion = keras.losses.BinaryCrossentropy(from_logits=True, reduction=keras.losses.Reduction.SUM)
 
-        if args.checkpoint:
-            classifier_model.load(args.checkpoint)
+    if args.train:
+        for epoch in range(args.epochs):
+            epoch_loss = 0
+            t1 = time.perf_counter()
 
-        if args.train:
-            prediction_optimizer = keras.optimizers.Adam(learning_rate=args.lr)
+            y_true_train = None
+            y_pred_train = None
 
-            for epoch in range(args.epochs):
-                epoch_loss = 0
-                t1 = time.perf_counter()
-                
-                if args.run_2d:
-                    inputs = keras.Input(shape=(image_height, image_width, 5))
-                else:
-                    inputs = keras.Input(shape=(5, image_height, image_width, 1))
+            for labels, local_batch, vid_f in tqdm(train_loader):
+                images = local_batch.permute(0, 2, 3, 4, 1).numpy()
 
-                filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')
+                torch_labels = np.zeros(len(labels))
+                torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
+                torch_labels = np.expand_dims(torch_labels, axis=1)
 
-                output = SparseCode(batch_size=args.batch_size, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs)
+                activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
+#                     print(tf.math.reduce_sum(activations))
 
-                sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
+                with tf.GradientTape() as tape:
+                    pred = classifier_model(activations)
+                    loss = criterion(torch_labels, pred)
 
-                for labels, local_batch, vid_f in tqdm(train_loader):
-                    if local_batch.size(0) != args.batch_size:
-                        continue
-                    images = local_batch.permute(0, 2, 3, 4, 1).numpy()
+#                         print(pred)
+#                         print(tf.math.sigmoid(pred))
+#                         print(loss)
+#                     print(torch_labels)
 
-                    torch_labels = np.zeros(len(labels))
-                    torch_labels[[i for i in range(len(labels)) if labels[i] == 'Positives']] = 1
-                    torch_labels = np.expand_dims(torch_labels, axis=1)
+                epoch_loss += loss * local_batch.size(0)
 
-                    activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
+                gradients = tape.gradient(loss, classifier_model.trainable_weights)
 
-                    with tf.GradientTape() as tape:
-                        pred = classifier_model(activations)
-                        loss = criterion(torch_labels, pred)
+                prediction_optimizer.apply_gradients(zip(gradients, classifier_model.trainable_weights))
 
-                    epoch_loss += loss * local_batch.size(0)
+                if y_true_train is None:
+                    y_true_train = torch_labels
+                    y_pred_train = tf.math.round(tf.math.sigmoid(pred))
+                else:
+                    y_true_train = tf.concat((y_true_train, torch_labels), axis=0)
+                    y_pred_train = tf.concat((y_pred_train, tf.math.round(tf.math.sigmoid(pred))), axis=0)
 
-                    gradients = tape.gradient(loss, classifier_model.trainable_weights)
+            t2 = time.perf_counter()
 
-                    prediction_optimizer.apply_gradients(zip(gradients, classifier_model.trainable_weights))
+            y_true = None
+            y_pred = None
+            test_loss = 0.0
+            for labels, local_batch, vid_f in tqdm(test_loader):
+                images = local_batch.permute(0, 2, 3, 4, 1).numpy()
+
+                torch_labels = np.zeros(len(labels))
+                torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
+                torch_labels = np.expand_dims(torch_labels, axis=1)
+
+                activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
+
+                pred = classifier_model(activations)
+                loss = criterion(torch_labels, pred)
+
+                test_loss += loss
 
-                t2 = time.perf_counter()
-                
-                if args.run_2d:
-                    inputs = keras.Input(shape=(image_height, image_width, 5))
+                if y_true is None:
+                    y_true = torch_labels
+                    y_pred = tf.math.round(tf.math.sigmoid(pred))
                 else:
-                    inputs = keras.Input(shape=(5, image_height, image_width, 1))
+                    y_true = tf.concat((y_true, torch_labels), axis=0)
+                    y_pred = tf.concat((y_pred, tf.math.round(tf.math.sigmoid(pred))), axis=0)
 
-                filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')
+            t2 = time.perf_counter()
 
-                output = SparseCode(batch_size=1, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs)
+            y_true = tf.cast(y_true, tf.int32)
+            y_pred = tf.cast(y_pred, tf.int32)
 
-                sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
+            y_true_train = tf.cast(y_true_train, tf.int32)
+            y_pred_train = tf.cast(y_pred_train, tf.int32)
 
-                y_true = None
-                y_pred = None
-                for labels, local_batch, vid_f in test_loader:
-                    images = local_batch.permute(0, 2, 3, 4, 1).numpy()
+            f1 = f1_score(y_true, y_pred, average='macro')
+            accuracy = accuracy_score(y_true, y_pred)
 
-                    torch_labels = np.zeros(len(labels))
-                    torch_labels[[i for i in range(len(labels)) if labels[i] == 'Positives']] = 1
-                    torch_labels = np.expand_dims(torch_labels, axis=1)
+            train_accuracy = accuracy_score(y_true_train, y_pred_train)
 
-                    activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
+            print('epoch={}, time={:.2f}, train_loss={:.2f}, test_loss={:.2f}, train_acc={:.2f}, test_f1={:.2f}, test_acc={:.2f}'.format(epoch, t2-t1, epoch_loss, test_loss, train_accuracy, f1, accuracy))
+#             print(epoch_loss)
+            if epoch_loss <= best_so_far:
+                print("found better model")
+                # Save model parameters
+                classifier_model.save(os.path.join(output_dir, "best_classifier.pt"))
+                pickle.dump(prediction_optimizer.get_weights(), open(os.path.join(output_dir, 'optimizer.pt'), 'wb+'))
+                best_so_far = epoch_loss
 
-                    pred = classifier_model(activations)
+        classifier_model = keras.models.load_model(os.path.join(output_dir, "best_classifier.pt"))
 
-                    if y_true is None:
-                        y_true = torch_labels
-                        y_pred = tf.math.round(tf.math.sigmoid(pred))
-                    else:
-                        y_true = tf.concat((y_true, torch_labels), axis=0)
-                        y_pred = tf.concat((y_pred, tf.math.round(tf.math.sigmoid(pred))), axis=0)
-
-                t2 = time.perf_counter()
-                
-                y_true = tf.cast(y_true, tf.int32)
-                y_pred = tf.cast(y_pred, tf.int32)
-
-                f1 = f1_score(y_true, y_pred, average='macro')
-                accuracy = accuracy_score(y_true, y_pred)
-
-                print('fold={}, epoch={}, time={:.2f}, loss={:.2f}, f1={:.2f}, acc={:.2f}'.format(i_fold, epoch, t2-t1, epoch_loss, f1, accuracy))
-    #             print(epoch_loss)
-                if epoch_loss <= best_so_far:
-                    print("found better model")
-                    # Save model parameters
-                    classifier_model.save(os.path.join(output_dir, "model-best_fold_" + str(i_fold) + ".pt"))
-                    best_so_far = epoch_loss
-
-            classifier_model = keras.models.load_model(os.path.join(output_dir, "model-best_fold_" + str(i_fold) + ".pt"))
+    if args.dataset == 'pnb':
+        epoch_loss = 0
+
+        y_true = None
+        y_pred = None
+
+        pred_dict = {}
+        gt_dict = {}
+
+        t1 = time.perf_counter()
+#         test_videos = [vid_f for labels, local_batch, vid_f in batch for batch in test_loader]
+        test_videos = [single_vid for labels, local_batch, vid_f in test_loader for single_vid in vid_f]
+
+        for k, v in tqdm(dataset.get_final_clips().items()):
+            if k not in test_videos:
+                continue
+            labels, local_batch, vid_f = v
+            images = local_batch.unsqueeze(0).permute(0, 2, 3, 4, 1).numpy()
+            labels = [labels]
+
+            torch_labels = np.zeros(len(labels))
+            torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
+            torch_labels = np.expand_dims(torch_labels, axis=1)
+
+            activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
+
+            pred = classifier_model(activations)
+
+            loss = criterion(torch_labels, pred)
+            epoch_loss += loss
+
+            final_pred = tf.math.round(tf.math.sigmoid(pred))
+            gt = torch_labels
             
-            if args.run_2d:
-                inputs = keras.Input(shape=(image_height, image_width, 5))
+            if final_pred != gt:
+                if final_pred == 0:
+                    fn_ids.append(k)
+                else:
+                    fp_ids.append(k)
+
+            overall_true.append(gt)
+            overall_pred.append(final_pred)
+
+            if y_true is None:
+                y_true = torch_labels
+                y_pred = tf.math.round(tf.math.sigmoid(pred))
             else:
-                inputs = keras.Input(shape=(5, image_height, image_width, 1))
+                y_true = tf.concat((y_true, torch_labels), axis=0)
+                y_pred = tf.concat((y_pred, tf.math.round(tf.math.sigmoid(pred))), axis=0)
 
-            filter_inputs = keras.Input(shape=(5, args.kernel_size, args.kernel_size, 1, args.num_kernels), dtype='float32')
+        t2 = time.perf_counter()
 
-            output = SparseCode(batch_size=1, image_height=image_height, image_width=image_width, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs, filter_inputs)
+        print('loss={:.2f}, time={:.2f}'.format(epoch_loss, t2-t1))
 
-            sparse_model = keras.Model(inputs=(inputs, filter_inputs), outputs=output)
+        y_true = tf.cast(y_true, tf.int32)
+        y_pred = tf.cast(y_pred, tf.int32)
 
-            epoch_loss = 0
+        f1 = f1_score(y_true, y_pred, average='macro')
+        accuracy = accuracy_score(y_true, y_pred)
 
-            y_true = None
-            y_pred = None
+        print("Test f1={:.2f}, vid_acc={:.2f}".format(f1, accuracy))
 
-            pred_dict = {}
-            gt_dict = {}
+        print(confusion_matrix(y_true, y_pred))
+    elif args.dataset == 'ptx':
+        epoch_loss = 0
 
-            t1 = time.perf_counter()
-            for labels, local_batch, vid_f in test_loader:
-                images = local_batch.permute(0, 2, 3, 4, 1).numpy()
+        y_true = None
+        y_pred = None
 
-                torch_labels = np.zeros(len(labels))
-                torch_labels[[i for i in range(len(labels)) if labels[i] == 'Positives']] = 1
-                torch_labels = np.expand_dims(torch_labels, axis=1)
+        pred_dict = {}
+        gt_dict = {}
 
-                activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
+        t1 = time.perf_counter()
+        for labels, local_batch, vid_f in test_loader:
+            images = local_batch.permute(0, 2, 3, 4, 1).numpy()
 
-                pred = classifier_model(activations)
+            torch_labels = np.zeros(len(labels))
+            torch_labels[[i for i in range(len(labels)) if labels[i] == positive_class]] = 1
+            torch_labels = np.expand_dims(torch_labels, axis=1)
 
-                loss = criterion(torch_labels, pred)
-                epoch_loss += loss * local_batch.size(0)
+            activations = tf.stop_gradient(sparse_model([images, tf.stop_gradient(tf.expand_dims(recon_model.trainable_weights[0], axis=0))]))
 
-                for i, v_f in enumerate(vid_f):
-                    final_pred = tf.math.round(pred[i])[0]
-                    gt = torch_labels[i]
-                    
-                    overall_true.append(gt)
-                    overall_pred.append(final_pred)
-                
-                    if final_pred != gt:
-                        if final_pred == 0:
-                            fn_ids.append(v_f)
-                        else:
-                            fp_ids.append(v_f)
+            pred = classifier_model(activations)
 
-                if y_true is None:
-                    y_true = torch_labels
-                    y_pred = tf.math.round(tf.math.sigmoid(pred))
+            loss = criterion(torch_labels, pred)
+            epoch_loss += loss * local_batch.size(0)
+
+            for i, v_f in enumerate(vid_f):
+                if v_f not in pred_dict:
+                    pred_dict[v_f] = tf.math.round(tf.math.sigmoid(pred[i]))
                 else:
-                    y_true = tf.concat((y_true, torch_labels), axis=0)
-                    y_pred = tf.concat((y_pred, tf.math.round(tf.math.sigmoid(pred))), axis=0)
+                    pred_dict[v_f] = tf.concat((pred_dict[v_f], tf.math.round(tf.math.sigmoid(pred[i]))), axis=0)
 
-            t2 = time.perf_counter()
+                if v_f not in gt_dict:
+                    gt_dict[v_f] = torch_labels[i]
+                else:
+                    gt_dict[v_f] = tf.concat((gt_dict[v_f], torch_labels[i]), axis=0)
 
-            print('fold={}, loss={:.2f}, time={:.2f}'.format(i_fold, loss, t2-t1))
-                
-            y_true = tf.cast(y_true, tf.int32)
-            y_pred = tf.cast(y_pred, tf.int32)
+            if y_true is None:
+                y_true = torch_labels
+                y_pred = tf.math.round(tf.math.sigmoid(pred))
+            else:
+                y_true = tf.concat((y_true, torch_labels), axis=0)
+                y_pred = tf.concat((y_pred, tf.math.round(tf.math.sigmoid(pred))), axis=0)
+
+        t2 = time.perf_counter()
+
+        vid_acc = []
+        for k in pred_dict.keys():
+            gt_tmp = torch.tensor(gt_dict[k].numpy())
+            pred_tmp = torch.tensor(pred_dict[k].numpy())
+
+            gt_mode = torch.mode(gt_tmp)[0].item()
+#             perm = torch.randperm(pred_tmp.size(0))
+#             cutoff = int(pred_tmp.size(0)/4)
+#             if cutoff < 3:
+#                 cutoff = 3
+#             idx = perm[:cutoff]
+#             samples = pred_tmp[idx]
+            pred_mode = torch.mode(pred_tmp)[0].item()
+            overall_true.append(gt_mode)
+            overall_pred.append(pred_mode)
+            if pred_mode == gt_mode:
+                vid_acc.append(1)
+            else:
+                vid_acc.append(0)
+                if pred_mode == 0:
+                    fn_ids.append(k)
+                else:
+                    fp_ids.append(k)
 
-            f1 = f1_score(y_true, y_pred, average='macro')
-            accuracy = accuracy_score(y_true, y_pred)
+        vid_acc = np.array(vid_acc)
+
+        print('----------------------------------------------------------------------------')
+        for k in pred_dict.keys():
+            print(k)
+            print('Predictions:')
+            print(pred_dict[k])
+            print('Ground Truth:')
+            print(gt_dict[k])
+            print('Overall Prediction:')
+            print(torch.mode(torch.tensor(pred_dict[k].numpy()))[0].item())
+            print('----------------------------------------------------------------------------')
 
-            print("Test f1={:.2f}, clip_acc={:.2f}, fold={}".format(f1, accuracy, i_fold))
+        print('loss={:.2f}, time={:.2f}'.format(epoch_loss, t2-t1))
 
-            print(confusion_matrix(y_true, y_pred))
+        y_true = tf.cast(y_true, tf.int32)
+        y_pred = tf.cast(y_pred, tf.int32)
 
-        i_fold = i_fold + 1
+        f1 = f1_score(y_true, y_pred, average='macro')
+        accuracy = accuracy_score(y_true, y_pred)
+        all_errors.append(np.sum(vid_acc) / len(vid_acc))
+
+        print("Test f1={:.2f}, clip_acc={:.2f}, vid_acc={:.2f}".format(f1, accuracy, np.sum(vid_acc) / len(vid_acc)))
+
+        print(confusion_matrix(y_true, y_pred))
 
     fp_fn_file = os.path.join(args.output_dir, 'fp_fn.txt')
     with open(fp_fn_file, 'w+') as in_f:
@@ -265,13 +408,3 @@ if __name__ == "__main__":
         in_f.write(str(fp_ids) + '\n\n')
         in_f.write('FN:\n')
         in_f.write(str(fn_ids) + '\n\n')
-
-    overall_true = np.array(overall_true)
-    overall_pred = np.array(overall_pred)
-
-    final_f1 = f1_score(overall_true, overall_pred, average='macro')
-    final_acc = accuracy_score(overall_true, overall_pred)
-    final_conf = confusion_matrix(overall_true, overall_pred)
-
-    print("Final accuracy={:.2f}, f1={:.2f}".format(final_acc, final_f1))
-    print(final_conf)
diff --git a/sparse_coding_torch/video_loader.py b/sparse_coding_torch/video_loader.py
index 96ff478fa5320f5d62b1e0b9bf3e6554eaa57fdf..c0a6bab61a4a5d21a35be424d9d7cb75b8f52011 100644
--- a/sparse_coding_torch/video_loader.py
+++ b/sparse_coding_torch/video_loader.py
@@ -22,8 +22,9 @@ import torchvision.transforms.functional as tv_f
 import csv
 import random
 import cv2
+from yolov4.get_bounding_boxes import YoloModel
 
-def get_video_participants():
+def get_ptx_participants():
     video_to_participant = {}
     with open('/shared_data/bamc_data/bamc_video_info.csv', 'r') as csv_in:
         reader = csv.DictReader(csv_in)
@@ -35,6 +36,9 @@ def get_video_participants():
             
     return video_to_participant
 
+def get_pnb_participants(filenames):
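+    # The participant ID is the parent directory name, assuming the layout
+    # <video_path>/<label>/<participant>/<clip>.mp4.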
+    return [f.split('/')[-2] for f in filenames]
+
 class MinMaxScaler(object):
     """
     Transforms each channel to the range [0, 1].
@@ -59,45 +63,210 @@ class VideoGrayScaler(nn.Module):
         # print(video.shape)
         return video
     
+def load_pnb_region_labels(file_path):
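+    # Read the SME region-label CSV: one row per participant index ('idx') with
+    # comma-separated frame ranges (e.g. '10-40,55-120'; positive ranges may end
+    # with 'end') marking negative and positive portions of a positive video.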
+    all_regions = {}
+    with open(file_path, newline='') as csv_in:
+        reader = csv.DictReader(csv_in)
+        for row in reader:
+            idx = row['idx']
+            positive_regions = row['positive_regions'].strip()
+            negative_regions = row['negative_regions'].strip()
+            
+            all_regions[idx] = (negative_regions, positive_regions)
+            
+        return all_regions
+    
+def get_yolo_regions(yolo_model, clip):
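+    # Run YOLO on the middle frame of the 5-frame clip, then crop a fixed window
+    # around each detection centre (400 px left/right, 200 px above, 50 px below)
+    # and apply the same crop to every frame of the clip.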
+    orig_height = clip.size(2)
+    orig_width = clip.size(3)
+    bounding_boxes = yolo_model.get_bounding_boxes(clip[:, 2, :, :].swapaxes(0, 2).swapaxes(0, 1).numpy()).squeeze(0)
+    
+    all_clips = []
+    for bb in bounding_boxes:
+        center_x = (bb[3] + bb[1]) / 2 * orig_width
+        center_y = (bb[2] + bb[0]) / 2 * orig_height
+
+        width_left = 400
+        width_right = 400
+        height_top = 200
+        height_bottom = 50
+
+        lower_y = round(center_y - height_top)
+        upper_y = round(center_y + height_bottom)
+        lower_x = round(center_x - width_left)
+        upper_x = round(center_x + width_right)
+
+        trimmed_clip = clip[:, :, lower_y:upper_y, lower_x:upper_x]
+        
+#         print(trimmed_clip.size())
+        
+#         cv2.imwrite('test.png', clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
+#         cv2.imwrite('test_yolo.png', trimmed_clip.numpy()[:, 0, :, :].swapaxes(0,1).swapaxes(1,2))
+#         raise Exception
+        
+        if trimmed_clip.shape[2] == 0 or trimmed_clip.shape[3] == 0:
+            continue
+        all_clips.append(trimmed_clip)
+
+    return all_clips
+                
+    
 class PNBLoader(Dataset):
     
-    def __init__(self, video_path, classify_mode=False, num_frames=5, frame_rate=20, frames_between_clips=None, transform=None, augmentation=None):
+    def __init__(self, video_path, classify_mode=False, balance_classes=False, num_frames=5, frame_rate=20, frames_between_clips=None, transform=None, augmentation=None):
         self.transform = transform
         self.augmentation = augmentation
         self.labels = [name for name in listdir(video_path) if isdir(join(video_path, name))]
         
+        clip_cache_file = 'clip_cache_pnb.pt'
+        clip_cache_final_file = 'clip_cache_pnb_final.pt'
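+        # Clips are expensive to build (video decoding + YOLO cropping), so the
+        # processed clips are cached to these files and reused on later runs.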
+        
+        region_labels = load_pnb_region_labels(join(video_path, 'sme_region_labels.csv'))
+
+        yolo_model = YoloModel()
+        
         self.videos = []
         for label in self.labels:
             self.videos.extend([(label, abspath(join(video_path, label, f)), f) for f in glob.glob(join(video_path, label, '*', '*.mp4'))])
+        
+#         self.videos = list(filter(lambda x: x[1].split('/')[-2] in ['67', '94', '134', '193', '222', '240'], self.videos))
+#         self.videos = list(filter(lambda x: x[1].split('/')[-2] in ['67'], self.videos))
+
             
         if not frames_between_clips:
             frames_between_clips = num_frames
             
         self.clips = []
-                   
-        self.video_idx = []
+        self.final_clips = {}
         
-        vid_idx = 0
-        for _, path, _ in tqdm(self.videos):
-            vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2)
-            if classify_mode:
-                if vc.size(1) < 5:
-                    continue
-                vc_sub = vc[:, -5:, :, :]
-                if self.transform:
-                    vc_sub = self.transform(vc_sub)
+        if exists(clip_cache_file):
+            self.clips = torch.load(open(clip_cache_file, 'rb'))
+            self.final_clips = torch.load(open(clip_cache_final_file, 'rb'))
+        else:
+            vid_idx = 0
+            for label, path, _ in tqdm(self.videos):
+                vc = tv.io.read_video(path)[0].permute(3, 0, 1, 2)
+                if classify_mode:
+                    person_idx = path.split('/')[-2]
+
+                    if vc.size(1) < 5:
+                        continue
+
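+                    # Positive videos with SME region labels contribute both
+                    # classes: 5-frame windows from labelled negative ranges are
+                    # stored as 'Negatives' and windows from positive ranges as
+                    # 'Positives', each cropped to the YOLO-detected regions.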
+                    if label == 'Positives' and person_idx in region_labels:
+                        negative_regions, positive_regions = region_labels[person_idx]
+                        for sub_region in negative_regions.split(','):
+                            sub_region = sub_region.split('-')
+                            start_loc = int(sub_region[0])
+                            end_loc = int(sub_region[1]) + 1
+                            for frame in range(start_loc, end_loc - 5, 5):
+                                vc_sub = vc[:, frame:frame+5, :, :]
+                                if vc_sub.size(1) < 5:
+                                    continue
+                                    
+#                                 cv2.imwrite('test.png', vc_sub[0, 0, :, :].unsqueeze(2).numpy())
+
+                                for clip in get_yolo_regions(yolo_model, vc_sub):
+                                    if self.transform:
+                                        clip = self.transform(clip)
+                                        
+#                                     print(clip[0, 0, :, :].size())
+#                                     cv2.imwrite('test_yolo.png', clip[0, 0, :, :].unsqueeze(2).numpy())
+#                                     raise Exception
+
+                                    self.clips.append(('Negatives', clip, self.videos[vid_idx][2]))
+
+                        if positive_regions:
+                            for sub_region in positive_regions.split(','):
+                                sub_region = sub_region.split('-')
+                                start_loc = int(sub_region[0])
+                                if len(sub_region) == 1:
+                                    vc_sub = vc[:, start_loc:start_loc+5, :, :]
+                                    if vc_sub.size(1) < 5:
+                                        continue
+                                        
+                                    for clip in get_yolo_regions(yolo_model, vc_sub):
+                                        if self.transform:
+                                            clip = self.transform(clip)
+
+                                        self.clips.append(('Positives', clip, self.videos[vid_idx][2]))
+                                else:
+                                    end_loc = sub_region[1]
+                                    if end_loc.strip().lower() == 'end':
+                                        end_loc = vc.size(1)
+                                    else:
+                                        end_loc = int(end_loc)
+                                    for frame in range(start_loc, end_loc - 5, 5):
+                                        vc_sub = vc[:, frame:frame+5, :, :]
+#                                         cv2.imwrite('test.png', vc_sub[0, 0, :, :].unsqueeze(2).numpy())
+                                        if vc_sub.size(1) < 5:
+                                            continue
+                                        for clip in get_yolo_regions(yolo_model, vc_sub):
+                                            if self.transform:
+                                                clip = self.transform(clip)
+                                                
+#                                             cv2.imwrite('test_yolo.png', clip[0, 0, :, :].unsqueeze(2).numpy())
+#                                             raise Exception
+
+                                            self.clips.append(('Positives', clip, self.videos[vid_idx][2]))
+                    elif label == 'Positives':
+                        vc_sub = vc[:, -5:, :, :]
+                        if vc_sub.size(1) < 5:
+                            continue
+                        for clip in get_yolo_regions(yolo_model, vc_sub):
+                            if self.transform:
+                                clip = self.transform(clip)
+
+                            self.clips.append((self.videos[vid_idx][0], clip, self.videos[vid_idx][2]))
+                    elif label == 'Negatives':
+                        for j in range(0, vc.size(1) - 5, 5):
+                            vc_sub = vc[:, j:j+5, :, :]
+                            if vc_sub.size(1) < 5:
+                                continue
+                            for clip in get_yolo_regions(yolo_model, vc_sub):
+                                if self.transform:
+                                    clip = self.transform(clip)
+
+                                self.clips.append((self.videos[vid_idx][0], clip, self.videos[vid_idx][2]))
+                    else:
+                        raise Exception('Invalid label')
+                else:
+                    for j in range(0, vc.size(1) - 5, 5):
+                        vc_sub = vc[:, j:j+5, :, :]
+                        if vc_sub.size(1) < 5:
+                            continue
+                        if self.transform:
+                            vc_sub = self.transform(vc_sub)
+
+                        self.clips.append((self.videos[vid_idx][0], vc_sub, self.videos[vid_idx][2]))
+
+                self.final_clips[self.videos[vid_idx][2]] = self.clips[-1]
+                vid_idx += 1
+                
+            torch.save(self.clips, open(clip_cache_file, 'wb+'))
+            torch.save(self.final_clips, open(clip_cache_final_file, 'wb+'))
+            
+        num_positive = len([clip[0] for clip in self.clips if clip[0] == 'Positives'])
+        num_negative = len([clip[0] for clip in self.clips if clip[0] == 'Negatives'])
+        
+        random.shuffle(self.clips)
+        
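+        # Optional class balancing: keep at most as many negative clips as there
+        # are positives; the shuffle above makes the kept negatives a random subset.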
+        if balance_classes:
+            new_clips = []
+            count_negative = 0
+            for clip in self.clips:
+                if clip[0] == 'Negatives':
+                    if count_negative < num_positive:
+                        new_clips.append(clip)
+                    count_negative += 1
+                else:
+                    new_clips.append(clip)
                     
-                self.clips.append((self.videos[vid_idx][0], vc_sub, self.videos[vid_idx][2]))
-                self.video_idx.append(vid_idx)
-            else:
-                for j in range(0, vc.size(1) - 5, 5):
-                    vc_sub = vc[:, j:j+5, :, :]
-                    if self.transform:
-                        vc_sub = self.transform(vc_sub)
-
-                    self.clips.append((self.videos[vid_idx][0], vc_sub, self.videos[vid_idx][2]))
-                    self.video_idx.append(vid_idx)
-            vid_idx += 1
+            self.clips = new_clips
+            num_positive = len([clip[0] for clip in self.clips if clip[0] == 'Positives'])
+            num_negative = len([clip[0] for clip in self.clips if clip[0] == 'Negatives'])
+        
+        print('Loaded', num_positive, 'positive examples.')
+        print('Loaded', num_negative, 'negative examples.')
         
     def get_filenames(self):
         return [self.clips[i][2] for i in range(len(self.clips))]
@@ -108,6 +277,9 @@ class PNBLoader(Dataset):
     def get_labels(self):
         return [self.clips[i][0] for i in range(len(self.clips))]
     
+    def get_final_clips(self):
+        return self.final_clips
+    
     def __getitem__(self, index):
         label, clip, vid_f = self.clips[index]
         if self.augmentation:
@@ -218,7 +390,7 @@ class YoloClipLoader(Dataset):
     #             video_to_clips[video].append(clip)
                 video_to_labels[video].append(lbl)
 
-            video_to_participants = get_video_participants()
+            video_to_participants = get_ptx_participants()
             participants_to_video = {}
             for k, v in video_to_participants.items():
                 if video_to_labels[k][0] == 'Sliding':