Commit 6c365931 authored by hannandarryl

fixed padding

parent 4d1cd450
@@ -9,6 +9,12 @@ import torch.nn as nn
 from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
 from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
 
+def load_pytorch_weights(file_path):
+    pytorch_checkpoint = torch.load(file_path, map_location='cpu')
+    weight_tensor = pytorch_checkpoint['model_state_dict']['filters'].swapaxes(1,3).swapaxes(2,4).swapaxes(0,4).numpy()
+
+    return weight_tensor
+
 @tf.function
 def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride):
     out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, 100, 200, 1), strides=stride)
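
Aside (not part of the diff): the swapaxes chain in load_pytorch_weights is easiest to sanity-check on a dummy array. Assuming ConvSparseLayer stores its filters in the usual PyTorch layout (out_channels, in_channels, kT, kH, kW), the chain lands on TensorFlow's conv3d filter layout (kT, kH, kW, in_channels, out_channels). The sizes below are placeholders, not values read from this repository.

import numpy as np

out_c, in_c, kT, kH, kW = 64, 1, 5, 15, 15          # hypothetical sizes
w = np.zeros((out_c, in_c, kT, kH, kW), dtype=np.float32)

# Same permutation as load_pytorch_weights, applied to a NumPy array
w_tf = w.swapaxes(1, 3).swapaxes(2, 4).swapaxes(0, 4)
assert w_tf.shape == (kT, kH, kW, in_c, out_c)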
@@ -21,8 +27,9 @@ def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations,
     return recon
 
-@tf.function
+# @tf.function
 def do_recon_3d(filters, activations, batch_size, stride):
+    activations = tf.pad(activations, paddings=[[0,0], [2, 2], [0, 0], [0, 0], [0, 0]])
     recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, 100, 200, 1), strides=[1, stride, stride])
     return recon
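
Aside (not part of the diff): with the activation maps now carrying a single time slice (see the SparseCode change further down), the [2, 2] pad on the depth axis widens them to five slices so conv3d_transpose can emit the expected (batch, 5, 100, 200, 1) reconstruction. A minimal shape check, with the kernel size, stride, and channel count assumed rather than taken from the repository:

import tensorflow as tf

batch, stride, out_c = 1, 2, 64                     # assumed values
acts = tf.zeros((batch, 1, 100 // stride, 200 // stride, out_c))
filters = tf.zeros((5, 15, 15, 1, out_c))           # (kT, kH, kW, image_channels, out_channels), assumed

acts = tf.pad(acts, paddings=[[0, 0], [2, 2], [0, 0], [0, 0], [0, 0]])   # depth 1 -> 5
recon = tf.nn.conv3d_transpose(acts, filters,
                               output_shape=(batch, 5, 100, 200, 1),
                               strides=[1, stride, stride])
assert recon.shape == (batch, 5, 100, 200, 1)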
@@ -39,9 +46,10 @@ def conv_error(filters_1, filters_2, filters_3, filters_4, filters_5, e, stride)
     return g
 
-@tf.function
+# @tf.function
 def conv_error_3d(filters, e, stride):
-    g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='SAME')
+    e = tf.pad(e, paddings=[[0,0], [0, 0], [7, 7], [7, 7], [0, 0]])
+    g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='VALID')
     return g
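
Aside (not part of the diff): the explicit [7, 7] spatial pad plus padding='VALID' keeps the error map at the (100 // stride, 200 // stride) spatial size the activations use, provided the spatial kernel is 15x15 (7 = (15 - 1) // 2). That kernel size, the stride, and the channel count below are assumptions, not values read from this commit.

import tensorflow as tf

batch, stride, out_c = 1, 2, 64                     # assumed values
e = tf.zeros((batch, 5, 100, 200, 1))
filters = tf.zeros((5, 15, 15, 1, out_c))           # (kT, kH, kW, in_channels, out_channels), assumed

e = tf.pad(e, paddings=[[0, 0], [0, 0], [7, 7], [7, 7], [0, 0]])
g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='VALID')
assert g.shape == (batch, 1, 100 // stride, 200 // stride, out_c)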
@@ -54,7 +62,7 @@ def normalize_weights(filters, out_channels):
     return adjusted
 
-@tf.function
+# @tf.function
 def normalize_weights_3d(filters, out_channels):
     norms = tf.norm(tf.reshape(filters[0], (out_channels, -1)), axis=1)
     norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
@@ -117,9 +125,9 @@ class SparseCode(keras.layers.Layer):
             m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
             v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
         else:
-            u = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
-            m = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
-            v = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
+            u = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
+            m = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
+            v = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
 
         # tf.print('activations before:', tf.reduce_sum(u))
@@ -211,3 +219,41 @@ class Classifier(keras.layers.Layer):
         x = self.ff_4(x)
 
         return x
+
+class MobileModel(keras.Model):
+    def __init__(self, sparse_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
+        super().__init__()
+        self.sparse_code = SparseCode(batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d)
+        self.classifier = Classifier()
+
+        self.out_channels = out_channels
+        self.in_channels = in_channels
+        self.stride = stride
+        self.lam = lam
+        self.activation_lr = activation_lr
+        self.max_activation_iter = max_activation_iter
+        self.batch_size = batch_size
+        self.run_2d = run_2d
+
+        pytorch_weights = load_pytorch_weights(sparse_checkpoint)
+
+        if run_2d:
+            weight_list = np.split(pytorch_weights, 5, axis=0)
+            self.filters_1 = tf.Variable(initial_value=weight_list[0].squeeze(0), dtype='float32', trainable=False)
+            self.filters_2 = tf.Variable(initial_value=weight_list[1].squeeze(0), dtype='float32', trainable=False)
+            self.filters_3 = tf.Variable(initial_value=weight_list[2].squeeze(0), dtype='float32', trainable=False)
+            self.filters_4 = tf.Variable(initial_value=weight_list[3].squeeze(0), dtype='float32', trainable=False)
+            self.filters_5 = tf.Variable(initial_value=weight_list[4].squeeze(0), dtype='float32', trainable=False)
+        else:
+            self.filters = tf.Variable(initial_value=pytorch_weights, dtype='float32', trainable=False)
+
+    @tf.function
+    def call(self, images):
+        if self.run_2d:
+            activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)])
+        else:
+            activations = self.sparse_code(images, tf.stop_gradient(self.filters))
+
+        pred = self.classifier(activations)
+
+        return pred
\ No newline at end of file
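
Aside (not part of the diff): a hypothetical usage sketch for the new MobileModel. The constructor arguments and the input clip shape are guesses that mirror the shapes hard-coded above; 'sparse.pt' is a placeholder checkpoint path, not a file in this repository, and the TFLite export at the end is only one plausible next step suggested by the class name.

import tensorflow as tf

model = MobileModel(sparse_checkpoint='sparse.pt',   # placeholder path
                    batch_size=1, in_channels=1, out_channels=64,
                    kernel_size=15, stride=2, lam=0.05,
                    activation_lr=1e-2, max_activation_iter=100,
                    run_2d=False)

clip = tf.zeros((1, 5, 100, 200, 1))                 # (batch, frames, height, width, channels), assumed layout
pred = model(clip)

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_bytes = converter.convert()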