use std::collections::VecDeque;
use crate::input::*;
use crate::io::byteio::*;

/// Per-channel decoder state: stream-level tables plus everything carried
/// over from one 128-sample frame to the next.
struct AudioState {
    /// Entry of `lpc0_cb` selected by the low six bits of the frame header word.
    lpc0_idx:   usize,
    /// Entry of `lpc1_cb` selected by bits 6-11 of the second header word.
    lpc1_idx:   usize,
    /// Entry of `lpc2_cb` selected by the low six bits of the second header word.
    lpc2_idx:   usize,
    /// Running pulse amplitude: set to `base_scale` on intra frames and
    /// multiplied by one of `decays` (then shifted right by 13) on every frame.
    scale:      i32,
    /// Two-bit pulse coding mode of the current frame (0..=3).
    frame_mode: usize,
    /// Accumulated coefficient set: reloaded from `base_filt` on intra frames,
    /// then the three selected codebook entries are added on every frame.
    lpc_coeffs: [i32; 8],
    /// Filter derived from `lpc_coeffs` by `update_lpc_coeffs`.
    cur_filt:   [i32; 8],
    /// Value of `cur_filt` before the current inter frame was decoded.
    prev_filt:  [i32; 8],

    /// Pulse signal of the current frame (input to the synthesis filter).
    pulse_buf:  [i32; 128],
    /// Pulse signals of the two preceding frames, older frame first.
    pulse_hist: [i32; 256],
    /// Last eight filter outputs, indexed by sample position modulo 8.
    lpc_hist:   [i32; 8],

    /// First coefficient codebook, read from the stream header.
    lpc0_cb:    [[i16; 8]; 64],
    /// Second coefficient codebook, read from the stream header.
    lpc1_cb:    [[i16; 8]; 64],
    /// Third coefficient codebook, read from the stream header.
    lpc2_cb:    [[i16; 8]; 64],
    /// Eight scale multipliers, one of which is picked per frame.
    decays:     [i32; 8],
    /// Coefficient set loaded into `lpc_coeffs` at the start of an intra frame.
    base_filt:  [i32; 8],
    /// Scale loaded into `scale` at the start of an intra frame.
    base_scale: i32,
}

impl AudioState {
    /// Creates a state with every table, buffer and index zeroed.
    ///
    /// The tables stay zero until `read_initial_params` fills them.
    fn new() -> Self {
        const ZERO_TAPS: [i32; 8] = [0; 8];
        const ZERO_CODEBOOK: [[i16; 8]; 64] = [[0; 8]; 64];

        Self {
            // stream-level tables
            lpc0_cb:    ZERO_CODEBOOK,
            lpc1_cb:    ZERO_CODEBOOK,
            lpc2_cb:    ZERO_CODEBOOK,
            decays:     ZERO_TAPS,
            base_filt:  ZERO_TAPS,
            base_scale: 0,

            // per-frame parameters
            lpc0_idx:   0,
            lpc1_idx:   0,
            lpc2_idx:   0,
            scale:      0,
            frame_mode: 0,

            // filter state
            lpc_coeffs: ZERO_TAPS,
            cur_filt:   ZERO_TAPS,
            prev_filt:  ZERO_TAPS,
            lpc_hist:   ZERO_TAPS,

            // pulse signal and its history
            pulse_buf:  [0; 128],
            pulse_hist: [0; 256],
        }
    }
    /// Loads the stream-level tables from `br`.
    ///
    /// Read order: the three 64x8 codebooks (`lpc0_cb`, `lpc1_cb`, `lpc2_cb`)
    /// as little-endian signed 16-bit values, eight signed 16-bit decay
    /// factors, eight 32-bit base filter coefficients and finally the 32-bit
    /// base scale.
    ///
    /// # Errors
    /// Returns whatever error the reader reports; tables read so far keep
    /// their new contents.
    fn read_initial_params(&mut self, br: &mut dyn ByteIO) -> DecoderResult<()> {
        let codebook_rows = self.lpc0_cb.iter_mut()
            .chain(self.lpc1_cb.iter_mut())
            .chain(self.lpc2_cb.iter_mut());
        for coef in codebook_rows.flatten() {
            *coef = br.read_u16le()? as i16;
        }

        for decay in self.decays.iter_mut() {
            // stored as 16-bit signed, kept widened for later multiplication
            *decay = i32::from(br.read_u16le()? as i16);
        }

        for coef in self.base_filt.iter_mut() {
            *coef = br.read_u32le()? as i32;
        }

        self.base_scale = br.read_u32le()? as i32;
        Ok(())
    }
    /// Parses the per-frame parameters and adds the coded pulses to `pulse_buf`.
    ///
    /// `val` is the first header word, already read by the caller: bits 0-5
    /// select the `lpc0_cb` entry and bits 6-8 select the decay factor that
    /// updates the running scale (`scale = (decay * scale) >> 13`).
    /// A second header word is read from `br`: bits 0-5 select the `lpc2_cb`
    /// entry, bits 6-11 the `lpc1_cb` entry, bits 12-13 the frame mode and
    /// bits 14-15 the position of the first pulse.
    ///
    /// * Mode 0 reads eight words. Each holds five 3-bit pulses (mapped to the
    ///   odd levels -7..=7) placed every third sample; the lowest bit of each
    ///   word is collected into an 8-bit tail whose top six bits give two more
    ///   3-bit pulses.
    /// * Modes 1, 2 and 3 read 5, 4 and 3 words of eight 2-bit pulses (levels
    ///   -3..=3) placed every 3rd, 4th and 5th sample respectively.
    ///
    /// Every pulse contributes `scale * level` to its sample. The scale and
    /// the pulse levels come straight from the bitstream, so these products
    /// use wrapping arithmetic: a damaged stream then yields garbage samples
    /// instead of an overflow panic in builds with overflow checks enabled.
    ///
    /// # Errors
    /// Returns whatever error the reader reports.
    fn unpack_data(&mut self, br: &mut dyn ByteIO, val: u16) -> DecoderResult<()> {
        /// Adds one scaled pulse to a sample without trapping on overflow.
        fn add_pulse(sample: &mut i32, scale: i32, level: i32) {
            *sample = sample.wrapping_add(scale.wrapping_mul(level));
        }

        self.lpc0_idx = usize::from(val & 0x3F);
        let decay = self.decays[usize::from((val >> 6) & 7)];
        self.scale = decay.wrapping_mul(self.scale) >> 13;

        let val1                        = br.read_u16le()?;
        self.lpc1_idx = usize::from((val1 >> 6) & 0x3F);
        self.lpc2_idx = usize::from(val1 & 0x3F);
        self.frame_mode = usize::from((val1 >> 12) & 3);
        let mut idx = usize::from(val1 >> 14);

        let scale = self.scale;
        if self.frame_mode == 0 {
            let mut tail = 0u16;
            for _ in 0..8 {
                let word                = br.read_u16le()?;
                for i in 0..5 {
                    let code = i32::from((word >> (13 - i * 3)) & 7);
                    add_pulse(&mut self.pulse_buf[idx], scale, code * 2 - 7);
                    idx += 3;
                }
                // bit 0 of every word belongs to the two trailing pulses
                tail = tail * 2 + (word & 1);
            }
            let code = i32::from((tail >> 5) & 7);
            add_pulse(&mut self.pulse_buf[idx], scale, code * 2 - 7);
            idx += 3;
            let code = i32::from((tail >> 2) & 7);
            add_pulse(&mut self.pulse_buf[idx], scale, code * 2 - 7);
        } else {
            let (len, step) = match self.frame_mode {
                    1 => (5, 3),
                    2 => (4, 4),
                    3 => (3, 5),
                    // frame_mode is a two-bit field and zero is handled above
                    _ => unreachable!(),
                };
            for _ in 0..len {
                let word                = br.read_u16le()?;
                for i in 0..8 {
                    let code = i32::from((word >> (14 - i * 2)) & 3);
                    add_pulse(&mut self.pulse_buf[idx], scale, code * 2 - 3);
                    idx += step;
                }
            }
        }
        Ok(())
    }
    /// Adds the three selected codebook entries to `lpc_coeffs` and rebuilds
    /// `cur_filt` from the result.
    ///
    /// `cur_filt` starts as `lpc_coeffs` in reversed order. Then, for every
    /// position `n` in 1..8, each element below `n` is increased by its
    /// product with element `n`, shifted right by 15. Finally every element is
    /// halved and negated.
    fn update_lpc_coeffs(&mut self) {
        let entry0 = &self.lpc0_cb[self.lpc0_idx];
        let entry1 = &self.lpc1_cb[self.lpc1_idx];
        let entry2 = &self.lpc2_cb[self.lpc2_idx];
        let deltas = entry0.iter().zip(entry1.iter()).zip(entry2.iter());
        for (coef, ((&d0, &d1), &d2)) in self.lpc_coeffs.iter_mut().zip(deltas) {
            *coef += i32::from(d0);
            *coef += i32::from(d1);
            *coef += i32::from(d2);
        }

        self.cur_filt = self.lpc_coeffs;
        self.cur_filt.reverse();

        for order in 1..8 {
            let (lower, upper) = self.cur_filt.split_at_mut(order);
            let factor = upper[0];
            for coef in lower.iter_mut() {
                // the product may not fit into 32 bits, hence the wrapping multiply
                *coef += coef.wrapping_mul(factor) >> 15;
            }
        }

        for coef in self.cur_filt.iter_mut() {
            *coef = -(*coef >> 1);
        }
    }
    /// Decodes a frame that does not depend on previous frames.
    ///
    /// The pulse buffer and filter history are cleared, the scale and the
    /// coefficient set restart from their base values, and the whole frame is
    /// filtered with a single filter. `val` is the frame header word already
    /// read by the caller; `out` receives the 128 decoded samples.
    ///
    /// # Errors
    /// Returns whatever error reading the frame data from `br` reports.
    fn decode_intra(&mut self, br: &mut dyn ByteIO, val: u16, out: &mut [i32; 128]) -> DecoderResult<()> {
        // start from a clean slate
        self.pulse_buf = [0; 128];
        self.lpc_hist = [0; 8];
        self.scale = self.base_scale;

        self.unpack_data(br, val)?;

        self.lpc_coeffs = self.base_filt;
        self.update_lpc_coeffs();

        apply_lpc(&mut out[..], &self.pulse_buf, &mut self.lpc_hist, &self.cur_filt);
        Ok(())
    }
    /// Decodes a frame that builds on the previous frames.
    ///
    /// The pulse history is advanced by one frame first. Unless `mode` is
    /// `0x7E` (start from silence), the new pulse buffer is seeded with 128
    /// history samples starting at `pulse_hist[127 - mode]`: the first seven
    /// are weighted by 1/16..7/16, the last seven by 7/16..1/16 and the rest
    /// are halved. The pulses coded in this frame are added on top.
    ///
    /// The frame is then filtered in four 32-sample parts. The last part uses
    /// the new filter and the earlier ones use filters interpolated between
    /// the previous and the new one.
    ///
    /// `val` carries the low nine bits of the frame header word and `out`
    /// receives the 128 decoded samples.
    ///
    /// # Errors
    /// Returns whatever error reading the frame data from `br` reports.
    fn decode_inter(&mut self, br: &mut dyn ByteIO, val: u16, mode: u16, out: &mut [i32; 128]) -> DecoderResult<()> {
        // drop the oldest frame from the history and append the latest one
        self.pulse_hist.copy_within(128.., 0);
        self.pulse_hist[128..].copy_from_slice(&self.pulse_buf);
        self.prev_filt = self.cur_filt;

        if mode == 0x7E {
            self.pulse_buf = [0; 128];
        } else {
            let history = &self.pulse_hist[127 - usize::from(mode)..];
            for (pos, (dst, &prev)) in self.pulse_buf.iter_mut().zip(history.iter()).enumerate() {
                *dst = match pos {
                    0..=6   => (prev * (pos as i32 + 1)) >> 4,
                    7..=120 => prev >> 1,
                    _       => (prev * ((128 - pos) as i32)) >> 4,
                };
            }
        }

        self.unpack_data(br, val)?;
        self.update_lpc_coeffs();

        let mut filters = [[0; 8]; 4];
        for (tap, (&old, &new)) in self.prev_filt.iter().zip(self.cur_filt.iter()).enumerate() {
            let mid = (old + new) >> 1;
            filters[0][tap] = (old + mid) >> 1;
            filters[1][tap] = mid;
            filters[2][tap] = (mid + new) >> 1;
            filters[3][tap] = new;
        }

        let parts = out.chunks_exact_mut(32).zip(self.pulse_buf.chunks_exact(32));
        for ((dst, src), filter) in parts.zip(filters.iter()) {
            apply_lpc(dst, src, &mut self.lpc_hist, filter);
        }
        Ok(())
    }
}

/// Runs the eight-tap recursive synthesis filter over `src`, writing to `dst`.
///
/// For the sample at position `n` the output is
/// `((src[n] << 14) + sum(hist[(n + i) & 7] * filt[i])) >> 14`, and the result
/// is stored back into `hist[n & 7]`. When the history has been filled by
/// this function (in runs whose lengths are multiples of eight), `filt[7]`
/// weights the most recent output and `filt[0]` the one eight samples back.
///
/// Only `min(dst.len(), src.len())` samples are processed. Accumulation wraps
/// on overflow instead of panicking.
fn apply_lpc(dst: &mut [i32], src: &[i32], hist: &mut [i32; 8], filt: &[i32; 8]) {
    for (pos, (out, &pulse)) in dst.iter_mut().zip(src).enumerate() {
        let acc = filt.iter().enumerate().fold(pulse << 14, |acc, (tap, &coef)| {
            acc.wrapping_add(hist[(pos + tap) & 7].wrapping_mul(coef))
        });
        let sample = acc >> 14;
        *out = sample;
        hist[pos & 7] = sample;
    }
}

/// Audio decoder front end producing interleaved 16-bit frames of 128 samples
/// per channel, queued until fetched with `get_frame`.
pub struct VXAudioDecoder {
    /// Channel count given at construction; 1 selects the mono path, any
    /// other value the two-channel path.
    channels:   usize,
    /// Per-channel decoding state (index 0 feeds `bufl`, index 1 feeds `bufr`).
    state:      [AudioState; 2],
    /// Most recently decoded samples of the first channel.
    bufl:       [i32; 128],
    /// Most recently decoded samples of the second channel.
    bufr:       [i32; 128],
    /// Decoded frames waiting to be fetched, oldest first.
    out:        VecDeque<Vec<i16>>,
}

impl VXAudioDecoder {
    /// Creates a decoder and reads the initial parameter tables for the first
    /// `channels` channel states (at most two) from `br`.
    ///
    /// # Errors
    /// Returns whatever error reading the tables reports.
    pub fn new(br: &mut dyn ByteIO, channels: usize) -> DecoderResult<Self> {
        let mut state = [AudioState::new(), AudioState::new()];
        state.iter_mut()
            .take(channels)
            .try_for_each(|chan| chan.read_initial_params(&mut *br))?;

        Ok(Self {
            channels,
            state,
            bufl: [0; 128],
            bufr: [0; 128],
            out: VecDeque::new(),
        })
    }
    /// Returns the channel count given at construction, truncated to `u8`.
    pub fn get_channels(&self) -> u8 { self.channels as u8 }
    pub fn decode(&mut self, src: &[u8], nblocks: usize) -> DecoderResult<()> {
        let mut br = MemoryReader::new_read(src);

        if self.channels == 1 {
            for _ in 0..nblocks {
                let val                 = br.read_u16le()?;
                if val == 0x00 { continue; }
                let mode = val >> 9;
                if mode == 0x7F {
                    self.state[0].decode_intra(&mut br, val, &mut self.bufl)?;
                } else {
                    self.state[0].decode_inter(&mut br, val & 0x1FF, mode, &mut self.bufl)?;
                }
                self.output_frame();
            }
        } else {
            for _ in 0..nblocks {
                let val                 = br.read_u16le()?;
                if val == 0x00 { continue; }
                let mode = val >> 9;
                if mode == 0x7F {
                    self.state[0].decode_intra(&mut br, val, &mut self.bufl)?;
                } else {
                    self.state[0].decode_inter(&mut br, val & 0x1FF, mode, &mut self.bufl)?;
                }
                let val                 = br.read_u16le()?;
                if val == 0x00 { continue; }
                let mode = val >> 9;
                if mode == 0x7F {
                    self.state[1].decode_intra(&mut br, val, &mut self.bufr)?;
                } else {
                    self.state[1].decode_inter(&mut br, val & 0x1FF, mode, &mut self.bufr)?;
                }
                self.output_frame();
            }
        }

        Ok(())
    }
    /// Converts the current channel buffers to 16-bit samples (saturating at
    /// the `i16` limits) and appends them to the output queue as one frame.
    ///
    /// With one channel the frame holds 128 samples; otherwise it holds 256
    /// samples with the two channels interleaved, first channel first.
    fn output_frame(&mut self) {
        let clip = |sample: i32| sample.clamp(-0x8000, 0x7FFF) as i16;

        let mono = self.channels == 1;
        let mut frame = Vec::with_capacity(if mono { 128 } else { 128 * 2 });
        if mono {
            frame.extend(self.bufl.iter().map(|&sample| clip(sample)));
        } else {
            for (&left, &right) in self.bufl.iter().zip(self.bufr.iter()) {
                frame.push(clip(left));
                frame.push(clip(right));
            }
        }
        self.out.push_back(frame);
    }
    /// Removes and returns the oldest queued frame, or `None` if the queue is empty.
    pub fn get_frame(&mut self) -> Option<Vec<i16>> { self.out.pop_front() }
}
