Commit 3d931dab authored by hannandarryl

keras training code

parent 13ae3ca4
@@ -9,84 +9,89 @@ import torch.nn as nn
 from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
 from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
-class SparseCode(keras.layers.Layer):
-    def __init__(self, pytorch_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter):
-        super(SparseCode, self).__init__()
-        self.out_channels = out_channels
-        self.in_channels = in_channels
-        self.stride = stride
-        self.lam = lam
-        self.activation_lr = activation_lr
-        self.max_activation_iter = max_activation_iter
-        self.batch_size = batch_size
-        # pytorch_checkpoint = torch.load(pytorch_checkpoint, map_location='cpu')
-        # weight_tensor = pytorch_checkpoint['model_state_dict']['filters'].swapaxes(1,3).swapaxes(2,4).swapaxes(0,4)
-        # self.filter = tf.Variable(initial_value=weight_tensor.numpy(), dtype='float32')
-        # weight_list = torch.chunk(weight_tensor, 5, dim=2)
-        #
-        initializer = tf.keras.initializers.HeNormal()
-        self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_2 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_3 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-        self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, out_channels, in_channels)), dtype='float32', trainable=True)
-    def normalize_weights(self):
-        norms = tf.norm(tf.reshape(tf.stack([self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5]), (self.out_channels, -1)), axis=1, keepdims=True)
-        norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), self.filters_1.shape)
-        self.filters_1 = self.filters_1 / norms
-        self.filters_2 = self.filters_2 / norms
-        self.filters_3 = self.filters_3 / norms
-        self.filters_4 = self.filters_4 / norms
-        self.filters_5 = self.filters_5 / norms
 @tf.function
-    def do_recon(self, activations):
-        out_1 = tf.nn.conv2d_transpose(activations, self.filters_1, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_2 = tf.nn.conv2d_transpose(activations, self.filters_2, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_3 = tf.nn.conv2d_transpose(activations, self.filters_3, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_4 = tf.nn.conv2d_transpose(activations, self.filters_4, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
-        out_5 = tf.nn.conv2d_transpose(activations, self.filters_5, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
+def do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride):
+    out_1 = tf.nn.conv2d_transpose(activations, filters_1, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_2 = tf.nn.conv2d_transpose(activations, filters_2, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_3 = tf.nn.conv2d_transpose(activations, filters_3, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_4 = tf.nn.conv2d_transpose(activations, filters_4, output_shape=(batch_size, 100, 200, 1), strides=stride)
+    out_5 = tf.nn.conv2d_transpose(activations, filters_5, output_shape=(batch_size, 100, 200, 1), strides=stride)
     recon = tf.concat([out_1, out_2, out_3, out_4, out_5], axis=3)
     return recon
 @tf.function
-    def do_recon_3d(self, activations):
-        recon = tf.nn.conv3d_transpose(activations, self.filter, output_shape=(self.batch_size, 5, 100, 200, 1), strides=self.stride)
+def do_recon_3d(filters, activations, batch_size, stride):
+    recon = tf.nn.conv3d_transpose(activations, filters, output_shape=(batch_size, 5, 100, 200, 1), strides=stride)
     return recon
 @tf.function
-    def conv_error(self, e):
+def conv_error(filters_1, filters_2, filters_3, filters_4, filters_5, e, stride):
     e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)
-        g = tf.nn.conv2d(e1, self.filters_1, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e2, self.filters_2, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e3, self.filters_3, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e4, self.filters_4, strides=self.stride, padding='SAME')
-        g = g + tf.nn.conv2d(e5, self.filters_5, strides=self.stride, padding='SAME')
+    g = tf.nn.conv2d(e1, filters_1, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e2, filters_2, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e3, filters_3, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e4, filters_4, strides=stride, padding='SAME')
+    g = g + tf.nn.conv2d(e5, filters_5, strides=stride, padding='SAME')
     return g
 @tf.function
-    def conv_error_3d(self, e):
-        g = tf.nn.conv3d(e, self.filter, strides=[1, 1, 1, 1, 1], padding='SAME')
+def conv_error_3d(filters, e, stride):
+    g = tf.nn.conv3d(e, filters, strides=[stride, stride, stride, stride, stride], padding='SAME')
     return g
+@tf.function
+def normalize_weights(filters, out_channels):
+    norms = tf.norm(tf.reshape(tf.stack(filters), (out_channels, -1)), axis=1)
+    norms = tf.broadcast_to(tf.math.maximum(norms, 1e-12*tf.ones_like(norms)), filters[0].shape)
+    adjusted = [f / norms for f in filters]
+    return adjusted
+class SparseCode(keras.layers.Layer):
+    def __init__(self, pytorch_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter):
+        super(SparseCode, self).__init__()
+        self.out_channels = out_channels
+        self.in_channels = in_channels
+        self.stride = stride
+        self.lam = lam
+        self.activation_lr = activation_lr
+        self.max_activation_iter = max_activation_iter
+        self.batch_size = batch_size
+        # pytorch_checkpoint = torch.load(pytorch_checkpoint, map_location='cpu')
+        # weight_tensor = pytorch_checkpoint['model_state_dict']['filters'].swapaxes(1,3).swapaxes(2,4).swapaxes(0,4)
+        # self.filter = tf.Variable(initial_value=weight_tensor.numpy(), dtype='float32')
+        # weight_list = torch.chunk(weight_tensor, 5, dim=2)
+        #
+        initializer = tf.keras.initializers.HeNormal()
+        self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_2 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_3 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
+        weights = normalize_weights(self.get_weights(), out_channels)
+        self.set_weights(weights)
     @tf.function
     def do_update(self, images, u, m, v, b1, b2, eps, i):
         activations = tf.nn.relu(u - self.lam)
-        recon = self.do_recon(activations)
+        recon = do_recon(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, activations, self.batch_size, self.stride)
         e = images - recon
         g = -1 * u
-        g = g + self.conv_error(e)
+        convd_error = conv_error(self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5, e, self.stride)
+        g = g + convd_error
         g = g + activations
@@ -94,37 +99,30 @@ class SparseCode(keras.layers.Layer):
         v = b2 * v + (1-b2) * g**2
         mh = m / (1 - b1**(1+i))
         vh = v / (1 - b2**(1+i))
-        u = u + (self.activation_lr * mh / (tf.math.sqrt(vh) + eps))
-        return u, m, v
+        du = self.activation_lr * mh / (tf.math.sqrt(vh) + eps)
+        u += du
+        return u, m, v
-    def loss(self, images, activations):
-        recon = self.do_recon(activations)
-        loss = 0.5 * (1/images.shape[0]) * tf.sum(tf.math.pow(images - recon, 2))
-        loss += self.lam * tf.reduce_mean(tf.sum(tf.math.abs(activations.reshape(activations.shape[0], -1)), axis=1))
-        return loss
     @tf.function
     def call(self, images):
-        u = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
-        m = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
-        v = tf.zeros(shape=(self.batch_size, 5, 100 // self.stride, 200 // self.stride, self.out_channels))
+        u = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+        m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+        v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
+        tf.print('activations before:', tf.reduce_sum(u))
         b1 = 0.9
         b2 = 0.999
         eps = 1e-8
+        # images, u, m, v, b1, b2, eps, i = self.do_update(images, u, m, v, b1, b2, eps, 0)
+        # i = tf.constant(0, dtype=tf.float16)
+        # c = lambda images, u, m, v, b1, b2, eps, i: tf.less(i, 20)
+        # images, u, m, v, b1, b2, eps, i = tf.while_loop(c, self.do_update, [images, u, m, v, b1, b2, eps, i], shape_invariants=[images.get_shape(), tf.TensorShape([None, 100, 200, self.out_channels]), tf.TensorShape([None, 100, 200, self.out_channels]), tf.TensorShape([None, 100, 200, self.out_channels]), None, None, None, i.get_shape()])
         for i in range(self.max_activation_iter):
             u, m, v = self.do_update(images, u, m, v, b1, b2, eps, i)
         u = tf.nn.relu(u - self.lam)
-        self.add_loss(self.loss(images, u))
+        tf.print('activations after:', tf.reduce_sum(u))
         return u
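The inner loop of do_update above amounts to sparse-coding inference by running Adam on the membrane potentials: the activations are the soft-thresholded potentials relu(u - lam), and the update direction -u + conv_error(images - recon) + activations has the familiar locally-competitive (LCA) form. Below is a minimal NumPy sketch of that inference, with the five strided convolutions collapsed into a single dense dictionary D for readability; infer_activations, D, and x are illustrative names, not from the commit.

import numpy as np

def infer_activations(x, D, lam=0.05, lr=1e-2, iters=150, b1=0.9, b2=0.999, eps=1e-8):
    # Minimize 0.5*||x - D @ relu(u - lam)||^2 + lam*||relu(u - lam)||_1
    # by running Adam on the potentials u, mirroring SparseCode.do_update.
    u = np.zeros(D.shape[1])
    m = np.zeros_like(u)
    v = np.zeros_like(u)
    for i in range(iters):
        a = np.maximum(u - lam, 0.0)            # activations = relu(u - lam)
        e = x - D @ a                           # reconstruction error (cf. do_recon)
        g = -u + D.T @ e + a                    # same form as do_update's gradient
        m = b1 * m + (1 - b1) * g               # Adam first moment
        v = b2 * v + (1 - b2) * g**2            # Adam second moment
        mh = m / (1 - b1**(i + 1))              # bias-corrected moments
        vh = v / (1 - b2**(i + 1))
        u = u + lr * mh / (np.sqrt(vh) + eps)   # matches "u += du" in the commit
    return np.maximum(u - lam, 0.0)

# Example: a = infer_activations(np.random.randn(64), np.random.randn(64, 128))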
@@ -8,6 +8,9 @@ from tqdm import tqdm
 import argparse
 import os
 from sparse_coding_torch.load_data import load_yolo_clips
+import tensorflow.keras as keras
+import tensorflow as tf
+from keras_model import SparseCode, do_recon, normalize_weights
 def plot_video(video):
@@ -83,16 +86,21 @@ def plot_filters(filters):
     return FuncAnimation(plt.gcf(), update, interval=1000/20)
+def sparse_loss(filters_1, filters_2, filters_3, filters_4, filters_5, images, activations, batch_size, lam, stride):
+    recon = do_recon(filters_1, filters_2, filters_3, filters_4, filters_5, activations, batch_size, stride)
+    loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2))
+    loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1))
+    return loss
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch_size', default=12, type=int)
+    parser.add_argument('--batch_size', default=8, type=int)
     parser.add_argument('--kernel_size', default=15, type=int)
     parser.add_argument('--num_kernels', default=64, type=int)
-    parser.add_argument('--stride', default=1, type=int)
-    parser.add_argument('--max_activation_iter', default=200, type=int)
+    parser.add_argument('--stride', default=2, type=int)
+    parser.add_argument('--max_activation_iter', default=150, type=int)
     parser.add_argument('--activation_lr', default=1e-2, type=float)
-    parser.add_argument('--lr', default=1e-3, type=float)
+    parser.add_argument('--lr', default=1e-5, type=float)
     parser.add_argument('--epochs', default=100, type=int)
     parser.add_argument('--lam', default=0.05, type=float)
     parser.add_argument('--output_dir', default='./output', type=str)
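sparse_loss above combines the two standard sparse-coding terms: a reconstruction error averaged over the batch and an L1 penalty on each example's flattened activation vector, weighted by lam. A stand-alone restatement with toy shapes (the shapes and zero-valued tensors here are illustrative, not from the commit):

import tensorflow as tf

batch_size, lam = 8, 0.05
images = tf.zeros((batch_size, 100, 200, 5))        # network input
recon = tf.zeros((batch_size, 100, 200, 5))         # do_recon output
activations = tf.zeros((batch_size, 50, 100, 64))   # codes at stride 2

recon_term = 0.5 * (1 / batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2))
l1_term = lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1))
loss = recon_term + l1_term                         # what sparse_loss returns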
@@ -104,15 +112,17 @@ if __name__ == "__main__":
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
         out_f.write(str(args))
-    train_loader, _ = load_yolo_clips(batch_size, mode='all_train')
+    train_loader, _ = load_yolo_clips(args.batch_size, num_clips=1, num_positives=15, mode='all_train', device=device, n_splits=1, sparse_model=None, whole_video=False, positive_videos='../positive_videos.json')
     print('Loaded', len(train_loader), 'train examples')
     example_data = next(iter(train_loader))
-    inputs = keras.Input(shape=(5, 100, 200, 1))
+    inputs = keras.Input(shape=(100, 200, 5))
     output = SparseCode('../sparse.pt', batch_size=args.batch_size, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter)(inputs)
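The line that wraps these tensors into a trainable object falls outside the changed hunks; under the usual Keras functional pattern it would look like the sketch below. keras.Model here is an assumption inferred from the later model(images) and model.get_weights() calls, not a line shown in the diff.

# Assumed from context, not shown in the diff. batch_size is baked into the
# layer because do_recon passes a static output_shape to tf.nn.conv2d_transpose.
model = keras.Model(inputs=inputs, outputs=output)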
@@ -124,21 +134,40 @@ if __name__ == "__main__":
     loss_log = []
     best_so_far = float('inf')
-    for epoch in tqdm(range(args.epochs)):
+    for epoch in range(args.epochs):
         epoch_loss = 0
         epoch_start = time.perf_counter()
-        for labels, local_batch, vid_f in train_loader:
+        for labels, local_batch, vid_f in tqdm(train_loader):
+            tf.print(vid_f)
             if local_batch.size(0) != args.batch_size:
                 continue
+            images = local_batch.squeeze(1).permute(0, 2, 3, 1).numpy()
             with tf.GradientTape() as tape:
-                activations = model(local_batch.numpy())
-                loss = tf.sum(model.losses)
+                activations = model(images)
+                filters = model.get_weights()
+                loss = sparse_loss(filters[0], filters[1], filters[2], filters[3], filters[4], images, activations, args.batch_size, args.lam, args.stride)
+                # recon = model.do_recon(activations)
+                # loss = 0.5 * (1/batch_size) * tf.math.reduce_sum(tf.math.pow(images - recon, 2))
+                # loss += lam * tf.reduce_mean(tf.math.reduce_sum(tf.math.abs(tf.reshape(activations, (batch_size, -1))), axis=1))
             epoch_loss += loss * local_batch.size(0)
+            tf.print('loss:', loss)
             gradients = tape.gradient(loss, model.trainable_weights)
+            tf.print('gradients:', tf.reduce_sum(gradients))
-            optimizer.apply_gradients(zip(gradients, model.trainable_weights))
+            filter_optimizer.apply_gradients(zip(gradients, model.trainable_weights))
+            tf.print('weights:', tf.reduce_sum(model.trainable_weights))
+            weights = normalize_weights(model.get_weights(), args.num_kernels)
+            model.set_weights(weights)
+            tf.print('normalized weights:', tf.reduce_sum(model.trainable_weights))
         epoch_end = time.perf_counter()
         epoch_loss /= len(train_loader.sampler)
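The normalize_weights call after each optimizer step projects the dictionary back toward unit L2 norm, preventing the degenerate solution where filters grow without bound while activations shrink to dodge the L1 penalty. One caveat: tf.stack(filters) has shape (5, kernel, kernel, in_channels, out_channels), so reshaping it directly to (out_channels, -1) groups elements positionally rather than per atom. A per-atom version of the projection would bring the atom axis to the front first, as in this sketch; normalize_filters_per_atom is an illustrative name, not from the commit.

import tensorflow as tf

def normalize_filters_per_atom(filters, out_channels):
    # filters: list of 5 tensors shaped (kernel, kernel, in_channels, out_channels)
    stacked = tf.stack(filters)                        # (5, k, k, in, out)
    per_atom = tf.transpose(stacked, [4, 0, 1, 2, 3])  # atom axis first
    norms = tf.norm(tf.reshape(per_atom, (out_channels, -1)), axis=1)
    norms = tf.math.maximum(norms, 1e-12)              # avoid division by zero
    # norms has shape (out_channels,), which broadcasts across the trailing
    # axis of each filter bank, scaling every atom independently.
    return [f / norms for f in filters]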