tenferro_einsum/
tensordot.rs1use std::collections::HashSet;
2
3use tenferro_runtime::error::{Error, Result};
4use tenferro_runtime::{DotGeneralConfig, TracedTensor};
5
/// Chooses which axes a `tensordot` contraction sums over.
///
/// Axes of the left operand are always paired positionally with axes of the
/// right operand: the i-th selected `lhs` axis is contracted against the i-th
/// selected `rhs` axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorDotAxes<'a> {
    /// Contract the trailing `n` axes of the left operand with the leading `n`
    /// axes of the right operand, in order.
    ///
    /// `n` must not exceed the rank of either operand.
    Count(usize),
    /// Contract explicitly listed axes.
    ///
    /// Both lists must have the same length and may not name the same axis
    /// twice. Negative entries count from the end, so `-1` is the last axis.
    Axes {
        /// Axes of the left operand to contract.
        lhs: &'a [isize],
        /// Axes of the right operand, paired element-wise with `lhs`.
        rhs: &'a [isize],
    },
}
40
41pub(crate) fn dot_general_config(
42 axes: TensorDotAxes<'_>,
43 lhs_rank: usize,
44 rhs_rank: usize,
45) -> Result<DotGeneralConfig> {
46 let (lhs_contracting_dims, rhs_contracting_dims) = match axes {
47 TensorDotAxes::Count(count) => {
48 if count > lhs_rank || count > rhs_rank {
49 return Err(contraction_error(format!(
50 "TensorDotAxes::Count({count}) cannot contract {count} axes \
51 for lhs rank {lhs_rank} and rhs rank {rhs_rank}"
52 )));
53 }
54 ((lhs_rank - count..lhs_rank).collect(), (0..count).collect())
55 }
56 TensorDotAxes::Axes { lhs, rhs } => {
57 if lhs.len() != rhs.len() {
58 return Err(contraction_error(format!(
59 "tensordot explicit axes must have matching lengths, got lhs {} and rhs {}",
60 lhs.len(),
61 rhs.len()
62 )));
63 }
64 (
65 normalize_axes(lhs, lhs_rank, "lhs")?,
66 normalize_axes(rhs, rhs_rank, "rhs")?,
67 )
68 }
69 };
70
71 let config = DotGeneralConfig {
72 lhs_contracting_dims,
73 rhs_contracting_dims,
74 lhs_batch_dims: Vec::new(),
75 rhs_batch_dims: Vec::new(),
76 };
77 config
78 .validate_dims_with_ranks(lhs_rank, rhs_rank)
79 .map_err(contraction_error)?;
80 Ok(config)
81}
82
/// Checks that every contracted axis pair has the same extent in the given
/// concrete shapes.
///
/// * `lhs_shape` / `rhs_shape` - concrete shapes of the two operands.
/// * `config` - contraction configuration whose contracting dims are checked.
///
/// # Errors
///
/// Returns a contraction error if `validate_dims_with_ranks` rejects `config`
/// for these ranks, or for the first axis pair whose extents differ.
///
/// # Panics
///
/// Indexes the shapes directly. NOTE(review): this presumably relies on
/// `validate_dims_with_ranks` having rejected out-of-range axes; confirm.
#[cfg(feature = "autodiff")]
pub(crate) fn validate_concrete_contract_dims(
    lhs_shape: &[usize],
    rhs_shape: &[usize],
    config: &DotGeneralConfig,
) -> Result<()> {
    config
        .validate_dims_with_ranks(lhs_shape.len(), rhs_shape.len())
        .map_err(contraction_error)?;

    let first_mismatch = config
        .lhs_contracting_dims
        .iter()
        .copied()
        .zip(config.rhs_contracting_dims.iter().copied())
        .map(|(l_axis, r_axis)| (l_axis, lhs_shape[l_axis], r_axis, rhs_shape[r_axis]))
        .find(|&(_, l_extent, _, r_extent)| l_extent != r_extent);

    match first_mismatch {
        Some((l_axis, l_extent, r_axis, r_extent)) => {
            Err(contracted_dims_error(l_axis, l_extent, r_axis, r_extent))
        }
        None => Ok(()),
    }
}
105
106pub(crate) fn validate_traced_contract_dims(
107 lhs: &TracedTensor,
108 rhs: &TracedTensor,
109 config: &DotGeneralConfig,
110) -> Result<()> {
111 config
112 .validate_dims_with_ranks(lhs.rank, rhs.rank)
113 .map_err(contraction_error)?;
114 for (&lhs_axis, &rhs_axis) in config
115 .lhs_contracting_dims
116 .iter()
117 .zip(config.rhs_contracting_dims.iter())
118 {
119 let lhs_dim = lhs.axis_sym_dim(lhs_axis)?;
120 let rhs_dim = rhs.axis_sym_dim(rhs_axis)?;
121 if lhs_dim == rhs_dim {
122 continue;
123 }
124 if let (Some(lhs_value), Some(rhs_value)) =
125 (lhs_dim.constant_value(), rhs_dim.constant_value())
126 {
127 if lhs_value != rhs_value {
128 return Err(contracted_dims_error(
129 lhs_axis, lhs_value, rhs_axis, rhs_value,
130 ));
131 }
132 }
133 }
134 Ok(())
135}
136
137fn normalize_axes(axes: &[isize], rank: usize, operand: &str) -> Result<Vec<usize>> {
138 let mut normalized = Vec::with_capacity(axes.len());
139 let mut seen = HashSet::with_capacity(axes.len());
140 for &axis in axes {
141 let normalized_axis = normalize_axis(axis, rank, operand)?;
142 if !seen.insert(normalized_axis) {
143 return Err(contraction_error(format!(
144 "duplicate {operand} axis {normalized_axis} in tensordot axes"
145 )));
146 }
147 normalized.push(normalized_axis);
148 }
149 Ok(normalized)
150}
151
152fn normalize_axis(axis: isize, rank: usize, operand: &str) -> Result<usize> {
153 let rank_isize = isize::try_from(rank).map_err(|_| {
154 contraction_error(format!(
155 "{operand} rank {rank} is too large to normalize tensordot axes"
156 ))
157 })?;
158 let normalized = if axis < 0 { rank_isize + axis } else { axis };
159 if normalized < 0 || normalized >= rank_isize {
160 return Err(contraction_error(format!(
161 "{operand} axis {axis} out of bounds for rank {rank}"
162 )));
163 }
164 usize::try_from(normalized).map_err(|_| {
165 contraction_error(format!(
166 "{operand} axis {axis} could not be normalized for rank {rank}"
167 ))
168 })
169}
170
171fn contracted_dims_error(
172 lhs_axis: usize,
173 lhs_dim: usize,
174 rhs_axis: usize,
175 rhs_dim: usize,
176) -> Error {
177 contraction_error(format!(
178 "contracted dimensions differ for lhs axis {lhs_axis} ({lhs_dim}) \
179 and rhs axis {rhs_axis} ({rhs_dim})"
180 ))
181}
182
183fn contraction_error(message: impl Into<String>) -> Error {
184 Error::ContractionError(message.into())
185}
186
// Unit tests for this module are declared out of line in a separate `tests`
// module file and compiled only for test builds.
#[cfg(test)]
mod tests;