use std::fs::File;
use std::io::BufWriter;

use crate::io::byteio::*;
use super::*;

// Early-returns `Err(EncoderError::InvalidData)` from the enclosing function when the condition
// is false. Debug builds also print the source location of the failed check.
#[cfg(debug_assertions)]
macro_rules! validate {
    ($cond:expr) => {
        if !$cond {
            println!("check failed at {}:{}", file!(), line!());
            return Err(EncoderError::InvalidData);
        }
    };
}
// Release variant: same early return, without the diagnostic print.
#[cfg(not(debug_assertions))]
macro_rules! validate {
    ($cond:expr) => {
        if !$cond {
            return Err(EncoderError::InvalidData);
        }
    };
}

/// One record of the `idx1` index written at the end of the file.
struct IndexEntry {
    /// Chunk FourCC: two stream-number digits followed by a two-letter chunk type.
    tag: [u8; 4],
    /// Chunk payload size in bytes, as written after the FourCC.
    size: u32,
    /// Position of the chunk header, as reported by the writer's `tell()`.
    offs: u32,
}

/// Per-stream bookkeeping gathered while muxing and used to patch the stream headers at the end.
struct StreamStats {
    /// Most recently written palette (256 three-byte entries); only used by 8-bit video.
    pal: [u8; 768],
    /// File position of the placeholder colour table in the stream's `strf` chunk.
    pal_pos: u32,
    /// Whether the initial palette has already been stored in the colour table.
    has_pal: bool,
    /// Whether at least one palette-change chunk was emitted for this stream.
    has_pc: bool,
    /// Number of frames/chunks written for this stream.
    nframes: u32,
    /// Largest chunk size seen; written as the suggested buffer size.
    max_size: u32,
    /// File position of the `strh` flags field.
    strh_pos: u64,
    /// Whether this is a video stream.
    is_video: bool,
}

/// AVI output writer: streams are registered, a header is emitted, frames are appended to the
/// `movi` list, and the index plus header fix-ups are written on finish.
struct AVIWriter {
    /// Buffered output file.
    fr: FileWriter<BufWriter<File>>,
    /// Bookkeeping for each stream, filled in when the header is written.
    sstats: Vec<StreamStats>,
    /// Stream descriptions in stream-number order.
    streams: Vec<StreamInfo>,
    /// Entries for the `idx1` chunk, in write order.
    index: Vec<IndexEntry>,
    /// Position of the `movi` FourCC, used to patch the `movi` list size.
    data_pos: u64,
}

impl OutputWriter for AVIWriter {
    fn add_stream(&mut self, stream_no: usize, stream_info: StreamInfo) -> EncoderResult<()> {
        validate!(stream_no == self.streams.len());
        if stream_no > 99 {
            return Err(EncoderError::Ignored);
        }
        match &stream_info {
            StreamInfo::Video(ref vinfo) => {
                validate!(vinfo.width > 0 && vinfo.height > 0);
                if !matches!(vinfo.bpp, 8 | 15 | 16 | 24) {
                    return Err(EncoderError::Unsupported);
                }
            },
            StreamInfo::Audio(ref ainfo) => {
                validate!(ainfo.sample_rate > 0 && ainfo.channels > 0);
            },
            _ => return Err(EncoderError::Ignored),
        }
        self.streams.push(stream_info);
        Ok(())
    }
    fn finish_header(&mut self) -> EncoderResult<()> {
        if self.streams.is_empty() {
            return Err(EncoderError::EmptyOutput);
        }

        let mut bw = ByteWriter::new(&mut self.fr);
        let hdrl_pos = bw.tell() + 20;
        bw.write_buf(b"RIFF")?;
        bw.write_u32le(0)?;
        bw.write_buf(b"AVI LIST")?;
        bw.write_u32le(0)?;
        bw.write_buf(b"hdrlavih")?;

        bw.write_u32le(56)?; // avih size
        let mut found = false;
        for stream in self.streams.iter() {
            if let StreamInfo::Video(ref vinfo) = stream {
                found = true;
                let ms_per_frame = if vinfo.tb_den != 0 {
                        if vinfo.tb_den > 1000 {
                            vinfo.tb_num / (vinfo.tb_den / 1000)
                        } else {
                            vinfo.tb_num * 1000 / vinfo.tb_den
                        }
                    } else {
                        0
                    };
                bw.write_u32le(ms_per_frame)?;
                bw.write_u32le(0)?; // max transfer rate
                bw.write_u32le(0)?; // padding granularity
                bw.write_u32le(0)?; // flags
                bw.write_u32le(0)?; // total frames
                bw.write_u32le(0)?; // initial frames
                bw.write_u32le(self.streams.len() as u32)?;
                bw.write_u32le(0)?; // suggested buffer size
                bw.write_u32le(vinfo.width as u32)?;
                bw.write_u32le(vinfo.height as u32)?;
                break;
            }
        }
        if !found {
            bw.write_u32le(0)?; // milliseconds per frame
            bw.write_u32le(0)?; // max transfer rate
            bw.write_u32le(0)?; // padding granularity
            bw.write_u32le(0)?; // flags
            bw.write_u32le(0)?; // total frames
            bw.write_u32le(0)?; // initial frames
            bw.write_u32le(self.streams.len() as u32)?;
            bw.write_u32le(0)?; // suggested buffer size
            bw.write_u32le(0)?;
            bw.write_u32le(0)?;
        }
        bw.write_u32le(0)?; // reserved
        bw.write_u32le(0)?; // reserved
        bw.write_u32le(0)?; // reserved
        bw.write_u32le(0)?; // reserved

        for _ in 0..self.streams.len() {
            self.sstats.push(StreamStats{
                pal:        [0; 768],
                has_pal:    false,
                has_pc:     false,
                pal_pos:    0,
                nframes:    0,
                max_size:   0,
                strh_pos:   0,
                is_video:   false,
            });
        }
        for (stream, sstat) in self.streams.iter().zip(self.sstats.iter_mut()) {
            let strl_pos = bw.tell() + 8;
            bw.write_buf(b"LIST\0\0\0\0strlstrh")?;
            bw.write_u32le(56)?; // strh size
            match stream {
                StreamInfo::Video(ref _vinfo) => {
                    bw.write_buf(b"vids")?;
                    bw.write_buf(b"DIB ")?;
                },
                StreamInfo::Audio(ref _ainfo) => {
                    bw.write_buf(b"auds")?;
                    bw.write_u32le(0)?;
                },
                _ => unreachable!(),
            };

            sstat.strh_pos = bw.tell();

            bw.write_u32le(0)?; // flags
            bw.write_u16le(0)?; // priority
            bw.write_u16le(0)?; // language
            bw.write_u32le(0)?; // initial frames
            match stream {
                StreamInfo::Video(ref vinfo) => {
                    bw.write_u32le(vinfo.tb_num)?;
                    bw.write_u32le(vinfo.tb_den)?;
                    sstat.is_video = true;
                },
                StreamInfo::Audio(ref ainfo) => {
                    bw.write_u32le(1)?;
                    bw.write_u32le(ainfo.sample_rate)?;
                },
                _ => unreachable!(),
            };
            bw.write_u32le(0)?; // start
            bw.write_u32le(0)?; // length
            bw.write_u32le(0)?; // suggested buffer size
            bw.write_u32le(0)?; // quality
            bw.write_u32le(0)?; // sample_size
            bw.write_u16le(0)?; // x
            bw.write_u16le(0)?; // y
            bw.write_u16le(0)?; // w
            bw.write_u16le(0)?; // h

            bw.write_buf(b"strf")?;
            bw.write_u32le(0)?;
            let strf_pos = bw.tell();
            match stream {
                StreamInfo::Video(ref vinfo) => {
                    let hdr_pos = bw.tell();
                    bw.write_u32le(0)?;
                    bw.write_u32le(vinfo.width as u32)?;
                    bw.write_u32le(vinfo.height as u32)?;
                    match vinfo.bpp {
                        8 => {
                            bw.write_u16le(1)?;
                            bw.write_u16le(8)?;
                        },
                        15 => {
                            bw.write_u16le(1)?;
                            bw.write_u16le(16)?;
                        },
                        16 | 24 => {
                            bw.write_u16le(1)?;
                            bw.write_u16le(24)?;
                        },
                        _ => unreachable!(),
                    }
                    sstat.max_size = if vinfo.bpp != 16 {
                            (vinfo.width * vinfo.height * (usize::from(vinfo.bpp + 1) / 8)) as u32
                        } else { 0 };
                    bw.write_u32le(0)?; // raw video
                    bw.write_u32le(sstat.max_size)?; // image size
                    bw.write_u32le(0)?; // x dpi
                    bw.write_u32le(0)?; // y dpi
                    if vinfo.bpp == 8 {
                        bw.write_u32le(256)?; // total colors
                        bw.write_u32le(0)?; // important colors
                        sstat.pal_pos = bw.tell() as u32;
                        for _ in 0..256 {
                            bw.write_u32le(0)?;
                        }
                    } else {
                        bw.write_u32le(0)?; // total colors
                        bw.write_u32le(0)?; // important colors
                    }
                    let bisize = bw.tell() - hdr_pos;
                    bw.seek(SeekFrom::Current(-(bisize as i64)))?;
                    bw.write_u32le(bisize as u32)?;
                    bw.seek(SeekFrom::End(0))?;
                },
                StreamInfo::Audio(ref ainfo) => {
                    let bps = if ainfo.sample_type == AudioSample::U8 { 1 } else { 2 };
                    bw.write_u16le(0x0001)?;
                    bw.write_u16le(u16::from(ainfo.channels))?;
                    bw.write_u32le(ainfo.sample_rate)?;
                    bw.write_u32le(ainfo.sample_rate * u32::from(ainfo.channels) * bps)?;
                    bw.write_u16le(u16::from(ainfo.channels) * (bps as u16))?;
                    bw.write_u16le(if ainfo.sample_type == AudioSample::U8 { 8 } else { 16 })?;
                },
                _ => unreachable!(),
            };
            patch_size(&mut bw, strf_pos)?;
            patch_size(&mut bw, strl_pos)?;
        }
        patch_size(&mut bw, hdrl_pos)?;

        self.data_pos = bw.tell() + 8;
        bw.write_buf(b"LIST\0\0\0\0movi")?;

        Ok(())
    }
    fn write(&mut self, stream_no: usize, frame: Frame) -> EncoderResult<()> {
        if stream_no >= self.streams.len() { return Err(EncoderError::Ignored); }

        let mut bw = ByteWriter::new(&mut self.fr);

        if bw.tell() > (1 << 31) || self.sstats[stream_no].nframes > (1 << 20) {
            return Err(EncoderError::ContainerFull);
        }

        let mut tag = [b'0' + ((stream_no / 10) as u8), b'0' + ((stream_no % 10) as u8), b' ', b' '];
        let offs = bw.tell() as u32;

        self.sstats[stream_no].nframes += 1;
        match (&self.streams[stream_no], &frame) {
            (StreamInfo::Video(ref vinfo), Frame::VideoPal(ref data, ref pal)) => {
                validate!(vinfo.bpp == 8);
                if !self.sstats[stream_no].has_pal {
                    bw.seek(SeekFrom::Start(u64::from(self.sstats[stream_no].pal_pos)))?;
                    for clr in pal.chunks_exact(3) {
                        bw.write_byte(clr[2])?;
                        bw.write_byte(clr[1])?;
                        bw.write_byte(clr[0])?;
                        bw.write_byte(0)?;
                    }
                    bw.seek(SeekFrom::Start(u64::from(offs)))?;
                    self.sstats[stream_no].pal.copy_from_slice(pal);
                    self.sstats[stream_no].has_pal = true;
                } else if &self.sstats[stream_no].pal != pal {
                    tag[2] = b'p';
                    tag[3] = b'c';
                    let pc_data = generate_palchange(&self.sstats[stream_no].pal, pal);
                    let size = pc_data.len() as u32;
                    self.index.push(IndexEntry { tag, offs, size });
                    bw.write_buf(&tag)?;
                    bw.write_u32le(size)?;
                    bw.write_buf(&pc_data)?;
                    self.sstats[stream_no].pal.copy_from_slice(pal);
                    self.sstats[stream_no].has_pc = true;
                }

                tag[2] = b'd';
                tag[3] = b'b';
                let offs = bw.tell() as u32;
                let size = data.len() as u32;
                self.index.push(IndexEntry { tag, offs, size });
                bw.write_buf(&tag)?;
                bw.write_u32le(size)?;
                let stride = vinfo.width;
                let pad = vec![0; if (stride & 3) != 0 { 4 - (stride & 3) } else { 0 }];
                for line in data.chunks_exact(vinfo.width).rev() {
                    bw.write_buf(line)?;
                    bw.write_buf(&pad)?;
                }
                self.sstats[stream_no].max_size = self.sstats[stream_no].max_size.max(size);
            },
            (StreamInfo::Video(ref vinfo), Frame::VideoRGB16(ref data)) => {
                validate!(vinfo.bpp == 15 || vinfo.bpp == 16);
                match vinfo.bpp {
                    15 => {
                        tag[2] = b'd';
                        tag[3] = b'b';
                        let size = (data.len() * 2) as u32;
                        self.index.push(IndexEntry { tag, offs, size });
                        bw.write_buf(&tag)?;
                        bw.write_u32le(size)?;
                        for line in data.chunks_exact(vinfo.width).rev() {
                            for &pix in line.iter() {
                                bw.write_u16le(pix)?;
                            }
                            if (line.len() & 1) != 0 {
                                bw.write_u16le(0)?;
                            }
                        }
                    },
                    16 => {
                        tag[2] = b'd';
                        tag[3] = b'b';
                        let size = (data.len() * 3) as u32;
                        self.index.push(IndexEntry { tag, offs, size });
                        bw.write_buf(&tag)?;
                        bw.write_u32le(size)?;
                        let stride = vinfo.width * 3;
                        let pad = vec![0; if (stride & 3) != 0 { 4 - (stride & 3) } else { 0 }];
                        for line in data.chunks_exact(vinfo.width).rev() {
                            for &pix in line.iter() {
                                let r = ((pix >> 11) & 0x1F) as u8;
                                let g = ((pix >>  5) & 0x3F) as u8;
                                let b = ( pix        & 0x1F) as u8;
                                bw.write_byte((r << 3) | (r >> 2))?;
                                bw.write_byte((g << 2) | (g >> 4))?;
                                bw.write_byte((b << 3) | (b >> 2))?;
                            }
                            bw.write_buf(&pad)?;
                        }
                        self.sstats[stream_no].max_size = self.sstats[stream_no].max_size.max(size);
                    },
                    _ => unreachable!(),
                }
            },
            (StreamInfo::Video(ref vinfo), Frame::VideoRGB24(ref data)) => {
                validate!(vinfo.bpp == 24);
                tag[2] = b'd';
                tag[3] = b'b';
                let size = data.len() as u32;
                self.index.push(IndexEntry { tag, offs, size });
                bw.write_buf(&tag)?;
                bw.write_u32le(size)?;
                let stride = vinfo.width * 3;
                let pad = vec![0; if (stride & 3) != 0 { 4 - (stride & 3) } else { 0 }];
                for line in data.chunks_exact(vinfo.width * 3).rev() {
                    bw.write_buf(line)?;
                    bw.write_buf(&pad)?;
                }
            },
            (StreamInfo::Audio(ref ainfo), Frame::AudioU8(ref data)) => {
                validate!(ainfo.sample_type == AudioSample::U8);
                tag[2] = b'w';
                tag[3] = b'b';
                let size = data.len() as u32;
                self.index.push(IndexEntry { tag, offs, size });
                bw.write_buf(&tag)?;
                bw.write_u32le(size)?;
                bw.write_buf(data)?;
                self.sstats[stream_no].max_size = self.sstats[stream_no].max_size.max(size);
            },
            (StreamInfo::Audio(ref ainfo), Frame::AudioS16(ref data)) => {
                validate!(ainfo.sample_type == AudioSample::S16);
                tag[2] = b'w';
                tag[3] = b'b';
                let size = (data.len() * 2) as u32;
                self.index.push(IndexEntry { tag, offs, size });
                bw.write_buf(&tag)?;
                bw.write_u32le(size)?;
                for &sample in data.iter() {
                    bw.write_u16le(sample as u16)?;
                }
                self.sstats[stream_no].max_size = self.sstats[stream_no].max_size.max(size);
            },
            _ => return Err(EncoderError::InvalidData),
        }
        if (bw.tell() & 1) != 0 {
            bw.write_byte(0)?; // padding
        }

        Ok(())
    }
    fn finish(&mut self) -> EncoderResult<()> {
        if !self.index.is_empty() {
            let mut bw = ByteWriter::new(&mut self.fr);

            patch_size(&mut bw, self.data_pos)?;
            if !self.index.is_empty() {
                bw.write_buf(b"idx1")?;
                bw.write_u32le((self.index.len() * 16) as u32)?;
                for item in self.index.iter() {
                    bw.write_buf(&item.tag)?;
                    bw.write_u32le(if &item.tag[2..] == b"pc" { 0x100 } else { 0x10 })?;
                    bw.write_u32le(item.offs)?;
                    bw.write_u32le(item.size)?;
                }
            }
            patch_size(&mut bw, 8)?;

            let mut max_frames = 0;
            let mut max_size = 0;
            for sstat in self.sstats.iter() {
                max_frames = max_frames.max(sstat.nframes);
                max_size = max_size.max(sstat.max_size);
                if sstat.has_pc {
                    bw.seek(SeekFrom::Start(sstat.strh_pos))?;
                    bw.write_u32le(0x00010000)?;
                }
                bw.seek(SeekFrom::Start(sstat.strh_pos + 0x18))?;
                bw.write_u32le(if sstat.is_video { sstat.nframes } else { 0 })?;
                bw.write_u32le(sstat.max_size)?;
            }
            bw.seek(SeekFrom::Start(0x30))?;
            bw.write_u32le(max_frames)?;
            bw.seek(SeekFrom::Current(8))?;
            bw.write_u32le(max_size)?;

            bw.flush()?;

            Ok(())
        } else {
            Err(EncoderError::EmptyOutput)
        }
    }
}

/// Back-patches the 32-bit size field that sits just before `pos`.
///
/// `pos` is where a chunk or list payload begins. The value written is the number of bytes from
/// `pos` to the current position. The writer is then moved back to the end of the stream.
fn patch_size(bw: &mut ByteWriter, pos: u64) -> EncoderResult<()> {
    let payload_len = bw.tell() - pos;
    // Step back over the payload plus the 4-byte size field itself.
    let rewind = (payload_len + 4) as i64;
    bw.seek(SeekFrom::Current(-rewind))?;
    bw.write_u32le(payload_len as u32)?;
    bw.seek(SeekFrom::End(0))?;
    Ok(())
}

/// Builds the payload of a palette-change chunk for the smallest contiguous run of entries that
/// differ between `old_pal` and `new_pal`.
///
/// Layout: first entry index, entry count (256 wraps to 0 in the byte), two zero bytes, then
/// 4 bytes per entry: the three bytes from `new_pal` followed by a zero.
fn generate_palchange(old_pal: &[u8; 768], new_pal: &[u8; 768]) -> Vec<u8> {
    let differs = |idx: usize| old_pal[idx * 3..idx * 3 + 3] != new_pal[idx * 3..idx * 3 + 3];
    let first = (0..256).find(|&idx| differs(idx)).unwrap_or(256);
    let last = (0..256).rfind(|&idx| differs(idx)).map_or(256, |idx| idx + 1);

    let mut payload = Vec::with_capacity(1028);
    payload.extend_from_slice(&[first as u8, (last - first) as u8, 0, 0]);
    for rgb in new_pal[first * 3..last * 3].chunks_exact(3) {
        payload.extend_from_slice(&[rgb[0], rgb[1], rgb[2], 0]);
    }
    payload
}

/// Creates the file `name` and returns an AVI writer that outputs to it.
///
/// Fails with `EncoderError::InvalidFilename` if the file cannot be created.
pub fn create(name: &str) -> EncoderResult<Box<dyn OutputWriter>> {
    let file = match File::create(name) {
        Ok(file) => file,
        Err(_) => return Err(EncoderError::InvalidFilename(name.to_owned())),
    };
    let writer = AVIWriter {
        fr: FileWriter::new_write(BufWriter::new(file)),
        streams: Vec::with_capacity(2),
        sstats: Vec::with_capacity(2),
        index: Vec::new(),
        data_pos: 0,
    };
    Ok(Box::new(writer))
}
