Commit 13ae3ca4 authored by Darryl Hannan

added keras modeling

parent d00c9911
from tensorflow import keras
import numpy as np
import torch
import tensorflow as tf
import cv2
import torchvision as tv
import torch.nn as nn
from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
from keras_model import SparseCode, Classifier
inputs = keras.Input(shape=(100, 200, 5))  # five grayscale frames stacked along the channel axis, matching the static shape set before conversion below
x = SparseCode('../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=1, lam=0.05, activation_lr=1, max_activation_iter=1)(inputs)
outputs = Classifier()(x)
model = keras.Model(inputs=inputs, outputs=outputs)
# Copy the classifier weights from the PyTorch checkpoint. PyTorch stores Conv2d kernels
# as (out, in, kH, kW) and Linear weights as (out, in); Keras expects (kH, kW, in, out)
# and (in, out), so the tensors are permuted before being copied.
pytorch_checkpoint = torch.load('../classifier.pt', map_location='cpu')['model_state_dict']

conv_weights = [pytorch_checkpoint['module.compress_activations_conv_1.weight'].permute(2, 3, 1, 0).numpy(), pytorch_checkpoint['module.compress_activations_conv_1.bias'].numpy()]
model.get_layer('classifier').conv.set_weights(conv_weights)
ff_1_weights = [pytorch_checkpoint['module.fc1.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc1.bias'].numpy()]
model.get_layer('classifier').ff_1.set_weights(ff_1_weights)
ff_2_weights = [pytorch_checkpoint['module.fc2.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc2.bias'].numpy()]
model.get_layer('classifier').ff_2.set_weights(ff_2_weights)
ff_3_weights = [pytorch_checkpoint['module.fc3.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc3.bias'].numpy()]
model.get_layer('classifier').ff_3.set_weights(ff_3_weights)
ff_4_weights = [pytorch_checkpoint['module.fc4.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc4.bias'].numpy()]
model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
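
# Optional sanity check (not in the original script): confirm that one of the copied
# Dense kernels round-trips exactly. Assumes the checkpoint keys used above exist and
# that the Classifier layer keeps its default name 'classifier'.
ff_1_kernel, _ = model.get_layer('classifier').ff_1.get_weights()
assert np.allclose(ff_1_kernel, ff_1_weights[0]), 'fc1 weights were not copied correctly'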
# frozen_sparse = ConvSparseLayer(in_channels=1,
#                                 out_channels=64,
#                                 kernel_size=(5, 15, 15),
#                                 stride=1,
#                                 padding=(0, 7, 7),
#                                 convo_dim=3,
#                                 rectifier=True,
#                                 lam=0.05,
#                                 max_activation_iter=1,
#                                 activation_lr=1)
#
# sparse_param = torch.load('../sparse.pt', map_location='cpu')
# frozen_sparse.load_state_dict(sparse_param['model_state_dict'])
#
# # pytorch_filter = frozen_sparse.filters[30, :, 0, :, :].squeeze(0).unsqueeze(2).detach().numpy()
# # keras_filter = model.get_layer('sparse_code').filter[0,:,:,:,30].numpy()
# #
# # cv2.imwrite('pytorch_filter.png', pytorch_filter / np.max(pytorch_filter) * 255.)
# # cv2.imwrite('keras_filter.png', keras_filter / np.max(keras_filter) * 255.)
# # raise Exception
#
# img = tv.io.read_video('../clips/No_Sliding/Image_262499828648_clean1050.mp4')[0].permute(3, 0, 1, 2)
# transform = tv.transforms.Compose(
#     [VideoGrayScaler(),
#      MinMaxScaler(0, 255),
#      tv.transforms.Normalize((0.2592,), (0.1251,)),
#      tv.transforms.CenterCrop((100, 200))
#     ])
# img = transform(img)
#
# with torch.no_grad():
#     activations = frozen_sparse(img.unsqueeze(0))
#
# output = model(img.unsqueeze(4).numpy())
# Pin the batch dimension to 1 so the TFLite converter sees a fully static input shape.
input_name = model.input_names[0]
index = model.input_names.index(input_name)
model.inputs[index].set_shape([1, 100, 200, 5])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# converter.experimental_new_converter = True
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
tflite_model = converter.convert()
print('Converted')
with open("./output/tf_lite_model.tflite", "wb") as f:
    f.write(tflite_model)
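
# Minimal sketch (not part of the original script) of running the exported model with the
# TFLite interpreter; the all-zeros array is only a placeholder for a preprocessed clip.
interpreter = tf.lite.Interpreter(model_path="./output/tf_lite_model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
dummy_clip = np.zeros(input_details[0]['shape'], dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], dummy_clip)
interpreter.invoke()
print('model output:', interpreter.get_tensor(output_details[0]['index']))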
from tensorflow import keras
import numpy as np
import torch
import tensorflow as tf
import cv2
import torchvision as tv
import torch.nn as nn
from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
class SparseCode(keras.layers.Layer):
    def __init__(self, pytorch_checkpoint, batch_size, in_channels, out_channels, kernel_size, stride, lam, activation_lr, max_activation_iter):
        super(SparseCode, self).__init__()
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.stride = stride
        self.lam = lam
        self.activation_lr = activation_lr
        self.max_activation_iter = max_activation_iter
        self.batch_size = batch_size

        # pytorch_checkpoint = torch.load(pytorch_checkpoint, map_location='cpu')
        # weight_tensor = pytorch_checkpoint['model_state_dict']['filters'].swapaxes(1,3).swapaxes(2,4).swapaxes(0,4)
        # self.filter = tf.Variable(initial_value=weight_tensor.numpy(), dtype='float32')
        # weight_list = torch.chunk(weight_tensor, 5, dim=2)

        # One 2-D filter bank per video frame, stored as (kH, kW, image_channels, num_atoms),
        # the concrete layout both tf.nn.conv2d and tf.nn.conv2d_transpose expect here.
        initializer = tf.keras.initializers.HeNormal()
        self.filters_1 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
        self.filters_2 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
        self.filters_3 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
        self.filters_4 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
        self.filters_5 = tf.Variable(initial_value=initializer(shape=(kernel_size, kernel_size, in_channels, out_channels)), dtype='float32', trainable=True)
    def normalize_weights(self):
        # Clamp each atom's norm away from zero, then rescale every per-frame filter bank.
        norms = tf.norm(tf.reshape(tf.stack([self.filters_1, self.filters_2, self.filters_3, self.filters_4, self.filters_5]), (self.out_channels, -1)), axis=1, keepdims=True)
        norms = tf.reshape(tf.math.maximum(norms, 1e-12 * tf.ones_like(norms)), (1, 1, 1, self.out_channels))
        self.filters_1 = self.filters_1 / norms
        self.filters_2 = self.filters_2 / norms
        self.filters_3 = self.filters_3 / norms
        self.filters_4 = self.filters_4 / norms
        self.filters_5 = self.filters_5 / norms
    @tf.function
    def do_recon(self, activations):
        # Reconstruct each of the five frames with its own filter bank and stack them on the channel axis.
        out_1 = tf.nn.conv2d_transpose(activations, self.filters_1, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
        out_2 = tf.nn.conv2d_transpose(activations, self.filters_2, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
        out_3 = tf.nn.conv2d_transpose(activations, self.filters_3, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
        out_4 = tf.nn.conv2d_transpose(activations, self.filters_4, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)
        out_5 = tf.nn.conv2d_transpose(activations, self.filters_5, output_shape=(self.batch_size, 100, 200, 1), strides=self.stride)

        recon = tf.concat([out_1, out_2, out_3, out_4, out_5], axis=3)

        return recon

    @tf.function
    def do_recon_3d(self, activations):
        # 3-D variant; relies on self.filter, which is only created by the commented-out checkpoint loading above.
        recon = tf.nn.conv3d_transpose(activations, self.filter, output_shape=(self.batch_size, 5, 100, 200, 1), strides=self.stride)
        return recon
    @tf.function
    def conv_error(self, e):
        e1, e2, e3, e4, e5 = tf.split(e, 5, axis=3)

        g = tf.nn.conv2d(e1, self.filters_1, strides=self.stride, padding='SAME')
        g = g + tf.nn.conv2d(e2, self.filters_2, strides=self.stride, padding='SAME')
        g = g + tf.nn.conv2d(e3, self.filters_3, strides=self.stride, padding='SAME')
        g = g + tf.nn.conv2d(e4, self.filters_4, strides=self.stride, padding='SAME')
        g = g + tf.nn.conv2d(e5, self.filters_5, strides=self.stride, padding='SAME')

        return g

    @tf.function
    def conv_error_3d(self, e):
        g = tf.nn.conv3d(e, self.filter, strides=[1, 1, 1, 1, 1], padding='SAME')
        return g
    @tf.function
    def do_update(self, images, u, m, v, b1, b2, eps, i):
        # One Adam-style update of the activation maps; sparsity enters through the
        # shrinkage term ReLU(u - lam).
        activations = tf.nn.relu(u - self.lam)
        recon = self.do_recon(activations)
        e = images - recon
        g = -1 * u
        g = g + self.conv_error(e)
        g = g + activations

        m = b1 * m + (1-b1) * g
        v = b2 * v + (1-b2) * g**2
        mh = m / (1 - b1**(1+i))
        vh = v / (1 - b2**(1+i))
        u = u + (self.activation_lr * mh / (tf.math.sqrt(vh) + eps))

        return u, m, v

    def loss(self, images, activations):
        recon = self.do_recon(activations)
        loss = 0.5 * (1 / self.batch_size) * tf.reduce_sum(tf.math.pow(images - recon, 2))
        loss += self.lam * tf.reduce_mean(tf.reduce_sum(tf.math.abs(tf.reshape(activations, (self.batch_size, -1))), axis=1))
        return loss
    @tf.function
    def call(self, images):
        # Activation maps are 2-D; the five frames are reconstructed along the channel
        # axis by do_recon (matching the shape invariants in the commented tf.while_loop).
        u = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
        m = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
        v = tf.zeros(shape=(self.batch_size, 100 // self.stride, 200 // self.stride, self.out_channels))
        b1 = 0.9
        b2 = 0.999
        eps = 1e-8

        # images, u, m, v, b1, b2, eps, i = self.do_update(images, u, m, v, b1, b2, eps, 0)
        # i = tf.constant(0, dtype=tf.float16)
        # c = lambda images, u, m, v, b1, b2, eps, i: tf.less(i, 20)
        # images, u, m, v, b1, b2, eps, i = tf.while_loop(c, self.do_update, [images, u, m, v, b1, b2, eps, i], shape_invariants=[images.get_shape(), tf.TensorShape([None, 100, 200, self.out_channels]), tf.TensorShape([None, 100, 200, self.out_channels]), tf.TensorShape([None, 100, 200, self.out_channels]), None, None, None, i.get_shape()])

        for i in range(self.max_activation_iter):
            u, m, v = self.do_update(images, u, m, v, b1, b2, eps, i)

        u = tf.nn.relu(u - self.lam)

        self.add_loss(self.loss(images, u))

        return u
class Classifier(keras.layers.Layer):
    def __init__(self):
        super(Classifier, self).__init__()

        self.max_pool = keras.layers.MaxPooling2D(pool_size=4, strides=4)
        self.conv = keras.layers.Conv2D(24, kernel_size=8, strides=4, activation='relu', padding='SAME')
        self.flatten = keras.layers.Flatten()
        self.dropout = keras.layers.Dropout(0.5)

        self.ff_1 = keras.layers.Dense(1000, activation='relu', use_bias=True)
        self.ff_2 = keras.layers.Dense(100, activation='relu', use_bias=True)
        self.ff_3 = keras.layers.Dense(20, activation='relu', use_bias=True)
        self.ff_4 = keras.layers.Dense(1, activation='sigmoid')

    @tf.function
    def call(self, activations):
        x = self.max_pool(activations)
        x = self.conv(x)
        x = self.flatten(x)
        x = self.ff_1(x)
        x = self.dropout(x)
        x = self.ff_2(x)
        x = self.dropout(x)
        x = self.ff_3(x)
        x = self.dropout(x)
        x = self.ff_4(x)
        return x
import time
import numpy as np
import torch
import tensorflow as tf
from tensorflow import keras
from matplotlib import pyplot as plt
from matplotlib import cm
from matplotlib.animation import FuncAnimation
from tqdm import tqdm
import argparse
import os
from sparse_coding_torch.load_data import load_yolo_clips
from keras_model import SparseCode
def plot_video(video):

    fig = plt.gcf()
    ax = plt.gca()

    DPI = fig.get_dpi()
    fig.set_size_inches(video.shape[2]/float(DPI), video.shape[3]/float(DPI))

    ax.set_title("Video")

    T = video.shape[1]
    im = ax.imshow(video[0, 0, :, :],
                   cmap=cm.Greys_r)

    def update(i):
        t = i % T
        im.set_data(video[0, t, :, :])

    return FuncAnimation(plt.gcf(), update, interval=1000/20)
def plot_original_vs_recon(original, reconstruction, idx=0):

    # create two subplots
    ax1 = plt.subplot(1, 2, 1)
    ax2 = plt.subplot(1, 2, 2)
    ax1.set_title("Original")
    ax2.set_title("Reconstruction")

    T = original.shape[2]
    im1 = ax1.imshow(original[idx, 0, 0, :, :],
                     cmap=cm.Greys_r)
    im2 = ax2.imshow(reconstruction[idx, 0, 0, :, :],
                     cmap=cm.Greys_r)

    def update(i):
        t = i % T
        im1.set_data(original[idx, 0, t, :, :])
        im2.set_data(reconstruction[idx, 0, t, :, :])

    return FuncAnimation(plt.gcf(), update, interval=1000/30)
def plot_filters(filters):
    num_filters = filters.shape[0]
    ncol = 3
    # ncol = int(np.sqrt(num_filters))
    # nrow = int(np.sqrt(num_filters))
    T = filters.shape[2]

    if num_filters // ncol == num_filters / ncol:
        nrow = num_filters // ncol
    else:
        nrow = num_filters // ncol + 1

    fig, axes = plt.subplots(ncols=ncol, nrows=nrow,
                             constrained_layout=True,
                             figsize=(ncol*2, nrow*2))

    ims = {}
    for i in range(num_filters):
        r = i // ncol
        c = i % ncol
        ims[(r, c)] = axes[r, c].imshow(filters[i, 0, 0, :, :],
                                        cmap=cm.Greys_r)

    def update(i):
        t = i % T
        for i in range(num_filters):
            r = i // ncol
            c = i % ncol
            ims[(r, c)].set_data(filters[i, 0, t, :, :])

    return FuncAnimation(plt.gcf(), update, interval=1000/20)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=12, type=int)
    parser.add_argument('--kernel_size', default=15, type=int)
    parser.add_argument('--num_kernels', default=64, type=int)
    parser.add_argument('--stride', default=1, type=int)
    parser.add_argument('--max_activation_iter', default=200, type=int)
    parser.add_argument('--activation_lr', default=1e-2, type=float)
    parser.add_argument('--lr', default=1e-3, type=float)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--lam', default=0.05, type=float)
    parser.add_argument('--output_dir', default='./output', type=str)
    parser.add_argument('--seed', default=42, type=int)

    args = parser.parse_args()

    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    with open(os.path.join(output_dir, 'arguments.txt'), 'w+') as out_f:
        out_f.write(str(args))
    train_loader, _ = load_yolo_clips(args.batch_size, mode='all_train')
    print('Loaded', len(train_loader), 'train examples')

    example_data = next(iter(train_loader))
    inputs = keras.Input(shape=(100, 200, 5))  # five frames as channels, matching the layout SparseCode reconstructs
    output = SparseCode('../sparse.pt', batch_size=args.batch_size, in_channels=1, out_channels=args.num_kernels, kernel_size=args.kernel_size, stride=args.stride, lam=args.lam, activation_lr=args.activation_lr, max_activation_iter=args.max_activation_iter)(inputs)
    model = keras.Model(inputs=inputs, outputs=output)

    learning_rate = args.lr
    filter_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    loss_log = []
    best_so_far = float('inf')

    for epoch in tqdm(range(args.epochs)):
        epoch_loss = 0
        epoch_start = time.perf_counter()

        for labels, local_batch, vid_f in train_loader:
            with tf.GradientTape() as tape:
                activations = model(local_batch.numpy())
                loss = tf.reduce_sum(model.losses)
                epoch_loss += loss * local_batch.size(0)

            gradients = tape.gradient(loss, model.trainable_weights)
            filter_optimizer.apply_gradients(zip(gradients, model.trainable_weights))

        epoch_end = time.perf_counter()
        epoch_loss /= len(train_loader.sampler)

        if epoch_loss < best_so_far:
            print("found better model")
            # Save model parameters
            model.save(os.path.join(output_dir, "sparse_conv3d_model-best.pt"))
            best_so_far = epoch_loss

        loss_log.append(epoch_loss)

        print('epoch={}, epoch_loss={:.2f}, time={:.2f}'.format(epoch, epoch_loss, epoch_end - epoch_start))

    plt.plot(loss_log)
    plt.savefig(os.path.join(output_dir, 'loss_graph.png'))