Skip to main content

tenferro_einsum/
tensordot.rs

1use std::collections::HashSet;
2
3use tenferro_runtime::error::{Error, Result};
4use tenferro_runtime::{DotGeneralConfig, TracedTensor};
5
/// Selects which axes [`TracedTensorEinsumExt::tensordot`](crate::TracedTensorEinsumExt::tensordot)
/// contracts between its two operands.
///
/// - `Count(n)` pairs the trailing `n` axes of the left operand, in order,
///   with the leading `n` axes of the right operand.
/// - `Axes { lhs, rhs }` pairs `lhs[i]` with `rhs[i]` for every position `i`.
///   A negative entry counts back from the end of that operand's axes, so
///   `-1` names the last axis. Both slices must have the same length, and a
///   slice may not name the same axis twice.
///
/// # Examples
///
/// ```
/// use tenferro_einsum::TensorDotAxes;
///
/// let count = TensorDotAxes::Count(2);
/// let explicit = TensorDotAxes::Axes {
///     lhs: &[-1],
///     rhs: &[0],
/// };
///
/// assert_eq!(count, TensorDotAxes::Count(2));
/// assert_ne!(count, explicit);
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TensorDotAxes<'a> {
    /// Pair the last `n` axes of the left operand with the first `n` axes of
    /// the right operand.
    Count(usize),
    /// Pair explicitly listed left/right axes position by position.
    Axes {
        /// Axes of the left operand; negative values are resolved against the
        /// left operand's rank.
        lhs: &'a [isize],
        /// Axes of the right operand; negative values are resolved against the
        /// right operand's rank.
        rhs: &'a [isize],
    },
}
40
41pub(crate) fn dot_general_config(
42    axes: TensorDotAxes<'_>,
43    lhs_rank: usize,
44    rhs_rank: usize,
45) -> Result<DotGeneralConfig> {
46    let (lhs_contracting_dims, rhs_contracting_dims) = match axes {
47        TensorDotAxes::Count(count) => {
48            if count > lhs_rank || count > rhs_rank {
49                return Err(contraction_error(format!(
50                    "TensorDotAxes::Count({count}) cannot contract {count} axes \
51                     for lhs rank {lhs_rank} and rhs rank {rhs_rank}"
52                )));
53            }
54            ((lhs_rank - count..lhs_rank).collect(), (0..count).collect())
55        }
56        TensorDotAxes::Axes { lhs, rhs } => {
57            if lhs.len() != rhs.len() {
58                return Err(contraction_error(format!(
59                    "tensordot explicit axes must have matching lengths, got lhs {} and rhs {}",
60                    lhs.len(),
61                    rhs.len()
62                )));
63            }
64            (
65                normalize_axes(lhs, lhs_rank, "lhs")?,
66                normalize_axes(rhs, rhs_rank, "rhs")?,
67            )
68        }
69    };
70
71    let config = DotGeneralConfig {
72        lhs_contracting_dims,
73        rhs_contracting_dims,
74        lhs_batch_dims: Vec::new(),
75        rhs_batch_dims: Vec::new(),
76    };
77    config
78        .validate_dims_with_ranks(lhs_rank, rhs_rank)
79        .map_err(contraction_error)?;
80    Ok(config)
81}
82
/// Checks that every contracted axis pair has equal extents in the concrete
/// operand shapes `lhs_shape` and `rhs_shape`.
///
/// The config is first validated against the two ranks. Axis pairs are then
/// compared in order, and the first pair whose extents differ is reported.
///
/// # Errors
///
/// Returns a contraction error if `validate_dims_with_ranks` fails or if a
/// contracted lhs/rhs dimension pair differs.
#[cfg(feature = "autodiff")]
pub(crate) fn validate_concrete_contract_dims(
    lhs_shape: &[usize],
    rhs_shape: &[usize],
    config: &DotGeneralConfig,
) -> Result<()> {
    config
        .validate_dims_with_ranks(lhs_shape.len(), rhs_shape.len())
        .map_err(contraction_error)?;
    // Indexing relies on the rank validation above having accepted every axis.
    let first_mismatch = config
        .lhs_contracting_dims
        .iter()
        .zip(&config.rhs_contracting_dims)
        .map(|(&left_axis, &right_axis)| {
            (left_axis, lhs_shape[left_axis], right_axis, rhs_shape[right_axis])
        })
        .find(|&(_, left_extent, _, right_extent)| left_extent != right_extent);
    match first_mismatch {
        Some((left_axis, left_extent, right_axis, right_extent)) => Err(
            contracted_dims_error(left_axis, left_extent, right_axis, right_extent),
        ),
        None => Ok(()),
    }
}
105
106pub(crate) fn validate_traced_contract_dims(
107    lhs: &TracedTensor,
108    rhs: &TracedTensor,
109    config: &DotGeneralConfig,
110) -> Result<()> {
111    config
112        .validate_dims_with_ranks(lhs.rank, rhs.rank)
113        .map_err(contraction_error)?;
114    for (&lhs_axis, &rhs_axis) in config
115        .lhs_contracting_dims
116        .iter()
117        .zip(config.rhs_contracting_dims.iter())
118    {
119        let lhs_dim = lhs.axis_sym_dim(lhs_axis)?;
120        let rhs_dim = rhs.axis_sym_dim(rhs_axis)?;
121        if lhs_dim == rhs_dim {
122            continue;
123        }
124        if let (Some(lhs_value), Some(rhs_value)) =
125            (lhs_dim.constant_value(), rhs_dim.constant_value())
126        {
127            if lhs_value != rhs_value {
128                return Err(contracted_dims_error(
129                    lhs_axis, lhs_value, rhs_axis, rhs_value,
130                ));
131            }
132        }
133    }
134    Ok(())
135}
136
137fn normalize_axes(axes: &[isize], rank: usize, operand: &str) -> Result<Vec<usize>> {
138    let mut normalized = Vec::with_capacity(axes.len());
139    let mut seen = HashSet::with_capacity(axes.len());
140    for &axis in axes {
141        let normalized_axis = normalize_axis(axis, rank, operand)?;
142        if !seen.insert(normalized_axis) {
143            return Err(contraction_error(format!(
144                "duplicate {operand} axis {normalized_axis} in tensordot axes"
145            )));
146        }
147        normalized.push(normalized_axis);
148    }
149    Ok(normalized)
150}
151
152fn normalize_axis(axis: isize, rank: usize, operand: &str) -> Result<usize> {
153    let rank_isize = isize::try_from(rank).map_err(|_| {
154        contraction_error(format!(
155            "{operand} rank {rank} is too large to normalize tensordot axes"
156        ))
157    })?;
158    let normalized = if axis < 0 { rank_isize + axis } else { axis };
159    if normalized < 0 || normalized >= rank_isize {
160        return Err(contraction_error(format!(
161            "{operand} axis {axis} out of bounds for rank {rank}"
162        )));
163    }
164    usize::try_from(normalized).map_err(|_| {
165        contraction_error(format!(
166            "{operand} axis {axis} could not be normalized for rank {rank}"
167        ))
168    })
169}
170
171fn contracted_dims_error(
172    lhs_axis: usize,
173    lhs_dim: usize,
174    rhs_axis: usize,
175    rhs_dim: usize,
176) -> Error {
177    contraction_error(format!(
178        "contracted dimensions differ for lhs axis {lhs_axis} ({lhs_dim}) \
179         and rhs axis {rhs_axis} ({rhs_dim})"
180    ))
181}
182
183fn contraction_error(message: impl Into<String>) -> Error {
184    Error::ContractionError(message.into())
185}
186
187#[cfg(test)]
188mod tests;