#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
adx2mid.py  —  ADT/ADP → MIDI 변환기 (APT v2.1, Hi-Hat choke)
- TIME_SIG 반영/추정, PPQ 설정 가능
- 하이햇 초크: CHH(42) 노트 시 OHH(46)를 동일 tick에서 즉시 note_off
- 동일 음 재타격 시 이전 발음을 먼저 끊어 단일발음 보장
"""
import argparse, struct, re, sys
from pathlib import Path
from typing import List, Tuple, Dict
import mido

# Kit layout used when a file does not declare SLOTS=: (slot name, MIDI note).
DEFAULT_SLOTS: List[Tuple[str,int]] = [
    ("Kick", 36),
    ("Snare", 38),
    ("CHH", 42),
    ("OHH", 46),
    ("Crash", 49),
    ("HighTom", 50),
    ("LowTom", 45),
    ("Ride", 51),
]
# Grid character -> accent level; position in the tuple is the level
# (0 = rest, 1 = normal, 2 = accent, 3 = strong accent).
ACC_FROM_CHAR = {
    ch: level
    for level, chars in enumerate(('.', 'oO', 'Xx', '^'))
    for ch in chars
}
# Accent level -> MIDI note_on velocity (index 0 is never sounded).
VEL_LUT = [0, 80, 100, 118]

def parse_slots(line: str) -> List[Tuple[str,int]]:
    """Parse a ``SLOTS=Name:note, Name:note, ...`` header line.

    Blank comma-separated entries are skipped and surrounding whitespace is
    stripped from names and note numbers. Raises ValueError when the line has
    no ``=``, an entry has no ``:``, or a note number is not an integer.
    """
    _key, spec = line.split('=', 1)

    def to_slot(entry: str) -> Tuple[str,int]:
        label, number = entry.split(':', 1)
        return label.strip(), int(number.strip())

    entries = (chunk.strip() for chunk in spec.split(','))
    return [to_slot(entry) for entry in entries if entry]

def cleanse_grid(s: str) -> str:
    """Return *s* with every space and ``|`` bar separator removed.

    Other whitespace (e.g. tabs) is left untouched.
    """
    return s.translate(str.maketrans('', '', ' |'))

def parse_adt(path: Path):
    """Parse a text .ADT pattern file.

    Recognised lines:
      * a header line starting with ``STEPS=``; ``BPM=``, ``MIDI_CH=`` and
        ``TIME_SIG=num/den`` are read from that same line, each optional;
      * ``SLOTS=Name:note,...`` declaring the kit layout (see parse_slots);
      * ``Name: <grid>`` rows, where the grid may contain spaces and ``|``.
    Any other line is ignored.

    Defaults when absent: steps=32, bpm=120, midi_ch=9, time_sig=(4, 4),
    slots=DEFAULT_SLOTS.

    Returns:
        (steps, bpm, midi_ch, time_sig, slots, rows), where rows holds one
        list of accent levels per slot (in slot order), each exactly `steps`
        long. A slot with no matching row is all rests; short rows are padded
        with rests and long rows truncated. Unknown grid characters count as
        rests.

    Raises:
        ValueError: if STEPS=0, if the SLOTS= line is malformed, or if the
            file is not valid UTF-8. OSError if the file cannot be read.
    """
    steps = 32
    bpm = 120
    midi_ch = 9
    time_sig = (4, 4)
    slots: List[Tuple[str,int]] = []
    rows_by_name: Dict[str,str] = {}
    # 'utf-8-sig' transparently drops a leading BOM (common in files saved by
    # Windows editors). With plain 'utf-8' the BOM would stay glued to the
    # first line, hiding the STEPS= header or corrupting the first row name.
    with open(path, encoding='utf-8-sig') as f:
        for ln in f:
            ln = ln.rstrip('\n')
            if ln.startswith('STEPS='):
                # All header fields live on the STEPS= line.
                m = re.search(r'STEPS=(\d+)', ln)
                if m:
                    steps = int(m.group(1))
                m = re.search(r'BPM=(\d+)', ln)
                if m:
                    bpm = int(m.group(1))
                m = re.search(r'MIDI_CH=(\d+)', ln)
                if m:
                    midi_ch = int(m.group(1))
                m = re.search(r'TIME_SIG=(\d+)/(\d+)', ln)
                if m:
                    time_sig = (int(m.group(1)), int(m.group(2)))
            elif ln.startswith('SLOTS='):
                slots = parse_slots(ln)
            elif ':' in ln:
                name, rest = ln.split(':', 1)
                rows_by_name[name.strip()] = cleanse_grid(rest.strip())
    if steps <= 0:
        # Same validation as the binary ADP reader; 0 steps cannot be rendered.
        raise ValueError("Invalid steps in ADT.")
    if not slots:
        slots = DEFAULT_SLOTS[:]
    rows: List[List[int]] = []
    for name, _note in slots:
        # Fit every row to exactly `steps` cells: truncate, then pad with rests.
        grid = rows_by_name.get(name, '')[:steps].ljust(steps, '.')
        rows.append([ACC_FROM_CHAR.get(ch, 0) for ch in grid])
    return steps, bpm, midi_ch, time_sig, slots, rows

def parse_adp(path: Path):
    """Parse a binary .ADP pattern file.

    Layout (little-endian): uint16 steps, uint8 bpm, uint8 midi_ch, followed
    by ``slots_count * steps`` accent bytes stored row by row (all steps of
    row 0, then row 1, ...). slots_count is derived from the payload size.

    Slot names/notes and the time signature come from a companion .ADT with
    the same stem when one exists, parses, and declares the same number of
    slots. Otherwise the slots are DEFAULT_SLOTS, extended with ``S8``,
    ``S9``, ... on notes 43, 44, ... for extra rows, and time_sig stays
    (4, 4).

    Returns:
        (steps, bpm, midi_ch, time_sig, slots, rows), the same shape as
        parse_adt.

    Raises:
        ValueError: header shorter than 4 bytes, zero steps, or a payload
            that is not a whole number of rows.
    """
    data = path.read_bytes()
    if len(data) < 4:
        raise ValueError("ADP too short.")
    steps, bpm, midi_ch = struct.unpack_from('<HBB', data, 0)
    payload = data[4:]
    if steps <= 0:
        raise ValueError("Invalid steps in ADP.")
    if len(payload) % steps != 0:
        raise ValueError("ADP payload length is not a multiple of steps.")
    slots_count = len(payload) // steps
    rows = [list(payload[i*steps:(i+1)*steps]) for i in range(slots_count)]

    slots = None
    time_sig = (4, 4)
    # Look for the companion .ADT. Try the casing that matches the input
    # suffix first, then the other, so the lookup also works on
    # case-sensitive filesystems (e.g. song.adp + song.adt).
    suffixes = ('.adt', '.ADT') if path.suffix.islower() else ('.ADT', '.adt')
    companion = next(
        (cand for cand in (path.with_suffix(sfx) for sfx in suffixes) if cand.exists()),
        None,
    )
    if companion is not None:
        try:
            _steps, _bpm, _ch, adt_ts, adt_slots, _rows = parse_adt(companion)
        except (OSError, ValueError) as exc:
            # Best effort: a broken companion falls back to defaults, but say so
            # instead of silently discarding the user's slot layout.
            print(f"Warning: ignoring companion {companion}: {exc}", file=sys.stderr)
        else:
            # Only trust the companion layout if it describes the same row count.
            if len(adt_slots) == slots_count:
                slots, time_sig = adt_slots, adt_ts
    if slots is None:
        extra = [(f"S{idx}", 35 + idx) for idx in range(len(DEFAULT_SLOTS), slots_count)]
        slots = DEFAULT_SLOTS[:slots_count] + extra
    return steps, bpm, midi_ch, time_sig, slots, rows

def infer_time_signature(steps_total:int) -> tuple[int,int]:
    """Guess a time signature from a grid's total step count.

    48 steps -> 12/8 and 24 steps -> 6/8; every other count (including 32)
    falls back to 4/4.
    """
    compound_meters = {48: (12, 8), 24: (6, 8)}
    return compound_meters.get(steps_total, (4, 4))

def grid_to_midi(rows: List[List[int]], slots: List[Tuple[str,int]],
                 steps_total:int, bpm:int, midi_ch:int, time_sig:tuple[int,int],
                 ppq:int=480, gate_ratio:float=0.5) -> mido.MidiFile:
    """Render an accent grid to a single-track, type-1 MIDI file.

    The grid spans two bars: ``steps_total // 2`` steps per bar (at least 1),
    with the bar length taken from time_sig. Each cell with accent > 0 becomes
    a note_on with velocity ``VEL_LUT[min(3, accent)]`` and a note_off
    ``gate_ratio`` of a step later (at least one tick). Note lengths are then
    shortened so that:

      * a retrigger of the same MIDI note ends the previous one at the
        retrigger tick. There is one voice per note: slots sharing a note are
        merged, and simultaneous hits keep the loudest velocity;
      * a closed hi-hat (42) ends any open hi-hat (46) still sounding at its
        tick (the hi-hat choke). A CHH on the same tick as an OHH onset does
        not choke it.

    With the default gate of half a step, notes never reach the next step, so
    the truncations only change anything for gate_ratio > 1.

    Step positions are computed from the exact bar length, so a PPQ that does
    not divide evenly into steps does not accumulate rounding drift.

    Args:
        rows: accent levels per slot, index-aligned with ``slots``.
        slots: (name, MIDI note) per row.
        steps_total: number of steps across both bars.
        bpm: tempo written as a set_tempo meta event.
        midi_ch: 0-based MIDI channel for every note event.
        time_sig: (numerator, denominator) written as a time_signature meta.
        ppq: ticks per quarter note.
        gate_ratio: note length as a fraction of one step (default 0.5).

    Returns:
        The populated mido.MidiFile (not saved).
    """
    num, den = time_sig
    quarters_per_bar = num * (4 / den)
    # Guard: fewer than 2 steps would otherwise mean 0 steps per bar (division by zero).
    steps_per_bar = max(1, steps_total // 2)
    ticks_per_bar = int(round(ppq * quarters_per_bar))
    step_len = ticks_per_bar / steps_per_bar  # exact, possibly fractional, ticks per step
    gate_ticks = max(1, int(round(step_len * gate_ratio)))

    HH_OPEN = 46
    HH_CLOSED = 42

    mid = mido.MidiFile(ticks_per_beat=ppq, type=1)
    track = mido.MidiTrack()
    mid.tracks.append(track)

    tempo = mido.bpm2tempo(bpm)
    track.append(mido.MetaMessage('set_tempo', tempo=tempo, time=0))
    track.append(mido.MetaMessage('time_signature', numerator=num, denominator=den, clocks_per_click=24, notated_32nd_notes_per_beat=8, time=0))

    # Collect onsets per MIDI note: {note: {tick: velocity}}. Dict insertion
    # order follows slot order, which keeps the event order deterministic.
    hits: Dict[int, Dict[int, int]] = {}
    for row_idx, (_name, note) in enumerate(slots):
        per_tick = hits.setdefault(note, {})
        for s, acc in enumerate(rows[row_idx]):
            if acc <= 0:
                continue
            t_on = round(s * ticks_per_bar / steps_per_bar)
            vel = VEL_LUT[min(3, acc)]
            per_tick[t_on] = max(vel, per_tick.get(t_on, 0))

    chh_ticks = sorted(hits.get(HH_CLOSED, ()))

    # (tick, order, message): order 0 = note_off, 1 = note_on.
    events: List[Tuple[int, int, mido.Message]] = []
    for note, per_tick in hits.items():
        onsets = sorted(per_tick)  # unique ticks, so every interval has length >= 1
        for i, t_on in enumerate(onsets):
            t_off = t_on + gate_ticks
            if i + 1 < len(onsets):
                # Same-note mono: release before the next hit of this note.
                t_off = min(t_off, onsets[i + 1])
            if note == HH_OPEN:
                # Choke: the first CHH strictly after this OHH onset cuts it.
                choke = next((t for t in chh_ticks if t > t_on), None)
                if choke is not None:
                    t_off = min(t_off, choke)
            events.append((t_on, 1, mido.Message('note_on', note=note, velocity=per_tick[t_on], channel=midi_ch, time=0)))
            events.append((t_off, 0, mido.Message('note_off', note=note, velocity=0, channel=midi_ch, time=0)))

    # At equal ticks, note_offs come first, so a released note is off before
    # anything re-triggers at that tick. The sort is stable for the rest.
    events.sort(key=lambda e: (e[0], e[1]))

    # Convert absolute ticks into the delta times MIDI tracks store.
    cur = 0
    for t, _order, msg in events:
        msg.time = t - cur
        track.append(msg)
        cur = t

    return mid

def main():
    """Command-line entry point: convert one .ADT/.ADP file to a .mid file.

    Exits with status 2 (via argparse) for invalid arguments, including a
    --ppq outside 1..32767. Exits with status 1, after printing a message to
    stderr, when the input is missing, has an unsupported extension, or
    cannot be parsed, converted or written.
    """
    ap = argparse.ArgumentParser(description="Convert ADT/ADP (APT v2.1) to MIDI (.mid) with Hi-Hat choke")
    ap.add_argument("input", help="Input .ADT or .ADP")
    ap.add_argument("--out", help="Output .mid path (default: same name)")
    ap.add_argument("--ppq", type=int, default=480, help="PPQ (ticks per beat), default 480")
    args = ap.parse_args()
    # The SMF header stores a metrical division in 15 bits (the top bit flags
    # SMPTE timing). 0 or negative values give an unusable file, and larger
    # values cannot be packed into the header at all.
    if not 1 <= args.ppq <= 0x7FFF:
        ap.error(f"--ppq must be between 1 and 32767 (got {args.ppq})")

    in_path = Path(args.input)
    if not in_path.exists():
        print(f"Input not found: {in_path}", file=sys.stderr)
        sys.exit(1)

    suffix = in_path.suffix.upper()
    if suffix not in (".ADT", ".ADP"):
        print("Unsupported input type. Use .ADT or .ADP", file=sys.stderr)
        sys.exit(1)

    try:
        if suffix == ".ADT":
            steps, bpm, midi_ch, time_sig, slots, rows = parse_adt(in_path)
        else:
            steps, bpm, midi_ch, time_sig, slots, rows = parse_adp(in_path)
            # ADP has no time signature of its own. Guess compound meters from
            # the step count. NOTE(review): an explicit 4/4 from a companion
            # .ADT looks the same as the default here and is overridden too.
            if time_sig == (4,4) and steps in (48,24):
                time_sig = infer_time_signature(steps)

        if not time_sig:
            time_sig = infer_time_signature(steps)

        mid = grid_to_midi(rows, slots, steps_total=steps, bpm=bpm, midi_ch=midi_ch, time_sig=time_sig, ppq=args.ppq)
        out_path = Path(args.out) if args.out else in_path.with_suffix(".mid")
        mid.save(out_path)
    except (OSError, ValueError, ZeroDivisionError) as exc:
        # CLI boundary: report malformed input or I/O failures without a traceback.
        print(f"Conversion failed for {in_path}: {exc}", file=sys.stderr)
        sys.exit(1)
    print(f"Saved MIDI: {out_path} (steps2bars={steps}, ts={time_sig[0]}/{time_sig[1]}, bpm={bpm})")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
