# -*- coding: utf-8 -*-
# USAGE
# python paseespere.py --conf conf.json

# import the necessary packages
from __future__ import print_function
from picamera2 import Picamera2, Preview
from fractions import Fraction
import argparse
import warnings
import datetime
import imutils
import json
import time
import cv2
import paho.mqtt.client as mqtt
from time import sleep
import requests
import numpy as np
import collections
import os

from enum import Enum


# Build the command-line interface: the only (mandatory) option is the path
# to the JSON configuration file.
ap = argparse.ArgumentParser()
ap.add_argument("-c", "--conf", required=True,
                help="path to the JSON configuration file")
args = vars(ap.parse_args())

# Load the configuration. The file is opened through a context manager so the
# handle is closed as soon as parsing finishes instead of being left to the
# garbage collector.
# warnings.filterwarnings("ignore")
with open(args["conf"]) as conf_file:
    conf = json.load(conf_file)


############## GLOBAL VARIABLES ###################

# Finite-state-machine states (the numeric value is what gets published as
# "estadoActual" over MQTT).
STATE_SEGURO_SIN_CLIENTE, STATE_SEGURO_CON_CLIENTE = 0, 2

# Detection flags
cosas_en_balanza = hay_movimiento = False

# Ring buffers with the most recent per-frame white-pixel ratios: one for the
# active customer mask and one for the scale ("pesa") mask.
ring_buffer, ring_buffer_pesa = (
    collections.deque(maxlen=conf["ring_buffer_size"]) for _ in range(2))

# Running extremes of the ring-buffer sum
min_temporal_whites, max_temporal_whites = 9999999999, 0

# (frame, difference score) pairs collected while the scene is static
new_reference_candidates = list()

# Time of the last MQTT message and display toggles
lastHeartbeat = datetime.datetime.now()
showMask = showThreshold = True

temporal_whites_pesa = 0

# Reference image (colour and grayscale); filled in by reloadMask()
backgroundImage = backgroundImage_bn = None

# Masks and their contours; filled in by reloadMask()
mask_low = mask_hi = mask_pesa = None
contours_low = contours_hi = contours_pesa = None

# Drawing settings: polygon outline colour, mask fill colour, line thickness
color_lineas = (255, 0, 0, 125)
color_blanco = (255, 255, 255, 125)
thickness = 2

# State-machine bookkeeping
fsm_state = last_state_txt = STATE_SEGURO_SIN_CLIENTE
frames_con_tempral_whites_invariantes = last_temporal_whites = 0

######################################################


def test_channels(probe, image, need):
    """
    Debug helper: prints how many channels an image has next to how many are expected.

    Parameters:
    - probe: str, label identifying the call site in the printed line.
    - image: numpy array (2-D or 3-D) or None.
    - need: expected number of channels; only echoed in the output.
    """
    if image is None:
        print("".join(("Probe:", probe, "SIN IMAGEN DEFINIDA")))
        return

    dims = image.shape
    # A 2-D array is a single-channel image; otherwise the third axis holds
    # the channel count.
    if len(dims) == 2:
        has_txt = ", Has:1"
    else:
        has_txt = ", Has" + str(dims[2])
    print("".join(("Probe:", probe, ", Need:", str(need),
                   has_txt, "Shape", str(dims))))


def convert_to_4_channel(image, alpha_value=255):
    """
    Converts an image to 4-channel format by appending a constant alpha channel.

    Parameters:
    - image: numpy array, either 2-D (grayscale), or 3-D with 3 or 4 channels
      on the last axis.
    - alpha_value: int, the value to set for the alpha channel (default is 255 for full opacity).

    Returns:
    - A 4-channel numpy array. A grayscale input is replicated across the three
      colour channels; a 3-channel input keeps its channel order; a 4-channel
      input is returned as is (same object, no copy).

    Raises:
    - ValueError: if image is None or has an unsupported shape.
    """
    if image is None:
        raise ValueError("Image could not be loaded.")

    # The dimensionality is checked before touching shape[2]: a grayscale
    # image has no third axis, so indexing it first would raise IndexError.
    if image.ndim == 2:
        colour_planes = (image, image, image)
    elif image.ndim == 3 and image.shape[2] == 4:
        print("Image already has 4 channels (RGBA).")
        return image
    elif image.ndim == 3 and image.shape[2] == 3:
        colour_planes = (image,)
    else:
        raise ValueError("Unsupported image format.")

    # Constant alpha plane with the same height, width and dtype as the image
    alpha_channel = np.full(image.shape[:2], alpha_value, dtype=image.dtype)

    # Stack the colour planes and the alpha plane along the channel axis
    return np.dstack(colour_planes + (alpha_channel,))


def getContours(gray_image):
    """
    Returns the external contours of the edges detected in a single-channel image.

    Parameters:
    - gray_image: single-channel image passed to the Canny edge detector.

    Returns:
    - The contours reported by cv2.findContours for the edge map.
    """
    # Edge map from the Canny detector (thresholds 100 and 200)
    edge_map = cv2.Canny(gray_image, 100, 200)

    # Only outer contours are retrieved; the hierarchy is not needed
    found, _hierarchy = cv2.findContours(
        edge_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return found


################### CAMERA INITIALISATION #########################
# Creates the Picamera2 object, applies a still-capture configuration, starts
# the camera and sets its controls. `picam2` is used by the rest of the script
# to grab frames with capture_array().
print("Iniciando PiCam")
picam2 = Picamera2()
# Get available sensor modes to check if we're using the full sensor resolution
sensor_modes = picam2.sensor_modes
print(f"Available sensor modes: {sensor_modes}")
# NOTE(review): conf["resolution"] comes from JSON and is therefore a list,
# not a tuple - confirm that Picamera2 accepts a list for "size".
# NOTE(review): 3280x2464 is hard-coded below (raw size and crop) - presumably
# the full resolution of the attached sensor; verify against sensor_modes.
camera_config = picam2.create_still_configuration(
    main={"size": conf["resolution"],  # e.g. 640x480, taken from the configuration
          "format": "RGB888"},   # Full-color image in RGB format
    raw={"size": (3280, 2464)}    # Full resolution for raw data as well
)
picam2.configure(camera_config)
picam2.start()
# Crop rectangle (x, y, width, height) covering the same 3280x2464 area
picam2.set_controls({"ScalerCrop": (0, 0, 3280, 2464)})
sleep(2)  # Allow the camera to warm up
# Turn on auto white balance and auto exposure.
# NOTE(review): ExposureTime is set to 0 together with AeEnable - presumably
# meaning "let auto exposure decide"; confirm in the Picamera2 documentation.
picam2.set_controls({"ExposureTime": 0, "AwbEnable": True, "AeEnable": True})
# Fixed white-balance gains, currently disabled:
# picam2.set_controls({"AwbGainRed": Fraction(353, 256),
#                    "AwbGainBlue": Fraction(273, 128)})
################################################################


def load_image(image_path, color):
    """
    Loads an image from disk, or creates it from a camera frame when the file is missing.

    NOTE(review): a second, equivalent definition of load_image appears later
    in this file and rebinds the name; one of the two should be removed.

    Parameters:
    - image_path: str, path of the image file to read (and to create when it
      does not exist yet).
    - color: int, flag forwarded to cv2.imread; 0 requests a grayscale image,
      any other value is treated as a colour load.

    Returns:
    - The image as a numpy array:
      * file exists and decodes: whatever cv2.imread returns;
      * file exists but cannot be decoded: a black image of the configured
        resolution, 2-D for grayscale loads and 3-channel for colour loads;
      * file missing: a frame captured from the camera (grayscale when
        color == 0, 4-channel otherwise), which is also saved to image_path.
    """
    resolution = conf["resolution"]  # (width, height)

    if os.path.exists(image_path):
        img = cv2.imread(image_path, color)
        # test_channels("LOADADIMG", img, 4)
        if img is not None:
            print(f"Loaded image {image_path}")
            return img

        print(
            f"Image {image_path} could not be loaded, creating a black image.")
        # Match the layout cv2.imread would have produced, so callers that
        # expect a colour image (e.g. cvtColor to gray) keep working.
        if color == 0:
            black_shape = (resolution[1], resolution[0])
        else:
            black_shape = (resolution[1], resolution[0], 3)
        return np.zeros(black_shape, dtype=np.uint8)

    # No file yet: bootstrap it from the current camera view
    print(f"Image {image_path} not found, capturing a new one from the camera.")
    frame_temp = picam2.capture_array()

    if color == 0:
        frame_temp = cv2.cvtColor(frame_temp, cv2.COLOR_BGR2GRAY)
    else:
        frame_temp = convert_to_4_channel(frame_temp)

    # cv2.imwrite reports failure through its return value
    if not cv2.imwrite(image_path, frame_temp):
        print(f"Image {image_path} could not be written.")
    return frame_temp


def reloadMask():
    """
    (Re)loads the reference image and the three masks, and refreshes every
    global derived from them.

    Globals written:
    - backgroundImage (4-channel) and backgroundImage_bn (grayscale);
    - mask_low, mask_hi, mask_pesa and their union mask_multi;
    - contours_low, contours_hi, contours_pesa;
    - whites_low, whites_hi, whites_pesa (number of pixels equal to 255).
    """
    global backgroundImage, backgroundImage_bn
    global mask_low, mask_hi, mask_pesa, mask_multi
    global contours_low, contours_hi, contours_pesa
    global whites_hi, whites_low, whites_pesa

    # Reference image: colour load, gray copy for differencing, then the
    # 4-channel version used for display
    backgroundImage = load_image(conf["backgroundImage"], 1)
    backgroundImage_bn = cv2.cvtColor(backgroundImage, cv2.COLOR_RGB2GRAY)
    backgroundImage = convert_to_4_channel(backgroundImage)

    # The three masks are loaded as grayscale images
    mask_low, mask_hi, mask_pesa = [
        load_image(conf[mask_key], 0)
        for mask_key in ("mask_low", "mask_hi", "mask_pesa")]

    # Outline of each mask
    contours_low, contours_hi, contours_pesa = map(
        getContours, (mask_low, mask_hi, mask_pesa))

    # Union of the three masks
    mask_multi = cv2.bitwise_or(mask_low, cv2.bitwise_or(mask_hi, mask_pesa))

    # White-pixel count of each mask
    whites_hi, whites_low, whites_pesa = (
        np.sum(one_mask == 255) for one_mask in (mask_hi, mask_low, mask_pesa))


# Initial load at start-up
reloadMask()
############## MANEJO DE MASCARA ############################


def click_event(event, x, y, flags, params):
    """
    Mouse callback: appends the clicked point to the global polygon.

    Parameters:
    - event: OpenCV mouse event code; only cv2.EVENT_LBUTTONDOWN is handled.
    - x, y: coordinates of the mouse event.
    - flags, params: unused, required by the callback signature.
    """
    global polygon_pts_array

    # Anything other than a left-button press is ignored
    if event != cv2.EVENT_LBUTTONDOWN:
        return

    clicked_point = np.array([x, y])
    flat_points = np.concatenate((polygon_pts_array.ravel(), clicked_point))
    # Store the vertices as an (N, 1, 2) array
    polygon_pts_array = flat_points.reshape((-1, 1, 2))


# Mask-creator window; only set up when video output is enabled
if conf["show_video"]:
    cv2.imshow("maskcreator", backgroundImage)

    height, width, channels = backgroundImage.shape
    # Empty polygon: no vertices clicked yet
    polygon_pts_array = np.empty((0, 1), dtype=int)
    # Black single-channel canvas, same size as the reference image, on which
    # the polygon is filled when a mask is saved
    img_mask = np.zeros(shape=(height, width, 1), dtype=np.uint8)
    cv2.setMouseCallback('maskcreator', click_event)


#############################################################


################### SOCKET ##################################
def on_connect(client, userdata, flags, rc):
    """MQTT connect callback: prints the result code reported for the connection."""
    print(f"Estado de conexión a Socket:{rc!s}")
    # Subscribing here would renew the subscriptions after every reconnect;
    # this script does not subscribe to anything.
    # client.subscribe(topic)

# Callback invoked when a PUBLISH message is received from the server.


def on_message(client, userdata, msg):
    """MQTT message callback: prints payload, topic and QoS of the received message."""
    payload_txt = str(msg.payload)
    qos_txt = str(msg.qos)
    print(f"Received message '{payload_txt}' on topic '{msg.topic}' with QoS {qos_txt}")


# Create the MQTT client (identified by the sensor name from the
# configuration) and register the callbacks before connecting.
client = mqtt.Client(client_id=conf["sensor"], protocol=mqtt.MQTTv311)
client.on_connect = on_connect
client.on_message = on_message
print("Conectando al socket...")
# client.username_pw_set("000000000000001", "000000000000001")
try:
    client.connect(conf["maincpu_ip"], conf["broker_port"], 60)
except Exception as exc:
    # Best effort: the script keeps running without a broker. Only regular
    # exceptions are caught (so Ctrl-C / SystemExit still propagate) and the
    # cause is reported instead of being discarded.
    print("ERROR AL CONECTAR AL SOCKET:", exc)
else:
    print("Ok")

# Run the network loop in a background thread
client.loop_start()
###############################################################


def load_image(image_path, color):
    """
    Loads an image from disk, or creates it from a camera frame when the file is missing.

    NOTE(review): this repeats the load_image defined earlier in this file and
    rebinds the name from this point on; one of the two should be removed.

    Parameters:
    - image_path: str, path of the image file to read (and to create when it
      does not exist yet).
    - color: int, flag forwarded to cv2.imread; 0 requests a grayscale image,
      any other value is treated as a colour load.

    Returns:
    - The image as a numpy array:
      * file exists and decodes: whatever cv2.imread returns;
      * file exists but cannot be decoded: a black image of the configured
        resolution, 2-D for grayscale loads and 3-channel for colour loads;
      * file missing: a frame captured from the camera (grayscale when
        color == 0, 4-channel otherwise), which is also saved to image_path.
    """
    resolution = conf["resolution"]  # (width, height)

    if os.path.exists(image_path):
        img = cv2.imread(image_path, color)
        # test_channels("LOADADIMG", img, 4)
        if img is not None:
            print(f"Loaded image {image_path}")
            return img

        print(
            f"Image {image_path} could not be loaded, creating a black image.")
        # Match the layout cv2.imread would have produced, so callers that
        # expect a colour image (e.g. cvtColor to gray) keep working.
        if color == 0:
            black_shape = (resolution[1], resolution[0])
        else:
            black_shape = (resolution[1], resolution[0], 3)
        return np.zeros(black_shape, dtype=np.uint8)

    # No file yet: bootstrap it from the current camera view
    print(f"Image {image_path} not found, capturing a new one from the camera.")
    frame_temp = picam2.capture_array()

    if color == 0:
        frame_temp = cv2.cvtColor(frame_temp, cv2.COLOR_BGR2GRAY)
    else:
        frame_temp = convert_to_4_channel(frame_temp)

    # cv2.imwrite reports failure through its return value
    if not cv2.imwrite(image_path, frame_temp):
        print(f"Image {image_path} could not be written.")
    return frame_temp


# ---------------------------------------------------------------------------
# Main loop.
# Every iteration grabs a camera frame, thresholds its difference against the
# grayscale reference image, adds the masked white-pixel ratio to a ring
# buffer and drives a two-state machine (no customer / customer present) whose
# state is published over MQTT. When video is enabled it also shows a 2x2
# monitor and lets the operator re-capture the reference image and save new
# masks from the keyboard.
# ---------------------------------------------------------------------------

# Start with the settings of the "no customer" state
mask_actual = mask_low
contours_actual = contours_low
whites_actual = whites_low
# Initialise the monitor panels with any colour image
overlay21 = backgroundImage
overlay12 = backgroundImage
overlay22 = backgroundImage

# Reset on every transition into STATE_SEGURO_CON_CLIENTE; defined here so the
# name always exists.
timer_iniciado = False

# 3x3 structuring element for the morphological opening (loop-invariant)
kernel = np.ones((3, 3), np.uint8)

# Keys that save the polygon drawn in the "maskcreator" window as a mask:
# key code -> (configuration entry holding the file path, console message)
mask_hotkeys = {
    ord("h"): ("mask_hi", 'Creando mask-hi'),
    ord("l"): ("mask_low", 'Creando mask-low'),
    ord("p"): ("mask_pesa", 'Creando mask-pesa'),
}

print("[INFO] Iniciando...")
# capture frames from the camera
while True:
    frame = picam2.capture_array()
    if frame is None:
        print("Sin Frame")
        sleep(2)
        continue

    # 4-channel working copy of the live frame for the monitor
    overlay11 = convert_to_4_channel(frame.copy())

    ##### MASK CREATOR: SHOW THE POLYGON IN REAL TIME ###############
    if conf["show_video"]:
        frame_orig_maskcreator = frame.copy()
        cv2.polylines(frame_orig_maskcreator, [
                      polygon_pts_array], True, color_lineas, thickness)
        cv2.imshow("maskcreator", frame_orig_maskcreator)
    ################################################################

    ######### CHECK MOVEMENT AND PRESENCE OF CUSTOMERS / ITEMS ON THE SCALE #####
    frame_bn = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)

    # Binary difference between the reference image and the current frame
    absdiff_main = cv2.absdiff(backgroundImage_bn, frame_bn)
    absdiff_main = cv2.threshold(
        absdiff_main, conf["abs_thresh"], 255, cv2.THRESH_BINARY)[1]

    # Opening (erosion followed by dilation) removes small white dots
    absdiff_main = cv2.morphologyEx(
        absdiff_main, cv2.MORPH_OPEN, kernel)

    absdiff_main_masked = cv2.bitwise_and(
        absdiff_main, mask_actual)

    # Find contours (areas of change) in the masked difference image
    contours, _ = cv2.findContours(
        absdiff_main_masked, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Draw rectangles around the detected areas
    for contour in contours:
        if cv2.contourArea(contour) < 500:  # Ignore small areas
            continue

        (x, y, w, h) = cv2.boundingRect(contour)
        cv2.rectangle(overlay11, (x, y), (x + w, y + h), (0, 255, 0, 200), 2)

    # Fraction of the active mask that differs from the reference, summed over
    # the last frames kept in the ring buffer
    thresh_white_pcnt = np.sum(
        absdiff_main_masked == 255)/whites_actual
    ring_buffer.append(thresh_white_pcnt)
    temporal_whites = sum(ring_buffer)

    if (fsm_state == STATE_SEGURO_SIN_CLIENTE):
        mask_actual = mask_low
        contours_actual = contours_low
        whites_actual = whites_low
        estado_txt = "DESOCUPADA"

        if (temporal_whites > conf["change_state_thershold"]):
            timer_iniciado = False
            fsm_state = STATE_SEGURO_CON_CLIENTE

    elif (fsm_state == STATE_SEGURO_CON_CLIENTE):
        mask_actual = mask_hi
        contours_actual = contours_hi
        whites_actual = whites_hi
        estado_txt = "OCUPADA"

        if (temporal_whites <= conf["change_state_thershold"]):
            # The customer stepped away
            print("Cliente se alejo")
            if (temporal_whites_pesa < conf["change_state_thershold_pesa"]):
                # Scale is empty
                print("pesa vacia")
                fsm_state = STATE_SEGURO_SIN_CLIENTE
                cosas_en_balanza = False
            else:
                print("pesa ocupada")
                cosas_en_balanza = True
                # The customer left items on the scale and stepped away.
                # NOTE(review): the timer is not reset if the customer comes
                # back, so a later absence keeps counting from the first one;
                # confirm this is intended.
                if (not timer_iniciado):
                    cliente_se_fue_a_las = datetime.datetime.now()

                    timer_iniciado = True
                else:
                    lleva = round((datetime.datetime.now() -
                                  cliente_se_fue_a_las).total_seconds(), 0)
                    print("Lleva", lleva, "/",
                          conf["segundos_cliente_puede_alejarse"])

                    if (lleva > conf["segundos_cliente_puede_alejarse"]):
                        fsm_state = STATE_SEGURO_SIN_CLIENTE
                        cosas_en_balanza = False
                        timer_iniciado = False

        # Same measurement restricted to the scale mask
        thresh_pesa = cv2.bitwise_and(
            absdiff_main, mask_pesa)
        overlay21 = cv2.cvtColor(thresh_pesa, cv2.COLOR_GRAY2RGBA)

        thresh_white_pesa_pcnt = np.sum(thresh_pesa == 255)/whites_pesa
        ring_buffer_pesa.append(thresh_white_pesa_pcnt)
        temporal_whites_pesa = sum(ring_buffer_pesa)

    max_temporal_whites = max(max_temporal_whites, temporal_whites)
    min_temporal_whites = min(min_temporal_whites, temporal_whites)

    ######### BACKGROUND RECTIFICATION #########
    if (not cosas_en_balanza
            and abs(last_temporal_whites - temporal_whites) < 0.02
            and round(temporal_whites, 1) != 0.0):
        # The image is static (does not vary) but differs from the reference:
        # count the frame and keep it as a candidate for a new reference.
        frames_con_tempral_whites_invariantes += 1

        gray_levels_scalar = np.sum(cv2.bitwise_and(absdiff_main, mask_multi))
        new_reference_candidates.append((frame, gray_levels_scalar))

        if (frames_con_tempral_whites_invariantes > conf["frames_rectif_fondo"]):
            # Pick the (first) candidate with the smallest difference score
            best_frame = min(new_reference_candidates,
                             key=lambda candidate: candidate[1])[0]

            frames_con_tempral_whites_invariantes = 0
            new_reference_candidates = []

            # The grayscale reference used by the detection is refreshed
            # together with the colour one; otherwise the rectified background
            # would only take effect after the next reloadMask() or restart.
            backgroundImage_bn = cv2.cvtColor(best_frame, cv2.COLOR_RGB2GRAY)
            backgroundImage = convert_to_4_channel(best_frame)
            cv2.imwrite(conf["backgroundImage"], backgroundImage)
            print("Rectificando fondo")
    else:
        frames_con_tempral_whites_invariantes = 0
        new_reference_candidates = []

    last_temporal_whites = temporal_whites

    ######### MQTT: STATE CHANGES AND HEARTBEAT #########
    secondsWithoutSendingMqtt = round(
        (datetime.datetime.now() - lastHeartbeat).total_seconds(), 0)

    # Publish on every state change and at least every 5 seconds
    if (last_state_txt != fsm_state) or (secondsWithoutSendingMqtt > 5):
        MQTT_MSG = json.dumps(
            {"sensor": conf["sensor"], "estadoActual": fsm_state})
        print("pdi/sensor_change --> ", MQTT_MSG)

        try:
            client.publish("pdi/sensor_change", MQTT_MSG)
        except Exception as exc:
            # Best effort: a failed publish must not stop the detection loop
            print("No se pudo enviar socket:", exc)

        lastHeartbeat = datetime.datetime.now()

    last_state_txt = fsm_state

    ######### MONITOR (2x2 PANEL MATRIX) #########
    if conf["show_video"]:
        overlay12 = backgroundImage.copy()

        if showMask:
            # Outline the active mask and the scale mask on both left panels
            for panel in (overlay11, overlay12):
                cv2.drawContours(panel, contours_actual, -
                                 1, (255, 255, 255, 200), 2)
                cv2.drawContours(panel, contours_pesa, -
                                 1, (255, 255, 255, 200), 2)

        # Bottom-right panel: difference under the active mask or under the
        # union of all masks. Both are converted to 4 channels so they can be
        # concatenated with the other panels.
        if showThreshold:
            overlay22 = cv2.cvtColor(
                absdiff_main_masked, cv2.COLOR_GRAY2RGBA)
        else:
            overlay22 = cv2.cvtColor(
                cv2.bitwise_and(absdiff_main, mask_multi), cv2.COLOR_GRAY2RGBA)

        overlay11 = cv2.putText(overlay11, estado_txt, (50, 50), cv2.FONT_HERSHEY_SIMPLEX,
                                1, (255, 0, 0, 0), 2, cv2.LINE_AA)

        overlay22 = cv2.putText(overlay22, "TW:" + str(round(temporal_whites, 1))+" (Rect@"+str(frames_con_tempral_whites_invariantes)+"/"+str(conf["frames_rectif_fondo"])+")", (50, 50), cv2.FONT_HERSHEY_SIMPLEX,
                                1, (255, 0, 0, 0), 2, cv2.LINE_AA)

        # Annotate a copy: overlay21 may still be the reference image itself
        # and is kept across frames, so it must not be drawn on in place.
        panel_pesa = cv2.putText(overlay21.copy(), "Pesa TW:" + str(round(temporal_whites_pesa, 1)), (50, 50), cv2.FONT_HERSHEY_SIMPLEX,
                                 1, (255, 0, 0, 0), 2, cv2.LINE_AA)

        #### BUILD THE MATRIX ########
        image_col1 = np.concatenate((overlay11, overlay12), axis=0)
        image_col2 = np.concatenate((panel_pesa, overlay22), axis=0)
        image_row1 = np.concatenate(
            (image_col1, image_col2), axis=1)
        cv2.imshow("Pase Espere Monitor", image_row1)

    ###################### KEYBOARD COMMANDS ######################
    key = cv2.waitKey(1) & 0xFF

    # if the `s` key is pressed, capture new reference image
    if key == ord("s"):
        print("Capturando nueva imagen de referencia")
        # Update the grayscale reference used by the detection as well
        backgroundImage_bn = frame_bn
        backgroundImage = convert_to_4_channel(frame)
        cv2.imwrite(conf["backgroundImage"], backgroundImage)

    # if the `c` key is pressed, clear the polygon points array
    if key == ord("c"):
        polygon_pts_array = np.empty(shape=[0, 1])
        polygon_pts_array = polygon_pts_array.astype(int)

    # `h`, `l` and `p` save the drawn polygon as the high, low or scale
    # ("pesa") mask, reload everything and start a new empty polygon/canvas
    if key in mask_hotkeys:
        mask_conf_key, mask_message = mask_hotkeys[key]
        print(mask_message)
        cv2.fillPoly(img_mask, [polygon_pts_array], color_blanco)
        cv2.imwrite(conf[mask_conf_key], img_mask)
        reloadMask()
        polygon_pts_array = np.empty(shape=[0, 1])
        polygon_pts_array = polygon_pts_array.astype(int)
        img_mask = np.zeros((height, width, 1), np.uint8)

    if key == ord("m"):
        print('Toggling Mask Display')
        showMask = not showMask

    if key == ord("t"):
        print('Toggling Treshold Display')
        showThreshold = not showThreshold

    # if the `q` key is pressed, break from the loop
    if key == ord("q"):
        print("Saliendo del programa...")
        break

# Stop the camera and close all windows once the loop ends
picam2.stop()
cv2.destroyAllWindows()
