// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements.  See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License.  You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{
    any::Any,
    fmt::{Debug, Formatter},
    sync::Arc,
};

use arrow::{array::*, datatypes::*};
use auron_memmgr::spill::{SpillCompressedReader, SpillCompressedWriter};
use datafusion::{common::Result, physical_expr::PhysicalExprRef};
use datafusion_ext_commons::{
    downcast_any,
    io::{read_len, write_len},
};

use crate::{
    agg::{
        acc::{AccColumn, AccColumnRef},
        agg::{Agg, IdxSelection},
    },
    idx_for, idx_for_zipped,
};

/// Count aggregate function.
///
/// With no argument expressions every input row is counted; with arguments a
/// row is counted only when all of them are valid at that row (see
/// `partial_update`).
pub struct AggCount {
    /// Argument expressions of the aggregate, handed back by `exprs()`.
    children: Vec<PhysicalExprRef>,
    /// Result data type; `try_new` requires this to be `DataType::Int64`.
    data_type: DataType,
}

impl AggCount {
    pub fn try_new(children: Vec<PhysicalExprRef>, data_type: DataType) -> Result<Self> {
        assert_eq!(data_type, DataType::Int64);
        Ok(Self {
            children,
            data_type,
        })
    }
}

impl Debug for AggCount {
    /// Renders the aggregate as `Count(<children>)`, using the `Debug` output
    /// of the argument expression list.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("Count({:?})", self.children))
    }
}

impl Agg for AggCount {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns a copy of the argument expression list.
    fn exprs(&self) -> Vec<PhysicalExprRef> {
        self.children.clone()
    }

    /// Creates a new count aggregate over `exprs`, keeping the current result
    /// data type.
    fn with_new_exprs(&self, exprs: Vec<PhysicalExprRef>) -> Result<Arc<dyn Agg>> {
        // `exprs` is already owned, so it is moved instead of cloned.
        let rebuilt = Self::try_new(exprs, self.data_type.clone())?;
        Ok(Arc::new(rebuilt))
    }

    fn data_type(&self) -> &DataType {
        &self.data_type
    }

    /// The result of this aggregate is reported as non-nullable.
    fn nullable(&self) -> bool {
        false
    }

    /// Creates an accumulator column holding `num_rows` counters, all zero.
    fn create_acc_column(&self, num_rows: usize) -> Box<dyn AccColumn> {
        let values = vec![0; num_rows];
        Box::new(AccCountColumn { values })
    }

    /// Adds input rows to the selected counters.
    ///
    /// `acc_idx` and `partial_arg_idx` are walked in lockstep: each pair names
    /// the counter to update and the input row feeding it. Without argument
    /// arrays every row adds 1; otherwise a row adds 1 only when every
    /// argument array is valid at that row, and 0 if any is not.
    fn partial_update(
        &self,
        accs: &mut AccColumnRef,
        acc_idx: IdxSelection<'_>,
        partial_args: &[ArrayRef],
        partial_arg_idx: IdxSelection<'_>,
    ) -> Result<()> {
        let counts = downcast_any!(accs, mut AccCountColumn)?;
        // NOTE(review): `ensure_size` is defined outside this file; presumably
        // it grows the column to cover `acc_idx`. The append fallback below
        // still handles any index at or past the current length.
        counts.ensure_size(acc_idx);

        if partial_args.is_empty() {
            idx_for_zipped! {
                ((slot_idx, _row_idx) in (acc_idx, partial_arg_idx)) => {
                    match counts.values.get_mut(slot_idx) {
                        Some(slot) => *slot += 1,
                        None => counts.values.push(1),
                    }
                }
            }
        } else {
            idx_for_zipped! {
                ((slot_idx, row_idx) in (acc_idx, partial_arg_idx)) => {
                    // 1 when all arguments are valid at this row, else 0
                    let add = i64::from(partial_args.iter().all(|arg| arg.is_valid(row_idx)));

                    match counts.values.get_mut(slot_idx) {
                        Some(slot) => *slot += add,
                        None => counts.values.push(add),
                    }
                }
            }
        }
        Ok(())
    }

    /// Adds the counters selected by `merging_acc_idx` in `merging_accs` onto
    /// the counters selected by `acc_idx` in `accs`, pairing the two
    /// selections in lockstep.
    fn partial_merge(
        &self,
        accs: &mut AccColumnRef,
        acc_idx: IdxSelection<'_>,
        merging_accs: &mut AccColumnRef,
        merging_acc_idx: IdxSelection<'_>,
    ) -> Result<()> {
        let target = downcast_any!(accs, mut AccCountColumn)?;
        let source = downcast_any!(merging_accs, mut AccCountColumn)?;
        target.ensure_size(acc_idx);

        idx_for_zipped! {
            ((dst_idx, src_idx) in (acc_idx, merging_acc_idx)) => {
                let incoming = source.values[src_idx];
                match target.values.get_mut(dst_idx) {
                    Some(slot) => *slot += incoming,
                    // index at or past the current length: append instead
                    None => target.values.push(incoming),
                }
            }
        }
        Ok(())
    }

    /// Produces the final result array: the first (and, for this aggregate,
    /// only) array returned by freezing the selected counters.
    fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result<ArrayRef> {
        let frozen = accs.freeze_to_arrays(acc_idx)?;
        Ok(frozen[0].clone())
    }

    /// The accumulator state is a single `Int64` array.
    fn acc_array_data_types(&self) -> &[DataType] {
        &[DataType::Int64]
    }
}

/// Accumulator column of the count aggregate: one `i64` counter per
/// accumulator slot.
pub struct AccCountColumn {
    /// Running counts, indexed by accumulator slot.
    pub values: Vec<i64>,
}

impl AccColumn for AccCountColumn {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// Resizes the column to `num_accs` counters; new counters start at zero.
    fn resize(&mut self, num_accs: usize) {
        let zero: i64 = 0;
        self.values.resize(num_accs, zero);
    }

    fn shrink_to_fit(&mut self) {
        self.values.shrink_to_fit();
    }

    /// Number of counters currently stored.
    fn num_records(&self) -> usize {
        self.values.len()
    }

    /// Reports twice the byte size of the allocated counter capacity.
    fn mem_used(&self) -> usize {
        let doubled_slots = self.values.capacity() * 2;
        doubled_slots * size_of::<i64>()
    }

    /// Copies the selected counters into a single `Int64Array`.
    fn freeze_to_arrays(&mut self, idx: IdxSelection<'_>) -> Result<Vec<ArrayRef>> {
        let mut frozen = Vec::with_capacity(idx.len());
        idx_for! {
            (slot_idx in idx) => {
                frozen.push(self.values[slot_idx]);
            }
        }
        let array: ArrayRef = Arc::new(Int64Array::from(frozen));
        Ok(vec![array])
    }

    /// Replaces the counters with the contents of the first array, which must
    /// be an `Int64Array`; null entries become zero.
    fn unfreeze_from_arrays(&mut self, arrays: &[ArrayRef]) -> Result<()> {
        let counts = downcast_any!(arrays[0], Int64Array)?;
        self.values = counts.iter().map(Option::unwrap_or_default).collect();
        Ok(())
    }

    /// Writes the selected counters to `w`, each encoded with `write_len`.
    fn spill(&mut self, idx: IdxSelection<'_>, w: &mut SpillCompressedWriter) -> Result<()> {
        idx_for! {
            (slot_idx in idx) => {
                let encoded = self.values[slot_idx] as usize;
                write_len(encoded, w)?;
            }
        }
        Ok(())
    }

    /// Reads `num_rows` counters from `r` into this column.
    ///
    /// # Panics
    ///
    /// Panics if the column is not empty when called.
    fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()> {
        assert_eq!(self.num_records(), 0, "expect empty AccColumn");
        for _ in 0..num_rows {
            let decoded = read_len(r)? as i64;
            self.values.push(decoded);
        }
        Ok(())
    }
}
