1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
use proc_macro2::TokenStream;
use quote::{format_ident, ToTokens, TokenStreamExt};
use std::fmt;

/// Allocator of names for intermediate tensors.
///
/// As explained in the crate-level documentation, factorizing an einsum
/// requires tracking the names of tensors in addition to their subscripts;
/// this struct takes care of that bookkeeping.
///
/// It is a plain counter: each call to [`Namespace::new_ident`] hands out the
/// next `out{N}` identifier, starting from `out0` for a freshly created
/// namespace.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Namespace {
    /// Index that the next issued `out{N}` identifier will receive.
    last: usize,
}

impl Namespace {
    /// Creates an empty namespace whose first issued identifier is `out0`.
    pub fn init() -> Self {
        Self { last: 0 }
    }

    /// Issues a fresh intermediate-tensor identifier.
    ///
    /// Returns `Position::Out(n)`, where `n` is the current counter value, and
    /// advances the counter so the following call yields `Position::Out(n + 1)`.
    pub fn new_ident(&mut self) -> Position {
        let issued = self.last;
        self.last = issued + 1;
        Position::Out(issued)
    }
}

/// Identifies which tensor a subscript refers to.
///
/// The variant order is significant: the derived `PartialOrd`/`Ord` place
/// every `Arg` before every `Out`, then compare indices within a variant.
#[derive(Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Position {
    /// The tensor passed by the user as the N-th argument of einsum.
    Arg(usize),
    /// The intermediate tensor produced by einsum at its N-th step.
    Out(usize),
}

impl fmt::Debug for Position {
    /// Formats the position as `arg{N}` or `out{N}`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (prefix, index) = match *self {
            Position::Arg(index) => ("arg", index),
            Position::Out(index) => ("out", index),
        };
        write!(f, "{}{}", prefix, index)
    }
}

impl fmt::Display for Position {
    /// Displays the position exactly as its `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Debug>::fmt(self, f)
    }
}

impl ToTokens for Position {
    /// Emits the position as a single identifier token, `arg{N}` or `out{N}`,
    /// the same text produced by its `Debug`/`Display` formatting.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let ident = match *self {
            Position::Arg(index) => format_ident!("arg{}", index),
            Position::Out(index) => format_ident!("out{}", index),
        };
        tokens.append(ident);
    }
}