Skip to content
Snippets Groups Projects
Commit 092b5e42 authored by hannandarryl
Browse files

added updates to keras model to run in 2d and 3d

parent 3d931dab
Branches
Tags
No related merge requests found
......@@ -23,7 +23,7 @@ def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations,
@tf.function
def do_recon_3d(filters, activations, batch_size, stride):
    """Reconstruct a batch of clips from activations using one 3D filter bank.

    Args:
        filters: Filter tensor handed to ``tf.nn.conv3d_transpose``.
        activations: Activation tensor to decode.
        batch_size: Leading dimension of the requested output shape.
        stride: Stride applied to the last two strided axes.

    Returns:
        The transposed-convolution output, requested with shape
        ``(batch_size, 5, 100, 200, 1)``.
    """
    # A single call only: an earlier variant passed the scalar ``stride`` for
    # every axis and was immediately overwritten. The first strided axis stays
    # at 1 so the length-5 axis is never subsampled; only the last two strided
    # axes use ``stride``.
    # NOTE(review): the output size (5, 100, 200, 1) is hard-coded here.
    recon = tf.nn.conv3d_transpose(
        activations,
        filters,
        output_shape=(batch_size, 5, 100, 200, 1),
        strides=[1, stride, stride],
    )
    return recon
......@@ -41,7 +41,7 @@ def conv_error(filters_1, filters_2, filters_3, filters_4, filters_5, e, stride)
@tf.function
def conv_error_3d(filters, e, stride):
    """Convolve the reconstruction error with the 3D filter bank.

    Args:
        filters: Filter tensor handed to ``tf.nn.conv3d``.
        e: Error tensor (callers pass ``images - recon``).
        stride: Stride applied to the third and fourth axes of ``e``.

    Returns:
        The ``'SAME'``-padded convolution of ``e`` with ``filters``.
    """
    # A single call only: an earlier variant used ``stride`` for all five
    # entries and was immediately overwritten. ``tf.nn.conv3d`` requires
    # strides[0] == strides[4] == 1, and the second axis is left unstrided.
    g = tf.nn.conv3d(e, filters, strides=[1, 1, stride, stride, 1], padding='SAME')
    return g
......@@ -54,8 +54,17 @@ def normalize_weights(filters, out_channels):
return adjusted
@tf.function
def normalize_weights_3d(filters, out_channels):
    """Rescale each output-channel filter to unit L2 norm.

    Args:
        filters: List of weight arrays; the norms are computed from
            ``filters[0]`` and applied to every entry of the list.
        out_channels: Size of the trailing (output-channel) axis of
            ``filters[0]``.

    Returns:
        A list with one rescaled tensor per entry of ``filters``.
    """
    # The output-channel axis is the LAST axis of the kernel, so collapse all
    # leading axes into rows and take one norm per column. Reshaping to
    # (out_channels, -1) would instead group arbitrary contiguous chunks,
    # because reshape never reorders elements.
    columns = tf.reshape(filters[0], (-1, out_channels))
    norms = tf.norm(columns, axis=0)
    # Floor the norms so an all-zero filter does not cause division by zero.
    norms = tf.math.maximum(norms, 1e-12 * tf.ones_like(norms))
    # ``norms`` has shape (out_channels,), which broadcasts along the trailing
    # axis of each kernel.
    adjusted = [f / norms for f in filters]
    return adjusted
class SparseCode(keras.layers.Layer):
def __init__(self, pytorch_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter):
def __init__(self, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
super(SparseCode, self).__init__()
self.out_channels = out_channels
......@@ -65,31 +74,24 @@ class SparseCode(keras.layers.Layer):
self.activation_lr = activation_lr
self.max_activation_iter = max_activation_iter
self.batch_size = batch_size
# pytorch_checkpoint = torch.load(pytorch_checkpoint, map_location='cpu')
# weight_tensor = pytorch_checkpoint['model_state_dict']['filters'].swapaxes(1,3).swapaxes(2,4).swapaxes(0,4)
# self.filter = tf.Variable(initial_value=weight_tensor.numpy(), dtype='float32')
# weight_list = torch.chunk(weight_tensor, 5, dim=2)
#
initializer = tf.keras.initializers.HeNormal()
self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
self.filters_2 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
self.filters_3 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
weights = normalize_weights(self.get_weights(), out_channels)
self.set_weights(weights)
self.run_2d = run_2d
@tf.function
def do_update(self, images, u, m, v, b1, b2, eps, i):
def do_update(self, images, filters, u, m, v, b1, b2, eps, i):
activations = tf.nn.relu(u - self.lam)
recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.batch_size, self.stride)
if self.run_2d:
recon = do_recon(filters[0], filters[1], filters[2], filters[3], filters[4], activations, self.batch_size, self.stride)
else:
recon = do_recon_3d(filters, activations, self.batch_size, self.stride)
e = images - recon
g = -1 * u
convd_error = conv_error(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, e, self.stride)
if self.run_2d:
convd_error = conv_error(filters[0], filters[1], filters[2], filters[3], filters[4], e, self.stride)
else:
convd_error = conv_error_3d(filters, e, self.stride)
g = g + convd_error
......@@ -103,29 +105,82 @@ class SparseCode(keras.layers.Layer):
du = self.activation_lr * mh / (tf.math.sqrt(vh) + eps)
u += du
# i += 1
# return images, u, m, v, b1, b2, eps, i
return u, m, v
@tf.function
def call(self, images, filters):
    """Infer activations for a batch by iterating ``do_update``.

    Args:
        images: Input batch forwarded unchanged to ``do_update``.
        filters: Filters forwarded unchanged to ``do_update`` (a list of
            five in 2D mode, a single tensor otherwise).

    Returns:
        ``relu(u - lam)`` after ``max_activation_iter`` update steps.
    """
    # The 3D state carries an extra length-5 axis after the batch axis.
    # NOTE(review): the 100 x 200 frame size and the 5 are hard-coded.
    spatial = (100 // self.stride, 200 // self.stride, self.out_channels)
    if self.run_2d:
        state_shape = (self.batch_size,) + spatial
    else:
        state_shape = (self.batch_size, 5) + spatial

    # ``u`` is the pre-threshold state; ``m`` and ``v`` are running statistics
    # updated by ``do_update`` (presumably Adam-style moments, confirm there).
    u = tf.zeros(shape=state_shape)
    m = tf.zeros(shape=state_shape)
    v = tf.zeros(shape=state_shape)

    b1 = tf.constant(0.9, dtype='float32')
    b2 = tf.constant(0.999, dtype='float32')
    eps = tf.constant(1e-8, dtype='float32')

    for i in range(self.max_activation_iter):
        u, m, v = self.do_update(images, filters, u, m, v, b1, b2, eps, i)

    # Shift by ``lam`` and rectify to obtain the final activations.
    u = tf.nn.relu(u - self.lam)
    return u
class SparseCodeConv(keras.Model):
    """Keras model owning the trainable filters around a ``SparseCode`` layer.

    In 2D mode it holds five filter variables (``filters_1`` .. ``filters_5``);
    otherwise it holds one 5-D variable (``filters``). The weights are
    normalized once at construction time.
    """

    def __init__(self, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d):
        super().__init__()

        self.sparse_code = SparseCode(batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter, run_2d)

        self.out_channels = out_channels
        self.in_channels = in_channels
        self.stride = stride
        self.lam = lam
        self.activation_lr = activation_lr
        self.max_activation_iter = max_activation_iter
        self.batch_size = batch_size
        self.run_2d = run_2d

        initializer = tf.keras.initializers.HeNormal()

        def new_filter(shape):
            # One trainable float32 variable drawn from the He-normal initializer.
            return tf.Variable(initial_value=initializer(shape=shape), dtype='float32', trainable=True)

        if run_2d:
            shape_2d = (kernel_size, kernel_size, in_channels, out_channels)
            self.filters_1 = new_filter(shape_2d)
            self.filters_2 = new_filter(shape_2d)
            self.filters_3 = new_filter(shape_2d)
            self.filters_4 = new_filter(shape_2d)
            self.filters_5 = new_filter(shape_2d)
            normalized = normalize_weights(self.get_weights(), out_channels)
        else:
            self.filters = new_filter((5, kernel_size, kernel_size, in_channels, out_channels))
            normalized = normalize_weights_3d(self.get_weights(), out_channels)

        self.set_weights(normalized)

    @tf.function
    def call(self, images):
        """Return ``(recon, activations)`` for a batch of inputs.

        The activation inference receives the filters wrapped in
        ``tf.stop_gradient``; the reconstruction uses the raw variables.
        """
        if not self.run_2d:
            activations = self.sparse_code(images, tf.stop_gradient(self.filters))
            recon = do_recon_3d(self.filters, activations, self.batch_size, self.stride)
            return recon, activations

        bank = (self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5)
        frozen = [tf.stop_gradient(f) for f in bank]
        activations = self.sparse_code(images, frozen)
        recon = do_recon(*bank, activations, self.batch_size, self.stride)
        return recon, activations
class Classifier(keras.layers.Layer):
def __init__(self):
super(Classifier, self).__init__()
......
......@@ -10,7 +10,7 @@ import os
from sparse_coding_torch.load_data import load_yolo_clips
import tensorflow.keras as keras
import tensorflow as tf
from keras_model import SparseCode, do_recon, normalize_weights
from keras_model import normalize_weights_3d, normalize_weights, SparseCodeConv
def plot_video(video):
......@@ -55,11 +55,11 @@ def plot_original_vs_recon(original, reconstruction, idx=0):
def plot_filters(filters):
num_filters = filters.shape[0]
num_filters = filters.shape[4]
ncol = 3
# ncol = int(np.sqrt(num_filters))
# nrow = int(np.sqrt(num_filters))
T = filters.shape[2]
T = filters.shape[0]
if num_filters // ncol == num_filters / ncol:
nrow = num_filters // ncol
......@@ -74,7 +74,7 @@ def plot_filters(filters):
for i in range(num_filters):
r = i // ncol
c = i % ncol
ims[(r, c)] = axes[r, c].imshow(filters[i, 0, 0, :, :],
ims[(r, c)] = axes[r, c].imshow(filters[0, :, :, 0, i],
cmap=cm.Greys_r)
def update(i):
......@@ -82,29 +82,30 @@ def plot_filters(filters):
for i in range(num_filters):
r = i // ncol
c = i % ncol
ims[(r, c)].set_data(filters[i, 0, t, :, :])
ims[(r, c)].set_data(filters[t, :, :, 0, i])
return FuncAnimation(plt.gcf(), update, interval=1000/20)
def sparse_loss(recon, activations, batch_size, lam, stride, images=None):
    """Compute the sparse-coding loss: reconstruction error plus L1 penalty.

    Args:
        recon: Reconstruction to compare against ``images``.
        activations: Activations whose absolute values are penalized.
        batch_size: Number of samples; scales the reconstruction term and is
            used to flatten ``activations`` per sample.
        lam: Weight of the L1 sparsity term.
        stride: Unused; kept so existing call sites keep working.
        images: Target batch. When omitted, the module-level ``images``
            variable is used, which is what the body relied on implicitly
            before this parameter existed.

    Returns:
        Scalar loss tensor.

    Raises:
        NameError: If ``images`` is omitted and no module-level ``images``
            exists.
    """
    if images is None:
        # Backward compatibility: older call sites do not pass the targets and
        # depend on the training loop having bound a global ``images``.
        try:
            images = globals()['images']
        except KeyError:
            raise NameError("name 'images' is not defined") from None

    recon_term = 0.5 * (1 / batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2))
    # Per-sample L1 norm of the activations, averaged over the batch.
    per_sample_l1 = tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1)
    return recon_term + lam * tf.reduce_mean(per_sample_l1)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=8, type=int)
parser.add_argument('--batch_size', default=6, type=int)
parser.add_argument('--kernel_size', default=15, type=int)
parser.add_argument('--num_kernels', default=64, type=int)
parser.add_argument('--stride', default=2, type=int)
parser.add_argument('--max_activation_iter', default=150, type=int)
parser.add_argument('--max_activation_iter', default=50, type=int)
parser.add_argument('--activation_lr', default=1e-2, type=float)
parser.add_argument('--lr', default=1e-5, type=float)
parser.add_argument('--lr', default=1e-2, type=float)
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--lam', default=0.05, type=float)
parser.add_argument('--output_dir', default='./output', type=str)
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--run_2d', action='store_true')
parser.add_argument('--save_filters', action='store_true')
args = parser.parse_args()
......@@ -122,14 +123,24 @@ if __name__ == "__main__":
example_data = next(iter(train_loader))
if args.run_2d:
inputs = keras.Input(shape=(100, 200, 5))
else:
inputs = keras.Input(shape=(5, 100, 200, 1))
output = SparseCode('../sparse.pt', batch_size=args.batch_size, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter)(inputs)
output = SparseCodeConv(batch_size=args.batch_size, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter, run_2d=args.run_2d)(inputs)
model = keras.Model(inputs=inputs, outputs=output)
if args.save_filters:
if args.run_2d:
filters = plot_filters(tf.stack(model.get_weights(), axis=0))
else:
filters = plot_filters(model.get_weights()[0])
filters.save(os.path.join(args.output_dir, 'filters_start.mp4'))
learning_rate = args.lr
filter_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
filter_optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss_log = []
best_so_far = float('inf')
......@@ -139,39 +150,39 @@ if __name__ == "__main__":
epoch_start = time.perf_counter()
for labels, local_batch, vid_f in tqdm(train_loader):
tf.print(vid_f)
if local_batch.size(0) != args.batch_size:
continue
if args.run_2d:
images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy()
else:
images = local_batch.permute(0, 2, 3, 4, 1).numpy()
with tf.GradientTape() as tape:
activations = model(images)
filters = model.get_weights()
loss = sparse_loss(filters[0], filters[1], filters[2], filters[3], filters[4], images, activations, args.batch_size, args.lam, args.stride)
# recon = model.do_recon(activations)
# loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2))
# loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1))
recon, activations = model(images)
loss = sparse_loss(recon, activations, args.batch_size, args.lam, args.stride)
epoch_loss += loss * local_batch.size(0)
tf.print('loss:', loss)
gradients = tape.gradient(loss, model.trainable_weights)
tf.print('gradients:', tf.reduce_sum(gradients))
filter_optimizer.apply_gradients(zip(gradients, model.trainable_weights))
tf.print('weights:', tf.reduce_sum(model.trainable_weights))
if args.run_2d:
weights = normalize_weights(model.get_weights(), args.num_kernels)
else:
weights = normalize_weights_3d(model.get_weights(), args.num_kernels)
model.set_weights(weights)
tf.print('normalized weights:', tf.reduce_sum(model.trainable_weights))
epoch_end = time.perf_counter()
epoch_loss /= len(train_loader.sampler)
if args.save_filters and epoch % 5 == 0:
if args.run_2d:
filters = plot_filters(tf.stack(model.get_weights(), axis=0))
else:
filters = plot_filters(model.get_weights()[0])
filters.save(os.path.join(args.output_dir, 'filters_' + str(epoch) +'.mp4'))
if epoch_loss < best_so_far:
print("found better model")
# Save model parameters
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment