import lz4.block
import msgpack
import pprint
from datetime import datetime
from mitmproxy import tcp
from mitmproxy import ctx

def unpack_packet(data: bytes, *, max_decompressed_size: int = 1048576):
    """Decode a single MaxProto frame that starts at the beginning of *data*.

    Header layout as parsed here (multi-byte fields are big-endian):

        bytes 0      ver
        bytes 1-2    cmd
        byte  3      seq
        bytes 4-5    opcode
        bytes 6-9    packed length: top byte is a compression flag,
                     low 24 bits are the payload length in bytes
        bytes 10..   payload: an LZ4 block when the flag is non-zero,
                     then MessagePack

    Bytes past the declared payload length are ignored.

    Args:
        data: bytes-like object beginning at a frame header.
        max_decompressed_size: upper bound passed to ``lz4.block.decompress``
            as ``uncompressed_size`` for compressed payloads (the payload is
            decoded without a stored size prefix). Defaults to 1 MiB.

    Returns:
        ``None`` when *data* is shorter than the 10-byte header. Otherwise a
        dict with keys ``ver``, ``cmd``, ``seq``, ``opcode`` and ``payload``.
        ``payload`` is the unpacked MessagePack object, or a bracketed marker
        string for an empty, truncated or undecodable payload.
    """
    if len(data) < 10:
        return None

    header = {
        "ver": data[0],
        "cmd": int.from_bytes(data[1:3], "big"),
        "seq": data[3],
        "opcode": int.from_bytes(data[4:6], "big"),
    }
    packed_len = int.from_bytes(data[6:10], "big", signed=False)
    comp_flag = packed_len >> 24
    payload_length = packed_len & 0xFFFFFF

    if payload_length == 0:
        return {**header, "payload": "[Empty Payload / System Message / ACK]"}

    payload_bytes = data[10:10 + payload_length]
    # A short slice means the frame was cut off (e.g. split across TCP
    # chunks); decoding a partial payload would only yield a misleading
    # LZ4/MessagePack error, so report the truncation explicitly.
    if len(payload_bytes) < payload_length:
        return {
            **header,
            "payload": (
                f"[Error: Truncated payload - expected {payload_length} bytes, "
                f"got {len(payload_bytes)}]"
            ),
        }

    if comp_flag != 0:
        try:
            payload_bytes = lz4.block.decompress(
                payload_bytes, uncompressed_size=max_decompressed_size
            )
        except lz4.block.LZ4BlockError as e:
            return {**header, "payload": f"[Error: LZ4 Decompression failed - {e}]"}

    try:
        payload = msgpack.unpackb(payload_bytes, raw=False, strict_map_key=False)
    except Exception as e:  # best-effort dumper: report any decode failure inline
        payload = f"[Error: MessagePack unpack failed - {e}]"

    return {**header, "payload": payload}

class MaxProtoDumper:
    """mitmproxy addon that decodes MaxProto frames from selected TCP flows.

    Each chunk of a matching TCP flow is appended to a buffer kept per flow
    and per direction. Every complete frame in that buffer (10-byte header
    plus the declared payload length) is decoded with ``unpack_packet``,
    logged via ``ctx.log.info`` and appended to a UTF-8 text file. Bytes of
    a frame that has not fully arrived stay buffered until more data comes
    in on the same direction.
    """

    HEADER_SIZE = 10  # unpack_packet needs at least this many bytes
    LENGTH_MASK = 0xFFFFFF  # low 24 bits of header bytes 6-9 = payload length

    def __init__(self, output_path="maxproto_decoded.txt", host_markers=("oneme.ru", "155.212")):
        """Create the dumper.

        Args:
            output_path: file that decoded frames are appended to.
            host_markers: a flow is decoded when any of these substrings
                occurs in the server SNI (or, without SNI, the server address).
        """
        self.output_path = output_path
        self.host_markers = tuple(host_markers)
        # (flow.id, from_client) -> bytes received but not yet decoded.
        self._buffers = {}

    def tcp_message(self, flow: "tcp.TCPFlow"):
        """Decode, log and persist every frame completed by the newest message."""
        if not self._is_target_host(self._server_host(flow)):
            return

        message = flow.messages[-1]
        direction = "C->S" if message.from_client else "S->C"

        # TCP chunks are not aligned with protocol frames: one chunk may hold
        # several frames or only part of one, so reassemble per direction.
        buffer = self._buffers.setdefault((flow.id, message.from_client), bytearray())
        buffer += message.content
        frames = self._drain_frames(buffer)
        if not frames:
            return

        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
        entries = []
        for frame in frames:
            parsed = unpack_packet(frame)
            if not parsed:
                continue
            entry = self._format_entry(timestamp, direction, parsed)
            ctx.log.info(entry)
            entries.append(entry + "\n")

        if entries:
            with open(self.output_path, "a", encoding="utf-8") as f:
                f.writelines(entries)

    def tcp_end(self, flow: "tcp.TCPFlow"):
        """Release reassembly state once the flow has closed."""
        self._discard(flow)

    def tcp_error(self, flow: "tcp.TCPFlow"):
        """Release reassembly state when the flow fails."""
        self._discard(flow)

    def _discard(self, flow):
        """Forget both direction buffers of *flow*, noting any undecoded bytes."""
        for from_client in (True, False):
            leftover = self._buffers.pop((flow.id, from_client), None)
            if leftover:
                direction = "C->S" if from_client else "S->C"
                ctx.log.info(
                    f"[MaxProto] {direction}: discarding {len(leftover)} byte(s) of an incomplete frame"
                )

    @staticmethod
    def _server_host(flow) -> str:
        """Return the server SNI, else the server address host, else ``""``."""
        server = flow.server_conn
        if server and server.sni:
            return server.sni
        if server and server.address:
            return server.address[0]
        return ""

    def _is_target_host(self, host: str) -> bool:
        """Return True when any configured marker occurs in *host*."""
        # NOTE(review): plain substring match, so e.g. "oneme.ru.example.com"
        # also matches; tighten if false positives show up.
        return any(marker in host for marker in self.host_markers)

    @classmethod
    def _drain_frames(cls, buffer: bytearray) -> list:
        """Remove and return every complete frame at the front of *buffer*."""
        frames = []
        while len(buffer) >= cls.HEADER_SIZE:
            payload_length = int.from_bytes(buffer[6:10], "big") & cls.LENGTH_MASK
            frame_size = cls.HEADER_SIZE + payload_length
            if len(buffer) < frame_size:
                break  # rest of this frame has not arrived yet
            frames.append(bytes(buffer[:frame_size]))
            del buffer[:frame_size]
        return frames

    @staticmethod
    def _format_entry(timestamp: str, direction: str, parsed: dict) -> str:
        """Render one decoded frame as the multi-line log/file entry."""
        payload = parsed["payload"]
        if isinstance(payload, (dict, list)):
            formatted_payload = pprint.pformat(payload, indent=2)
        else:
            formatted_payload = str(payload)
        return (
            f"\n[{timestamp}]\n{direction}\n"
            f"VER: {parsed['ver']} | CMD: {parsed['cmd']} | SEQ: {parsed['seq']} | OPCODE: {hex(parsed['opcode'])}\n"
            f"Payload Data:\n{formatted_payload}\n"
            f"{'=' * 50}"
        )

# Addon instances registered with mitmproxy when this script is loaded.
addons = [MaxProtoDumper()]