Skip to content
Snippets Groups Projects
Commit bd43054c authored by hannandarryl's avatar hannandarryl
Browse files

keras model updates

parent e3e0dafd
Branches
Tags
No related merge requests found
......@@ -8,28 +8,27 @@ import torch
import torch.nn as nn
from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
from keras_model import SparseCode, Classifier
from keras_model import MobileModel
inputs = keras.Input(shape=(5, 100, 200, 1))
inputs = keras.Input(shape=(100, 200, 5))
x = SparseCode('../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=1, lam=0.05, activation_lr=1, max_activation_iter=1)(inputs)
outputs = Classifier()(x)
outputs = MobileModel(sparse_checkpoint='../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=1, lam=0.05, activation_lr=1e-2, max_activation_iter=40, run_2d=True)(inputs)
model = keras.Model(inputs=inputs, outputs=x)
model = keras.Model(inputs=inputs, outputs=outputs)
pytorch_checkpoint = torch.load('../classifier.pt', map_location='cpu')['model_state_dict']
pytorch_checkpoint = torch.load('../output/final_model_75_iter/model-best_fold_0.pt', map_location='cpu')['model_state_dict']
conv_weights = [pytorch_checkpoint['module.compress_activations_conv_1.weight'].view(8, 8, 64, 24).numpy(), pytorch_checkpoint['module.compress_activations_conv_1.bias'].numpy()]
model.get_layer('classifier').conv.set_weights(conv_weights)
model.get_layer('mobile_model').classifier.conv.set_weights(conv_weights)
ff_1_weights = [pytorch_checkpoint['module.fc1.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc1.bias'].numpy()]
model.get_layer('classifier').ff_1.set_weights(ff_1_weights)
model.get_layer('mobile_model').classifier.ff_1.set_weights(ff_1_weights)
ff_2_weights = [pytorch_checkpoint['module.fc2.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc2.bias'].numpy()]
model.get_layer('classifier').ff_2.set_weights(ff_2_weights)
model.get_layer('mobile_model').classifier.ff_2.set_weights(ff_2_weights)
ff_3_weights = [pytorch_checkpoint['module.fc3.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc3.bias'].numpy()]
model.get_layer('classifier').ff_3.set_weights(ff_3_weights)
model.get_layer('mobile_model').classifier.ff_3.set_weights(ff_3_weights)
ff_4_weights = [pytorch_checkpoint['module.fc4.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc4.bias'].numpy()]
model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
model.get_layer('mobile_model').classifier.ff_4.set_weights(ff_4_weights)
# frozen_sparse = ConvSparseLayer(in_channels=1,
# out_channels=64,
......@@ -39,9 +38,9 @@ model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
# convo_dim=3,
# rectifier=True,
# lam=0.05,
# max_activation_iter=1,
# activation_lr=1)
#
# max_activation_iter=10,
# activation_lr=1e-2)
# sparse_param = torch.load('../sparse.pt', map_location='cpu')
# frozen_sparse.load_state_dict(sparse_param['model_state_dict'])
#
......@@ -60,11 +59,16 @@ model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
# tv.transforms.CenterCrop((100, 200))
# ])
# img = transform(img)
#
# with torch.no_grad():
# activations = frozen_sparse(img.unsqueeze(0))
#
# output = model(img.unsqueeze(4).numpy())
# output = model(img.swapaxes(1, 3).swapaxes(1,2).numpy())
# print(activations.size())
# print(output.shape)
# print(torch.sum(activations))
# print(tf.math.reduce_sum(output))
input_name = model.input_names[0]
index = model.input_names.index(input_name)
......@@ -80,5 +84,5 @@ tflite_model = converter.convert()
print('Converted')
with open("./output/tf_lite_model.tflite", "wb") as f:
with open("./mobile_output/tf_lite_model.tflite", "wb") as f:
f.write(tflite_model)
......@@ -10,7 +10,7 @@ import os
from sparse_coding_torch.load_data import load_yolo_clips
import tensorflow.keras as keras
import tensorflow as tf
from keras_model import normalize_weights_3d, normalize_weights, SparseCodeConv
from keras_model import normalize_weights_3d, normalize_weights, SparseCodeConv, load_pytorch_weights
def plot_video(video):
......@@ -99,7 +99,7 @@ if __name__ == "__main__":
parser.add_argument('--stride', default=2, type=int)
parser.add_argument('--max_activation_iter', default=50, type=int)
parser.add_argument('--activation_lr', default=1e-2, type=float)
parser.add_argument('--lr', default=1e-2, type=float)
parser.add_argument('--lr', default=5e-2, type=float)
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--lam', default=0.05, type=float)
parser.add_argument('--output_dir', default='./output', type=str)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment