use std::fs::File;
use crate::io::byteio::*;
use super::super::*;
use crate::input::util::lzss::lz_copy;

/// Size in bytes of one decoded audio (`MU`) block.
///
/// A chunk whose payload is exactly this long is read as-is; any other chunk is
/// LZ-unpacked and must expand to exactly this many bytes.
/// NOTE(review): 15862 Hz / 12 ≈ 1321.8, so this is presumably one frame's worth
/// of audio — confirm against the format description.
const ABLK_LEN: usize = 1321;

/// Chunk reader and decoder state for an SMV file opened by `open()`.
///
/// Video is emitted as one 8-bit palette index per pixel together with the
/// current palette; audio is emitted as unsigned 8-bit sample blocks.
struct SMVDecoder {
    /// Underlying file; positioned at the next chunk header between calls.
    fr:         FileReader<File>,
    /// Frame width in pixels (`open()` requires a multiple of 16, at most 320).
    width:      usize,
    /// Frame height in pixels (`open()` requires a multiple of 16, at most 240).
    height:     usize,
    /// Number of palette entries every `GP` chunk must carry (1..=256).
    nclrs:      usize,
    /// Header field that `open()` only checks to be non-zero; it is not used by the
    /// code visible here. NOTE(review): meaning unconfirmed.
    nib_offs:   usize,
    /// Current palette, 256 entries x 3 bytes; `GP` chunks overwrite the first
    /// `nclrs * 3` bytes through `read_vga_pal_some`.
    pal:        [u8; 768],
    /// Current picture, `width * height` palette indices stored row by row.
    frame:      Vec<u8>,
    /// Scratch buffer for the raw payload of the last `FR` or packed `MU` chunk.
    data:       Vec<u8>,
    /// Scratch buffer for the LZ-unpacked video payload.
    ubuf:       Vec<u8>,
    /// Video codec from the header (0..=3); selects the path in `unpack_frame`.
    codec_id:   u8,
}

impl SMVDecoder {
    fn unpack_frame(&mut self) -> DecoderResult<()> {
        match self.codec_id {
            0 | 1 => {
                lz_unpack(&self.data, &mut self.ubuf, self.codec_id == 0)?;
                let vec_size = self.width * self.height / 16;
                validate!(self.ubuf.len() == vec_size + self.width * self.height / 2);
                let (vectors, blocks) = self.ubuf.split_at(vec_size);
                let mut blk_iter = vectors.chunks_exact(16).zip(blocks.chunks_exact(16 * 16 / 2));
                for strip in self.frame.chunks_exact_mut(self.width * 16) {
                    for x in (0..self.width).step_by(16) {
                        let (blk_vec, blk_data) = blk_iter.next().unwrap();
                        for (line, bline) in strip.chunks_exact_mut(self.width)
                                .zip(blk_data.chunks_exact(16 / 2)) {
                            for (pair, &b) in line[x..].chunks_exact_mut(2).zip(bline.iter()) {
                                pair[0] = blk_vec[usize::from(b >> 4)];
                                pair[1] = blk_vec[usize::from(b & 0xF)];
                            }
                        }
                    }
                }
            },
            2 => {
                const NPARTITIONS: usize = 8;
                lz_unpack(&self.data, &mut self.ubuf, false)?;
                let vec_size = self.width * self.height / 16;
                validate!(self.ubuf.len() == vec_size + 256 * 2 * NPARTITIONS + self.width * self.height / 4);

                let (vectors, data) = self.ubuf.split_at(vec_size);
                let (blk_data, codebooks) = data.split_at(self.width * self.height / 4);
                let blk_data_size = self.width * self.height / 4 / NPARTITIONS;
                for (partition, codebook) in codebooks.chunks_exact(256 * 2).enumerate() {
                    let xstart = (partition & 1) * (self.width / 2);
                    let ystart = (partition >> 1) * (self.height / 4);

                    let mut blk_iter = blk_data[partition * blk_data_size..].chunks_exact(16 * 16 / 4);
                    let mut bvrow_iter = vectors[ystart / 16 * self.width + xstart..].chunks(self.width);
                    for strip in self.frame[ystart * self.width..].chunks_mut(self.width * 16).take(self.height / 16 / 4) {
                        let bvrow = bvrow_iter.next().unwrap();
                        for (xoff, blk_vec) in (xstart..).step_by(16).take(self.width / 16 / 2)
                                .zip(bvrow.chunks_exact(16)) {
                            let bdata = blk_iter.next().unwrap();
                            for (lines, indices) in strip[xoff..].chunks_mut(self.width * 2).zip(bdata.chunks_exact(8)) {
                                for (i, &idx) in indices.iter().enumerate() {
                                    let idx = usize::from(idx);
                                    let nibs = read_u16be(&codebook[idx * 2..]).unwrap_or(0) as usize;
                                    lines[i * 2]                  = blk_vec[nibs >> 12];
                                    lines[i * 2 + 1]              = blk_vec[(nibs >> 8) & 0xF];
                                    lines[i * 2 + self.width]     = blk_vec[(nibs >> 4) & 0xF];
                                    lines[i * 2 + self.width + 1] = blk_vec[nibs & 0xF];
                                }
                            }
                        }
                    }
                }
            },
            3 => unimplemented!(),
            _ => unreachable!(),
        }
        Ok(())
    }
}

impl InputSource for SMVDecoder {
    /// Stream 0 is video, stream 1 is audio.
    fn get_num_streams(&self) -> usize {
        2
    }
    /// Describes stream `stream_no`: 8-bit paletted video with time base 1/12, or
    /// mono unsigned 8-bit audio at 15862 Hz. Any other index yields `StreamInfo::None`.
    fn get_stream_info(&self, stream_no: usize) -> StreamInfo {
        match stream_no {
            0 => StreamInfo::Video(VideoInfo {
                width:  self.width,
                height: self.height,
                bpp:    8,
                tb_num: 1,
                tb_den: 12,
            }),
            1 => StreamInfo::Audio(AudioInfo {
                sample_rate: 15862,
                sample_type: AudioSample::U8,
                channels:    1,
            }),
            _ => StreamInfo::None,
        }
    }
    /// Reads chunks (2-byte tag + little-endian 16-bit payload size) until one yields
    /// output. `GP` updates the palette and reading continues; `MU` returns an audio
    /// block on stream 1; `FR` returns a decoded video frame on stream 0; `FE` ends the
    /// stream with `DecoderError::EOF`; any other tag is reported on stdout and
    /// rejected with `DecoderError::InvalidData`.
    fn decode_frame(&mut self) -> DecoderResult<(usize, Frame)> {
        let mut rdr = ByteReader::new(&mut self.fr);
        let mut chunk_id = [0; 2];
        loop {
            rdr.read_buf(&mut chunk_id)?;
            let chunk_len = rdr.read_u16le()? as usize;
            match chunk_id {
                [b'G', b'P'] => {
                    validate!(chunk_len == self.nclrs * 3);
                    rdr.read_vga_pal_some(&mut self.pal[..chunk_len])?;
                },
                [b'M', b'U'] => {
                    let samples = if chunk_len != ABLK_LEN {
                        // Any other payload size means the block is LZ-packed.
                        self.data.resize(chunk_len, 0);
                        rdr.read_buf(&mut self.data)?;
                        let mut unpacked = Vec::with_capacity(ABLK_LEN);
                        lz_unpack(&self.data, &mut unpacked, true).map_err(|_| DecoderError::InvalidData)?;
                        validate!(unpacked.len() == ABLK_LEN);
                        unpacked
                    } else {
                        // A payload of exactly ABLK_LEN bytes is stored raw.
                        let mut raw = vec![0u8; ABLK_LEN];
                        rdr.read_buf(&mut raw)?;
                        raw
                    };
                    return Ok((1, Frame::AudioU8(samples)));
                },
                [b'F', b'R'] => {
                    self.data.resize(chunk_len, 0);
                    rdr.read_buf(&mut self.data)?;
                    self.unpack_frame().map_err(|_| DecoderError::InvalidData)?;
                    return Ok((0, Frame::VideoPal(self.frame.clone(), self.pal)));
                },
                [b'F', b'E'] => return Err(DecoderError::EOF),
                _ => {
                    println!("unknown tag {}{} @ {:X}", chunk_id[0] as char, chunk_id[1] as char, rdr.tell() - 4);
                    return Err(DecoderError::InvalidData);
                },
            }
        }
    }
}

/// Reader over a byte slice that interleaves single bits and whole bytes.
///
/// Bits are taken MSB-first from a one-byte cache that is refilled from `src` when
/// it runs empty; whole bytes are read straight from `src`, bypassing the cache.
struct HybridReader<'a> {
    /// Input data.
    src:        &'a [u8],
    /// Index of the next unread byte in `src`.
    pos:        usize,
    /// Cached bits, with the next bit in the most significant position.
    bitbuf:     u8,
    /// Number of bits still available in `bitbuf`.
    bits:       u8,
}

impl<'a> HybridReader<'a> {
    /// Wraps `src`, starting at its first byte with an empty bit cache.
    fn new(src: &'a [u8]) -> Self {
        Self { src, pos: 0, bitbuf: 0, bits: 0 }
    }
    /// Returns the next whole byte of `src`; the bit cache is left untouched.
    ///
    /// # Errors
    /// `DecoderError::ShortData` once `src` is exhausted.
    fn read_byte(&mut self) -> DecoderResult<u8> {
        let byte = *self.src.get(self.pos).ok_or(DecoderError::ShortData)?;
        self.pos += 1;
        Ok(byte)
    }
    /// Returns the next bit (MSB first), refilling the cache from the byte stream
    /// when it is empty.
    ///
    /// # Errors
    /// `DecoderError::ShortData` if a refill is needed and `src` is exhausted.
    fn read_bit(&mut self) -> DecoderResult<bool> {
        if self.bits == 0 {
            let refill = self.read_byte()?;
            self.bitbuf = refill;
            self.bits = 8;
        }
        let msb_set = self.bitbuf >= 0x80;
        self.bitbuf <<= 1;
        self.bits -= 1;
        Ok(msb_set)
    }
    /// Reads `nbits` bits (fewer than 8) and returns them as an MSB-first value.
    ///
    /// # Errors
    /// `InvalidData` (via `validate!`) if `nbits >= 8`, or `ShortData` if the input ends.
    fn read_bits(&mut self, nbits: u8) -> DecoderResult<u8> {
        validate!(nbits < 8);
        (0..nbits).try_fold(0u8, |acc, _| -> DecoderResult<u8> {
            Ok((acc << 1) | u8::from(self.read_bit()?))
        })
    }
}

/// Decompresses the LZ stream in `src` into `dst`, replacing its previous contents.
///
/// Each token starts with a flag bit: 0 means one literal byte follows. 1 means a
/// back-reference follows:
/// * offset — flag 0: one byte; flag 1: 2 (word mode) or 3 (byte mode) high bits plus
///   a low byte. The stored value is biased by one and, in `words` mode, doubled;
/// * length — prefix `0` → 2, `10x` → 3/4, `110x` → 5/6, `1110` → 7, `1111` + byte
///   → byte + 7, where a zero byte marks the end of the stream.
///
/// # Errors
/// `ShortData` if `src` ends before the end marker; `InvalidData` if a match
/// reaches back before the start of the output.
fn lz_unpack(src: &[u8], dst: &mut Vec<u8>, words: bool) -> DecoderResult<()> {
    dst.clear();
    let mut rdr = HybridReader::new(src);
    loop {
        if !rdr.read_bit()? {
            dst.push(rdr.read_byte()?);
            continue;
        }

        let raw_offset = if rdr.read_bit()? {
            let hi_bits = if words { 2 } else { 3 };
            let hi = usize::from(rdr.read_bits(hi_bits)?);
            let lo = usize::from(rdr.read_byte()?);
            (hi << 8) | lo
        } else {
            usize::from(rdr.read_byte()?)
        };
        let offset = if words { (raw_offset + 1) * 2 } else { raw_offset + 1 };

        // Count up to four leading 1 bits of the length prefix; a 0 bit stops early.
        let mut ones = 0;
        while ones < 4 && rdr.read_bit()? {
            ones += 1;
        }
        let length = match ones {
            0 => 2,
            1 => 3 + usize::from(rdr.read_bit()?),
            2 => 5 + usize::from(rdr.read_bit()?),
            3 => 7,
            _ => match rdr.read_byte()? {
                0 => return Ok(()),
                n => usize::from(n) + 7,
            },
        };

        let start = dst.len();
        validate!(start >= offset);
        dst.resize(start + length, 0);
        lz_copy(dst, start, offset, length);
    }
}

/// Opens the SMV file `name`, parses its `ST` header and returns a decoder for it.
///
/// After the magic and a 16-bit header size, the header holds little-endian 16-bit
/// fields (width, height, tile width, tile height, frame count, palette size and
/// `nib_offs`), optionally followed by a codec byte; headers longer than 15 bytes
/// have the excess skipped.
///
/// # Errors
/// `InputNotFound` if the file cannot be opened, `InvalidData` if a header field is
/// out of range, or the reader's error if the header is truncated.
pub fn open(name: &str) -> DecoderResult<Box<dyn InputSource>> {
    let file = File::open(name).map_err(|_| DecoderError::InputNotFound(name.to_owned()))?;
    let mut fr = FileReader::new_read(file);
    let mut rdr = ByteReader::new(&mut fr);

    // Magic "ST", checked one byte at a time.
    for &expected in b"ST" {
        validate!(rdr.read_byte()? == expected);
    }
    let hdr_len = rdr.read_u16le()? as usize;
    validate!(hdr_len >= 14);

    let width = rdr.read_u16le()? as usize;
    let height = rdr.read_u16le()? as usize;
    validate!(matches!(width, 1..=320) && matches!(height, 1..=240));
    // Pictures are built from 16x16 tiles, so both dimensions must be tile-aligned.
    validate!(width % 16 == 0 && height % 16 == 0);

    let tile_dims = (rdr.read_u16le()?, rdr.read_u16le()?);
    validate!(tile_dims == (16, 16));

    let nframes = rdr.read_u16le()?;
    validate!(nframes != 0);
    let nclrs = rdr.read_u16le()? as usize;
    validate!(matches!(nclrs, 1..=256));
    let nib_offs = rdr.read_u16le()? as usize;
    validate!(nib_offs != 0);

    // A 14-byte header carries no codec byte; codec 0 is assumed then.
    let codec_id = if hdr_len > 14 { rdr.read_byte()? } else { 0 };
    validate!(codec_id < 4);
    // Codec 2 splits the frame into 2x4 partitions made of whole 16x16 tiles.
    validate!(codec_id != 2 || (width % 32 == 0 && height % 64 == 0));
    rdr.read_skip(hdr_len.saturating_sub(15))?;

    Ok(Box::new(SMVDecoder {
        codec_id, nib_offs, nclrs, width, height,
        frame: vec![0; width * height],
        pal:   [0; 768],
        ubuf:  Vec::new(),
        data:  Vec::new(),
        fr,
    }))
}
