#!/bin/python
import json
import math
import os
import subprocess
import sys

from check_video import check_video_ext, normalize_language, normalize_lang_code

# termcolor is optional: when it cannot be imported, np() prints plain text.
try:
    from termcolor import colored
    color = True
except ImportError:
    color = False
    print(
        "For nicer output install termcolor:\nsudo 'your installer' python-termcolor"
        if os.name == "posix"
        else "For nicer output install termcolor:\npip install termcolor"
    )

# Style tuples are unpacked positionally into termcolor's
# colored(text, color, on_color, attrs) by np().
NORMAL_STYLE = ("white", None, [])
ERROR_STYLE = ("red", None, ["bold"])
WARN_STYLE = ("yellow", None, ["bold"])
INFO_STYLE = ("cyan", None, [])
SUCCESS_STYLE = ("green", None, ["bold"])
DEBUG_STYLE = ("magenta", None, ["dark"])

def np(string, style, end = "\n"):
    """Print *string*, styled via termcolor when colouring is available.

    *style* is a (color, on_color, attrs) tuple unpacked into ``colored()``.
    When the module-level ``color`` flag is False the text is printed as-is.
    *end* is forwarded to ``print()``.
    """
    print(colored(string, *style) if color else string, end=end)

def human_readable_size(size, decimal_places=2):
    """Format a byte count using binary (1024-based) steps, e.g. 1536 -> "1.50 KB".

    Args:
        size: Number of bytes.
        decimal_places: Digits after the decimal point.

    Returns:
        The formatted string. Values of 1024 TB or more are expressed in PB
        (previously such values fell off the loop and returned None).
    """
    units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
    for unit in units[:-1]:
        if size < 1024:
            return f"{size:.{decimal_places}f} {unit}"
        size /= 1024
    # Largest unit absorbs everything that is left, however big.
    return f"{size:.{decimal_places}f} {units[-1]}"

def calculate_aspect(width: int, height: int) -> str:
    """Reduce width:height to lowest terms, e.g. (1920, 1080) -> "16:9".

    The ratio keeps the input orientation, so (1080, 1920) -> "9:16".
    Equal sides (including 0x0) return "1:1". A zero side yields "0:1" or "1:0".

    Args:
        width: Frame width in pixels (integral int or float).
        height: Frame height in pixels (integral int or float).
    """
    if width == height:
        return "1:1"
    w, h = int(width), int(height)
    # gcd is symmetric, so no need to order the arguments; it is non-zero
    # here because both sides cannot be 0 (that case returned above).
    divisor = math.gcd(w, h)
    return f"{w // divisor}:{h // divisor}"

def get_interlace_label(fo):
    """Map an ffprobe ``field_order`` value to a human-readable scan label.

    "tt"/"tff"/"tb" -> "Interlaced (TFF)", "bb"/"bff"/"bt" -> "Interlaced (BFF)"
    (case-insensitive). Everything else, including empty/None, is reported as
    "Progressive".
    """
    # NOTE(review): FFmpeg documents "tb"/"bt" as coded-vs-displayed field
    # orders that differ; confirm the TFF/BFF grouping of those two codes.
    code = str(fo).lower() if fo else ""
    if code in ("tt", "tff", "tb"):
        return "Interlaced (TFF)"
    if code in ("bb", "bff", "bt"):
        return "Interlaced (BFF)"
    return "Progressive"  # explicit "progressive", unknown codes, or no value

def fixed_width(s, width, align="left", fill=" "):
    s = str(s)
    if len(s) > width:
        return s[:width]
    if align == "right":
        return s.rjust(width, fill)
    if align == "center":
        return s.center(width, fill)
    return s.ljust(width, fill)

def short_codec_name(codec):
    """Return a compact upper-case label for an ffprobe codec name.

    Verbose names for some PCM variants, TrueHD, MPEG-2, ProRes and common
    subtitle formats are abbreviated. Any other name is used as given. The label
    is upper-cased and fitted to 5 characters with fixed_width(). Empty or None
    input returns "".
    """
    if not codec:
        return ""

    abbreviations = {
        # Audio
        "pcm_s16le": "PCM16",
        "pcm_s24le": "PCM24",
        "pcm_s32le": "PCM32",
        "pcm_f32le": "PCMF",
        "truehd": "THD",
        # Video
        "mpeg2video": "MPG2",
        "prores": "PRRS",
        "prores_ks": "PRRSK",
        # Subtitles
        "subrip": "SRT",
        "webvtt": "VTT",
        "hdmv_pgs_subtitle": "PGS",
        "dvb_subtitle": "DVB",
    }
    key = codec.lower()
    return fixed_width(abbreviations.get(key, key).upper(), 5)

class video_lines:
    """Display-ready summary of one video stream from ffprobe's JSON output.

    The constructor takes one entry of ffprobe's ``streams`` list (a dict) and
    stores the fields rendered by ``str()``. Missing fields become ``""``,
    except ``id`` and ``duration``, which become ``None``.
    """

    def __init__(self, stream):
        tags = stream.get("tags", {})

        # Stream index. 0 is a valid index, so test presence, not truthiness.
        self.id = stream.get("index")

        # get_media_info() requests no "name" field; fall back to the title tag.
        self.name = stream.get("name") or tags.get("title", "")

        # seconds_to_hms() expects a number, so convert first; values that are
        # not numeric are ignored.
        raw_duration = stream.get("duration")
        try:
            seconds = float(raw_duration) if raw_duration else None
        except (TypeError, ValueError):
            seconds = None
        self.duration = seconds_to_hms(seconds) if seconds is not None else None

        self.codec = stream.get("codec_name") or ""
        self.width = stream.get("width") or ""
        self.height = stream.get("height") or ""
        self.resolution = f"{self.width}x{self.height}"

        # r_frame_rate is a "num/den" fraction; a zero denominator (e.g. "0/0")
        # would raise ZeroDivisionError, so leave the rate empty in that case.
        self.framerate = ""
        rate = stream.get("r_frame_rate")
        if rate:
            num, den = map(int, rate.split("/"))
            if den:
                self.framerate = round(num / den, 2)

        if stream.get("display_aspect_ratio"):
            self.aspect_ratio = stream.get("display_aspect_ratio")
        elif self.width and self.height:
            # calculate_aspect needs both dimensions.
            self.aspect_ratio = calculate_aspect(self.width, self.height)
        else:
            self.aspect_ratio = ""

        self.pix_fmt = stream.get("pix_fmt") or ""
        self.color_space = stream.get("color_space") or ""

        field_order = stream.get("field_order")
        self.field_order = get_interlace_label(field_order) if field_order else ""

    def __str__(self):
        string = fixed_width("Video", 6)
        string += f" {self.id:02d}: " if self.id is not None else ": "
        string += short_codec_name(self.codec)
        if self.duration:
            string += f" {self.duration}s"
        if self.width and self.height:
            # Frame rate is optional; the parenthesis is always closed.
            rate = f"@{self.framerate}" if self.framerate != "" else ""
            string += f" ({self.resolution}{rate})"
        if self.aspect_ratio:
            string += f" [{self.aspect_ratio}]"
        if self.pix_fmt and self.color_space:
            string += f" [{self.pix_fmt}, {self.color_space}]"
        if self.field_order:
            string += f" [{self.field_order}]"
        return string

class audio_lines:
    """Display-ready summary of one audio stream from ffprobe's JSON output.

    Built from one entry of ffprobe's ``streams`` list. Values that ffprobe did
    not report become ``""``, except ``id`` and ``duration``, which become
    ``None``.
    """

    def __init__(self, stream):
        tags = stream.get("tags", {})

        self.id = stream.get("index")
        # Track title and language both come from the stream tags.
        self.name = tags.get("title", "")
        self.language = normalize_language(tags.get("language", "und").lower())

        # Stream-level duration only; None when the stream reports none.
        raw_duration = stream.get("duration")
        self.duration = seconds_to_hms(float(raw_duration)) if raw_duration else None

        self.codec = stream.get("codec_name", "")

        # Hz -> kHz, e.g. "48000" -> "48.0 kHz".
        rate_hz = stream.get("sample_rate")
        self.sample_rate = f"{int(rate_hz) / 1000} kHz" if rate_hz else ""

        self.channels = stream.get("channels", "")

        # Prefer bits_per_sample, falling back to bits_per_raw_sample when the
        # former is missing or 0.
        bits = stream.get("bits_per_sample") or stream.get("bits_per_raw_sample")
        self.bit_depth = f"{bits}-bit" if bits else ""

        # bit/s -> whole kbps.
        bps = stream.get("bit_rate")
        self.bitrate = f"{int(bps) // 1000} kbps" if bps else ""

    def __str__(self):
        line = "dub"
        line += ": " if self.id is None else f" {self.id:02d}: "
        line += short_codec_name(self.codec)

        if self.language:
            line += f" [{self.language}]"
        if self.duration:
            line += f" {self.duration}s"

        # Channels, sample rate and bit depth share one parenthesised group.
        specs = [
            spec
            for spec in (
                f"{self.channels}ch" if self.channels else "",
                self.sample_rate,
                self.bit_depth,
            )
            if spec
        ]
        if specs:
            line += f" ({', '.join(specs)})"

        if self.bitrate:
            line += f" @{self.bitrate}"
        if self.name:
            line += f" [{self.name}]"
        return line

class subtitles:
    """Display-ready summary of one subtitle stream from ffprobe's JSON output.

    Built from one entry of ffprobe's ``streams`` list. Besides the
    codec/language/title fields, it records whether the stream's disposition
    marks it as forced or default.
    """

    def __init__(self, stream):
        tags = stream.get("tags", {})
        disposition = stream.get("disposition", {})

        self.id = stream.get("index")
        self.name = tags.get("title", "")
        self.language = normalize_language(tags.get("language", "und"))

        raw_duration = stream.get("duration")
        self.duration = None if not raw_duration else seconds_to_hms(float(raw_duration))

        self.codec = stream.get("codec_name", "")

        # Disposition flags are compared to 1, so any other value counts as off.
        self.is_forced = disposition.get("forced") == 1
        self.is_default = disposition.get("default") == 1

    def __str__(self):
        pieces = ["Subtitle:" if self.id is None else f"sub {self.id:02d}:"]

        if self.codec:
            pieces.append(short_codec_name(self.codec))
        if self.language:
            pieces.append(f"[{self.language}]")
        if self.duration:
            pieces.append(f"{self.duration}s")

        flags = [
            label
            for label, enabled in (("FORCED", self.is_forced), ("Default", self.is_default))
            if enabled
        ]
        if flags:
            pieces.append("(" + "/".join(flags) + ")")

        if self.name:
            pieces.append(f"[{self.name}]")

        return " ".join(pieces)

# def get_media_info(file):
#     cmd = [
#         "ffprobe",
#         "-v", "error",
#         "-show_entries",
#         (
#             "format=duration:"
#             "stream=index,codec_type,codec_name,"
#             "width,height,r_frame_rate,bit_rate,duration,nb_frames,"
#             "pix_fmt,field_order,time_base,display_aspect_ratio,"
#             "color_space,color_transfer,color_primaries,bits_per_raw_sample,"
#             "sample_rate,channels,bits_per_sample,"
#             "stream_disposition=forced,default:"
#             "stream_tags=language,title"
#         ),
#         "-of", "json",
#         file
#     ]

#     try:
#         result = subprocess.run(cmd, capture_output=True, text=True, check=True)
#         info = json.loads(result.stdout)
#     except subprocess.CalledProcessError as e:
#         print(f"Error running ffprobe: {e.stderr}")
#         return None, [], [], []

#     # Container / file duration (string seconds, per ffprobe convention)
#     duration = None
#     if "format" in info:
#         duration = info["format"].get("duration")

#     video_streams = []
#     audio_streams = []
#     subtitle_streams = []

#     for stream in info.get("streams", []):
#         stream_type = stream.get("codec_type")
#         if stream_type == "video":
#             video_streams.append(stream)
#         elif stream_type == "audio":
#             audio_streams.append(stream)
#         elif stream_type == "subtitle":
#             subtitle_streams.append(stream)

#     if duration is not None:
#         duration = float(duration)
#     else:
#         duration = float('nan')

#     return duration, video_streams, audio_streams, subtitle_streams

def get_media_info(file):
    """Run ffprobe on *file* and split its streams by type.

    Returns:
        A tuple ``(duration, video_streams, audio_streams, subtitle_streams, title)``:
        ``duration`` is the container duration in seconds as a float (NaN when
        ffprobe reports none or a non-numeric value); the stream lists hold
        ffprobe's raw per-stream dicts; ``title`` is the container title tag or
        None. If ffprobe fails, cannot be found, or emits output that is not
        JSON, an error is printed and ``(None, [], [], [], None)`` is returned.
    """
    cmd = [
        "ffprobe",
        "-v", "error",
        "-show_entries",
        (
            "format=duration:format_tags=title:"
            "stream=index,codec_type,codec_name,"
            "width,height,r_frame_rate,bit_rate,duration,nb_frames,"
            "pix_fmt,field_order,time_base,display_aspect_ratio,"
            "color_space,color_transfer,color_primaries,bits_per_raw_sample,"
            "sample_rate,channels,bits_per_sample,"
            "stream_disposition=forced,default:"
            "stream_tags=language,title"
        ),
        "-of", "json",
        file
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error running ffprobe: {e.stderr}")
        return None, [], [], [], None
    except FileNotFoundError:
        # Raised by subprocess.run when the ffprobe executable is missing.
        print("Error running ffprobe: executable not found")
        return None, [], [], [], None

    try:
        info = json.loads(result.stdout)
    except json.JSONDecodeError as e:
        print(f"Error parsing ffprobe output: {e}")
        return None, [], [], [], None

    container = info.get("format", {})
    title = container.get("tags", {}).get("title")

    # Bucket streams by codec_type; other types (data, attachment) are dropped.
    buckets = {"video": [], "audio": [], "subtitle": []}
    for stream in info.get("streams", []):
        bucket = buckets.get(stream.get("codec_type"))
        if bucket is not None:
            bucket.append(stream)

    try:
        duration = float(container.get("duration"))
    except (TypeError, ValueError):
        # Missing (None) or non-numeric duration.
        duration = float('nan')

    return duration, buckets["video"], buckets["audio"], buckets["subtitle"], title

def seconds_to_hms(seconds):
    """Format a duration in seconds as ``HH:MM:SS``, truncating any fraction.

    Accepts an int, a float, or a numeric string. Returns ``"ERROR"`` for
    anything else, including bool, None, NaN, infinity and negative values.
    Previously NaN, which get_media_info uses for "unknown duration", crashed
    in ``int()``.
    """
    if isinstance(seconds, bool):
        return "ERROR"
    if isinstance(seconds, str):
        try:
            seconds = float(seconds)
        except ValueError:
            return "ERROR"
    if not isinstance(seconds, (int, float)) or not math.isfinite(seconds) or seconds < 0:
        return "ERROR"
    hours, remainder = divmod(int(seconds), 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02}:{minutes:02}:{secs:02}"

def get_stream_bitrate(file_size, duration):
    """Return a file's average bitrate in megabits per second.

    Computed as ``file_size * 8 / duration / 1_000_000`` and rounded to two
    decimals.

    Args:
        file_size: Size of the file in bytes.
        duration: Length in seconds (int or float).

    Returns:
        The rounded bitrate. Returns ``0`` when *duration* is zero or negative,
        and NaN when *duration* is not a real number (None, str, bool) or is NaN.
    """
    if isinstance(duration, bool) or not isinstance(duration, (int, float)):
        return float('nan')
    # NaN never equals anything, so `duration != float('nan')` is always true;
    # math.isnan is the reliable test.
    if math.isnan(duration):
        return float('nan')
    if duration <= 0:
        return 0
    return round(file_size * 8 / duration / 1_000_000, 2)

class video_file:
    """A media file on disk together with its ffprobe-derived stream summaries.

    Args:
        path: Path to the media file.
        base_tab: A single tab character when the file is printed as part of a
            folder listing (the caller prints the folder header). Any other
            value makes ``print()`` emit the file's directory as a header first.
    """

    def __init__(self, path, base_tab=""):
        self.base_tab = base_tab
        self.path = path                        # e.g. folder/25.mkv
        self.name = os.path.basename(path)      # e.g. 25.mkv
        size_bytes = os.path.getsize(path)

        duration, videos, audios, subs, title = get_media_info(path)

        self.sort_video_info(videos)
        self.sort_audio_info(audios)
        self.sort_subs_info(subs)

        # get_stream_bitrate computes bytes * 8 / seconds / 1e6, i.e. megabits/s.
        self.bitrate = get_stream_bitrate(size_bytes, duration)
        self.size = human_readable_size(size_bytes)  # e.g. "198.00 MB"
        self.duration = seconds_to_hms(duration)
        self.title = title

    def sort_video_info(self, videos):
        """Set ``self.videos`` to video_lines objects, keeping ffprobe's stream order."""
        self.videos = [video_lines(stream) for stream in (videos or [])]

    def sort_audio_info(self, audios):
        """Set ``self.audios`` to audio_lines objects, keeping ffprobe's stream order."""
        self.audios = [audio_lines(stream) for stream in (audios or [])]

    def sort_subs_info(self, subs):
        """Set ``self.subtitles`` to subtitles objects, keeping ffprobe's stream order."""
        self.subtitles = [subtitles(stream) for stream in (subs or [])]

    def _summary_line(self):
        """Return the ``name (size, duration, bitrate Mbit/s):`` header for this file."""
        # The bitrate is in megabits per second; it was previously mislabelled "MB/s".
        return f"{self.name} ({self.size}, {self.duration}, {self.bitrate} Mbit/s):"

    def print(self):
        """Print the file header, optional title, then one line per stream."""
        if self.base_tab == "\t":
            np(f"{self.base_tab}{self._summary_line()}", NORMAL_STYLE)
        else:
            np(f"{os.path.dirname(self.path)}/", INFO_STYLE)
            np(f"\t{self._summary_line()}", NORMAL_STYLE)
        if self.title:
            np(f"\t\tTitle: \"{self.title}\"", NORMAL_STYLE)
        for line in (*self.videos, *self.audios, *self.subtitles):
            np(f"\t\t{line}", NORMAL_STYLE)
        print()
        

def get_folder_info(files):
    """Print a header naming the folder of ``files[0]``, then each file's details.

    *files* must be non-empty; each path is printed as a tab-indented entry.
    """
    folder = os.path.dirname(files[0])
    np(f"Videos in {folder}/", INFO_STYLE)
    for path in files:
        video_file(path, "\t").print()

def get_file_info(file, file_name):
    """Print details for a single video file at *file*, headed by its directory.

    *file_name* is accepted for call compatibility but is not used.
    """
    video_file(file, "").print()

def handle_files(files, all_files):
    """Group video paths by directory and append the groups to *all_files*.

    *files* is sorted in place by (directory, file name). This keeps all files
    of one directory in a single group even when a subdirectory's path would
    sort between them. Sorting by full path split such directories into
    several groups.

    Each directory contributes one entry: its only path as a plain string, or a
    list of paths when it holds several files. An empty *files* adds nothing.
    """
    if not files:
        return
    files.sort(key=lambda path: (os.path.dirname(path), os.path.basename(path)))

    # dicts keep insertion order, so groups come out in sorted directory order.
    by_dir = {}
    for path in files:
        by_dir.setdefault(os.path.dirname(path), []).append(path)

    for group in by_dir.values():
        all_files.append(group[0] if len(group) == 1 else group)



def handle_folders(dirs, all_files):
    """Append, per directory in *dirs*, a sorted list of its video files to *all_files*.

    *dirs* is sorted in place by each entry's parent directory. Only regular
    files directly inside each directory are considered, and only those whose
    extension passes check_video_ext(). Directories without any such file add
    nothing.
    """
    dirs.sort(key=os.path.dirname)
    for folder in dirs:
        with os.scandir(folder) as entries:
            videos = [
                entry.path
                for entry in entries
                if entry.is_file() and check_video_ext(os.path.splitext(entry.path)[1])
            ]
        if videos:
            videos.sort()
            all_files.append(videos)

if __name__ == "__main__":
    # External tracker hook from the original setup. It is skipped when the
    # script is absent so the tool still runs elsewhere, and it is run with the
    # current interpreter instead of whatever "python" resolves to on PATH.
    tracker_script = "/home/honney/.bin/tracker.py"
    if os.path.isfile(tracker_script):
        subprocess.run([sys.executable, tracker_script, "add", "ff"])

    file_dir_array = []
    if len(sys.argv) == 0:
        print("Something went horribly wrong!")
    elif len(sys.argv) == 1:
        # No arguments: list the videos in the current working directory.
        handle_folders([os.path.abspath(os.getcwd())], file_dir_array)
    else:
        files = []
        dirs = []
        for argv in sys.argv[1:]:
            if os.path.isfile(argv):
                # Files without a recognised video extension are skipped silently.
                if check_video_ext(os.path.splitext(argv)[1]):
                    files.append(os.path.abspath(argv))
            elif os.path.isdir(argv):
                dirs.append(os.path.abspath(argv))
            else:
                np(f"This is not a file or directory: {argv}\nNow canceling!", ERROR_STYLE)
                # Non-zero status so shells and scripts can detect the failure.
                sys.exit(1)
        handle_folders(dirs, file_dir_array)
        handle_files(files, file_dir_array)

    # Lists are per-directory groups; bare strings are single files.
    for element in file_dir_array:
        if isinstance(element, list):
            get_folder_info(element)
        else:
            get_file_info(element, element)
