import os

from tensorflow import keras
import numpy as np
import torch
import tensorflow as tf
import cv2
import torchvision as tv
import torch
import torch.nn as nn
from sparse_coding_torch.video_loader import VideoGrayScaler, MinMaxScaler
from sparse_coding_torch.keras_model import MobileModel

# Symbolic Keras input; `shape` excludes the batch axis. The layout is
# presumably (frames, height, width, channels) — TODO confirm against MobileModel.
inputs = keras.Input(shape=(5, 100, 200, 3))

# Wrap the sparse-coding MobileModel as a single Keras layer. It is later
# retrieved by its auto-generated layer name 'mobile_model'.
# NOTE(review): in_channels=1 while the input carries 3 channels — presumably
# MobileModel converts to grayscale internally; verify.
mobile_model_layer = MobileModel(
    sparse_checkpoint='../sparse.pt',
    batch_size=1,
    in_channels=1,
    out_channels=64,
    kernel_size=15,
    stride=2,
    lam=0.05,
    activation_lr=1e-1,
    max_activation_iter=100,
    run_2d=True,
)
outputs = mobile_model_layer(inputs)

model = keras.Model(inputs=inputs, outputs=outputs)


# Copy the trained classifier-head parameters from the PyTorch checkpoint into
# the Keras MobileModel. The 'module.' key prefix suggests the checkpoint came
# from a torch.nn.DataParallel-wrapped model — TODO confirm.
pytorch_checkpoint = torch.load('../stride_2_100_iter.pt', map_location='cpu')['model_state_dict']
keras_classifier = model.get_layer('mobile_model').classifier

# Conv kernel: squeeze axis 2 (removed only if it has size 1), then reorder
# axes (0, 1, 2, 3) -> (2, 3, 1, 0). Assuming PyTorch (out, in, kH, kW) and
# Keras Conv2D (kH, kW, in, out) layouts, this converts between them.
conv_kernel = (
    pytorch_checkpoint['module.compress_activations_conv_1.weight']
    .squeeze(2)
    .permute(2, 3, 1, 0)
    .numpy()
)
conv_bias = pytorch_checkpoint['module.compress_activations_conv_1.bias'].numpy()
keras_classifier.conv.set_weights([conv_kernel, conv_bias])

# Fully connected layers: transpose the weight matrix. Presumably fc3/fc4 are
# torch.nn.Linear (out, in) and ff_3/ff_4 are Keras Dense (in, out) — confirm.
for dense_layer, weight_key, bias_key in (
    (keras_classifier.ff_3, 'module.fc3.weight', 'module.fc3.bias'),
    (keras_classifier.ff_4, 'module.fc4.weight', 'module.fc4.bias'),
):
    dense_kernel = pytorch_checkpoint[weight_key].transpose(0, 1).numpy()
    dense_bias = pytorch_checkpoint[bias_key].numpy()
    dense_layer.set_weights([dense_kernel, dense_bias])

# Fix the model's (single) input to a static batch size of 1 before handing it
# to the TFLite converter. The previous lookup
# `model.input_names.index(model.input_names[0])` always evaluates to 0 and
# depends on `model.input_names`, which newer Keras versions drop, so the first
# input is indexed directly.
# NOTE(review): mutating a KerasTensor's shape after the model is built relies
# on tf.keras (Keras 2) behaviour; passing batch_size=1 to keras.Input would
# be the sturdier route — confirm it is compatible with MobileModel first.
model.inputs[0].set_shape([1, 5, 100, 200, 3])

# Convert the Keras model to a TFLite flatbuffer using the float16
# post-training quantization setup (Optimize.DEFAULT + float16 supported
# type), restricted to built-in TFLite ops only.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

target_spec = converter.target_spec
target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
target_spec.supported_types = [tf.float16]

tflite_model = converter.convert()

print('Converted')

# Persist the serialized flatbuffer (bytes, hence binary mode). Create the
# output directory first: without it, open() raises FileNotFoundError when
# ./mobile_output does not exist, discarding the conversion result above.
tflite_path = "./mobile_output/tf_lite_model.tflite"
os.makedirs(os.path.dirname(tflite_path), exist_ok=True)
with open(tflite_path, "wb") as f:
    f.write(tflite_model)
