Commit 3d931dab authored by hannandarryl

keras training code

parent 13ae3ca4
@@ -9,84 +9,89 @@ import torch.nn as nn
 from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
 from sparse_coding_torch.conv_sparse_model import ConvSparseLayer

-class SparseCode(keras.layers.Layer):
-    def __init__(self, pytorch_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter):
-        super(SparseCode, self).__init__()
-        self.out_channels = out_channels
-        self.in_channels = in_channels
-        self.stride = stride
-        self.lam = lam
-        self.activation_lr = activation_lr
-        self.max_activation_iter = max_activation_iter
-        self.batch_size = batch_size
-
-        # pytorch_checkpoint = torch.load(pytorch_checkpoint, map_location='cpu')
-        # weight_tensor = pytorch_checkpoint['model_state_dict']['filters'].swapaxes(1,3).swapaxes(2,4).swapaxes(0,4)
-        # self.filter = tf.Variable(initial_value=weight_tensor.numpy(), dtype='float32')
-        # weight_list = torch.chunk(weight_tensor, 5, dim=2)
-        #
-        initializer = tf.keras.initializers.HeNormal()
-        self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_2 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_3 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-
-    def normalize_weights(self):
-        norms = tf.norm(tf.reshape(tf.stack([self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5]), (self.out_channels, -1)), axis=1, keepdims=True)
-        norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), self.filters_1.shape)
-        self.filters_1 = self.filters_1 / norms
-        self.filters_2 = self.filters_2 / norms
-        self.filters_3 = self.filters_3 / norms
-        self.filters_4 = self.filters_4 / norms
-        self.filters_5 = self.filters_5 / norms
-    @tf.function
-    def do_recon(self, activations):
-        out_1 = tf.nn.conv2d_transpose(activations, self.filters_1, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_2 = tf.nn.conv2d_transpose(activations, self.filters_2, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_3 = tf.nn.conv2d_transpose(activations, self.filters_3, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_4 = tf.nn.conv2d_transpose(activations, self.filters_4, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_5 = tf.nn.conv2d_transpose(activations, self.filters_5, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        recon = tf.concat([out_1, out_2, out_3, out_4, out_5], axis=3)
-        return recon
-
-    @tf.function
-    def do_recon_3d(self, activations):
-        recon = tf.nn.conv3d_transpose(activations, self.filter, output_shape=(self.batch_size, 5, 100, 200, 1), strides=self.stride)
-        return recon
-
-    @tf.function
-    def conv_error(self, e):
-        e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)
-        g = tf.nn.conv2d(e1, self.filters_1, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e2, self.filters_2, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e3, self.filters_3, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e4, self.filters_4, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e5, self.filters_5, strides=self.stride, padding='SAME')
-        return g
-
-    @tf.function
-    def conv_error_3d(self, e):
-        g = tf.nn.conv3d(e, self.filter, strides=[1, 1, 1, 1, 1], padding='SAME')
-        return g
+@tf.function
+def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride):
+    out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_2 = tf.nn.conv2d_transpose(activations, filters_2, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_3 = tf.nn.conv2d_transpose(activations, filters_3, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_4 = tf.nn.conv2d_transpose(activations, filters_4, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_5 = tf.nn.conv2d_transpose(activations, filters_5, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    recon = tf.concat([out_1, out_2, out_3, out_4, out_5], axis=3)
+    return recon
+
+@tf.function
+def do_recon_3d(filters, activations, batch_size, stride):
+    recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, 100, 200, 1), strides=stride)
+    return recon
+
+@tf.function
+def conv_error(filters_1, filters_2, filters_3, filters_4, filters_5, e, stride):
+    e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)
+    g = tf.nn.conv2d(e1, filters_1, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e2, filters_2, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e3, filters_3, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e4, filters_4, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e5, filters_5, strides=stride, padding='SAME')
+    return g
+
+@tf.function
+def conv_error_3d(filters, e, stride):
+    g = tf.nn.conv3d(e, filters, strides=[stride, stride, stride, stride, stride], padding='SAME')
+    return g
+
+@tf.function
+def normalize_weights(filters, out_channels):
+    norms = tf.norm(tf.reshape(tf.stack(filters), (out_channels, -1)), axis=1)
+    norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
+    adjusted = [f / norms for f in filters]
+    return adjusted
+
+class SparseCode(keras.layers.Layer):
+    def __init__(self, pytorch_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter):
+        super(SparseCode, self).__init__()
+        self.out_channels = out_channels
+        self.in_channels = in_channels
+        self.stride = stride
+        self.lam = lam
+        self.activation_lr = activation_lr
+        self.max_activation_iter = max_activation_iter
+        self.batch_size = batch_size
+
+        # pytorch_checkpoint = torch.load(pytorch_checkpoint, map_location='cpu')
+        # weight_tensor = pytorch_checkpoint['model_state_dict']['filters'].swapaxes(1,3).swapaxes(2,4).swapaxes(0,4)
+        # self.filter = tf.Variable(initial_value=weight_tensor.numpy(), dtype='float32')
+        # weight_list = torch.chunk(weight_tensor, 5, dim=2)
+        #
+        initializer = tf.keras.initializers.HeNormal()
+        self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_2 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_3 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+
+        weights = normalize_weights(self.get_weights(), out_channels)
+        self.set_weights(weights)
     @tf.function
     def do_update(self, images, u, m, v, b1, b2, eps, i):
         activations = tf.nn.relu(u - self.lam)
-        recon = self.do_recon(activations)
+        recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.batch_size, self.stride)
         e = images - recon
         g = -1 * u
-        g = g + self.conv_error(e)
+        convd_error = conv_error(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, e, self.stride)
+        g = g + convd_error
         g = g + activations
@@ -94,37 +99,30 @@ class SparseCode(keras.layers.Layer):
         v = b2 * v + (1-b2) * g**2
         mh = m / (1 - b1**(1+i))
         vh = v / (1 - b2**(1+i))
-        u = u + (self.activation_lr * mh / (tf.math.sqrt(vh) + eps))
+        du = self.activation_lr * mh / (tf.math.sqrt(vh) + eps)
+        u += du
         return u, m, v

-    def loss(self, images, activations):
-        recon = self.do_recon(activations)
-        loss = 0.5 * (1/images.shape[0]) * tf.sum(tf.math.pow(images - recon, 2))
-        loss += self.lam * tf.reduce_mean(tf.sum(tf.math.abs(activations.reshape(activations.shape[0], -1)), axis=1))
-        return loss
-
     @tf.function
     def call(self, images):
-        u = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
-        m = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
-        v = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
+        u = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+        m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+        v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+        tf.print('activations before:', tf.reduce_sum(u))

         b1 = 0.9
         b2 = 0.999
         eps = 1e-8

-        # images, u, m, v, b1, b2, eps, i = self.do_update(images, u, m, v, b1, b2, eps, 0)
-        # i = tf.constant(0, dtype=tf.float16)
-        # c = lambda images, u, m, v, b1, b2, eps, i: tf.less(i, 20)
-        # images, u, m, v, b1, b2, eps, i = tf.while_loop(c, self.do_update, [images, u, m, v, b1, b2, eps, i], shape_invariants=[images.get_shape(), tf.TensorShape([None, 100, 200, self.out_channels]), tf.TensorShape([None, 100, 200, self.out_channels]), tf.TensorShape([None, 100, 200, self.out_channels]), None, None, None, i.get_shape()])
         for i in range(self.max_activation_iter):
             u, m, v = self.do_update(images, u, m, v, b1, b2, eps, i)

         u = tf.nn.relu(u - self.lam)
-        self.add_loss(self.loss(images, u))
+        tf.print('activations after:', tf.reduce_sum(u))
         return u
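For orientation: `call` runs `max_activation_iter` Adam-style updates on the pre-activations `u`, with `tf.nn.relu(u - self.lam)` playing the role of the shrinkage (soft-threshold) step in ISTA-style sparse coding. A minimal NumPy sketch of the same update rule on a flat, non-convolutional dictionary may make the loop easier to follow; every name in it (`D`, `x`, `infer_activations`) is illustrative, not from the repo:

import numpy as np

def infer_activations(D, x, lam=0.05, lr=1e-2, iters=150, b1=0.9, b2=0.999, eps=1e-8):
    # D: (features, atoms) dictionary with unit-norm atoms; x: (features,) signal.
    u = np.zeros(D.shape[1])           # pre-activations, the `u` of call()
    m = np.zeros_like(u)               # Adam first moment
    v = np.zeros_like(u)               # Adam second moment
    for i in range(iters):
        a = np.maximum(u - lam, 0.0)   # shrinkage: relu(u - lam)
        e = x - D @ a                  # reconstruction error, like do_recon
        g = D.T @ e - u + a            # same gradient form as do_update
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g**2
        mh = m / (1 - b1**(i + 1))     # bias-corrected moments
        vh = v / (1 - b2**(i + 1))
        u = u + lr * mh / (np.sqrt(vh) + eps)
    return np.maximum(u - lam, 0.0)    # final sparse activations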
...
@@ -8,6 +8,9 @@ from tqdm import tqdm
 import argparse
 import os
 from sparse_coding_torch.load_data import load_yolo_clips
+import tensorflow.keras as keras
+import tensorflow as tf
+from keras_model import SparseCode, do_recon, normalize_weights

 def plot_video(video):
@@ -83,16 +86,21 @@ def plot_filters(filters):
     return FuncAnimation(plt.gcf(), update, interval=1000/20)

+def sparse_loss(filters_1, filters_2, filters_3, filters_4, filters_5, images, activations, batch_size, lam, stride):
+    recon = do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride)
+    loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2))
+    loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1))
+    return loss
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch_size', default=12, type=int)
+    parser.add_argument('--batch_size', default=8, type=int)
     parser.add_argument('--kernel_size', default=15, type=int)
     parser.add_argument('--num_kernels', default=64, type=int)
-    parser.add_argument('--stride', default=1, type=int)
-    parser.add_argument('--max_activation_iter', default=200, type=int)
+    parser.add_argument('--stride', default=2, type=int)
+    parser.add_argument('--max_activation_iter', default=150, type=int)
     parser.add_argument('--activation_lr', default=1e-2, type=float)
-    parser.add_argument('--lr', default=1e-3, type=float)
+    parser.add_argument('--lr', default=1e-5, type=float)
     parser.add_argument('--epochs', default=100, type=int)
     parser.add_argument('--lam', default=0.05, type=float)
     parser.add_argument('--output_dir', default='./output', type=str)
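The `sparse_loss` added above is the standard sparse-coding objective: half the per-batch squared reconstruction error plus an L1 penalty on the activations, weighted by `lam`. A quick shape check with random tensors, using the new defaults (batch 8, stride 2, 64 kernels); this snippet is illustrative only and assumes `sparse_loss` and `do_recon` are importable as in this file:

import tensorflow as tf

batch_size, stride, out_channels, kernel_size = 8, 2, 64, 15
acts = tf.random.uniform((batch_size, 100 // stride, 200 // stride, out_channels))
imgs = tf.random.uniform((batch_size, 100, 200, 5))
# Five per-frame filter banks, shaped for conv2d_transpose: (h, w, out=1, in=64).
fs = [tf.random.normal((kernel_size, kernel_size, 1, out_channels)) for _ in range(5)]
loss = sparse_loss(fs[0], fs[1], fs[2], fs[3], fs[4], imgs, acts, batch_size, 0.05, stride)
print(float(loss))  # scalar: 0.5/batch * ||imgs - recon||^2 + lam * mean ||acts||_1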
@@ -104,15 +112,17 @@ if __name__ == "__main__":
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)

+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
     with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
         out_f.write(str(args))

-    train_loader, _ = load_yolo_clips(batch_size, mode='all_train')
+    train_loader, _ = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='../positive_videos.json')
     print('Loaded', len(train_loader), 'train examples')

     example_data = next(iter(train_loader))

-    inputs = keras.Input(shape=(5, 100, 200, 1))
+    inputs = keras.Input(shape=(100, 200, 5))
     output = SparseCode('../sparse.pt', batch_size=args.batch_size, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter)(inputs)
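Note the input layout change: the old graph took a 5-frame volume shaped (5, 100, 200, 1), while the new one stacks the five grayscale frames as channels of a single (100, 200, 5) image, matching the five separate filter banks. The hunk cuts off before the model is assembled; the elided lines presumably do roughly the following (a sketch under that assumption; `model` and `filter_optimizer` are inferred from their use in the training loop below, not shown in this diff):

model = keras.Model(inputs=inputs, outputs=output)                   # wrap the layer functionally
filter_optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)   # updates the filter banks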
@@ -124,21 +134,40 @@ if __name__ == "__main__":
     loss_log = []
     best_so_far = float('inf')

-    for epoch in tqdm(range(args.epochs)):
+    for epoch in range(args.epochs):
         epoch_loss = 0
         epoch_start = time.perf_counter()

-        for labels, local_batch, vid_f in train_loader:
+        for labels, local_batch, vid_f in tqdm(train_loader):
+            tf.print(vid_f)
+            if local_batch.size(0) != args.batch_size:
+                continue
+            images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy()
             with tf.GradientTape() as tape:
-                activations = model(local_batch.numpy())
-                loss = tf.sum(model.losses)
+                activations = model(images)
+                filters = model.get_weights()
+                loss = sparse_loss(filters[0], filters[1], filters[2], filters[3], filters[4], images, activations, args.batch_size, args.lam, args.stride)
+                # recon = model.do_recon(activations)
+                # loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2))
+                # loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1))

             epoch_loss += loss * local_batch.size(0)
+            tf.print('loss:', loss)

             gradients = tape.gradient(loss, model.trainable_weights)
+            tf.print('gradients:', tf.reduce_sum(gradients))

-            optimizer.apply_gradients(zip(gradients, model.trainable_weights))
+            filter_optimizer.apply_gradients(zip(gradients, model.trainable_weights))
+            tf.print('weights:', tf.reduce_sum(model.trainable_weights))
+
+            weights = normalize_weights(model.get_weights(), args.num_kernels)
+            model.set_weights(weights)
+            tf.print('normalized weights:', tf.reduce_sum(model.trainable_weights))

         epoch_end = time.perf_counter()
         epoch_loss /= len(train_loader.sampler)
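One detail worth calling out: the loop projects the filters back to unit norm with `normalize_weights` after every optimizer step. This is the usual dictionary-learning constraint; without it, the objective can be gamed by growing the filters while shrinking the activations, deflating the L1 term without improving the reconstruction. (The `tf.reduce_sum(gradients)` debug print only works here because all five filter banks share one shape, so the list stacks into a single tensor.) The projection is equivalent to this small sketch, with illustrative names; each column of `D` stands for one flattened kernel:

import numpy as np

def project_unit_norm(D):
    # Rescale each atom (column) to unit L2 norm, guarding against
    # divide-by-zero exactly as normalize_weights does with its 1e-12 floor.
    norms = np.maximum(np.linalg.norm(D, axis=0, keepdims=True), 1e-12)
    return D / norms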
...