from typing import List, Optional
import argparse
import ipaddress
import socket
import struct
import threading
import time

# Port of the ESPTouch encoder from Espressif's example app (Android version):
# https://github.com/EspressifApp/EsptouchForAndroid/blob/master/esptouch/src/main/java/com/espressif/iot/esptouch/protocol/EsptouchGenerator.java


class EspTouchCRC:
    """
    CRC-8/MAXIM (Dallas 1-Wire) checksum, as used by the ESPTouch protocol.

    Normal-form polynomial 0x31, applied here in reflected form (0x8C) so that
    input and output are both bit-reflected; no final XOR. The running state
    is kept in ``value`` as a 16-bit quantity (mirroring the Java reference,
    which stores it in a ``short``), but only its low byte is the CRC.
    """

    CRC_POLYNOM = 0x8C  # 0x31 reflected

    # 256-entry lookup table shared by all instances; built on first use.
    crcTable = None

    def __init__(self, init: int = 0):
        """Create a CRC accumulator starting from ``init`` (default 0)."""
        if EspTouchCRC.crcTable is None:
            table = []
            for byte in range(256):
                reg = byte
                for _ in range(8):
                    carry = reg & 0x01
                    reg >>= 1
                    if carry:
                        reg ^= EspTouchCRC.CRC_POLYNOM
                table.append(reg)
            EspTouchCRC.crcTable = table

        self.value = init

    def update(self, data: bytes) -> int:
        """
        Feed ``data`` (an iterable of byte values) into the checksum.

        Returns the current 8-bit CRC (low byte of the running state).
        """
        table = EspTouchCRC.crcTable
        for byte in data:
            slot = (byte ^ self.value) & 0xFF
            # Only the low byte of `value` influences later steps; the shift
            # keeps the high byte in step with the 16-bit Java state.
            self.value = (table[slot] ^ (self.value << 8)) & 0xFFFF
        return self.value & 0xFF


class ESPTouchDataCode:
    """
    One ESPTouch data code: a single byte tagged with its sequence index.

    The byte and an 8-bit CRC computed over ``(data, index)`` are each split
    into high/low nibbles. ``to_bytes`` interleaves those nibbles with the
    sequence index into a fixed 6-byte layout.
    """

    MAX_INDEX = 127

    def __init__(self, data: int, index: int):
        """
        :param data: byte value to encode; only the low 8 bits are used.
        :param index: sequence index, must be in ``0..MAX_INDEX``.
        :raises ValueError: if ``index`` is out of range (a subclass of the
            bare ``Exception`` this used to raise, so existing handlers work).
        """
        if index > self.MAX_INDEX:
            raise ValueError(f"index must be <= {self.MAX_INDEX}")
        if index < 0:
            # Previously a negative index was silently masked to 128..255,
            # yielding a sequence number the protocol does not allow.
            raise ValueError("index must be >= 0")

        data &= 0xFF

        self.data_high, self.data_low = self._split_byte(data)
        crc = EspTouchCRC().update(bytes([data, index]))
        self.crc_high, self.crc_low = self._split_byte(crc)
        self.seq_no = index

    def to_bytes(self) -> bytes:
        """Return the 6-byte encoding: 0, crcH|dataH, 1, seq, 0, crcL|dataL."""
        return bytes(
            (
                0x00,
                self._combine_byte(self.crc_high, self.data_high),
                0x01,
                self.seq_no,
                0x00,
                self._combine_byte(self.crc_low, self.data_low),
            )
        )

    def _split_byte(self, data: int) -> tuple:
        """Split ``data`` into its (high nibble, low nibble) of the low byte."""
        return (data & 0xF0) >> 4, data & 0x0F

    def _combine_byte(self, high: int, low: int) -> int:
        """Pack two nibbles into one byte, ``high`` in the upper four bits."""
        return ((high & 0x0F) << 4) | (low & 0x0F)


class ESPTouch:
    """
    Implementation of ESPTouch device provisioning algorithm.

    Information is carried in the *length* of UDP broadcast datagrams sent to
    255.255.255.255:7001, not in their contents. A background thread
    alternates a guide-code phase (lengths 515..512) with a data-code phase.

    payload: arbitrary string payload encoded into the apPwd field.
    encoded_ip: IP address encoded into the ESPTouch ipAddress field. Defaults
        to the local address the OS selects for the broadcast socket.
    bind_ip: local adapter IP to bind the UDP socket to.
    bssid: access-point BSSID as colon-separated hex.

    The original script mixed encoded_ip and bind_ip into one parameter, which
    made --adapter_ip override the IP transmitted to the device. Here they are
    intentionally separated.

    The socket is opened by the constructor and closed by ``stop()``.
    """

    def __init__(
        self,
        ssid: str,
        payload: str,
        encoded_ip: Optional[str] = None,
        bind_ip: Optional[str] = None,
        bssid: str = "00:00:00:00:00:00",
    ) -> None:
        self._ssid = ssid
        self._payload = payload
        self._encoded_ip = encoded_ip
        self._bind_ip = bind_ip
        self._bssid = bssid

        self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

            if self._bind_ip is not None:
                self._sock.bind((self._bind_ip, 0))

            # On a UDP socket connect() only fixes the default destination for
            # send(); it also lets getsockname() report the local address used.
            self._sock.connect(("255.255.255.255", 7001))

            if self._encoded_ip is None:
                # Default to the actual local interface IP selected by the OS/socket.
                self._encoded_ip = self._sock.getsockname()[0]

            codes = self._build_datum_codes(
                ssid_enc=self._ssid.encode(),
                payload_enc=self._payload.encode(),
                ip_enc=ipaddress.ip_address(self._encoded_ip).packed,
                bssid_enc=bytes.fromhex(self._bssid.replace(":", "")),
            )
        except BaseException:
            # Don't leak the socket if binding, connecting or encoding fails.
            self._sock.close()
            raise

        guide_code_packet_lengths = [515, 514, 513, 512]
        self._guide_code_packets = [b"\x01" * n for n in guide_code_packet_lengths]
        self._datum_code_packets = [b"\x01" * n for n in self._datum_lengths(codes)]

        self._thread = threading.Thread(target=self._thread_target, daemon=True)
        self._should_be_running = True

    @staticmethod
    def _build_datum_codes(
        ssid_enc: bytes, payload_enc: bytes, ip_enc: bytes, bssid_enc: bytes
    ) -> "List[ESPTouchDataCode]":
        """
        Build the ordered data-code list, following Espressif's DatumCode.

        Sequence indices: 0 total length, 1 payload length, 2 SSID CRC,
        3 BSSID CRC, 4 XOR of indices 0-3 and every ip/payload/ssid byte;
        ip, payload and ssid bytes follow from index 5; BSSID bytes use
        indices starting at ``total_len``.

        :raises ValueError: (from ESPTouchDataCode) if the data is too long
            for the available sequence indices.
        """
        extra_head_len = 5
        # ip, payload and ssid bytes occupy consecutive indices from 5 onwards.
        body = ip_enc + payload_enc + ssid_enc
        total_len = extra_head_len + len(body)

        header = [
            total_len & 0xFF,
            len(payload_enc) & 0xFF,
            EspTouchCRC().update(ssid_enc) & 0xFF,
            EspTouchCRC().update(bssid_enc) & 0xFF,
        ]
        codes = [ESPTouchDataCode(value, i) for i, value in enumerate(header)]

        total_xor = 0
        for value in header:
            total_xor ^= value
        for offset, b in enumerate(body):
            codes.append(ESPTouchDataCode(b, extra_head_len + offset))
            total_xor ^= b

        # The reference implementation places the XOR code at list position 4,
        # directly after the header codes, rather than at the end.
        codes.insert(4, ESPTouchDataCode(total_xor & 0xFF, 4))

        # Interleave BSSID codes every 4th slot; list.insert() past the end
        # simply appends, matching the reference's explicit add() fallback.
        insert_at = extra_head_len
        for i, b in enumerate(bssid_enc):
            codes.insert(insert_at, ESPTouchDataCode(b, total_len + i))
            insert_at += 4

        return codes

    @staticmethod
    def _datum_lengths(codes: "List[ESPTouchDataCode]") -> List[int]:
        """Map codes to datagram lengths: each big-endian 16-bit word + 40."""
        extra_length = 40
        # Every code serialises to 6 bytes, so the buffer is always even-sized.
        datum = b"".join(code.to_bytes() for code in codes)
        return [word + extra_length for (word,) in struct.iter_unpack(">H", datum)]

    def start(self):
        """Start broadcasting in a background daemon thread."""
        self._thread.start()

    def stop(self):
        """
        Stop broadcasting, wait for the sender thread and close the socket.

        Safe to call when ``start()`` was never called, and safe to repeat.
        """
        self._should_be_running = False
        if self._thread.is_alive():
            self._thread.join()
        self._sock.close()

    def _thread_target(self):
        """
        Alternate 2 s of guide codes with 4 s of data codes until stopped.

        Every datagram is followed by an 8 ms pause. The data phase sends three
        datagrams per step and resumes where the previous phase stopped.
        """
        interval_guide_code = 8
        interval_data_code = 8
        timeout_guide_code = 2000
        timeout_data_code = 4000

        index = 0
        increment = 3
        packets = self._datum_code_packets

        while self._should_be_running:
            # monotonic() is immune to wall-clock adjustments during a phase.
            start = time.monotonic()
            while self._should_be_running and (time.monotonic() - start) < (timeout_guide_code / 1000.0):
                for datagram in self._guide_code_packets:
                    self._sock.send(datagram)
                    time.sleep(interval_guide_code / 1000.0)

            start = time.monotonic()
            while self._should_be_running and (time.monotonic() - start) < (timeout_data_code / 1000.0):
                for datagram in packets[index:index + increment]:
                    self._sock.send(datagram)
                    time.sleep(interval_data_code / 1000.0)
                index = (index + increment) % len(packets)


def main():
    """
    Command-line entry point: provision a Riden RD60xx over ESPTouch.

    Positional arguments: ``ssid``, ``password``, ``endpoint_ip``; optional
    ``--adapter_ip`` picks the local interface to broadcast from.

    Two broadcast streams are built up front and then run one after another;
    each is started, the user is prompted, and it is stopped once enter is
    pressed. Ctrl-C at any point exits quietly.
    """
    try:
        parser = argparse.ArgumentParser(
            description="Provision Riden RD60xx power supply with Wi-Fi network details and data endpoint address."
        )
        parser.add_argument("ssid", type=str, help="SSID of the Wi-Fi network")
        parser.add_argument("password", type=str, help="Password of the Wi-Fi network")
        parser.add_argument(
            "endpoint_ip",
            type=str,
            help="IP address that should be encoded and sent to the device as the endpoint/server address",
        )
        parser.add_argument(
            "--adapter_ip",
            type=str,
            help="Local adapter IP to bind the UDP broadcast socket to",
        )
        args = parser.parse_args()

        # Both streams encode the same SSID and endpoint IP and bind the same adapter.
        shared = {
            "ssid": args.ssid,
            "encoded_ip": args.endpoint_ip,
            "bind_ip": args.adapter_ip,
        }

        # Stage 1 carries no payload (endpoint IP only); stage 2 carries the
        # Wi-Fi password. Both are constructed before either one starts.
        stages = [
            (
                ESPTouch(payload="", **shared),
                "Please press enter when power supply displays message 'Connecting wifi'...",
            ),
            (
                ESPTouch(payload=args.password, **shared),
                "Please press enter when power supply displays message 'Connecting server'...",
            ),
        ]

        for broadcaster, prompt in stages:
            broadcaster.start()
            print(prompt, end="", flush=True)
            input()
            broadcaster.stop()
    except KeyboardInterrupt:
        pass


# Run the provisioning CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
