#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
mid2adt.py (APT v2.1, auto time-signature + HH exclusivity)
- MIDI를 2마디 단위로 ADT(Ardule Pattern Text)로 변환
- 타임시그 자동 감지(12/8 → 48스텝, 4/4 → 32스텝 등)
- 텍스트 그리드: 첫 데이터 11열, 중앙 ' || ', 박당 스텝 간격으로 공백(4/4=4, 12/8=6)
- 하이햇 상호배제: 같은 스텝에서 CHH(42)와 OHH(46)가 겹치면 CHH 우선, OHH 제거
- FAT 8.3 파일명: {BASE6}{NN}.ADT
"""

import argparse
import re
from pathlib import Path
from typing import Dict, List, Tuple, Union

import mido

# Drum slots as (display name, MIDI note number); the note numbers match the
# General MIDI percussion map. A slot's position in this list is its row index
# in every grid, and the SLOTS= header line lists slots in this order.
SLOT_DEF: List[Tuple[str, int]] = [
    ("Kick", 36),
    ("Snare", 38),
    ("CHH", 42),
    ("OHH", 46),
    ("Crash", 49),
    ("HighTom", 50),
    ("LowTom", 45),
    ("Ride", 51),
]
# Reverse lookup: MIDI note number -> grid row index (position in SLOT_DEF).
NOTE_TO_SLOT: Dict[int, int] = {note: row for row, note in enumerate(nn for _label, nn in SLOT_DEF)}

def vel_to_char(vel: int) -> str:
    """Map a MIDI velocity to its one-character ADT hit glyph.

    ``.`` for vel <= 0, ``o`` below 90, ``X`` below 110, ``^`` otherwise.
    """
    if vel <= 0:
        return "."
    # Upper (exclusive) bound of each velocity band, checked from softest up.
    for upper_bound, glyph in ((90, "o"), (110, "X")):
        if vel < upper_bound:
            return glyph
    return "^"

def _prefix_for_name(name: str) -> str:
    label = f"{name}:"
    pad = max(0, 10 - len(label))
    return label + (" " * pad)

def _format_row(chars: List[str], steps_per_bar: int, beats_per_bar: float) -> str:
    block_size = max(1, int(round(steps_per_bar / beats_per_bar)))
    blocks = ["".join(chars[i:i+block_size]) for i in range(0, len(chars), block_size)]
    half = len(blocks) // 2
    left = " ".join(blocks[:half])
    right = " ".join(blocks[half:])
    return f"{left} || {right}"

def _sorted_slots_for_output() -> List[Tuple[str, int]]:
    """Return a new list of SLOT_DEF entries ordered by note number, highest first."""
    ordered = list(SLOT_DEF)
    ordered.sort(key=lambda slot: slot[1], reverse=True)
    return ordered

def get_first_bpm(mid: mido.MidiFile) -> int:
    """Return the first usable tempo of *mid* as an integer BPM clamped to 30..300.

    Tracks are scanned in file order and events in track order. The first
    ``set_tempo`` message with a positive tempo wins. Returns 120 when no
    usable tempo is found.
    """
    for track in mid.tracks:
        for msg in track:
            if msg.type != 'set_tempo':
                continue
            # A tempo of 0 would divide by zero inside tempo2bpm; skip it
            # explicitly rather than swallowing every possible exception.
            if msg.tempo <= 0:
                continue
            return max(30, min(300, round(mido.tempo2bpm(msg.tempo))))
    return 120

def get_time_signature(mid: mido.MidiFile):
    """Return ``(numerator, denominator)`` of the first time-signature event.

    Tracks are searched in file order. When no ``time_signature`` meta message
    exists, the 4/4 default ``(4, 4)`` is returned.
    """
    first_signature = next(
        (
            (event.numerator, event.denominator)
            for track in mid.tracks
            for event in track
            if event.type == 'time_signature'
        ),
        (4, 4),
    )
    return first_signature

def calc_grid_params(numerator: int, denominator: int, steps_override: Union[int, None]):
    """Derive the step grid for one two-bar window from a time signature.

    Returns ``(steps_total, steps_per_bar, beats_per_bar, is_compound)``:

    * compound meters (``x/8`` with ``x`` divisible by 3, e.g. 6/8, 12/8) get
      one beat per three eighths and 6 steps per beat (12/8 -> 48 steps);
    * every other meter gets 16 steps per whole note (4/4 -> 32 steps).

    A truthy *steps_override* replaces the total step count for the two bars.
    ``beats_per_bar`` and ``is_compound`` still come from the time signature,
    so per-beat spacing in the text grid keeps following the meter. Previously
    they were guessed from whether the step count was divisible by 6.

    Raises:
        ValueError: if *steps_override* is given (non-zero) but smaller than 2,
            which would leave a bar with no steps.
    """
    is_compound = denominator == 8 and numerator % 3 == 0
    if is_compound:
        beats_per_bar = numerator / 3.0
        auto_steps_per_bar = int(round(beats_per_bar * 6))
    else:
        beats_per_bar = 4.0 * numerator / denominator
        auto_steps_per_bar = int(round(16 * numerator / denominator))

    if steps_override:
        if steps_override < 2:
            raise ValueError(f"steps_override must be >= 2, got {steps_override}")
        steps_total = steps_override
        return steps_total, steps_total // 2, beats_per_bar, is_compound

    return auto_steps_per_bar * 2, auto_steps_per_bar, beats_per_bar, is_compound

def build_empty_grid(steps: int) -> List[List[str]]:
    """Return an all-rest grid: one row per SLOT_DEF entry, *steps* ``"."`` cells each.

    Every row is a separate list, so writing into one row never affects another.
    """
    grid: List[List[str]] = []
    for _slot in SLOT_DEF:
        grid.append(list("." * steps))
    return grid

def write_adt(path: Path, steps_total: int, bpm: int, midi_ch: int,
              grid: List[List[str]], numerator:int, denominator:int,
              steps_per_bar:int, beats_per_bar:float) -> None:
    """Write one two-bar pattern to *path* as an ADT text file.

    The file contains, in order:
      1. a header ``STEPS=.. BPM=.. MIDI_CH=.. TIME_SIG=num/den``;
      2. a ``SLOTS=`` line listing every slot as ``name:note`` in SLOT_DEF order;
      3. one formatted grid row per slot, ordered by descending note number.

    *grid* rows are matched to slots by position (row i <-> SLOT_DEF[i]).
    Missing parent directories are created, and an existing file is overwritten.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(f"STEPS={steps_total} BPM={bpm} MIDI_CH={midi_ch} TIME_SIG={numerator}/{denominator}\n")
        slot_list = ",".join(f"{label}:{note}" for label, note in SLOT_DEF)
        fh.write(f"SLOTS={slot_list}\n")
        rows_by_label = dict(zip((label for label, _note in SLOT_DEF), grid))
        fh.writelines(
            f"{_prefix_for_name(label)}{_format_row(rows_by_label[label], steps_per_bar, beats_per_bar)}\n"
            for label, _note in _sorted_slots_for_output()
        )

def quantize_to_step(pos_ticks: int, ticks_in_window: int, steps_total: int) -> int:
    """Snap a tick offset within a window to the nearest step index.

    Rounding uses Python's ``round`` (ties go to the even step). The result is
    clamped to ``0..steps_total - 1``. A zero-length window maps everything to
    step 0.
    """
    if ticks_in_window:
        nearest = int(round(pos_ticks / ticks_in_window * steps_total))
    else:
        nearest = 0
    return max(0, min(nearest, steps_total - 1))

def _locate_step(abs_ticks: int, ticks_two_bars: int, steps_total: int) -> Tuple[int, int]:
    """Map an absolute tick position to ``(chunk_index, step_index)``.

    The tick is rounded to the nearest step of one continuous grid
    (*steps_total* steps per *ticks_two_bars* ticks; exact midpoints round up)
    *before* it is split into two-bar windows. A hit played slightly ahead of a
    window's downbeat therefore lands on step 0 of that window, instead of being
    clamped onto the last step of the previous window.

    Raises:
        ValueError: if the window length or the step count is not positive.
    """
    if ticks_two_bars <= 0 or steps_total <= 0:
        raise ValueError("two-bar window length and step count must both be positive")
    # Exact integer form of floor(abs_ticks * steps_total / ticks_two_bars + 0.5).
    global_step = (2 * abs_ticks * steps_total + ticks_two_bars) // (2 * ticks_two_bars)
    return divmod(global_step, steps_total)


def _stronger_hit(current: str, new: str) -> str:
    """Return whichever glyph represents the louder hit (ranking ``. < o < X < ^``)."""
    rank = ".oX^"
    return new if rank.index(new) > rank.index(current) else current


def _apply_hh_exclusivity(grids: List[List[List[str]]]) -> None:
    """In place, clear every OHH hit that shares a step with a CHH hit (CHH wins)."""
    row_of = {name: idx for idx, (name, _) in enumerate(SLOT_DEF)}
    chh_idx = row_of.get("CHH")
    ohh_idx = row_of.get("OHH")
    if chh_idx is None or ohh_idx is None:
        return
    for grid in grids:
        chh_row, ohh_row = grid[chh_idx], grid[ohh_idx]
        for step, cell in enumerate(chh_row):
            if cell != ".":
                ohh_row[step] = "."


def extract_chunks_to_grids(mid: mido.MidiFile,
                            target_channel: Union[int, str],
                            steps_total: int,
                            min_vel: int):
    """Slice the drum hits of *mid* into consecutive two-bar step grids.

    Only ``note_on`` messages count, and only when all of these hold:
    velocity > 0, channel equal to *target_channel* (or any channel for
    ``"all"``), velocity >= *min_vel*, and note present in NOTE_TO_SLOT.

    The window length comes from the first time signature found (later changes
    are ignored). When several hits of one slot fall on the same step, the
    loudest glyph is kept. After slicing, CHH beats OHH on shared steps.

    Returns:
        ``(bpm, numerator, denominator, grids)``. ``grids`` has one grid per
        two-bar window, from window 0 up to the last window containing a hit;
        windows without hits in between are all rests. It is ``[]`` when no
        hit matched.
    """
    ppq = mid.ticks_per_beat
    bpm = get_first_bpm(mid)
    num, den = get_time_signature(mid)
    # Bar length in quarter notes, e.g. 4/4 -> 4.0, 12/8 -> 6.0.
    beats_per_bar_quarter = num * (4 / den)
    ticks_per_bar = int(round(ppq * beats_per_bar_quarter))
    ticks_two_bars = ticks_per_bar * 2

    grids: Dict[int, List[List[str]]] = {}

    def pick_channel(msg) -> bool:
        # Meta/sysex messages carry no channel and are never selected.
        if not hasattr(msg, "channel"):
            return False
        if target_channel == "all":
            return True
        return msg.channel == target_channel

    abs_ticks = 0
    for msg in mido.merge_tracks(mid.tracks):
        abs_ticks += msg.time
        # note_on with velocity 0 acts as a note-off and is skipped.
        if msg.type != "note_on" or msg.velocity <= 0 or not pick_channel(msg):
            continue
        vel = msg.velocity
        if vel < min_vel or msg.note not in NOTE_TO_SLOT:
            continue
        chunk_idx, step_idx = _locate_step(abs_ticks, ticks_two_bars, steps_total)
        if chunk_idx not in grids:
            grids[chunk_idx] = build_empty_grid(steps_total)
        row = grids[chunk_idx][NOTE_TO_SLOT[msg.note]]
        row[step_idx] = _stronger_hit(row[step_idx], vel_to_char(vel))

    if not grids:
        return bpm, num, den, []

    out = [grids[i] if i in grids else build_empty_grid(steps_total) for i in range(max(grids) + 1)]
    _apply_hh_exclusivity(out)
    return bpm, num, den, out

def make_83_name(base: str, index: int) -> str:
    """Build a FAT 8.3 file name ``{BASE6}{NN}.ADT`` for chunk *index*.

    *base* is stripped to its ASCII letters and digits, upper-cased, then
    truncated or right-padded with ``X`` to exactly six characters.

    Raises:
        ValueError: if *index* is outside 0..99. Such an index would need more
            than the two digits available and break the 8-character base name.
    """
    if not 0 <= index <= 99:
        raise ValueError(f"chunk index must be within 0..99 for an 8.3 name, got {index}")
    stem = re.sub(r'[^A-Za-z0-9]', '', base).upper()
    return f"{stem[:6].ljust(6, 'X')}{index:02d}.ADT"

def main():
    """Command-line entry point: slice a MIDI file into two-bar ADT files.

    Parses the arguments and builds one step grid per two-bar window of the
    input. Each grid is written as ``{BASE6}{NN}.ADT`` into the output
    directory, up to ``min(--max-chunks, 99)`` files. Invalid arguments or an
    unreadable input file end the program via ``SystemExit`` with a message.
    """
    ap = argparse.ArgumentParser(description="Slice MIDI into 2-bar APT v2.1 ADT files (8.3, auto TS, HH exclusivity).")
    ap.add_argument("midi", help="Input MIDI path")
    ap.add_argument("--channel", default="9", help="0..15 or 'all' (default: 9)")
    ap.add_argument("--steps", type=int, default=None, help="Total steps for 2 bars (auto by TIME_SIG if omitted)")
    ap.add_argument("--min-vel", type=int, default=1, help="Ignore hits below this velocity")
    ap.add_argument("--max-chunks", type=int, default=99, help="Max chunks (01..99)")
    ap.add_argument("--out-dir", default=None, help="Output directory (default: same as input)")
    ap.add_argument("--out-prefix", default=None, help="Override base for 8.3 generation")
    args = ap.parse_args()

    in_path = Path(args.midi)
    if not in_path.exists():
        raise SystemExit(f"Input not found: {in_path}")

    channel_error = "--channel must be 0..15 or 'all'"
    target_channel: Union[int, str]
    if str(args.channel).lower() == "all":
        target_channel = "all"
    else:
        try:
            target_channel = int(args.channel)
        except ValueError:
            raise SystemExit(channel_error) from None
        if not 0 <= target_channel <= 15:
            raise SystemExit(channel_error)

    # Omitted or 0 means "derive from the time signature". Any explicit value
    # must leave at least one step per bar.
    if args.steps and args.steps < 2:
        raise SystemExit("--steps must be >= 2 (total steps for two bars)")

    try:
        mid = mido.MidiFile(in_path, clip=True)
    except (OSError, EOFError, ValueError) as err:
        raise SystemExit(f"Failed to read MIDI file {in_path}: {err}") from err

    num, den = get_time_signature(mid)
    steps_total, steps_per_bar, beats_per_bar, _ = calc_grid_params(num, den, args.steps)

    bpm, _n, _d, grids = extract_chunks_to_grids(mid, target_channel, steps_total, min_vel=args.min_vel)

    out_dir = Path(args.out_dir) if args.out_dir else in_path.parent
    base = args.out_prefix or in_path.stem

    # The 8.3 name has only two digits for the chunk number, hence the cap of 99.
    # A zero or negative --max-chunks writes nothing.
    limit = max(0, min(args.max_chunks, 99))
    selected = grids[:limit]
    for i, grid in enumerate(selected, start=1):
        out_path = out_dir / make_83_name(base, i)
        # NOTE(review): MIDI_CH in the header is always 9, independent of
        # --channel (which only filters the input) — confirm this is intended.
        write_adt(out_path, steps_total=steps_total, bpm=bpm, midi_ch=9, grid=grid,
                  numerator=num, denominator=den, steps_per_bar=steps_per_bar, beats_per_bar=beats_per_bar)
    written = len(selected)

    if len(grids) > written:
        print(f"Note: {len(grids) - written} chunk(s) beyond the limit of {limit} were not written")
    print(f"Wrote {written} ADT (TS={num}/{den}, steps2bars={steps_total}, HH exclusive) to {out_dir}")

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
