use std::fs::File;
use std::io::BufReader;
use crate::io::byteio::*;
use crate::io::bitreader::*;
use crate::util::imaadpcm::*;
use super::super::*;

/// Time-base numerator reported for the video stream (`tb_num`).
/// Together with `FPS_DEN` this gives a time base of 1/25 s.
const FPS_NUM: u32 = 1;
/// Time-base denominator reported for the video stream (`tb_den`).
const FPS_DEN: u32 = 25;

/// Reduces a fixed-point intermediate to the 7-bit index used by the
/// colour-conversion tables.
trait ClipVal {
    fn clip(self) -> u8;
}

impl ClipVal for i32 {
    /// Returns bits 17..=23 of the value as `0..=127`.
    ///
    /// Despite the name this wraps rather than saturates: out-of-range
    /// inputs simply keep those seven bits. The arithmetic shift used here
    /// only differs from a logical one in bits that the mask discards.
    fn clip(self) -> u8 {
        const SHIFT: u32 = 17;
        const MASK: i32 = 0x7F;
        let bits = (self >> SHIFT) & MASK;
        bits as u8
    }
}

/// Table-driven YCbCr -> RGB converter for 7-bit component indices.
///
/// Index `i` (0..128) stands for the signed value `(2 * i + 1) as i8`, i.e.
/// the odd values in `-127..=127`. The chroma factors are 16.16 fixed-point
/// versions of 1.402, 0.714, 0.344 and 1.772.
struct YUV2RGB {
    /// Luma value plus 0x80 (index 0 -> 129, index 64 -> 1).
    y_tab:  [i16; 128],
    /// `u` contribution to green (about -0.344 * u).
    u_tab1: [i16; 128],
    /// `u` contribution to blue (about 1.772 * u).
    u_tab2: [i16; 128],
    /// `v` contribution to red (about 1.402 * v).
    v_tab1: [i16; 128],
    /// `v` contribution to green (about -0.714 * v).
    v_tab2: [i16; 128],
}

impl YUV2RGB {
    /// Builds all lookup tables.
    fn new() -> Self {
        let mut y_tab = [0; 128];
        let mut u_tab1 = [0; 128];
        let mut u_tab2 = [0; 128];
        let mut v_tab1 = [0; 128];
        let mut v_tab2 = [0; 128];
        for (i, el) in y_tab.iter_mut().enumerate() {
            // Index i encodes the odd signed value (2 * i + 1) as i8.
            let base = i32::from((i * 2 + 1) as i8);
            *el = (base + 0x80) as i16;
            // Every factor is 16.16 fixed point, so every product is shifted
            // by 16. A shift by 32 would reduce the blue and v->green terms
            // to 0/-1 and drop those contributions entirely.
            u_tab1[i] = ((base * -22550) >> 16) as i16; // -0.344 * u
            u_tab2[i] = ((base * 116129) >> 16) as i16; //  1.772 * u
            v_tab1[i] = ((base * 91881) >> 16) as i16;  //  1.402 * v
            v_tab2[i] = ((base * -46799) >> 16) as i16; // -0.714 * v
        }

        Self { y_tab, u_tab1, u_tab2, v_tab1, v_tab2 }
    }

    /// Saturates an intermediate channel value to the `u8` range.
    fn saturate(val: i16) -> u8 {
        val.clamp(0, 255) as u8
    }

    /// Converts one pixel given 7-bit table indices and writes R, G, B into
    /// `dst[0]`, `dst[1]`, `dst[2]`, saturating each channel to `0..=255`.
    ///
    /// # Panics
    /// Panics if any index is >= 128 or `dst` is shorter than 3 bytes.
    fn convert(&self, y: u8, u: u8, v: u8, dst: &mut [u8]) {
        let luma = self.y_tab[usize::from(y)];
        let (u, v) = (usize::from(u), usize::from(v));
        dst[0] = Self::saturate(luma + self.v_tab1[v]);
        dst[1] = Self::saturate(luma + self.v_tab2[v] + self.u_tab1[u]);
        dst[2] = Self::saturate(luma + self.u_tab2[u]);
    }
}

/// cos(pi/4) ~= 0.70710678 scaled by 2^31; `idct_row` multiplies by it in
/// 64 bits, shifts right by 32 and doubles, i.e. multiplies by ~0.7071.
const C4: i32 = 1518500250;
/// sin(pi/8) = cos(3*pi/8) ~= 0.38268343.
const C6: f32 = 0.38268343;
/// sqrt(2) * cos(pi/8) ~= 1.306563.
const Q1: f32 = 1.306563;
/// sqrt(2) * cos(3*pi/8) ~= 0.5411961.
const Q0: f32 = 0.5411961;

/// One-dimensional 8-point inverse transform applied in place to `row[0..8]`.
///
/// The butterfly layout resembles the AAN (Arai-Agui-Nakajima) scaled IDCT
/// ("AAN DCT essentially", per the original note). It mixes fixed-point
/// multiplication by `C4` (~0.7071 via a 64-bit product) with `f32`
/// arithmetic for `C6`, `Q0` and `Q1`; the `as i32` casts truncate the float
/// results toward zero.
///
/// NOTE(review): the additions use plain `i32` arithmetic and would panic on
/// overflow in debug builds; presumably safe for the coefficient ranges
/// produced by `decode_block` — confirm.
///
/// Panics if `row` has fewer than 8 elements.
fn idct_row(row: &mut [i32]) {
    // Odd-index inputs.
    let tmp0 =  row[1] + row[7];
    let tmp1 = (row[1] - row[7]) as f32;
    let tmp2 =  row[5] + row[3];
    let tmp3 = (row[5] - row[3]) as f32;
    let tmp4 = (tmp1 + tmp3) * C6;
    // (tmp0 - tmp2) * ~0.7071
    let tmp5 = ((i64::from(tmp0 - tmp2) * i64::from(C4)) >> 32) as i32 * 2;
    let tmp6 = (tmp1 * Q1 - tmp4) as i32;
    let tmp7 = (tmp3 * Q0 + tmp4) as i32;
    // Even-index inputs.
    let tmp8 = row[0] + row[4];
    let tmp9 = row[0] - row[4];
    // (row[2] - row[6]) * ~0.7071
    let tmp10 = ((i64::from(row[2] - row[6]) * i64::from(C4)) >> 32) as i32 * 2;
    let tmp11 = row[2] + row[6] + tmp10;
    let tmp12 = tmp9 + tmp10;
    let tmp13 = tmp9 - tmp10;
    let tmp14 = tmp8 + tmp11;
    let tmp15 = tmp8 - tmp11;

    // Final butterflies: outputs k and 7-k are sum/difference pairs.
    row[0] = tmp14 + tmp0 + tmp2 + tmp6;
    row[1] = tmp12 + tmp5 + tmp6;
    row[2] = tmp13 + tmp5;
    row[3] = tmp15 + tmp7;
    row[4] = tmp15 - tmp7;
    row[5] = tmp13 - tmp5;
    row[6] = tmp12 - (tmp5 + tmp6);
    row[7] = tmp14 - (tmp0 + tmp2 + tmp6);
}

fn idct(blk: &mut [i32; 64]) {
    for row in blk.chunks_exact_mut(8) {
        idct_row(row);
    }
    for i in 1..8 {
        for j in 0..i {
            blk.swap(i + j * 8, i * 8 + j);
        }
    }
    for row in blk.chunks_exact_mut(8) {
        idct_row(row);
    }
}

/// Demuxer/decoder state for one opened Reaper file.
struct ReaperDecoder {
    /// Buffered reader over the input file; chunks are read from it sequentially.
    fr:         FileReader<BufReader<File>>,
    /// RGB24 output frame (`width * height * 3` bytes).
    frame:      Vec<u8>,
    /// The other RGB24 buffer; swapped with `frame` at the start of each
    /// video frame and used as the motion-compensation reference.
    pframe:     Vec<u8>,
    /// Scratch buffer holding the payload of the chunk being processed.
    data:       Vec<u8>,
    /// Decoded audio samples not yet returned to the caller.
    abuf:       Vec<i16>,
    /// Number of samples per returned audio frame (`arate / 4` as set by `open`).
    aframe_len: usize,
    /// Set after a video frame is returned; allows one audio frame to be emitted.
    audio:      bool,
    /// Audio sample rate from the header.
    arate:      u16,
    /// Audio channel count from the header.
    channels:   u8,
    /// NOTE(review): initialised in `open` but not read anywhere in this file.
    pred:       [i16; 2],
    width:      usize,
    height:     usize,
    /// Shared colour-conversion tables.
    yuv2rgb:    YUV2RGB,
    /// One quantiser value per frame, unpacked from the header.
    quants:     Vec<u8>,
    /// Number of frames declared in the header.
    nframes:    usize,
    /// Index of the next video frame to decode.
    cur_frame:  usize,
    /// Dequantisation matrix derived from `QMAT` for `last_q`.
    qmat:       [i32; 64],
    /// Quantiser `qmat` was last built for; `open` seeds it with 4242 so the
    /// first frame always rebuilds it.
    last_q:     usize,
    /// IMA ADPCM decoder state per channel (index 0 is also used for mono).
    adpcm:      [IMAState; 2],
}

/// Borrowed view of the decoder state needed to decode one video frame.
struct FrameDecoder<'a> {
    /// Destination RGB24 frame.
    frame:      &'a mut [u8],
    /// Reference RGB24 frame used by motion-compensated blocks.
    pframe:     &'a [u8],
    /// Reader over the video chunk payload.
    br:         ByteReader<'a>,
    /// NOTE(review): not read by any method visible in this file.
    width:      usize,
    height:     usize,
    /// Bytes per frame line (`width * 3`).
    stride:     usize,
    yuv2rgb:    &'a YUV2RGB,
    /// Dequantisation matrix for the current frame.
    qmat:       &'a [i32; 64],
}

impl<'a> FrameDecoder<'a> {
    /// Decodes one 16x16 macroblock whose top-left pixel is at (`xoff`, `yoff`).
    ///
    /// The first byte, XORed with 0xB6, selects the coding mode:
    /// * 0 — invalid;
    /// * 1 — copy a 16x16 block from the reference frame, displaced by a
    ///   signed 4-bit motion vector (high nibble X, low nibble Y);
    /// * 2, 4, 5, 7..=11 — nothing is read and the block is left as is;
    /// * 3 — one luma byte for all quadrants plus two chroma bytes;
    /// * 6 — four luma bytes (one per 8x8 quadrant) plus two chroma bytes;
    /// * 12 — like 6, but every value is followed by one skipped byte;
    /// * 13..=255 — that many bytes of bit-packed DCT data.
    ///
    /// # Errors
    /// `InvalidData` (via `validate!`) for opcode 0, a motion vector pointing
    /// outside the reference frame, or malformed DCT data; read errors are
    /// propagated.
    fn do_block(&mut self, xoff: usize, yoff: usize) -> DecoderResult<()> {
        let opcode = self.br.read_byte()? ^ 0xB6;
        match opcode {
            0 => return Err(DecoderError::InvalidData),
            1 => {
                let mv = self.br.read_byte()?;
                // Both nibbles are signed 4-bit values in -8..=7.
                let mut mv_x = (mv >> 4) as isize;
                let mut mv_y = (mv & 0xF) as isize;
                if mv_x >= 8 { mv_x -= 16; }
                if mv_y >= 8 { mv_y -= 16; }
                let src_x = xoff as isize - mv_x;
                let src_y = yoff as isize - mv_y;
                // The source is addressed linearly (a horizontal overshoot
                // wraps into the neighbouring line), so only the buffer bounds
                // are checked — for the whole 16x16 footprint: 15 full lines
                // plus 16 pixels on the last one.
                let src_pos = src_x * 3 + src_y * (self.stride as isize);
                validate!(src_pos >= 0);
                let src_pos = src_pos as usize;
                validate!(src_pos + 15 * self.stride + 16 * 3 <= self.pframe.len());
                let src = &self.pframe[src_pos..];
                let dst = &mut self.frame[xoff * 3 + yoff * self.stride..];
                // Exactly 16 lines of 16 RGB pixels.
                for (dline, sline) in dst.chunks_mut(self.stride)
                        .zip(src.chunks(self.stride)).take(16) {
                    dline[..16 * 3].copy_from_slice(&sline[..16 * 3]);
                }
            },
            2 => {},
            3 => {
                let y = self.br.read_byte()?;
                let cr = self.br.read_byte()?;
                let cb = self.br.read_byte()?;
                self.fill_block([y; 4], cr, cb, xoff, yoff);
            },
            4 | 5 => {},
            6 => {
                let mut y = [0; 4];
                self.br.read_buf(&mut y)?;
                let cr = self.br.read_byte()?;
                let cb = self.br.read_byte()?;
                self.fill_block(y, cr, cb, xoff, yoff);
            },
            7..=11 => {},
            12 => {
                let mut y = [0; 4];
                for el in y.iter_mut() {
                    *el = self.br.read_byte()?;
                    self.br.read_skip(1)?;
                }
                let cr = self.br.read_byte()?;
                self.br.read_skip(1)?;
                let cb = self.br.read_byte()?;
                self.br.read_skip(1)?;
                self.fill_block(y, cr, cb, xoff, yoff);
            },
            _ => {
                // The opcode itself is the payload length (13..=255 bytes).
                let len = usize::from(opcode);
                let mut buf = [0; 256];
                self.br.read_buf(&mut buf[..len])?;
                self.dct_block(&buf[..len], xoff, yoff)?;
            },
        }
        Ok(())
    }

    /// Fills a 16x16 block with flat colour per 8x8 quadrant.
    ///
    /// `luma` holds one signed value per quadrant, ordered top-left, top-right,
    /// bottom-left, bottom-right. `cr`/`cb` are signed chroma values shared by
    /// the whole block. Every value is scaled by `qmat[0]` and reduced with
    /// `clip`.
    ///
    /// NOTE(review): `cr` is passed as the `u` argument and `cb` as `v` to
    /// `YUV2RGB::convert`; confirm the naming against real samples.
    fn fill_block(&mut self, luma: [u8; 4], cr: u8, cb: u8, xoff: usize, yoff: usize) {
        let uu = (i32::from(cr as i8) * self.qmat[0]).clip();
        let vv = (i32::from(cb as i8) * self.qmat[0]).clip();
        for (y, line) in self.frame.chunks_exact_mut(self.stride).skip(yoff).take(16).enumerate() {
            for (x, clr) in line[xoff * 3..].chunks_exact_mut(3).take(16).enumerate() {
                let yy = luma[x / 8 + (y / 8) * 2] as i8;
                let yy = (i32::from(yy) * self.qmat[0]).clip();
                self.yuv2rgb.convert(yy, uu, vv, clr);
            }
        }
    }

    /// Decodes six 8x8 transform blocks from `src` (read with
    /// `BitReaderMode::LE`) and writes the resulting 16x16 RGB block.
    ///
    /// Blocks 0..4 are the luma quadrants (top-left, top-right, bottom-left,
    /// bottom-right). Blocks 4 and 5 are chroma, each covering the whole
    /// 16x16 area at half resolution in both directions.
    fn dct_block(&mut self, src: &[u8], xoff: usize, yoff: usize) -> DecoderResult<()> {
        let mut br = BitReader::new(src, BitReaderMode::LE);
        let mut coeffs = [[0; 64]; 6];
        for blk in coeffs.iter_mut() {
            self.decode_block(&mut br, blk)?;
            idct(blk);
        }
        let dst = &mut self.frame[xoff * 3 + yoff * self.stride..];
        for (y, line) in dst.chunks_mut(self.stride).take(16).enumerate() {
            let cur_y = if y < 8 { [&coeffs[0], &coeffs[1]] } else { [&coeffs[2], &coeffs[3]] };
            for (x, pix) in line.chunks_exact_mut(3).take(16).enumerate() {
                let yy = if x < 8 { cur_y[0][x + (y & 7) * 8] } else { cur_y[1][(x & 7) + (y & 7) * 8] };
                let uu = coeffs[4][x / 2 + (y / 2) * 8];
                let vv = coeffs[5][x / 2 + (y / 2) * 8];
                self.yuv2rgb.convert(yy.clip(), uu.clip(), vv.clip(), pix);
            }
        }
        Ok(())
    }

    /// Reads one block of dequantised coefficients into `blk`.
    ///
    /// An 8-bit signed value times `qmat[0]` goes to `blk[0]`. After that,
    /// 3-bit codes fill positions `SCAN[63]` down to `SCAN[1]`:
    /// 0 — one zero; 1 — `2*n` zeros (`n` 5 bits); 2 — `+qmat`;
    /// 3 — signed 5-bit `v`, value `2v * qmat`; 4 — two zeros;
    /// 5 — `2*n + 1` zeros; 6 — `-qmat`; 7 — escape: if the next 5 bits are
    /// all ones they are skipped and a signed 8-bit value follows, otherwise
    /// a signed 5-bit `v` gives `(2v | 1) * qmat`.
    /// Skipped positions keep their existing contents (the caller zeroes them).
    fn decode_block(&self, br: &mut BitReader, blk: &mut [i32; 64]) -> DecoderResult<()> {
        blk[0] = i32::from(br.read(8)? as i8) * self.qmat[0];
        let mut idx = blk.len() - 1;
        while idx > 0 {
            match br.read(3)? {
                0 => { // zero
                    idx -= 1;
                },
                1 => { // even run of zeroes
                    let zero_run = br.read(5)? as usize * 2;
                    validate!(idx >= zero_run);
                    idx -= zero_run;
                },
                2 => {
                    blk[SCAN[idx]] = self.qmat[SCAN[idx]];
                    idx -= 1;
                },
                3 => {
                    let coef = br.read_s(5)? * 2;
                    blk[SCAN[idx]] = coef * self.qmat[SCAN[idx]];
                    idx -= 1;
                },
                4 => { // a pair of zeroes
                    validate!(idx >= 2);
                    idx -= 2;
                },
                5 => { // odd run of zeroes
                    let zero_run = br.read(5)? as usize * 2 + 1;
                    validate!(idx >= zero_run);
                    idx -= zero_run;
                },
                6 => {
                    blk[SCAN[idx]] = -self.qmat[SCAN[idx]];
                    idx -= 1;
                },
                _ => {
                    if br.peek(5) == 0x1F {
                        br.skip(5)?;
                        let coef = br.read_s(8)?;
                        blk[SCAN[idx]] = coef * self.qmat[SCAN[idx]];
                    } else {
                        let coef = (br.read_s(5)? * 2) | 1;
                        blk[SCAN[idx]] = coef * self.qmat[SCAN[idx]];
                    }
                    idx -= 1;
                },
            }
        }
        Ok(())
    }
}

impl ReaperDecoder {
    /// Rebuilds `qmat` for quantiser `quant` by scaling the base `QMAT` table.
    ///
    /// Entry (row, col) is `QMAT[row][col] * ((row + col) * step / 14 + bias)`,
    /// where `bias` and `step` depend linearly on the signed quantiser.
    fn build_qmat(qmat: &mut [i32; 64], quant: usize) {
        let scale = 100 - i32::from(quant as i8);
        let bias = scale * 22 / 100 + 8;
        let step = scale * 50 / 100 + 10 - bias;
        let rows = qmat.chunks_exact_mut(8).zip(QMAT.chunks_exact(8));
        for (row, (dline, sline)) in rows.enumerate() {
            for (col, (dst, &base)) in dline.iter_mut().zip(sline.iter()).enumerate() {
                let weight = step * ((row + col) as i32);
                *dst = base * (weight / 14 + bias);
            }
        }
    }

    /// Decodes the video payload held in `self.data` into `self.frame`.
    ///
    /// The previously decoded buffer becomes `pframe` (the motion reference).
    /// Then every 16x16 macroblock is decoded in raster order.
    fn decode_video(&mut self) -> DecoderResult<()> {
        std::mem::swap(&mut self.frame, &mut self.pframe);

        let quant = usize::from(self.quants[self.cur_frame]);
        if self.last_q != quant {
            Self::build_qmat(&mut self.qmat, quant);
            self.last_q = quant;
        }

        let mut mr = MemoryReader::new_read(&self.data);
        let mut fdec = FrameDecoder {
            frame:   &mut self.frame,
            pframe:  &self.pframe,
            br:      ByteReader::new(&mut mr),
            width:   self.width,
            height:  self.height,
            stride:  self.width * 3,
            yuv2rgb: &self.yuv2rgb,
            qmat:    &self.qmat,
        };
        for yoff in (0..self.height).step_by(16) {
            for xoff in (0..self.width).step_by(16) {
                fdec.do_block(xoff, yoff)?;
            }
        }

        Ok(())
    }
}

impl InputSource for ReaperDecoder {
    /// Always reports two streams: 0 is video, 1 is audio.
    fn get_num_streams(&self) -> usize { 2 }
    fn get_stream_info(&self, stream_no: usize) -> StreamInfo {
        match stream_no {
            0 => StreamInfo::Video(VideoInfo{
                    width:  self.width,
                    height: self.height,
                    bpp:    24,
                    tb_num: FPS_NUM,
                    tb_den: FPS_DEN,
                 }),
            1 => StreamInfo::Audio(AudioInfo{
                    sample_rate: u32::from(self.arate),
                    channels:    self.channels,
                    sample_type: AudioSample::S16,
                }),
            _ => StreamInfo::None
        }
    }
    /// Reads chunks until a frame can be returned.
    ///
    /// Returns `(0, video)` after each video chunk. After a video frame, one
    /// audio frame of `aframe_len` samples is returned as `(1, audio)` once
    /// enough samples are buffered.
    ///
    /// # Errors
    /// `EOF` on a zero chunk type; `NotImplemented` for chunk type
    /// 0xD6DB0106, whose layout is unknown; `InvalidData` for unknown chunk
    /// types, bad chunk sizes or a failed frame decode; read errors are
    /// propagated.
    fn decode_frame(&mut self) -> DecoderResult<(usize, Frame)> {
        let mut br = ByteReader::new(&mut self.fr);
        loop {
            if self.audio && self.abuf.len() >= self.aframe_len {
                let audio: Vec<_> = self.abuf.drain(..self.aframe_len).collect();
                self.audio = false;
                return Ok((1, Frame::AudioS16(audio)));
            }
            let ctype = br.read_u32le()?;
            let csize = br.read_u32le()? as usize;
            // The size includes the 8-byte chunk header.
            validate!(csize >= 8);
            let csize = csize - 8;

            match ctype {
                0x0B85120F => {
                    validate!(csize > 0);
                    validate!(self.cur_frame < self.nframes);
                    self.data.resize(csize, 0);
                    br.read_buf(&mut self.data)?;
                    self.decode_video().map_err(|_| DecoderError::InvalidData)?;
                    self.audio = true;
                    self.cur_frame += 1;

                    return Ok((0, Frame::VideoRGB24(self.frame.clone())));
                },
                0x90FC0302 if self.arate > 0 => {
                    self.data.resize(csize, 0);
                    br.read_buf(&mut self.data)?;
                    // Each byte holds two 4-bit codes, high nibble first.
                    // Mono runs both through channel 0; stereo sends the low
                    // nibble to channel 1.
                    let second = if self.channels == 1 { 0 } else { 1 };
                    for &b in self.data.iter() {
                        self.abuf.push(self.adpcm[0].expand_sample(b >> 4));
                        self.abuf.push(self.adpcm[second].expand_sample(b & 0xF));
                    }
                },
                0xD6DB0106 => {
                    // The layout of this chunk is unknown. Report it as an
                    // error rather than panicking on input data.
                    return Err(DecoderError::NotImplemented);
                },
                0x00000000 => return Err(DecoderError::EOF),
                _ => {
                    return Err(DecoderError::InvalidData);
                },
            }
        }
    }
}

/// Expands the run-length-coded per-frame quantiser table from the header.
///
/// Bytes are copied verbatim. Whenever a byte equals the byte before it, the
/// next input byte is a count and the value is appended that many extra
/// times. The result is checked (via `validate!`) to be exactly `size` bytes
/// long.
fn unpack_qdata(src: &[u8], dst: &mut Vec<u8>, size: usize) -> DecoderResult<()> {
    let mut mr = MemoryReader::new_read(src);
    let mut br = ByteReader::new(&mut mr);

    // Seed the "previous byte" with the complement of the first byte so the
    // stream cannot open with a run.
    let mut prev = !br.peek_byte()?;
    while br.left() > 0 {
        let cur = br.read_byte()?;
        dst.push(cur);
        if cur == prev {
            let extra = usize::from(br.read_byte()?);
            dst.resize(dst.len() + extra, cur);
        }
        prev = cur;
    }

    validate!(dst.len() == size);
    Ok(())
}

/// Opens a Reaper movie file (magic `!Reaper!`) and returns a decoder for it.
///
/// Header fields, all little-endian: magic[8], header size (u32), width,
/// height, frame count, audio compression, audio rate, audio bits, channels,
/// and one more u16 that is read and ignored. The remaining
/// `hdr_size - 0x1C` header bytes hold the run-length-coded per-frame
/// quantisers.
///
/// # Errors
/// * `InputNotFound` if the file cannot be opened.
/// * `NotImplemented` for audio compression other than 1, or a sample size
///   other than 16 bits.
/// * `InvalidData` (via `validate!`) for malformed headers.
pub fn open(name: &str) -> DecoderResult<Box<dyn InputSource>> {
    let file = File::open(name).map_err(|_| DecoderError::InputNotFound(name.to_owned()))?;
    let mut fr = FileReader::new_read(BufReader::new(file));
    let mut br = ByteReader::new(&mut fr);

    let mut magic = [0u8; 8];
    br.read_buf(&mut magic)?;
    validate!(&magic == b"!Reaper!");

    let hdr_size = br.read_u32le()? as usize;
    validate!(hdr_size > 0x1C);

    let width = usize::from(br.read_u16le()?);
    let height = usize::from(br.read_u16le()?);
    validate!((1..=320).contains(&width) && (1..=240).contains(&height));
    // Pictures are coded as whole 16x16 macroblocks.
    validate!((width | height) & 0xF == 0);

    let nframes = usize::from(br.read_u16le()?);
    validate!(nframes > 0);

    let acompr = br.read_u16le()?;
    if acompr != 1 {
        return Err(DecoderError::NotImplemented);
    }
    let arate = br.read_u16le()?;
    // NOTE(review): this range check makes the `arate == 0` alternatives
    // below unreachable, i.e. files without audio are rejected here.
    validate!((8000..=32000).contains(&arate));
    let abits = br.read_u16le()?;
    if abits != 16 && arate != 0 {
        return Err(DecoderError::NotImplemented);
    }
    let channels = br.read_u16le()?;
    validate!((arate == 0 && channels == 0) || channels == 1 || channels == 2);
    br.read_u16le()?; // ignored field (noted as 2)

    // At most three bytes of packed quantiser data per frame.
    validate!(hdr_size <= 0x1C + nframes * 3);
    let mut qdata = vec![0; hdr_size - 0x1C];
    br.read_buf(&mut qdata)?;
    let mut quants = Vec::with_capacity(nframes);
    unpack_qdata(&qdata, &mut quants, nframes).map_err(|_| DecoderError::InvalidData)?;

    let frame_size = width * height * 3;
    Ok(Box::new(ReaperDecoder {
        fr,
        width,
        height,
        data:       Vec::new(),
        frame:      vec![0; frame_size],
        pframe:     vec![0; frame_size],
        abuf:       Vec::with_capacity(32000),
        aframe_len: usize::from(arate) / 4,
        audio:      false,
        arate,
        channels:   channels as u8,
        pred:       [0; 2],
        yuv2rgb:    YUV2RGB::new(),
        quants,
        nframes,
        cur_frame:  0,
        qmat:       [0; 64],
        // Sentinel that never matches a real quantiser, forcing a qmat build
        // on the first frame.
        last_q:     4242,
        adpcm:      [IMAState::new(), IMAState::new()],
    }))
}

/// Base dequantisation weights (8x8, row-major).
///
/// `ReaperDecoder::decode_video` scales this table per quantiser to build the
/// per-frame `qmat`. The table is symmetric and its top-left entry is 8192.
const QMAT: [i32; 64] = [
     8192,  5906,  6270,  6967,  8192, 10426, 15137,  29692,
     5906,  4258,  4520,  5023,  5906,  7517, 10913,  21407,
     6270,  4520,  4799,  5332,  6270,  7980, 11585,  22725,
     6967,  5023,  5332,  5925,  6967,  8867, 12873,  25251,
     8192,  5906,  6270,  6967,  8192, 10426, 15137,  29692,
    10426,  7517,  7980,  8867, 10426, 13270, 19266,  37791,
    15137, 10913, 11585, 12873, 15137, 19266, 27969,  54864,
    29692, 21407, 22725, 25251, 29692, 37791, 54864, 107619
];

/// Maps the coefficient counter used by `FrameDecoder::decode_block` to a
/// block position.
///
/// The decoder walks the counter from 63 down to 1, so coefficients are
/// filled starting at `SCAN[63]` (position 8). Entry 0 is never used there,
/// because position 0 is read separately.
const SCAN: [usize; 64] = [
     0,  63,  55,  62, 61,  54,  47,  39,
    46,  53,  60,  59, 52,  45,  38,  31,
    23,  30,  37,  44, 51,  58,  57,  50,
    43,  36,  29,  22, 15,   7,  14,  21,
    28,  35,  42,  49, 56,  48,  41,  34,
    27,  20,  13,   6,  5,  12,  19,  26,
    33,  40,  32,  25, 18,  11,   4,   3,
    10,  17,  24,  16,  9,   2,   1,   8
];
