import tensorflow as tf
# NOTE(review): this GPU device list is not referenced anywhere else in the
# code shown here; presumably a leftover from GPU memory-growth setup — confirm
# nothing else imports `physical_devices` before removing it.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
from absl import app, flags, logging
from absl.flags import FLAGS
import yolov4.core.utils as utils
from yolov4.core.yolov4 import filter_boxes
from tensorflow.python.saved_model import tag_constants
from PIL import Image
import cv2
import numpy as np
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
import time

class YoloModel:
    """Load a YOLO detector once and run it on individual images.

    Configuration is read from absl flags. The flags are registered at most
    once per process and parsed with a fake argv, so the defaults (or values
    already set programmatically) are used.
    """

    def __init__(self):
        self._define_flags()
        # Parse with a fake argv: this class is not driven by the real CLI.
        FLAGS(['detect.py'])

        config = ConfigProto()
        config.gpu_options.allow_growth = True
        # Keep a reference on the instance so the session outlives __init__.
        self._session = InteractiveSession(config=config)
        # These values are not used by this class itself; the call is kept
        # for whatever validation / side effects utils.load_config performs.
        self.strides, self.anchors, self.num_class, self.xyscale = \
            utils.load_config(FLAGS)
        self.input_size = FLAGS.size
        self.framework = FLAGS.framework

        start = time.time()
        print('loading model\n')
        self.saved_model_loaded = None
        self.interpreter = None
        if self.framework == 'tflite':
            # A .tflite file is not a SavedModel directory, so load it with
            # the TFLite interpreter instead, and build it only once.
            self.interpreter = tf.lite.Interpreter(model_path=FLAGS.weights)
            self.interpreter.allocate_tensors()
        else:
            self.saved_model_loaded = tf.saved_model.load(
                FLAGS.weights, tags=[tag_constants.SERVING])
        elapsed_time = time.time() - start
        print('model loaded\n')
        print('Took %.2f seconds to load model\n' % (elapsed_time))

    @staticmethod
    def _define_flags():
        """Register this detector's absl flags, skipping already-defined ones.

        Defining an absl flag twice raises DuplicateFlagError, so each flag
        is only defined when FLAGS does not know it yet. This lets the class
        be instantiated more than once per process.
        """
        definitions = (
            (flags.DEFINE_string, 'framework', 'tf', '(tf, tflite, trt)'),
            (flags.DEFINE_string, 'weights', 'yolov4/Pleural_Line_TensorFlow',
             'path to weights file'),
            (flags.DEFINE_integer, 'size', 416, 'resize images to'),
            (flags.DEFINE_boolean, 'tiny', False, 'yolo or yolo-tiny'),
            (flags.DEFINE_string, 'model', 'yolov4', 'yolov3 or yolov4'),
            (flags.DEFINE_string, 'image',
             '/shared_data/YOLO_Updated_PL_Model_Results/Sliding/image_677741729740_clean/frame0.png',
             'path to input image'),
            (flags.DEFINE_string, 'output', 'result.png', 'path to output image'),
            (flags.DEFINE_float, 'iou', 0.45, 'iou threshold'),
            (flags.DEFINE_float, 'score', 0.25, 'score threshold'),
        )
        for define, name, default, help_text in definitions:
            if name not in FLAGS:
                define(name, default, help_text)

    def get_bounding_boxes(self, original_image, iou_threshold=0.5,
                           score_threshold=0.25, max_boxes=50):
        """Detect objects in one image and return the kept boxes.

        Args:
            original_image: image array accepted by ``cv2.resize``; presumably
                an HxWxC frame (TODO confirm channel order with the caller).
            iou_threshold: IoU threshold used by non-max suppression.
            score_threshold: minimum score kept, both when filtering raw
                TFLite predictions and during non-max suppression.
            max_boxes: per-class and total cap on boxes kept by NMS.

        Returns:
            float64 array of shape (1, k, 4) holding the k boxes whose
            coordinates sum to a positive value, in the model's coordinate
            order.

        Raises:
            ValueError: if the SavedModel signature returns no outputs.
        """
        start = time.time()
        resized = cv2.resize(original_image, (self.input_size, self.input_size))
        # Scale pixels to [0, 1] and add a batch axis of size 1.
        images_data = (resized / 255.)[np.newaxis, ...].astype(np.float32)

        if self.framework == 'tflite':
            boxes, pred_conf = self._run_tflite(images_data, score_threshold)
        else:
            boxes, pred_conf = self._run_saved_model(images_data)

        nms_boxes, _scores, _classes, _valid = tf.image.combined_non_max_suppression(
            boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
            scores=tf.reshape(
                pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
            max_output_size_per_class=max_boxes,
            max_total_size=max_boxes,
            iou_threshold=iou_threshold,
            score_threshold=score_threshold,
        )
        result = self._drop_padding_boxes(nms_boxes.numpy())
        logging.debug('Took %.2f seconds to run whole bounding box function',
                      time.time() - start)
        return result

    def _run_tflite(self, images_data, score_threshold):
        """Run the cached TFLite interpreter and return (boxes, class scores)."""
        interpreter = self.interpreter
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        interpreter.set_tensor(input_details[0]['index'], images_data)
        interpreter.invoke()
        pred = [interpreter.get_tensor(detail['index']) for detail in output_details]
        # For yolov3-tiny the two output tensors are consumed in swapped order.
        if FLAGS.model == 'yolov3' and FLAGS.tiny:
            box_pred, conf_pred = pred[1], pred[0]
        else:
            box_pred, conf_pred = pred[0], pred[1]
        input_shape = tf.constant([self.input_size, self.input_size])
        return filter_boxes(box_pred, conf_pred,
                            score_threshold=score_threshold,
                            input_shape=input_shape)

    def _run_saved_model(self, images_data):
        """Run the SavedModel serving signature and split its output tensor.

        The last output of the signature is used: its first 4 channels are
        taken as box coordinates and the remaining ones as class scores.
        """
        infer = self.saved_model_loaded.signatures['serving_default']
        yolo_start_time = time.time()
        pred_bbox = infer(tf.constant(images_data))
        outputs = list(pred_bbox.values())
        if not outputs:
            raise ValueError('serving_default signature returned no outputs')
        value = outputs[-1]
        logging.debug('Took %.2f seconds to run yolo',
                      time.time() - yolo_start_time)
        return value[:, :, 0:4], value[:, :, 4:]

    @staticmethod
    def _drop_padding_boxes(batch_boxes):
        """Keep the first batch entry's boxes whose coordinates sum above 0.

        combined_non_max_suppression pads its output with all-zero boxes up
        to max_total_size; those rows sum to 0 and are dropped here.
        Returns a float64 array of shape (1, k, 4).
        """
        first_image = np.asarray(batch_boxes, dtype=np.float64)[0]
        keep = first_image.sum(axis=1) > 0
        return first_image[keep][np.newaxis, ...]