From 6c3659313c0929bf804e075711e86bf25a0b873c Mon Sep 17 00:00:00 2001
From: hannandarryl <hannandarryl@gmail.com>
Date: Fri, 11 Feb 2022 20:29:33 +0000
Subject: [PATCH] fixed padding

---
 keras/keras_model.py | 60 ++++++++++++++++++++++++++++++++++++++------
 1 file changed, 53 insertions(+), 7 deletions(-)

diff --git a/keras/keras_model.py b/keras/keras_model.py
index 6327fe9..919dc7d 100644
--- a/keras/keras_model.py
+++ b/keras/keras_model.py
@@ -9,6 +9,12 @@ import torch.nn as nn
 from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
 from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
 
+def load_pytorch_weights(file_path):
+    pytorch_checkpoint = torch.load(file_path, map_location='cpu')
+    weight_tensor = pytorch_checkpoint['model_state_dict']['filters'].swapaxes(1,3).swapaxes(2,4).swapaxes(0,4).numpy()
+
+    return weight_tensor
+
 @tf.function
 def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride):
     out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, 100, 200, 1), strides=stride)
@@ -21,8 +27,9 @@ def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations,
 
     return recon
 
-@tf.function
+# @tf.function
 def do_recon_3d(filters, activations, batch_size, stride):
+    activations = tf.pad(activations, paddings=[[0,0], [2, 2], [0, 0], [0, 0], [0, 0]])
     recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, 100, 200, 1), strides=[1, stride, stride])
 
     return recon
@@ -39,9 +46,10 @@ def conv_error(filters_1, filters_2, filters_3, filters_4, filters_5, e, stride)
 
     return g
 
-@tf.function
+# @tf.function
 def conv_error_3d(filters, e, stride):
-    g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='SAME')
+    e = tf.pad(e, paddings=[[0,0], [0, 0], [7, 7], [7, 7], [0, 0]])
+    g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='VALID')
 
     return g
 
@@ -54,7 +62,7 @@ def normalize_weights(filters, out_channels):
 
     return adjusted
 
-@tf.function
+# @tf.function
 def normalize_weights_3d(filters, out_channels):
     norms = tf.norm(tf.reshape(filters[0], (out_channels, -1)), axis=1)
     norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
@@ -117,9 +125,9 @@ class SparseCode(keras.layers.Layer):
             m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
             v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
         else:
-            u = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
-            m = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
-            v = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
+            u = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
+            m = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
+            v = tf.zeros(shape=(self.batch_size, 1, 100 // self.stride, 200 // self.stride, self.out_channels))
 
         # tf.print('activations before:', tf.reduce_sum(u))
 
@@ -211,3 +219,41 @@ class Classifier(keras.layers.Layer):
         x = self.ff_4(x)
 
         return x
+
+class MobileModel(keras.Model):
+    def __init__(self, sparse_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
+        super().__init__()
+        self.sparse_code = SparseCode(batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d)
+        self.classifier = Classifier()
+
+        self.out_channels = out_channels
+        self.in_channels = in_channels
+        self.stride = stride
+        self.lam = lam
+        self.activation_lr = activation_lr
+        self.max_activation_iter = max_activation_iter
+        self.batch_size = batch_size
+        self.run_2d = run_2d
+
+        pytorch_weights = load_pytorch_weights(sparse_checkpoint)
+
+        if run_2d:
+            weight_list = np.split(pytorch_weights, 5, axis=0)
+            self.filters_1 = tf.Variable(initial_value=weight_list[0].squeeze(0), dtype='float32', trainable=False)
+            self.filters_2 = tf.Variable(initial_value=weight_list[1].squeeze(0), dtype='float32', trainable=False)
+            self.filters_3 = tf.Variable(initial_value=weight_list[2].squeeze(0), dtype='float32', trainable=False)
+            self.filters_4 = tf.Variable(initial_value=weight_list[3].squeeze(0), dtype='float32', trainable=False)
+            self.filters_5 = tf.Variable(initial_value=weight_list[4].squeeze(0), dtype='float32', trainable=False)
+        else:
+            self.filters = tf.Variable(initial_value=pytorch_weights, dtype='float32', trainable=False)
+
+    @tf.function
+    def call(self, images):
+        if self.run_2d:
+            activations = self.sparse_code(images, [tf.stop_gradient(self.filters_1), tf.stop_gradient(self.filters_2), tf.stop_gradient(self.filters_3), tf.stop_gradient(self.filters_4), tf.stop_gradient(self.filters_5)])
+        else:
+            activations = self.sparse_code(images, tf.stop_gradient(self.filters))
+
+        pred = self.classifier(activations)
+
+        return pred
\ No newline at end of file
-- 
GitLab
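
Note on the padding change: TensorFlow's padding='SAME' splits its padding asymmetrically when stride > 1 (for a width-15 kernel at stride 2 over a 200-wide input it pads 6 before and 7 after), while PyTorch's padding=7 pads both sides equally. Padding explicitly with tf.pad and convolving with padding='VALID' reproduces the symmetric behavior. Because 'VALID' leaves the temporal axis unpadded, a depth-5 kernel over 5 frames collapses the activations to a single time step, which is why u, m, and v shrink from temporal size 5 to 1 and why do_recon_3d now pads the activations with [2, 2] frames before the transposed convolution to recover a 5-frame reconstruction. The sketch below illustrates the shape effect; the 5x15x15 kernel, stride of 2, and 64 output channels are inferred from the [7, 7] pads and output shapes in the diff, not stated in it.

# Minimal sketch of the conv_error_3d padding change (assumed sizes: 5x15x15
# kernel, stride 2, 64 output channels; inferred from the diff, not stated).
import tensorflow as tf

stride = 2
e = tf.random.normal((1, 5, 100, 200, 1))        # (batch, time, height, width, channels)
filters = tf.random.normal((5, 15, 15, 1, 64))   # (kt, kh, kw, in_channels, out_channels)

# Old behavior: 'SAME' picks the pad amounts itself; with stride > 1 they are
# asymmetric (6 before / 7 after here), and the time axis is padded as well.
g_same = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='SAME')
print(g_same.shape)   # (1, 5, 50, 100, 64)

# New behavior: symmetric spatial pads of 7 (PyTorch's padding=7 convention) and
# no temporal padding, so the depth-5 kernel collapses time to a single step.
e_pad = tf.pad(e, paddings=[[0, 0], [0, 0], [7, 7], [7, 7], [0, 0]])
g_valid = tf.nn.conv3d(e_pad, filters, strides=[1, 1, stride, stride, 1], padding='VALID')
print(g_valid.shape)  # (1, 1, 50, 100, 64), matching the 5 -> 1 activation shapes above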
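
A related note on load_pytorch_weights: its swapaxes chain permutes the checkpoint's filter tensor into the layout tf.nn.conv3d expects. Below is a minimal sketch of the permutation, assuming the PyTorch filters are stored as (out_channels, in_channels, kt, kh, kw); that layout and the sizes are illustrative assumptions, not stated in the diff.

# Sketch of the axis permutation in load_pytorch_weights; the source layout
# (out, in, kt, kh, kw) and the sizes below are assumptions for illustration.
import numpy as np

w_pt = np.zeros((64, 1, 5, 15, 15))   # assumed PyTorch layout: (out, in, kt, kh, kw)
w_tf = w_pt.swapaxes(1, 3).swapaxes(2, 4).swapaxes(0, 4)
print(w_tf.shape)                     # (5, 15, 15, 1, 64), i.e. TensorFlow conv3d's
                                      # (kt, kh, kw, in_channels, out_channels)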