use std::fs::File;
use std::io::BufReader;
use crate::io::byteio::*;
use crate::io::bitreader::*;
use super::super::*;

/// What `decode_frame` produces on its next call.
#[derive(Debug, PartialEq)]
enum State {
    /// Read the next frame record's chunks from the file and emit its video.
    NextFrame,
    /// Emit an audio packet, then return to `NextFrame`.
    OutputAudio,
    /// Re-emit the current video frame; the payload is the remaining repeat counter.
    Still(u8),
    /// Emit an audio packet between repeated video frames; the payload is the
    /// remaining repeat counter.
    StillAudio(u8),
}

/// Progress of a palette fade.
///
/// A fade starts with `cur_val == 64` and `cur_val` drops by `step` on every
/// `update()`; once it reaches 0 the state reports that no fade is running.
#[derive(Default)]
struct FadeState {
    /// Remaining fade amount, 0..=64; 0 means idle.
    cur_val:    u8,
    /// Amount removed from `cur_val` per `update()`.
    step:       u8,
    /// `true` brightens from black, `false` darkens towards black.
    fade_in:    bool,
}

impl FadeState {
    /// Creates an idle fade state.
    fn new() -> Self {
        Self { cur_val: 0, step: 0, fade_in: false }
    }

    /// Returns `true` while `cur_val` is non-zero.
    fn is_fading(&self) -> bool {
        self.cur_val > 0
    }

    /// Returns a copy of `pal` with every byte multiplied by `level / 64`
    /// (rounded down), where `level` is `64 - cur_val` for a fade-in and
    /// `cur_val` for a fade-out.
    fn fade_pal(&self, pal: &[u8; 768]) -> [u8; 768] {
        let level = if self.fade_in {
            64 - self.cur_val
        } else {
            self.cur_val
        };
        let level = u16::from(level);
        let mut faded = *pal;
        for comp in faded.iter_mut() {
            *comp = ((u16::from(*comp) * level) >> 6) as u8;
        }
        faded
    }

    /// Advances the fade by one step, never going below 0.
    fn update(&mut self) {
        self.cur_val -= self.step.min(self.cur_val);
    }

    /// Starts a fade from black that advances by `step` per `update()`.
    ///
    /// NOTE(review): with `step == 0` the fade never finishes and the
    /// faded palette stays fully black — confirm callers never pass 0.
    fn set_fade_in(&mut self, step: u8) {
        self.start(step, true);
    }

    /// Starts a fade towards black that advances by `step` per `update()`.
    ///
    /// NOTE(review): once `cur_val` reaches 0, `is_fading()` is false, so a
    /// caller that only fades while `is_fading()` never shows the fully-dark
    /// level and falls back to the unfaded palette.
    fn set_fade_out(&mut self, step: u8) {
        self.start(step, false);
    }

    /// Shared reset for both fade directions.
    fn start(&mut self, step: u8, fade_in: bool) {
        self.cur_val = 64;
        self.step = step;
        self.fade_in = fade_in;
    }
}

/// One entry of the frame table read from the file header.
struct FrameRecord {
    /// Size in bytes of the packed video chunk for this frame.
    vsize: u32,
    /// Size in bytes of the packed audio chunk (0 when the file has no audio).
    asize: u32,
    /// Frame type: 1/3/5 switch to the next palette, 2/3 start a fade-out,
    /// 5/6 start a fade-in.
    ftype: u8,
    /// Repeat count used by the decoder to re-emit the frame (recomputed
    /// from `fade` for type 2).
    repeat: u8,
    /// Fade step passed to the palette fader.
    fade: u8,
}

/// One entry of the palette table read from the file header.
struct PalRecord {
    /// Palette bytes (256 entries × 3 components); entries past the stored
    /// colour count stay zero.
    clrs: [u8; 768],
    /// Optional per-colour table (16 bytes per colour) when the header flag
    /// is set, empty otherwise; not read by the decoder code shown here.
    fade_tab: Vec<u8>,
}

/// Decoder state for one open `LSANM` file.
struct CDADecoder {
    // Input.
    /// Reader over the opened file, positioned at the next frame's chunks.
    fr:         FileReader<BufReader<File>>,
    /// Scratch buffer for the packed video or audio chunk being decoded.
    data:       Vec<u8>,

    // Video.
    /// Indexed frame buffer, always 320x240 and kept between frames since
    /// skip opcodes leave pixels untouched.
    frame:      [u8; 320 * 240],
    /// Declared frame width in pixels.
    width:      usize,
    /// Declared frame height in pixels.
    height:     usize,
    /// Palette currently applied to emitted frames.
    pal:        [u8; 768],
    /// Palette fade in progress, if any.
    fade:       FadeState,

    // Header tables and positions in them.
    /// Per-frame records from the header.
    frame_tab:  Vec<FrameRecord>,
    /// Index of the next frame record to decode.
    cur_frm:    usize,
    /// Palette records from the header.
    pal_tab:    Vec<PalRecord>,
    /// Index of the next palette record to switch to.
    pal_no:     usize,

    // Audio.
    /// Whether the file declares any audio channels.
    has_audio:  bool,
    /// Audio sample rate in Hz.
    arate:      u16,
    /// Number of audio channels (0..=2).
    channels:   u8,
    /// Decoded unsigned 8-bit samples not yet handed out.
    audio:      Vec<u8>,
    /// Maximum number of samples returned per audio packet.
    ablk_size:  usize,

    /// What the next `decode_frame` call emits.
    state:      State,
}

impl CDADecoder {
    /// Applies the packed video chunk held in `self.data` to `self.frame`.
    ///
    /// The chunk starts with a frame-type byte. Type 1 is rejected with
    /// `NotImplemented`. For any other type the next byte becomes pixel 0,
    /// followed by a 32-bit LE opcode count; the opcodes then update the frame
    /// from pixel 1 onwards. Skip opcodes leave pixels untouched, so the
    /// previous frame's contents show through. Runs or references that would
    /// leave `self.frame` fail through `validate!`.
    fn draw_frame(&mut self) -> DecoderResult<()> {
        let mut mr = MemoryReader::new_read(&self.data);
        let mut br = ByteReader::new(&mut mr);

        let frm_type = br.read_byte()?;
        if frm_type == 1 {
            // the format seems to be the same though
            return Err(DecoderError::NotImplemented);
        }

        let mut pos = 1;
        self.frame[0] = br.read_byte()?;
        let num_ops = br.read_u32le()? as usize;
        for _ in 0..num_ops {
            let op = br.read_byte()?;
            match op {
                // skip 45..=300 pixels (8-bit length)
                0 => {
                    let len = usize::from(br.read_byte()?) + 45;
                    validate!(pos + len <= self.frame.len());
                    pos += len;
                },
                // skip 45..=65580 pixels (16-bit length)
                1 => {
                    let len = br.read_u16le()? as usize + 45;
                    validate!(pos + len <= self.frame.len());
                    pos += len;
                },
                // skip 301..=556 pixels
                2 => {
                    let len = usize::from(br.read_byte()?) + 301;
                    validate!(pos + len <= self.frame.len());
                    pos += len;
                },
                // skip 557..=812 pixels
                3 => {
                    let len = usize::from(br.read_byte()?) + 557;
                    validate!(pos + len <= self.frame.len());
                    pos += len;
                },
                // fill with one colour, 16-bit length
                4 => {
                    let clr = br.read_byte()?;
                    let len = br.read_u16le()? as usize;
                    validate!(pos + len <= self.frame.len());
                    for el in self.frame[pos..][..len].iter_mut() {
                        *el = clr;
                    }
                    pos += len;
                },
                // fill with one colour, 8-bit length
                5 => {
                    let clr = br.read_byte()?;
                    let len = usize::from(br.read_byte()?);
                    validate!(pos + len <= self.frame.len());
                    for el in self.frame[pos..][..len].iter_mut() {
                        *el = clr;
                    }
                    pos += len;
                },
                // fill with one colour, length in the opcode (0..=53)
                6..=59 => {
                    let len = usize::from(op - 6);
                    let clr = br.read_byte()?;
                    validate!(pos + len <= self.frame.len());
                    for el in self.frame[pos..][..len].iter_mut() {
                        *el = clr;
                    }
                    pos += len;
                },
                // raw pixels, length in the opcode (0..=69)
                60..=129 => {
                    let len = usize::from(op - 60);
                    validate!(pos + len <= self.frame.len());
                    br.read_buf(&mut self.frame[pos..][..len])?;
                    pos += len;
                },
                // copy 6..=45 pixels forward from an absolute frame position;
                // copied one pixel at a time, so an overlapping source sees
                // pixels written earlier in this run
                130..=169 => {
                    let len = usize::from(op - 124);
                    let ref_pos = br.read_u16le()? as usize;
                    validate!(pos + len <= self.frame.len());
                    validate!(ref_pos + len <= self.frame.len());
                    for i in 0..len {
                        self.frame[pos + i] = self.frame[ref_pos + i];
                    }
                    pos += len;
                },
                // copy 6..=45 pixels reading backwards from an absolute position
                170..=209 => {
                    let len = usize::from(op - 164);
                    let ref_pos = br.read_u16le()? as usize;
                    validate!(pos + len <= self.frame.len());
                    validate!(ref_pos < self.frame.len() && ref_pos + 1 >= len);
                    for i in 0..len {
                        self.frame[pos + i] = self.frame[ref_pos - i];
                    }
                    pos += len;
                },
                // short skip, length in the opcode (0..=45)
                _ => {
                    let skip_len = usize::from(op - 210);
                    validate!(pos + skip_len <= self.frame.len());
                    pos += skip_len;
                },
            }
        }
        Ok(())
    }

    /// Emits the visible part of the frame buffer with the current palette on
    /// stream 0, applying (and then advancing) the palette fade if one is running.
    fn get_video(&mut self) -> DecoderResult<(usize, Frame)> {
        let pal = if self.fade.is_fading() {
            let faded = self.fade.fade_pal(&self.pal);
            self.fade.update();
            faded
        } else {
            self.pal
        };
        Ok((0, Frame::VideoPal(self.frame[..self.width * self.height].to_vec(), pal)))
    }

    /// Decodes the packed audio chunk in `self.data` and appends the
    /// resulting unsigned 8-bit samples to `self.audio`.
    ///
    /// The chunk starts with a 32-bit LE total sample count (at least 2, and
    /// even for stereo), followed by an MSB-first bitstream. The first sample
    /// of each channel is stored raw in 8 bits; the rest are delta-coded by
    /// `read_sample`. Stereo samples are interleaved left/right.
    fn unpack_audio(&mut self) -> DecoderResult<()> {
        // pad because last sample may overread
        self.data.push(0);
        let nsamples = read_u32le(&self.data)? as usize;
        validate!(nsamples >= 2);
        // `nsamples` comes straight from the file, so do not trust it for the
        // up-front allocation: every coded sample needs at least 5 bits (the
        // initial raw ones need 8), which bounds how many the payload can hold.
        // The decoding loops below still run up to `nsamples` as before.
        let max_samples = self.data.len().saturating_sub(4) * 8 / 5 + 2;
        let mut br = BitReader::new(&self.data[4..], BitReaderMode::BE);
        self.audio.reserve(nsamples.min(max_samples));
        if self.channels == 1 {
            let mut pred = br.read(8)? as u8;
            self.audio.push(pred);
            for _ in 1..nsamples {
                pred = Self::read_sample(&mut br, pred)?;
                self.audio.push(pred);
            }
        } else {
            validate!((nsamples & 1) == 0);
            let mut pred_l = br.read(8)? as u8;
            let mut pred_r = br.read(8)? as u8;
            self.audio.push(pred_l);
            self.audio.push(pred_r);
            for _ in 1..(nsamples / 2) {
                pred_l = Self::read_sample(&mut br, pred_l)?;
                self.audio.push(pred_l);
                pred_r = Self::read_sample(&mut br, pred_r)?;
                self.audio.push(pred_r);
            }
        }
        Ok(())
    }

    /// Reads one delta-coded sample relative to `prev`.
    ///
    /// Codes: `0` + `read_s(4)` value (presumably a signed 4-bit delta);
    /// `10` + 5-bit code mapped to a delta of 8..=23 or -24..=-9; `11` + a raw
    /// 8-bit sample. Deltas wrap around in `u8`.
    fn read_sample(br: &mut BitReader, prev: u8) -> DecoderResult<u8> {
        if !br.read_bool()? {
            let delta = br.read_s(4)? as u8;
            Ok(prev.wrapping_add(delta))
        } else if !br.read_bool()? {
            let mut delta = br.read(5)? as u8;
            if (delta & 0x10) == 0 {
                delta += 0x8;
            } else {
                delta = (delta | 0xE0) - 8;
            }
            Ok(prev.wrapping_add(delta))
        } else {
            Ok(br.read(8)? as u8)
        }
    }

    /// Emits up to `ablk_size` pending samples as an audio packet on stream 1.
    ///
    /// Fails through `validate!` when no samples are pending.
    fn get_audio(&mut self) -> DecoderResult<(usize, Frame)> {
        validate!(!self.audio.is_empty());
        let ret: Vec<u8> = if self.audio.len() > self.ablk_size {
            self.audio.drain(..self.ablk_size).collect()
        } else {
            // hand out everything and keep a preallocated buffer for later chunks
            std::mem::replace(&mut self.audio, Vec::with_capacity(self.ablk_size))
        };
        Ok((1, Frame::AudioU8(ret)))
    }
}

impl InputSource for CDADecoder {
    fn get_num_streams(&self) -> usize { if self.has_audio { 2 } else { 1 } }
    fn get_stream_info(&self, stream_no: usize) -> StreamInfo {
        match stream_no {
            0 => StreamInfo::Video(VideoInfo{
                    width:  self.width,
                    height: self.height,
                    bpp:    8,
                    tb_num: 3,
                    tb_den: 50,
                 }),
            1 if self.has_audio => StreamInfo::Audio(AudioInfo{
                    sample_rate: u32::from(self.arate),
                    channels:    self.channels,
                    sample_type: AudioSample::U8,
                 }),
            _ => StreamInfo::None
        }
    }
    fn decode_frame(&mut self) -> DecoderResult<(usize, Frame)> {
        let mut br = ByteReader::new(&mut self.fr);
        match self.state {
            State::NextFrame => {
                if self.cur_frm >= self.frame_tab.len() {
                    return Err(DecoderError::EOF);
                }

                let vsize = self.frame_tab[self.cur_frm].vsize as usize;
                let asize = self.frame_tab[self.cur_frm].asize as usize;
                let ftype = self.frame_tab[self.cur_frm].ftype;
                let mut repeat = self.frame_tab[self.cur_frm].repeat;
                let fade  = self.frame_tab[self.cur_frm].fade;

                if ftype == 2 {
                    validate!(fade > 0);
                    repeat = 64 / fade;
                }

                self.data.resize(vsize, 0);
                br.read_buf(&mut self.data)?;
                self.draw_frame().map_err(|_| DecoderError::InvalidData)?;

                if self.has_audio {
                    br = ByteReader::new(&mut self.fr);
                    validate!(asize > 4);
                    self.data.resize(asize, 0);
                    br.read_buf(&mut self.data)?;
                    self.unpack_audio().map_err(|_| DecoderError::InvalidData)?;
                    self.state = State::OutputAudio;
                }

                if matches!(self.frame_tab[self.cur_frm].ftype, 1 | 3 | 5) {
                    validate!(self.pal_no < self.pal_tab.len());
                    self.pal = self.pal_tab[self.pal_no].clrs;
                    self.pal_no += 1;
                }

                match ftype {
                    2 | 3 => {
                        self.fade.set_fade_out(fade);
                    },
                    5 | 6 => {
                        self.fade.set_fade_in(fade);
                    },
                    _ => {},
                }

                self.state = match (repeat > 0, self.has_audio) {
                        (false, false) => State::NextFrame,
                        (false, true)  => State::OutputAudio,
                        (true,  false) => State::Still(repeat - 1),
                        (true,  true)  => State::StillAudio(repeat - 1),
                    };

                self.cur_frm += 1;

                self.get_video()
            },
            State::OutputAudio => {
                self.state = State::NextFrame;
                self.get_audio()
            },
            State::Still(count) => {
                self.state = match (count > 0, self.has_audio) {
                        (true,  false) => State::Still(count - 1),
                        (true,  true)  => State::StillAudio(count - 1),
                        (false, false) => State::NextFrame,
                        (false, true)  => State::OutputAudio,
                    };
                self.get_video()
            },
            State::StillAudio(count) => {
                self.state = if count > 0 { State::Still(count) } else { State::NextFrame };
                self.get_audio()
            }
        }
    }
}

/// Opens an `LSANM` file and returns a decoder positioned at the first
/// frame's data.
///
/// Reads and validates the header, the frame table and the palette table,
/// then seeks to the data start offset stored in the header.
///
/// # Errors
/// `InputNotFound` if the file cannot be opened; `NotImplemented` for
/// version-1 files or when the header field read into `smth` is not 17;
/// header validation failures are reported through `validate!`.
pub fn open(name: &str) -> DecoderResult<Box<dyn InputSource>> {
    let file = File::open(name).map_err(|_| DecoderError::InputNotFound(name.to_owned()))?;
    let mut fr = FileReader::new_read(BufReader::new(file));
    let mut br = ByteReader::new(&mut fr);

    let mut magic = [0u8; 5];
    br.read_buf(&mut magic)?;
    validate!(&magic == b"LSANM");
    validate!(br.read_byte()? == 1);
    let version = br.read_byte()?;
    validate!(version == 1 || version == 2);
    if version == 1 {
        return Err(DecoderError::NotImplemented);
    }

    // zero channels means no audio, which also requires a zero sample rate
    let channels = br.read_byte()?;
    validate!(channels < 3);
    let arate = br.read_u16le()?;
    let audio_params_ok = if channels == 0 {
        arate == 0
    } else {
        (8000..=22050).contains(&arate)
    };
    validate!(audio_params_ok);

    let width  = usize::from(br.read_u16le()?);
    let height = usize::from(br.read_u16le()?);
    validate!((1..=320).contains(&width) && (1..=240).contains(&height));

    let smth = br.read_u16le()?;
    if smth != 17 {
        return Err(DecoderError::NotImplemented);
    }
    let pal_entries = usize::from(br.read_u16le()?);
    // only version 2 gets this far, and it requires at least one palette
    validate!(pal_entries > 0);

    let start = br.read_u32le()?;
    let nframes = br.read_u32le()? as usize;
    validate!(nframes > 0 && nframes < 5000);
    let ablk_size = br.read_u32le()? as usize;
    validate!((arate == 0) == (ablk_size == 0));

    let mut frame_tab = Vec::with_capacity(nframes);
    for _ in 0..nframes {
        let vsize = br.read_u32le()?;
        let asize = br.read_u32le()?;
        validate!((arate == 0) == (asize == 0));
        br.read_skip(3)?;
        let repeat = br.read_byte()?;
        let ftype  = br.read_byte()?;
        let fade   = br.read_byte()?;
        frame_tab.push(FrameRecord { vsize, asize, ftype, repeat, fade });
    }

    let mut pal_tab = Vec::with_capacity(pal_entries);
    for _ in 0..pal_entries {
        let nclrs = usize::from(br.read_byte()?) + 1;
        let has_fade_tab = br.read_byte()? != 0;
        let mut clrs = [0u8; 768];
        br.read_vga_pal_some(&mut clrs[..nclrs * 3])?;
        let fade_tab = if has_fade_tab {
            let mut tab = vec![0; nclrs * 16];
            br.read_buf(&mut tab)?;
            tab
        } else {
            Vec::new()
        };
        pal_tab.push(PalRecord { clrs, fade_tab });
    }

    let data_start = u64::from(start);
    validate!(br.tell() <= data_start);
    br.seek(SeekFrom::Start(data_start))?;

    Ok(Box::new(CDADecoder {
        fr,
        // decoded data seems to go past declared dimensions quite often
        frame: [0; 320 * 240],
        data: Vec::new(),
        width, height,
        pal: [0; 768],
        cur_frm: 0,
        pal_no: 0,
        frame_tab, pal_tab,
        arate, channels, ablk_size,
        audio: Vec::new(),
        has_audio: channels > 0,
        state: State::NextFrame,
        fade: FadeState::new(),
    }))
}
