Skip to content
Snippets Groups Projects
Commit bd43054c authored by hannandarryl's avatar hannandarryl
Browse files

keras model updates

parent e3e0dafd
No related branches found
No related tags found
No related merge requests found
......@@ -8,28 +8,27 @@ import torch
import torch.nn as nn
from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
from keras_model import SparseCode, Classifier
from keras_model import MobileModel
inputs = keras.Input(shape=(5, 100, 200, 1))
inputs = keras.Input(shape=(100, 200, 5))
x = SparseCode('../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=1, lam=0.05, activation_lr=1, max_activation_iter=1)(inputs)
outputs = Classifier()(x)
outputs = MobileModel(sparse_checkpoint='../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=1, lam=0.05, activation_lr=1e-2, max_activation_iter=40, run_2d=True)(inputs)
model = keras.Model(inputs=inputs, outputs=x)
model = keras.Model(inputs=inputs, outputs=outputs)
pytorch_checkpoint = torch.load('../classifier.pt', map_location='cpu')['model_state_dict']
pytorch_checkpoint = torch.load('../output/final_model_75_iter/model-best_fold_0.pt', map_location='cpu')['model_state_dict']
conv_weights = [pytorch_checkpoint['module.compress_activations_conv_1.weight'].view(8, 8, 64, 24).numpy(), pytorch_checkpoint['module.compress_activations_conv_1.bias'].numpy()]
model.get_layer('classifier').conv.set_weights(conv_weights)
model.get_layer('mobile_model').classifier.conv.set_weights(conv_weights)
ff_1_weights = [pytorch_checkpoint['module.fc1.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc1.bias'].numpy()]
model.get_layer('classifier').ff_1.set_weights(ff_1_weights)
model.get_layer('mobile_model').classifier.ff_1.set_weights(ff_1_weights)
ff_2_weights = [pytorch_checkpoint['module.fc2.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc2.bias'].numpy()]
model.get_layer('classifier').ff_2.set_weights(ff_2_weights)
model.get_layer('mobile_model').classifier.ff_2.set_weights(ff_2_weights)
ff_3_weights = [pytorch_checkpoint['module.fc3.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc3.bias'].numpy()]
model.get_layer('classifier').ff_3.set_weights(ff_3_weights)
model.get_layer('mobile_model').classifier.ff_3.set_weights(ff_3_weights)
ff_4_weights = [pytorch_checkpoint['module.fc4.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc4.bias'].numpy()]
model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
model.get_layer('mobile_model').classifier.ff_4.set_weights(ff_4_weights)
# frozen_sparse = ConvSparseLayer(in_channels=1,
# out_channels=64,
......@@ -39,9 +38,9 @@ model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
# convo_dim=3,
# rectifier=True,
# lam=0.05,
# max_activation_iter=1,
# activation_lr=1)
#
# max_activation_iter=10,
# activation_lr=1e-2)
# sparse_param = torch.load('../sparse.pt', map_location='cpu')
# frozen_sparse.load_state_dict(sparse_param['model_state_dict'])
#
......@@ -60,11 +59,16 @@ model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
# tv.transforms.CenterCrop((100, 200))
# ])
# img = transform(img)
#
# with torch.no_grad():
# activations = frozen_sparse(img.unsqueeze(0))
#
# output = model(img.unsqueeze(4).numpy())
# output = model(img.swapaxes(1, 3).swapaxes(1,2).numpy())
# print(activations.size())
# print(output.shape)
# print(torch.sum(activations))
# print(tf.math.reduce_sum(output))
input_name = model.input_names[0]
index = model.input_names.index(input_name)
......@@ -80,5 +84,5 @@ tflite_model = converter.convert()
print('Converted')
with open("./output/tf_lite_model.tflite", "wb") as f:
with open("./mobile_output/tf_lite_model.tflite", "wb") as f:
f.write(tflite_model)
......@@ -10,7 +10,7 @@ import os
from sparse_coding_torch.load_data import load_yolo_clips
import tensorflow.keras as keras
import tensorflow as tf
from keras_model import normalize_weights_3d, normalize_weights, SparseCodeConv
from keras_model import normalize_weights_3d, normalize_weights, SparseCodeConv, load_pytorch_weights
def plot_video(video):
......@@ -99,7 +99,7 @@ if __name__ == "__main__":
parser.add_argument('--stride', default=2, type=int)
parser.add_argument('--max_activation_iter', default=50, type=int)
parser.add_argument('--activation_lr', default=1e-2, type=float)
parser.add_argument('--lr', default=1e-2, type=float)
parser.add_argument('--lr', default=5e-2, type=float)
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--lam', default=0.05, type=float)
parser.add_argument('--output_dir', default='./output', type=str)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment