diff --git a/keras/generate_tflite.py b/keras/generate_tflite.py
index f68b95b30a5b37321e04f793a5884b1aebfcad26..3dd871a95cc24a0d310393e502829c642600cfd1 100644
--- a/keras/generate_tflite.py
+++ b/keras/generate_tflite.py
@@ -8,28 +8,27 @@ import torch
 import torch.nn as nn
 from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
 from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
-from keras_model import SparseCode, Classifier
+from keras_model import MobileModel
 
-inputs = keras.Input(shape=(5, 100, 200, 1))
+inputs = keras.Input(shape=(100, 200, 5))
 
-x = SparseCode('../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=1, lam=0.05, activation_lr=1, max_activation_iter=1)(inputs)
-outputs = Classifier()(x)
+outputs = MobileModel(sparse_checkpoint='../sparse.pt', batch_size=1, in_channels=1, out_channels=64, kernel_size=15, stride=1, lam=0.05, activation_lr=1e-2, max_activation_iter=40, run_2d=True)(inputs)
 
-model = keras.Model(inputs=inputs, outputs=x)
+model = keras.Model(inputs=inputs, outputs=outputs)
 
-pytorch_checkpoint = torch.load('../classifier.pt', map_location='cpu')['model_state_dict']
+pytorch_checkpoint = torch.load('../output/final_model_75_iter/model-best_fold_0.pt', map_location='cpu')['model_state_dict']
 
 conv_weights = [pytorch_checkpoint['module.compress_activations_conv_1.weight'].view(8, 8, 64, 24).numpy(), pytorch_checkpoint['module.compress_activations_conv_1.bias'].numpy()]
-model.get_layer('classifier').conv.set_weights(conv_weights)
+model.get_layer('mobile_model').classifier.conv.set_weights(conv_weights)
 ff_1_weights = [pytorch_checkpoint['module.fc1.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc1.bias'].numpy()]
-model.get_layer('classifier').ff_1.set_weights(ff_1_weights)
+model.get_layer('mobile_model').classifier.ff_1.set_weights(ff_1_weights)
 ff_2_weights = [pytorch_checkpoint['module.fc2.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc2.bias'].numpy()]
-model.get_layer('classifier').ff_2.set_weights(ff_2_weights)
+model.get_layer('mobile_model').classifier.ff_2.set_weights(ff_2_weights)
 ff_3_weights = [pytorch_checkpoint['module.fc3.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc3.bias'].numpy()]
-model.get_layer('classifier').ff_3.set_weights(ff_3_weights)
+model.get_layer('mobile_model').classifier.ff_3.set_weights(ff_3_weights)
 ff_4_weights = [pytorch_checkpoint['module.fc4.weight'].permute(1,0).numpy(), pytorch_checkpoint['module.fc4.bias'].numpy()]
-model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
+model.get_layer('mobile_model').classifier.ff_4.set_weights(ff_4_weights)
 
 
 # frozen_sparse = ConvSparseLayer(in_channels=1,
 #                                 out_channels=64,
@@ -39,9 +38,9 @@ model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
 #                                 convo_dim=3,
 #                                 rectifier=True,
 #                                 lam=0.05,
-#                                 max_activation_iter=1,
-#                                 activation_lr=1)
-#
+#                                 max_activation_iter=10,
+#                                 activation_lr=1e-2)
+
 # sparse_param = torch.load('../sparse.pt', map_location='cpu')
 # frozen_sparse.load_state_dict(sparse_param['model_state_dict'])
 #
@@ -60,11 +59,16 @@ model.get_layer('classifier').ff_4.set_weights(ff_4_weights)
 #     tv.transforms.CenterCrop((100, 200))
 # ])
 # img = transform(img)
-#
+
 # with torch.no_grad():
 #     activations = frozen_sparse(img.unsqueeze(0))
-#
-# output = model(img.unsqueeze(4).numpy())
+
+# output = model(img.swapaxes(1, 3).swapaxes(1,2).numpy())
+
+# print(activations.size())
+# print(output.shape)
+# print(torch.sum(activations))
+# print(tf.math.reduce_sum(output))
 
 input_name = model.input_names[0]
 index = model.input_names.index(input_name)
@@ -80,5 +84,5 @@
 tflite_model = converter.convert()
 
 print('Converted')
-with open("./output/tf_lite_model.tflite", "wb") as f:
+with open("./mobile_output/tf_lite_model.tflite", "wb") as f:
     f.write(tflite_model)
diff --git a/keras/train_sparse_model.py b/keras/train_sparse_model.py
index f54caede35daaf45bdd1c266a7a4c86f7f0e452e..d543e3af8d470e2db1b090f0718a829e40fcd18c 100644
--- a/keras/train_sparse_model.py
+++ b/keras/train_sparse_model.py
@@ -10,7 +10,7 @@ import os
 from sparse_coding_torch.load_data import load_yolo_clips
 import tensorflow.keras as keras
 import tensorflow as tf
-from keras_model import normalize_weights_3d, normalize_weights, SparseCodeConv
+from keras_model import normalize_weights_3d, normalize_weights, SparseCodeConv, load_pytorch_weights
 
 
 def plot_video(video):
@@ -99,7 +99,7 @@ if __name__ == "__main__":
     parser.add_argument('--stride', default=2, type=int)
     parser.add_argument('--max_activation_iter', default=50, type=int)
     parser.add_argument('--activation_lr', default=1e-2, type=float)
-    parser.add_argument('--lr', default=1e-2, type=float)
+    parser.add_argument('--lr', default=5e-2, type=float)
     parser.add_argument('--epochs', default=100, type=int)
     parser.add_argument('--lam', default=0.05, type=float)
     parser.add_argument('--output_dir', default='./output', type=str)
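
Not part of the patch, but a quick way to sanity-check the exported artifact: a minimal smoke-test sketch, assuming the ./mobile_output/tf_lite_model.tflite path written above (run from keras/) and the (100, 200, 5) input shape of the new keras.Input. The dummy clip and variable names are illustrative only, not from the repo.

    # Load the converted model with the TFLite interpreter and push one
    # random clip through it (batch of 1, matching batch_size=1 above).
    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path='./mobile_output/tf_lite_model.tflite')
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Random stand-in for a preprocessed 5-frame grayscale clip.
    dummy_clip = np.random.rand(1, 100, 200, 5).astype(np.float32)
    interpreter.set_tensor(input_details[0]['index'], dummy_clip)
    interpreter.invoke()

    prediction = interpreter.get_tensor(output_details[0]['index'])
    print(prediction.shape, prediction)

This only confirms the converted graph loads and runs end to end; comparing its output against the PyTorch model on a real clip (as the commented-out block in generate_tflite.py does) is still the meaningful numerical check.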