Commit bd43054c authored by hannandarryl

keras model updates

parent e3e0dafd
@@ -8,28 +8,27 @@ import torch
 import torch.nn as nn
 from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
 from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
-from keras_model import SparseCode, Classifier
+from keras_model import MobileModel
 
-inputs = keras.Input(shape=(5, 100, 200, 1))
+inputs = keras.Input(shape=(100, 200, 5))
 
-x = SparseCode('../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=1, lam=0.05, activation_lr=1, max_activation_iter=1)(inputs)
-outputs = Classifier()(x)
+outputs = MobileModel(sparse_checkpoint='../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=1, lam=0.05, activation_lr=1e-2, max_activation_iter=40, run_2d=True)(inputs)
 
-model = keras.Model(inputs=inputs, outputs=x)
+model = keras.Model(inputs=inputs, outputs=outputs)
 
-pytorch_checkpoint = torch.load('../classifier.pt', map_location='cpu')['model_state_dict']
+pytorch_checkpoint = torch.load('../output/final_model_75_iter/model-best_fold_0.pt', map_location='cpu')['model_state_dict']
 
 conv_weights = [pytorch_checkpoint['module.compress_activations_conv_1.weight'].view(8, 8, 64, 24).numpy(), pytorch_checkpoint['module.compress_activations_conv_1.bias'].numpy()]
-model.get_layer('classifier').conv.set_weights(conv_weights)
+model.get_layer('mobile_model').classifier.conv.set_weights(conv_weights)
 
 ff_1_weights = [pytorch_checkpoint['module.fc1.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc1.bias'].numpy()]
-model.get_layer('classifier').ff_1.set_weights(ff_1_weights)
+model.get_layer('mobile_model').classifier.ff_1.set_weights(ff_1_weights)
 
 ff_2_weights = [pytorch_checkpoint['module.fc2.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc2.bias'].numpy()]
-model.get_layer('classifier').ff_2.set_weights(ff_2_weights)
+model.get_layer('mobile_model').classifier.ff_2.set_weights(ff_2_weights)
 
 ff_3_weights = [pytorch_checkpoint['module.fc3.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc3.bias'].numpy()]
-model.get_layer('classifier').ff_3.set_weights(ff_3_weights)
+model.get_layer('mobile_model').classifier.ff_3.set_weights(ff_3_weights)
 
 ff_4_weights = [pytorch_checkpoint['module.fc4.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc4.bias'].numpy()]
-model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
+model.get_layer('mobile_model').classifier.ff_4.set_weights(ff_4_weights)
 
 # frozen_sparse = ConvSparseLayer(in_channels=1,
 #                                 out_channels=64,
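
Note on the weight-porting block above (not part of the commit; a sketch under assumptions): PyTorch stores nn.Linear weights as (out_features, in_features), while a Keras Dense kernel is (in_features, out_features), which is why every fc weight is transposed with permute(1,0) before set_weights. One caveat worth double-checking: .view(8, 8, 64, 24) reinterprets memory rather than transposing, whereas mapping a PyTorch conv weight (out, in, kH, kW) onto Keras' (kH, kW, in, out) layout normally needs permute(2, 3, 1, 0). The repeated three-line fc pattern could be captured by a helper like this (names hypothetical):

import torch

def port_linear(state_dict, pt_prefix, keras_dense):
    # PyTorch Linear weight: (out, in); Keras Dense kernel: (in, out).
    # keras_dense must already be built so set_weights sees matching shapes.
    kernel = state_dict[pt_prefix + '.weight'].permute(1, 0).numpy()
    bias = state_dict[pt_prefix + '.bias'].numpy()
    keras_dense.set_weights([kernel, bias])

e.g. port_linear(pytorch_checkpoint, 'module.fc1', model.get_layer('mobile_model').classifier.ff_1) would replace the first fc pair above.
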
@@ -39,9 +38,9 @@ model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
 # convo_dim=3,
 # rectifier=True,
 # lam=0.05,
-# max_activation_iter=1,
-# activation_lr=1)
+# max_activation_iter=10,
+# activation_lr=1e-2)
 #
 # sparse_param = torch.load('../sparse.pt', map_location='cpu')
 # frozen_sparse.load_state_dict(sparse_param['model_state_dict'])
 #
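
Context for the two bumped defaults (an illustrative sketch, not code from this repo): in convolutional sparse coding the activations are not a single forward pass but the result of an inner iterative optimization, so max_activation_iter and activation_lr trade code quality against inference cost. One ISTA-style step, assuming stride-1 2D convolutions and non-negative codes (rectifier=True), looks roughly like:

import torch
import torch.nn.functional as F

def ista_step(acts, x, filters, lr, lam):
    # filters: (num_filters, img_channels, kH, kW); acts: (N, num_filters, H', W').
    recon = F.conv_transpose2d(acts, filters)   # decode activations into image space
    grad = F.conv2d(recon - x, filters)         # gradient of 0.5 * ||recon - x||^2
    acts = acts - lr * grad                     # gradient step at activation_lr
    return F.relu(acts - lr * lam)              # soft-threshold for non-negative codes

Raising max_activation_iter from 1 to 10 (and 40 in the mobile model above) runs more such steps per input.
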
@@ -60,11 +59,16 @@ model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
 # tv.transforms.CenterCrop((100, 200))
 # ])
 # img = transform(img)
 #
 # with torch.no_grad():
 #     activations = frozen_sparse(img.unsqueeze(0))
 #
-# output = model(img.unsqueeze(4).numpy())
+# output = model(img.swapaxes(1, 3).swapaxes(1,2).numpy())
+# print(activations.size())
+# print(output.shape)
+# print(torch.sum(activations))
+# print(tf.math.reduce_sum(output))
 
 input_name = model.input_names[0]
 index = model.input_names.index(input_name)
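
The rewritten parity check feeds the same clip to both models and compares summed outputs: the frozen PyTorch layer consumes channels-first tensors, while the new Input(shape=(100, 200, 5)) expects channels-last, hence the swapaxes chain. That chain is equivalent to a single permute (a sketch, assuming a 4-D (N, C, H, W) tensor):

import torch

def to_channels_last(x):
    # (N, C, H, W) -> (N, H, W, C); same result as
    # x.swapaxes(1, 3).swapaxes(1, 2) in the commented check above.
    return x.permute(0, 2, 3, 1).contiguous()
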
@@ -80,5 +84,5 @@ tflite_model = converter.convert()
 print('Converted')
 
-with open("./output/tf_lite_model.tflite", "wb") as f:
+with open("./mobile_output/tf_lite_model.tflite", "wb") as f:
     f.write(tflite_model)
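
The converter setup itself sits outside this hunk; only tflite_model = converter.convert() shows in the header context. For orientation, a typical Keras-to-TFLite conversion looks like the sketch below. The SELECT_TF_OPS fallback is an assumption, often needed when a model contains ops (such as an iterative activation loop) with no TFLite builtin:

import tensorflow as tf

# Hypothetical reconstruction of the elided setup, not the repo's exact code.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,   # standard TFLite kernels
    tf.lite.OpsSet.SELECT_TF_OPS,     # fall back to TF ops where needed
]
tflite_model = converter.convert()
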
@@ -10,7 +10,7 @@ import os
 from sparse_coding_torch.load_data import load_yolo_clips
 import tensorflow.keras as keras
 import tensorflow as tf
-from keras_model import normalize_weights_3d, normalize_weights, SparseCodeConv
+from keras_model import normalize_weights_3d, normalize_weights, SparseCodeConv, load_pytorch_weights
 
 def plot_video(video):
@@ -99,7 +99,7 @@ if __name__ == "__main__":
 parser.add_argument('--stride', default=2, type=int)
 parser.add_argument('--max_activation_iter', default=50, type=int)
 parser.add_argument('--activation_lr', default=1e-2, type=float)
-parser.add_argument('--lr', default=1e-2, type=float)
+parser.add_argument('--lr', default=5e-2, type=float)
 parser.add_argument('--epochs', default=100, type=int)
 parser.add_argument('--lam', default=0.05, type=float)
 parser.add_argument('--output_dir', default='./output', type=str)
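
Only the default changed here, so runs that pass --lr explicitly behave as before. Worth keeping straight: --lr is the outer learning rate for the classifier weights, while --activation_lr is the step size of the inner sparse-coding solver. A hypothetical downstream use (the optimizer choice is assumed, not shown in this diff):

args = parser.parse_args()

# Outer training rate for the classifier weights (default now 5e-2).
optimizer = keras.optimizers.Adam(learning_rate=args.lr)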