// tenferro_einsum/eager.rs — eager einsum execution.

1use std::collections::{HashMap, HashSet};
2
3use tenferro_tensor::{DotGeneralConfig, Error, Result, Tensor, TensorBackend, TensorExec};
4
5use crate::{ContractionTree, Subscripts};
6
/// Operation name reported in invalid-config errors raised by the eager
/// einsum path.
const EAGER_EINSUM_OP: &str = "eager_einsum";
8
/// A tensor that is either borrowed from the caller's inputs or owned as an
/// intermediate produced during contraction.
enum TensorValue<'a> {
    /// Borrowed input tensor; cloned only if it must be returned as the result.
    Borrowed(&'a Tensor),
    /// Intermediate tensor owned by the einsum evaluation.
    Owned(Tensor),
}
13
14impl TensorValue<'_> {
15    fn as_tensor(&self) -> &Tensor {
16        match self {
17            Self::Borrowed(tensor) => tensor,
18            Self::Owned(tensor) => tensor,
19        }
20    }
21
22    fn into_tensor(self) -> Tensor {
23        match self {
24            Self::Borrowed(tensor) => tensor.clone(),
25            Self::Owned(tensor) => tensor,
26        }
27    }
28}
29
/// A tensor paired with the einsum axis labels describing each dimension.
struct LabeledTensor<'a> {
    /// Underlying tensor value (borrowed input or owned intermediate).
    tensor: TensorValue<'a>,
    /// One label per tensor axis, in axis order.
    labels: Vec<u32>,
}
34
35impl LabeledTensor<'_> {
36    fn tensor(&self) -> &Tensor {
37        self.tensor.as_tensor()
38    }
39
40    fn shape(&self) -> &[usize] {
41        self.tensor().shape()
42    }
43}
44
45fn eager_invalid_config(message: impl Into<String>) -> Error {
46    Error::InvalidConfig {
47        op: EAGER_EINSUM_OP,
48        message: message.into(),
49    }
50}
51
52fn take_labeled<'a>(
53    labeled: &mut [Option<LabeledTensor<'a>>],
54    index: usize,
55    role: &'static str,
56) -> Result<LabeledTensor<'a>> {
57    labeled
58        .get_mut(index)
59        .ok_or_else(|| eager_invalid_config(format!("missing {role} operand at index {index}")))?
60        .take()
61        .ok_or_else(|| eager_invalid_config(format!("missing {role} operand at index {index}")))
62}
63
64fn find_label_axis(labels: &[u32], label: u32) -> Result<usize> {
65    labels
66        .iter()
67        .position(|candidate| *candidate == label)
68        .ok_or_else(|| eager_invalid_config(format!("label {label} missing from tensor labels")))
69}
70
71fn label_size(label: u32, operands: &[&LabeledTensor<'_>]) -> Result<usize> {
72    for operand in operands {
73        if let Some(axis) = operand
74            .labels
75            .iter()
76            .position(|candidate| *candidate == label)
77        {
78            return Ok(operand.shape()[axis]);
79        }
80    }
81    Err(eager_invalid_config(format!(
82        "label {label} missing from eager einsum operands"
83    )))
84}
85
86fn reduce_tensor<'a>(
87    exec: &mut dyn TensorExec,
88    operand: LabeledTensor<'a>,
89    reduce_labels: &HashSet<u32>,
90) -> Result<LabeledTensor<'a>> {
91    if reduce_labels.is_empty() {
92        return Ok(operand);
93    }
94
95    let reduce_axes: Vec<usize> = operand
96        .labels
97        .iter()
98        .enumerate()
99        .filter(|(_, label)| reduce_labels.contains(label))
100        .map(|(axis, _)| axis)
101        .collect();
102    if reduce_axes.is_empty() {
103        return Ok(operand);
104    }
105
106    let reduce_set: HashSet<usize> = reduce_axes.iter().copied().collect();
107    let labels: Vec<u32> = operand
108        .labels
109        .iter()
110        .enumerate()
111        .filter(|(axis, _)| !reduce_set.contains(axis))
112        .map(|(_, label)| *label)
113        .collect();
114    let tensor = exec.reduce_sum(operand.tensor(), &reduce_axes)?;
115    Ok(LabeledTensor {
116        tensor: TensorValue::Owned(tensor),
117        labels,
118    })
119}
120
121fn diagonalize_repeated<'a>(
122    exec: &mut dyn TensorExec,
123    mut operand: LabeledTensor<'a>,
124) -> Result<LabeledTensor<'a>> {
125    loop {
126        let mut seen = HashMap::new();
127        let mut repeated_pair = None;
128        for (axis, label) in operand.labels.iter().copied().enumerate() {
129            if let Some(first_axis) = seen.insert(label, axis) {
130                repeated_pair = Some((first_axis, axis));
131                break;
132            }
133        }
134
135        let Some((axis_a, axis_b)) = repeated_pair else {
136            return Ok(operand);
137        };
138
139        let tensor = exec.extract_diagonal(operand.tensor(), axis_a, axis_b)?;
140        let mut labels = operand.labels;
141        labels.remove(axis_b);
142        operand = LabeledTensor {
143            tensor: TensorValue::Owned(tensor),
144            labels,
145        };
146    }
147}
148
149fn embed_repeated<'a>(
150    exec: &mut dyn TensorExec,
151    mut operand: LabeledTensor<'a>,
152    output_labels: &[u32],
153) -> Result<LabeledTensor<'a>> {
154    loop {
155        let mut embedded = false;
156        for &label in output_labels {
157            let current_count = operand
158                .labels
159                .iter()
160                .filter(|candidate| **candidate == label)
161                .count();
162            let output_count = output_labels
163                .iter()
164                .filter(|candidate| **candidate == label)
165                .count();
166            if output_count > current_count {
167                let axis_a = find_label_axis(&operand.labels, label)?;
168                let axis_b = axis_a + 1;
169                let tensor = exec.embed_diagonal(operand.tensor(), axis_a, axis_b)?;
170                let mut labels = operand.labels;
171                labels.insert(axis_b, label);
172                operand = LabeledTensor {
173                    tensor: TensorValue::Owned(tensor),
174                    labels,
175                };
176                embedded = true;
177                break;
178            }
179        }
180
181        if !embedded {
182            return Ok(operand);
183        }
184    }
185}
186
187fn transpose_to_labels<'a>(
188    exec: &mut dyn TensorExec,
189    operand: LabeledTensor<'a>,
190    target_labels: &[u32],
191) -> Result<LabeledTensor<'a>> {
192    if operand.labels == target_labels {
193        return Ok(operand);
194    }
195
196    let perm: Vec<usize> = target_labels
197        .iter()
198        .map(|label| find_label_axis(&operand.labels, *label))
199        .collect::<Result<_>>()?;
200    if perm
201        .iter()
202        .enumerate()
203        .all(|(axis, target)| axis == *target)
204    {
205        return Ok(operand);
206    }
207
208    let tensor = exec.transpose(operand.tensor(), &perm)?;
209    Ok(LabeledTensor {
210        tensor: TensorValue::Owned(tensor),
211        labels: target_labels.to_vec(),
212    })
213}
214
215fn outer_product<'a>(
216    exec: &mut dyn TensorExec,
217    lhs: LabeledTensor<'a>,
218    rhs: LabeledTensor<'a>,
219    batch_labels: &[u32],
220    lhs_free_labels: &[u32],
221    rhs_free_labels: &[u32],
222) -> Result<LabeledTensor<'a>> {
223    if lhs.labels == rhs.labels {
224        let tensor = exec.mul(lhs.tensor(), rhs.tensor())?;
225        return Ok(LabeledTensor {
226            tensor: TensorValue::Owned(tensor),
227            labels: lhs.labels,
228        });
229    }
230
231    let combined_labels: Vec<u32> = lhs_free_labels
232        .iter()
233        .chain(rhs_free_labels.iter())
234        .chain(batch_labels.iter())
235        .copied()
236        .collect();
237    let combined_shape: Vec<usize> = combined_labels
238        .iter()
239        .map(|label| label_size(*label, &[&lhs, &rhs]))
240        .collect::<Result<_>>()?;
241    let lhs_dims: Vec<usize> = lhs
242        .labels
243        .iter()
244        .map(|label| find_label_axis(&combined_labels, *label))
245        .collect::<Result<_>>()?;
246    let rhs_dims: Vec<usize> = rhs
247        .labels
248        .iter()
249        .map(|label| find_label_axis(&combined_labels, *label))
250        .collect::<Result<_>>()?;
251
252    let lhs_tensor = exec.broadcast_in_dim(lhs.tensor(), &combined_shape, &lhs_dims)?;
253    let rhs_tensor = exec.broadcast_in_dim(rhs.tensor(), &combined_shape, &rhs_dims)?;
254    let tensor = exec.mul(&lhs_tensor, &rhs_tensor)?;
255    Ok(LabeledTensor {
256        tensor: TensorValue::Owned(tensor),
257        labels: combined_labels,
258    })
259}
260
/// Contracts two operands into one, keeping only `survive_labels`.
///
/// Labels private to one operand and absent from `survive_labels` are summed
/// away first. Shared labels either batch (when they survive) or contract via
/// `dot_general`; with no contracting labels the operands combine as an outer
/// product. When `reorder_result` is set the result axes are permuted into
/// `survive_labels` order (used only on the final contraction step).
fn binary_contract<'a>(
    exec: &mut dyn TensorExec,
    lhs: LabeledTensor<'a>,
    rhs: LabeledTensor<'a>,
    survive_labels: &[u32],
    reorder_result: bool,
) -> Result<LabeledTensor<'a>> {
    let survive_set: HashSet<u32> = survive_labels.iter().copied().collect();
    let rhs_label_set: HashSet<u32> = rhs.labels.iter().copied().collect();
    let lhs_label_set: HashSet<u32> = lhs.labels.iter().copied().collect();

    // Labels that appear in only one operand and are not needed downstream
    // can be summed out before the pairwise contraction.
    let lhs_reduce: HashSet<u32> = lhs
        .labels
        .iter()
        .filter(|label| !rhs_label_set.contains(label) && !survive_set.contains(label))
        .copied()
        .collect();
    let rhs_reduce: HashSet<u32> = rhs
        .labels
        .iter()
        .filter(|label| !lhs_label_set.contains(label) && !survive_set.contains(label))
        .copied()
        .collect();

    let lhs = reduce_tensor(exec, lhs, &lhs_reduce)?;
    let rhs = reduce_tensor(exec, rhs, &rhs_reduce)?;

    // Recompute label sets: reduction may have removed labels.
    let lhs_label_set: HashSet<u32> = lhs.labels.iter().copied().collect();
    let rhs_label_set: HashSet<u32> = rhs.labels.iter().copied().collect();

    let mut batch_labels = Vec::new();
    let mut contracting_labels = Vec::new();
    let mut lhs_free_labels = Vec::new();
    let mut rhs_free_labels = Vec::new();

    // Classify each lhs label: shared+surviving -> batch, shared+dropped ->
    // contracting, lhs-only -> free. Dedup via `contains` keeps first-seen
    // order, which fixes the result axis order below.
    for &label in &lhs.labels {
        if rhs_label_set.contains(&label) {
            if survive_set.contains(&label) {
                if !batch_labels.contains(&label) {
                    batch_labels.push(label);
                }
            } else if !contracting_labels.contains(&label) {
                contracting_labels.push(label);
            }
        } else if !lhs_free_labels.contains(&label) {
            lhs_free_labels.push(label);
        }
    }

    // Remaining rhs-only labels are rhs free labels.
    for &label in &rhs.labels {
        if !lhs_label_set.contains(&label) && !rhs_free_labels.contains(&label) {
            rhs_free_labels.push(label);
        }
    }

    let result = if contracting_labels.is_empty() {
        // Nothing to contract: broadcasted elementwise product.
        outer_product(
            exec,
            lhs,
            rhs,
            &batch_labels,
            &lhs_free_labels,
            &rhs_free_labels,
        )?
    } else {
        // Translate label classes into dot_general dimension indices.
        let lhs_contracting_dims: Vec<usize> = contracting_labels
            .iter()
            .map(|label| find_label_axis(&lhs.labels, *label))
            .collect::<Result<_>>()?;
        let rhs_contracting_dims: Vec<usize> = contracting_labels
            .iter()
            .map(|label| find_label_axis(&rhs.labels, *label))
            .collect::<Result<_>>()?;
        let lhs_batch_dims: Vec<usize> = batch_labels
            .iter()
            .map(|label| find_label_axis(&lhs.labels, *label))
            .collect::<Result<_>>()?;
        let rhs_batch_dims: Vec<usize> = batch_labels
            .iter()
            .map(|label| find_label_axis(&rhs.labels, *label))
            .collect::<Result<_>>()?;
        // Result axis order assumed from dot_general: lhs free, rhs free,
        // then batch dims.
        let labels: Vec<u32> = lhs_free_labels
            .iter()
            .chain(rhs_free_labels.iter())
            .chain(batch_labels.iter())
            .copied()
            .collect();
        let config = DotGeneralConfig {
            lhs_contracting_dims,
            rhs_contracting_dims,
            lhs_batch_dims,
            rhs_batch_dims,
            lhs_rank: lhs.labels.len(),
            rhs_rank: rhs.labels.len(),
        };
        let tensor = exec.dot_general(lhs.tensor(), rhs.tensor(), &config)?;
        LabeledTensor {
            tensor: TensorValue::Owned(tensor),
            labels,
        }
    };

    if !reorder_result {
        return Ok(result);
    }

    // Reorder to `survive_labels`, skipping any surviving labels the result
    // does not actually carry.
    let result_label_set: HashSet<u32> = result.labels.iter().copied().collect();
    let target_labels: Vec<u32> = survive_labels
        .iter()
        .filter(|label| result_label_set.contains(label))
        .copied()
        .collect();
    transpose_to_labels(exec, result, &target_labels)
}
375
/// Executes the optimized contraction `tree` over `inputs` on `exec`.
///
/// Operands live in a growing `Vec<Option<..>>`: the first `n_inputs` slots
/// hold the (diagonalized) inputs and each contraction step appends its
/// result, so the tree's pair indices address both inputs and intermediates.
fn eager_einsum_exec(
    exec: &mut dyn TensorExec,
    inputs: &[&Tensor],
    tree: &ContractionTree,
) -> Result<Tensor> {
    let subscripts = &tree.subscripts;
    let n_inputs = subscripts.inputs.len();
    let output_labels = &subscripts.output;

    // Pair each input tensor (borrowed, not copied) with its labels.
    let mut labeled: Vec<Option<LabeledTensor<'_>>> = inputs
        .iter()
        .zip(subscripts.inputs.iter())
        .map(|(tensor, labels)| {
            Some(LabeledTensor {
                tensor: TensorValue::Borrowed(tensor),
                labels: labels.clone(),
            })
        })
        .collect();

    // Collapse repeated labels within each operand (e.g. "ii") up front so
    // contraction steps only see unique per-operand labels.
    for index in 0..labeled.len() {
        let operand = take_labeled(&mut labeled, index, "input")?;
        labeled[index] = Some(diagonalize_repeated(exec, operand)?);
    }

    // Single-operand (or trivial-tree) path: reduce, re-embed any labels the
    // output repeats (e.g. "i->ii"), and transpose to the output order.
    if n_inputs == 1 || tree.step_count() == 0 {
        let operand = take_labeled(&mut labeled, 0, "input")?;
        let output_set: HashSet<u32> = output_labels.iter().copied().collect();
        let reduce_labels: HashSet<u32> = operand
            .labels
            .iter()
            .filter(|label| !output_set.contains(label))
            .copied()
            .collect();
        let reduced = reduce_tensor(exec, operand, &reduce_labels)?;
        let embedded = embed_repeated(exec, reduced, output_labels)?;
        let reordered = transpose_to_labels(exec, embedded, output_labels)?;
        return Ok(reordered.tensor.into_tensor());
    }

    // Execute each pairwise contraction in tree order, appending results so
    // later steps can reference them by index.
    for step_idx in 0..tree.step_count() {
        let (left, right) = tree.step_pair(step_idx).ok_or_else(|| {
            eager_invalid_config(format!("missing contraction pair for step {step_idx}"))
        })?;
        let (_, _, step_output_labels) = tree.step_subscripts(step_idx).ok_or_else(|| {
            eager_invalid_config(format!(
                "missing contraction subscripts for step {step_idx}"
            ))
        })?;
        let lhs = take_labeled(&mut labeled, left, "lhs")?;
        let rhs = take_labeled(&mut labeled, right, "rhs")?;
        // Only the last step pays for reordering into output label order.
        let result = binary_contract(
            exec,
            lhs,
            rhs,
            step_output_labels,
            step_idx + 1 == tree.step_count(),
        )?;
        labeled.push(Some(result));
    }

    // The final result occupies the last appended slot.
    let final_index = n_inputs + tree.step_count() - 1;
    let result = take_labeled(&mut labeled, final_index, "result")?;
    // Sum away any labels not present in the requested output.
    let output_set: HashSet<u32> = output_labels.iter().copied().collect();
    let extra_labels: HashSet<u32> = result
        .labels
        .iter()
        .filter(|label| !output_set.contains(label))
        .copied()
        .collect();
    let reduced = reduce_tensor(exec, result, &extra_labels)?;
    let reordered = transpose_to_labels(exec, reduced, output_labels)?;
    Ok(reordered.tensor.into_tensor())
}
450
451/// Eager N-ary einsum on concrete [`Tensor`] values.
452///
453/// This applies the same contraction-tree optimization strategy used by the
454/// traced einsum path, but executes each contraction immediately against the
455/// provided backend context.
456///
457/// # Examples
458///
459/// ```
460/// use tenferro_einsum::eager_einsum;
461/// use tenferro_tensor::{Tensor, TensorBackend, cpu::CpuBackend};
462///
463/// let mut ctx = CpuBackend::new();
464/// let a = Tensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
465/// let b = Tensor::from_vec(vec![3, 2], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
466/// let c = eager_einsum(&mut ctx, &[&a, &b], "ij,jk->ik").unwrap();
467///
468/// assert_eq!(c.shape(), &[2, 2]);
469/// assert_eq!(c.as_slice::<f64>().unwrap(), &[22.0, 28.0, 49.0, 64.0]);
470/// ```
471pub fn eager_einsum(
472    ctx: &mut impl TensorBackend,
473    inputs: &[&Tensor],
474    subscripts: &str,
475) -> Result<Tensor> {
476    if inputs.is_empty() {
477        return Err(eager_invalid_config(
478            "eager einsum requires at least one input tensor",
479        ));
480    }
481
482    let subs = Subscripts::parse(subscripts)
483        .map_err(|err| eager_invalid_config(format!("invalid subscripts: {err}")))?;
484    if subs.inputs.len() != inputs.len() {
485        return Err(eager_invalid_config(format!(
486            "eager einsum subscripts expect {} inputs, got {}",
487            subs.inputs.len(),
488            inputs.len()
489        )));
490    }
491
492    let shapes: Vec<&[usize]> = inputs.iter().map(|tensor| tensor.shape()).collect();
493    let tree = ContractionTree::optimize(&subs, &shapes).map_err(|err| {
494        eager_invalid_config(format!("failed to optimize contraction tree: {err}"))
495    })?;
496    ctx.with_exec_session(|exec| eager_einsum_exec(exec, inputs, &tree))
497}