Skip to content
Snippets Groups Projects
Commit d59500a8 authored by Darryl Hannan's avatar Darryl Hannan
Browse files

pushing code for debugging

parent bd43054c
Branches
No related tags found
No related merge requests found
File added
......@@ -8,49 +8,56 @@ import torch
import torch.nn as nn
from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
from sparse_coding_torch.small_data_classifier import SmallDataClassifierConv3d
from keras_model import MobileModel
inputs = keras.Input(shape=(100, 200, 5))
inputs = keras.Input(shape=(5, 100, 200, 3))
outputs = MobileModel(sparse_checkpoint='../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=1, lam=0.05, activation_lr=1e-2, max_activation_iter=40, run_2d=True)(inputs)
outputs = MobileModel(sparse_checkpoint='../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=2, lam=0.05, activation_lr=1e-1, max_activation_iter=100, run_2d=True)(inputs)
# outputs = tf.math.add(inputs, 1)
model = keras.Model(inputs=inputs, outputs=outputs)
pytorch_checkpoint = torch.load('../output/final_model_75_iter/model-best_fold_0.pt', map_location='cpu')['model_state_dict']
conv_weights = [pytorch_checkpoint['module.compress_activations_conv_1.weight'].view(8, 8, 64, 24).numpy(), pytorch_checkpoint['module.compress_activations_conv_1.bias'].numpy()]
pytorch_checkpoint = torch.load('../stride_2_100_iter.pt', map_location='cpu')['model_state_dict']
conv_weights = [pytorch_checkpoint['module.compress_activations_conv_1.weight'].squeeze(2).swapaxes(0, 2).swapaxes(1, 3).swapaxes(2, 3).numpy(), pytorch_checkpoint['module.compress_activations_conv_1.bias'].numpy()]
model.get_layer('mobile_model').classifier.conv.set_weights(conv_weights)
ff_1_weights = [pytorch_checkpoint['module.fc1.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc1.bias'].numpy()]
model.get_layer('mobile_model').classifier.ff_1.set_weights(ff_1_weights)
ff_2_weights = [pytorch_checkpoint['module.fc2.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc2.bias'].numpy()]
model.get_layer('mobile_model').classifier.ff_2.set_weights(ff_2_weights)
ff_3_weights = [pytorch_checkpoint['module.fc3.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc3.bias'].numpy()]
# # ff_1_weights = [pytorch_checkpoint['module.fc1.weight'].swapaxes(1,0).numpy(), pytorch_checkpoint['module.fc1.bias'].numpy()]
# # model.get_layer('mobile_model').classifier.ff_1.set_weights(ff_1_weights)
# # ff_2_weights = [pytorch_checkpoint['module.fc2.weight'].swapaxes(1,0).numpy(), pytorch_checkpoint['module.fc2.bias'].numpy()]
# # model.get_layer('mobile_model').classifier.ff_2.set_weights(ff_2_weights)
ff_3_weights = [pytorch_checkpoint['module.fc3.weight'].swapaxes(1,0).numpy(), pytorch_checkpoint['module.fc3.bias'].numpy()]
model.get_layer('mobile_model').classifier.ff_3.set_weights(ff_3_weights)
ff_4_weights = [pytorch_checkpoint['module.fc4.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc4.bias'].numpy()]
ff_4_weights = [pytorch_checkpoint['module.fc4.weight'].swapaxes(1,0).numpy(), pytorch_checkpoint['module.fc4.bias'].numpy()]
model.get_layer('mobile_model').classifier.ff_4.set_weights(ff_4_weights)
# frozen_sparse = ConvSparseLayer(in_channels=1,
# out_channels=64,
# kernel_size=(5, 15, 15),
# stride=1,
# padding=(0, 7, 7),
# stride=2,
# padding=0,
# convo_dim=3,
# rectifier=True,
# lam=0.05,
# max_activation_iter=10,
# activation_lr=1e-2)
# max_activation_iter=100,
# activation_lr=1e-1)
#
# sparse_param = torch.load('../sparse.pt', map_location='cpu')
# frozen_sparse.load_state_dict(sparse_param['model_state_dict'])
#
# # pytorch_filter = frozen_sparse.filters[30, :, 0, :, :].squeeze(0).unsqueeze(2).detach().numpy()
# # keras_filter = model.get_layer('sparse_code').filter[0,:,:,:,30].numpy()
# #
# # cv2.imwrite('pytorch_filter.png', pytorch_filter / np.max(pytorch_filter) * 255.)
# # cv2.imwrite('keras_filter.png', keras_filter / np.max(keras_filter) * 255.)
# # raise Exception
# predictive_model = SmallDataClassifierConv3d()
# classifier_param = {k.replace('module.', ''): v for k,v in torch.load('../stride_2_100_iter.pt', map_location='cpu')['model_state_dict'].items()}
# predictive_model.load_state_dict(classifier_param)
#
# predictive_model.eval()
# #
# # # pytorch_filter = frozen_sparse.filters[30, :, 0, :, :].squeeze(0).unsqueeze(2).detach().numpy()
# # # keras_filter = model.get_layer('sparse_code').filter[0,:,:,:,30].numpy()
# # #
# # # cv2.imwrite('pytorch_filter.png', pytorch_filter / np.max(pytorch_filter) * 255.)
# # # cv2.imwrite('keras_filter.png', keras_filter / np.max(keras_filter) * 255.)
# # # raise Exception
# #
# img = tv.io.read_video('../clips/No_Sliding/Image_262499828648_clean1050.mp4')[0].permute(3, 0, 1, 2)
# transform = tv.transforms.Compose(
# [VideoGrayScaler(),
......@@ -59,12 +66,13 @@ model.get_layer('mobile_model').classifier.ff_4.set_weights(ff_4_weights)
# tv.transforms.CenterCrop((100, 200))
# ])
# img = transform(img)
#
# with torch.no_grad():
# activations = frozen_sparse(img.unsqueeze(0))
# activations, _ = predictive_model(frozen_sparse(img.unsqueeze(0)).squeeze(2))
# activations = torch.nn.Sigmoid()(activations)
#
# output = model(img.swapaxes(1, 3).swapaxes(1,2).numpy())
#
# print(activations.size())
# print(output.shape)
# print(torch.sum(activations))
......@@ -72,7 +80,7 @@ model.get_layer('mobile_model').classifier.ff_4.set_weights(ff_4_weights)
input_name = model.input_names[0]
index = model.input_names.index(input_name)
model.inputs[index].set_shape([1, 100, 200, 5])
model.inputs[index].set_shape([1, 5, 100, 200, 3])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# converter.experimental_new_converter = True
......
......@@ -17,17 +17,17 @@ def load_pytorch_weights(file_path):
@tf.function
def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride):
out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, 100, 200, 1), strides=stride)
out_2 = tf.nn.conv2d_transpose(activations, filters_2, output_shape=(batch_size, 100, 200, 1), strides=stride)
out_3 = tf.nn.conv2d_transpose(activations, filters_3, output_shape=(batch_size, 100, 200, 1), strides=stride)
out_4 = tf.nn.conv2d_transpose(activations, filters_4, output_shape=(batch_size, 100, 200, 1), strides=stride)
out_5 = tf.nn.conv2d_transpose(activations, filters_5, output_shape=(batch_size, 100, 200, 1), strides=stride)
out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, 100, 200, 1), strides=stride, padding='VALID')
out_2 = tf.nn.conv2d_transpose(activations, filters_2, output_shape=(batch_size, 100, 200, 1), strides=stride, padding='VALID')
out_3 = tf.nn.conv2d_transpose(activations, filters_3, output_shape=(batch_size, 100, 200, 1), strides=stride, padding='VALID')
out_4 = tf.nn.conv2d_transpose(activations, filters_4, output_shape=(batch_size, 100, 200, 1), strides=stride, padding='VALID')
out_5 = tf.nn.conv2d_transpose(activations, filters_5, output_shape=(batch_size, 100, 200, 1), strides=stride, padding='VALID')
recon = tf.concat([out_1, out_2, out_3, out_4, out_5], axis=3)
return recon
# @tf.function
@tf.function
def do_recon_3d(filters, activations, batch_size, stride):
activations = tf.pad(activations, paddings=[[0,0], [2, 2], [0, 0], [0, 0], [0, 0]])
recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, 100, 200, 1), strides=[1, stride, stride])
......@@ -35,18 +35,13 @@ def do_recon_3d(filters, activations, batch_size, stride):
return recon
@tf.function
def conv_error(filters_1, filters_2, filters_3, filters_4, filters_5, e, stride):
e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)
def conv_error(filter, e, stride):
g = tf.nn.conv2d(e1, filters_1, strides=stride, padding='SAME')
g = g + tf.nn.conv2d(e2, filters_2, strides=stride, padding='SAME')
g = g + tf.nn.conv2d(e3, filters_3, strides=stride, padding='SAME')
g = g + tf.nn.conv2d(e4, filters_4, strides=stride, padding='SAME')
g = g + tf.nn.conv2d(e5, filters_5, strides=stride, padding='SAME')
g = tf.nn.conv2d(e, filter, strides=stride, padding='VALID')
return g
# @tf.function
@tf.function
def conv_error_3d(filters, e, stride):
e = tf.pad(e, paddings=[[0,0], [0, 0], [7, 7], [7, 7], [0, 0]])
g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='VALID')
......@@ -65,7 +60,7 @@ def normalize_weights(filters, out_channels):
return adjusted
# @tf.function
@tf.function
def normalize_weights_3d(filters, out_channels):
#for f in filters:
# print('filters 3d shape', f.shape)
......@@ -88,6 +83,7 @@ class SparseCode(keras.layers.Layer):
self.out_channels = out_channels
self.in_channels = in_channels
self.kernel_size = kernel_size
self.stride = stride
self.lam = lam
self.activation_lr = activation_lr
......@@ -108,7 +104,12 @@ class SparseCode(keras.layers.Layer):
g = -1 * u
if self.run_2d:
convd_error = conv_error(filters[0], filters[1], filters[2], filters[3], filters[4], e, self.stride)
e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)
g += conv_error(filters[0], e1, self.stride)
g += conv_error(filters[1], e2, self.stride)
g += conv_error(filters[2], e3, self.stride)
g += conv_error(filters[3], e4, self.stride)
g += conv_error(filters[4], e5, self.stride)
else:
convd_error = conv_error_3d(filters, e, self.stride)
......@@ -132,9 +133,9 @@ class SparseCode(keras.layers.Layer):
@tf.function
def call(self, images, filters):
if self.run_2d:
u = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
u = tf.zeros(shape=(self.batch_size, 100 // self.stride - self.kernel_size // self.stride, 200 // self.stride - self.kernel_size // self.stride, self.out_channels))
m = tf.zeros(shape=(self.batch_size, 100 // self.stride - self.kernel_size // self.stride, 200 // self.stride - self.kernel_size // self.stride, self.out_channels))
v = tf.zeros(shape=(self.batch_size, 100 // self.stride - self.kernel_size // self.stride, 200 // self.stride - self.kernel_size // self.stride, self.out_channels))
else:
u = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
m = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
......@@ -205,14 +206,14 @@ class Classifier(keras.layers.Layer):
super(Classifier, self).__init__()
self.max_pool = keras.layers.MaxPooling2D(pool_size=4, strides=4)
self.conv = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='SAME')
self.conv = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='valid')
self.flatten = keras.layers.Flatten()
self.dropout = keras.layers.Dropout(0.5)
self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True)
self.ff_2 = keras.layers.Dense(100, activation='relu', use_bias=True)
# self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True)
# self.ff_2 = keras.layers.Dense(100, activation='relu', use_bias=True)
self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True)
self.ff_4 = keras.layers.Dense(1, activation='sigmoid')
......@@ -221,10 +222,10 @@ class Classifier(keras.layers.Layer):
x = self.max_pool(activations)
x = self.conv(x)
x = self.flatten(x)
x = self.ff_1(x)
x = self.dropout(x)
x = self.ff_2(x)
x = self.dropout(x)
# # x = self.ff_1(x)
# # x = self.dropout(x)
# # x = self.ff_2(x)
# # x = self.dropout(x)
x = self.ff_3(x)
x = self.dropout(x)
x = self.ff_4(x)
......@@ -260,6 +261,11 @@ class MobileModel(keras.Model):
@tf.function
def call(self, images):
images = tf.squeeze(tf.image.rgb_to_grayscale(images), axis=-1)
images = tf.transpose(images, perm=[0, 2, 3, 1])
images = images / 255
images = (images - 0.2592) / 0.1251
if self.run_2d:
activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)])
else:
......
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment