#[cfg(doc)]
use super::function_definition;
use crate::Subscripts;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use std::collections::HashSet;
/// Builds the loop-variable identifier used for index `i` (e.g. `'a'` becomes `a`).
fn index_ident(i: char) -> syn::Ident {
    let name = i.to_string();
    quote::format_ident!("{}", name)
}
/// Builds the identifier that holds the extent of index `i` (e.g. `'a'` becomes `n_a`).
fn n_ident(i: char) -> syn::Ident {
    let name = format!("n_{}", i);
    quote::format_ident!("{}", name)
}
/// Wraps `inner` in one `for` loop per entry of `indices`.
///
/// `indices[0]` becomes the outermost loop and the last entry the innermost.
/// Each loop runs its index variable (e.g. `a`) over `0..n_a`.
fn contraction_for(indices: &[char], inner: TokenStream2) -> TokenStream2 {
    // Folding from the innermost index outwards nests each new loop around
    // the body built so far.
    indices.iter().rev().fold(inner, |body, &c| {
        let var = index_ident(c);
        let bound = n_ident(c);
        quote! {
            for #var in 0..#bound { #body }
        }
    })
}
fn contraction_inner(subscripts: &Subscripts) -> TokenStream2 {
    let mut inner_args_tt = Vec::new();
    for (argc, arg) in subscripts.inputs.iter().enumerate() {
        let mut index = Vec::new();
        for i in subscripts.inputs[argc].indices() {
            index.push(index_ident(i));
        }
        inner_args_tt.push(quote! {
            #arg[(#(#index),*)]
        })
    }
    let mut inner_mul = None;
    for inner in inner_args_tt {
        match inner_mul {
            Some(i) => inner_mul = Some(quote! { #i * #inner }),
            None => inner_mul = Some(inner),
        }
    }
    let output_ident = &subscripts.output;
    let mut output_indices = Vec::new();
    for i in &subscripts.output.indices() {
        let index = index_ident(*i);
        output_indices.push(index.clone());
    }
    quote! {
        #output_ident[(#(#output_indices),*)] = #inner_mul;
    }
}
pub fn contraction(subscripts: &Subscripts) -> TokenStream2 {
    let mut indices: Vec<char> = subscripts.output.indices();
    for i in subscripts.contraction_indices() {
        indices.push(i);
    }
    let inner = contraction_inner(subscripts);
    contraction_for(&indices, inner)
}
pub fn define_array_size(subscripts: &Subscripts) -> TokenStream2 {
    let mut appeared: HashSet<char> = HashSet::new();
    let mut tt = Vec::new();
    for arg in subscripts.inputs.iter() {
        let n_ident: Vec<syn::Ident> = arg
            .indices()
            .into_iter()
            .map(|i| {
                if appeared.contains(&i) {
                    quote::format_ident!("_")
                } else {
                    appeared.insert(i);
                    n_ident(i)
                }
            })
            .collect();
        tt.push(quote! {
            let (#(#n_ident),*) = #arg.dim();
        });
    }
    quote! { #(#tt)* }
}
pub fn array_size_asserts(subscripts: &Subscripts) -> TokenStream2 {
    let mut tt = Vec::new();
    for arg in &subscripts.inputs {
        let n_each: Vec<_> = (0..arg.indices().len())
            .map(|m| quote::format_ident!("n_{}", m))
            .collect();
        let n: Vec<_> = arg.indices().into_iter().map(n_ident).collect();
        tt.push(quote! {
            let (#(#n_each),*) = #arg.dim();
            #(assert_eq!(#n_each, #n);)*
        });
    }
    quote! { #({ #tt })* }
}
fn define_output_array(subscripts: &Subscripts) -> TokenStream2 {
    let output_ident = &subscripts.output;
    let mut n_output = Vec::new();
    for i in subscripts.output.indices() {
        n_output.push(n_ident(i));
    }
    quote! {
        let mut #output_ident = ndarray::Array::zeros((#(#n_output),*));
    }
}
pub fn inner(subscripts: &Subscripts) -> TokenStream2 {
    let array_size = define_array_size(subscripts);
    let array_size_asserts = array_size_asserts(subscripts);
    let output_ident = &subscripts.output;
    let output_tt = define_output_array(subscripts);
    let contraction_tt = contraction(subscripts);
    quote! {
        #array_size
        #array_size_asserts
        #output_tt
        #contraction_tt
        #output_ident
    }
}
#[cfg(test)]
mod test {
    use crate::{codegen::format_block, *};
    #[test]
    fn define_array_size() {
        let mut namespace = Namespace::init();
        let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap();
        let tt = format_block(super::define_array_size(&subscripts).to_string());
        insta::assert_snapshot!(tt, @r###"
        let (n_a, n_b) = arg0.dim();
        let (_, n_c) = arg1.dim();
        "###);
    }
    #[test]
    fn contraction() {
        let mut namespace = Namespace::init();
        let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap();
        let tt = format_block(super::contraction(&subscripts).to_string());
        insta::assert_snapshot!(tt, @r###"
        for a in 0..n_a {
            for c in 0..n_c {
                for b in 0..n_b {
                    out0[(a, c)] = arg0[(a, b)] * arg1[(b, c)];
                }
            }
        }
        "###);
    }
    #[test]
    fn inner() {
        let mut namespace = Namespace::init();
        let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap();
        let tt = format_block(super::inner(&subscripts).to_string());
        insta::assert_snapshot!(tt, @r###"
        let (n_a, n_b) = arg0.dim();
        let (_, n_c) = arg1.dim();
        {
            let (n_0, n_1) = arg0.dim();
            assert_eq!(n_0, n_a);
            assert_eq!(n_1, n_b);
        }
        {
            let (n_0, n_1) = arg1.dim();
            assert_eq!(n_0, n_b);
            assert_eq!(n_1, n_c);
        }
        let mut out0 = ndarray::Array::zeros((n_a, n_c));
        for a in 0..n_a {
            for c in 0..n_c {
                for b in 0..n_b {
                    out0[(a, c)] = arg0[(a, b)] * arg1[(b, c)];
                }
            }
        }
        out0
        "###);
    }
}