tenferro_internal_frontend_core/
structured_einsum.rs

1use std::collections::{HashMap, HashSet};
2
3use chainrules_core::Differentiable as _;
4use tenferro_algebra::{Conjugate, HasAlgebra, Scalar, Standard};
5use tenferro_einsum::{self as tf_einsum, EinsumBackend, Subscripts};
6use tenferro_internal_error::{Error, Result};
7use tenferro_prims::{TensorSemiringCore, TensorSemiringFastPath, TensorTempPoolContext};
8use tenferro_tensor::Tensor;
9
10use crate::structured_meta::{plan_axis_classes_for_subscripts, OperandAxisClasses};
11use crate::{DynTensorTyped, StructuredTensor};
12
/// Marker trait collecting every bound a scalar element type needs to flow
/// through the structured-einsum helpers in this module: backend scalar
/// (`Scalar`), dynamic-tensor typing, the standard algebra over itself, and
/// conjugation support.
///
/// Blanket-implemented below for any qualifying `T`; never implement it
/// manually.
#[doc(hidden)]
pub trait StructuredEinsumRuntimeValue:
    Scalar + DynTensorTyped + HasAlgebra<Algebra = Standard<Self>> + Conjugate + 'static
{
}
18
// Blanket impl: any type that satisfies the full bound set automatically is
// a StructuredEinsumRuntimeValue, so downstream code only names this one
// trait instead of repeating the bound list.
impl<T> StructuredEinsumRuntimeValue for T where
    T: Scalar + DynTensorTyped + HasAlgebra<Algebra = Standard<T>> + Conjugate + 'static
{
}
23
/// Marker trait for dense einsum backends usable by the structured layer.
///
/// A backend must provide plain einsum over the `Standard<T>` algebra plus
/// the tensor-semiring core and fast-path interfaces, with all three sharing
/// one context type `C`, and the fast path reusing the core's `Plan`
/// associated type (so plans are interchangeable between the two paths).
#[doc(hidden)]
pub trait StructuredDenseEinsumBackend<T, C>:
    EinsumBackend<Standard<T>>
    + TensorSemiringCore<Standard<T>, Context = C>
    + TensorSemiringFastPath<
        Standard<T>,
        Context = C,
        Plan = <Self as TensorSemiringCore<Standard<T>>>::Plan,
    >
where
    T: StructuredEinsumRuntimeValue,
{
}
37
// Blanket impl mirroring the trait's supertrait list: every backend that
// already satisfies the three interfaces (with a shared context and plan
// type) automatically qualifies as a StructuredDenseEinsumBackend.
impl<T, C, B> StructuredDenseEinsumBackend<T, C> for B
where
    T: StructuredEinsumRuntimeValue,
    B: EinsumBackend<Standard<T>>
        + TensorSemiringCore<Standard<T>, Context = C>
        + TensorSemiringFastPath<
            Standard<T>,
            Context = C,
            Plan = <B as TensorSemiringCore<Standard<T>>>::Plan,
        >,
{
}
50
51pub fn to_dense_in_ctx<B, C, T>(ctx: &mut C, tensor: &StructuredTensor<T>) -> Result<Tensor<T>>
52where
53    T: StructuredEinsumRuntimeValue,
54    B: StructuredDenseEinsumBackend<T, C>,
55    C: TensorTempPoolContext,
56{
57    if tensor.is_dense() {
58        return Ok(tensor.payload().clone());
59    }
60
61    let input_labels = usize_vec_to_u32(&(0..tensor.payload().dims().len()).collect::<Vec<_>>())?;
62    let output_labels = usize_vec_to_u32(tensor.axis_classes())?;
63    let inputs = [input_labels.as_slice()];
64    let subs = Subscripts::new(&inputs, &output_labels);
65    let out =
66        tf_einsum::einsum_with_subscripts::<Standard<T>, B>(ctx, &subs, &[tensor.payload()], None)
67            .map_err(Error::from)?;
68    if out.dims() != tensor.logical_dims() {
69        return Err(Error::InvalidTensorOperands {
70            message: format!(
71                "structured_to_dense output shape mismatch: expected {:?}, got {:?}",
72                tensor.logical_dims(),
73                out.dims()
74            ),
75        });
76    }
77    Ok(out)
78}
79
80pub fn compress_dense_to_layout_in_ctx<B, C, T>(
81    ctx: &mut C,
82    dense: &Tensor<T>,
83    layout: &StructuredTensor<T>,
84) -> Result<StructuredTensor<T>>
85where
86    T: StructuredEinsumRuntimeValue,
87    B: StructuredDenseEinsumBackend<T, C>,
88    C: TensorTempPoolContext,
89{
90    if dense.dims() != layout.logical_dims() {
91        return Err(Error::InvalidTensorOperands {
92            message: format!(
93                "structured compression shape mismatch: expected {:?}, got {:?}",
94                layout.logical_dims(),
95                dense.dims()
96            ),
97        });
98    }
99    if layout.is_dense() {
100        return Ok(StructuredTensor(
101            tenferro_tensor::StructuredTensor::from_dense(dense.clone()),
102        ));
103    }
104
105    let input_labels = usize_vec_to_u32(layout.axis_classes())?;
106    let output_labels = usize_vec_to_u32(&(0..layout.class_count()).collect::<Vec<_>>())?;
107    let inputs = [input_labels.as_slice()];
108    let subs = Subscripts::new(&inputs, &output_labels);
109    let payload = tf_einsum::einsum_with_subscripts::<Standard<T>, B>(ctx, &subs, &[dense], None)
110        .map_err(Error::from)?;
111    Ok(StructuredTensor(layout.0.with_payload_like(payload)?))
112}
113
114pub fn einsum_with_subscripts_in_ctx<B, C, T>(
115    ctx: &mut C,
116    subscripts: &Subscripts,
117    operands: &[&StructuredTensor<T>],
118) -> Result<StructuredTensor<T>>
119where
120    T: StructuredEinsumRuntimeValue,
121    B: StructuredDenseEinsumBackend<T, C>,
122    C: TensorTempPoolContext,
123{
124    let operand_meta: Vec<OperandAxisClasses> = operands
125        .iter()
126        .map(|operand| {
127            OperandAxisClasses::new(
128                operand.logical_dims().to_vec(),
129                operand.axis_classes().to_vec(),
130            )
131        })
132        .collect::<std::result::Result<Vec<_>, _>>()
133        .map_err(|e| Error::InvalidTensorOperands {
134            message: format!("invalid structured operand metadata: {e}"),
135        })?;
136    let plan = plan_axis_classes_for_subscripts(&operand_meta, subscripts).map_err(|e| {
137        Error::InvalidTensorOperands {
138            message: format!("failed to plan structured einsum: {e}"),
139        }
140    })?;
141
142    let mut normalized_payloads: Vec<Tensor<T>> = Vec::with_capacity(operands.len());
143    let mut normalized_roots: Vec<Vec<usize>> = Vec::with_capacity(operands.len());
144
145    for (operand_idx, operand) in operands.iter().enumerate() {
146        let class_roots = &plan.operand_plans[operand_idx].class_roots;
147        if operand.payload().dims().len() != class_roots.len() {
148            return Err(Error::InvalidTensorOperands {
149                message: format!(
150                    "operand {} payload rank {} does not match planned local class count {}",
151                    operand_idx,
152                    operand.payload().dims().len(),
153                    class_roots.len()
154                ),
155            });
156        }
157        let (normalized, roots) =
158            normalize_payload_for_roots::<B, _, T>(ctx, operand.payload(), class_roots)?;
159        normalized_payloads.push(normalized);
160        normalized_roots.push(roots);
161    }
162
163    let input_labels_u32: Vec<Vec<u32>> = normalized_roots
164        .iter()
165        .map(|roots| usize_vec_to_u32(roots))
166        .collect::<Result<_>>()?;
167    let output_labels_u32 = usize_vec_to_u32(&plan.output_compressed_roots)?;
168    let input_refs: Vec<&[u32]> = input_labels_u32.iter().map(Vec::as_slice).collect();
169    let payload_refs: Vec<&Tensor<T>> = normalized_payloads.iter().collect();
170    let backend_subs = Subscripts::new(&input_refs, &output_labels_u32);
171
172    let compressed_output = tf_einsum::einsum_with_subscripts::<Standard<T>, B>(
173        ctx,
174        &backend_subs,
175        &payload_refs,
176        None,
177    )
178    .map_err(Error::from)?;
179
180    Ok(StructuredTensor(tenferro_tensor::StructuredTensor::new(
181        plan.output_dims.clone(),
182        plan.output_axis_classes.clone(),
183        compressed_output,
184    )?))
185}
186
187pub fn accumulate_tangent<T>(
188    lhs: StructuredTensor<T>,
189    rhs: &StructuredTensor<T>,
190) -> Result<StructuredTensor<T>>
191where
192    T: Scalar,
193{
194    if lhs.logical_dims() != rhs.logical_dims() || lhs.axis_classes() != rhs.axis_classes() {
195        return Err(Error::InvalidTensorOperands {
196            message: format!(
197                "structured tangent layout mismatch: lhs dims {:?} classes {:?}, rhs dims {:?} classes {:?}",
198                lhs.logical_dims(),
199                lhs.axis_classes(),
200                rhs.logical_dims(),
201                rhs.axis_classes(),
202            ),
203        });
204    }
205
206    let logical_dims = lhs.logical_dims().to_vec();
207    let axis_classes = lhs.axis_classes().to_vec();
208    let payload = Tensor::<T>::accumulate_tangent(lhs.0.into_payload(), rhs.payload());
209    Ok(StructuredTensor(tenferro_tensor::StructuredTensor::new(
210        logical_dims,
211        axis_classes,
212        payload,
213    )?))
214}
215
216pub fn reverse_subscripts(subscripts: &Subscripts, input_idx: usize) -> Subscripts {
217    let mut rev_inputs = vec![subscripts.output.clone()];
218    for (idx, input) in subscripts.inputs.iter().enumerate() {
219        if idx != input_idx {
220            rev_inputs.push(input.clone());
221        }
222    }
223    Subscripts {
224        inputs: rev_inputs,
225        output: subscripts.inputs[input_idx].clone(),
226    }
227}
228
229#[doc(hidden)]
/// Returns the distinct ids of `ids`, ordered by first appearance.
#[doc(hidden)]
pub fn unique_ids_first_appearance(ids: &[usize]) -> Vec<usize> {
    let mut seen = HashSet::with_capacity(ids.len());
    // HashSet::insert returns true only on the first sighting of an id.
    ids.iter().copied().filter(|&id| seen.insert(id)).collect()
}
240
241#[doc(hidden)]
/// Finds the earliest repeated id in `ids`, returning the positions of its
/// first two occurrences as `(first, second)`; `None` when all ids are
/// distinct.
#[doc(hidden)]
pub fn first_duplicate_pair(ids: &[usize]) -> Option<(usize, usize)> {
    let mut earliest: HashMap<usize, usize> = HashMap::new();
    ids.iter().enumerate().find_map(|(pos, &id)| {
        // insert returns the previously stored position when id repeats;
        // we stop at the first such repeat, so the overwrite is harmless.
        earliest.insert(id, pos).map(|first| (first, pos))
    })
}
252
253#[doc(hidden)]
254pub fn normalize_payload_for_roots<B, C, T>(
255    ctx: &mut C,
256    payload: &Tensor<T>,
257    roots: &[usize],
258) -> Result<(Tensor<T>, Vec<usize>)>
259where
260    T: StructuredEinsumRuntimeValue,
261    B: StructuredDenseEinsumBackend<T, C>,
262    C: TensorTempPoolContext,
263{
264    if payload.dims().len() != roots.len() {
265        return Err(Error::InvalidTensorOperands {
266            message: format!(
267                "payload rank {} must match roots length {}",
268                payload.dims().len(),
269                roots.len()
270            ),
271        });
272    }
273    if unique_ids_first_appearance(roots).len() == roots.len() {
274        return Ok((payload.clone(), roots.to_vec()));
275    }
276
277    let mut current_payload = payload.clone();
278    let mut current_roots = roots.to_vec();
279    let mut round = 0u32;
280
281    while let Some((pos_a, pos_b)) = first_duplicate_pair(&current_roots) {
282        let rank = current_roots.len();
283        debug_assert!(pos_b < rank);
284        let base = 1_000_000u32.saturating_add(round.saturating_mul(10_000));
285        let mut input_labels: Vec<u32> = (0..rank).map(|i| base + i as u32).collect();
286        input_labels[pos_b] = input_labels[pos_a];
287        let output_labels: Vec<u32> = input_labels
288            .iter()
289            .enumerate()
290            .filter_map(|(axis, &label)| (axis != pos_b).then_some(label))
291            .collect();
292        let inputs = [input_labels.as_slice()];
293        let subs = Subscripts::new(&inputs, &output_labels);
294        current_payload = tf_einsum::einsum_with_subscripts::<Standard<T>, B>(
295            ctx,
296            &subs,
297            &[&current_payload],
298            None,
299        )
300        .map_err(Error::from)?;
301        current_roots.remove(pos_b);
302        round = round.saturating_add(1);
303    }
304
305    Ok((current_payload, current_roots))
306}
307
308#[doc(hidden)]
309pub fn usize_vec_to_u32(values: &[usize]) -> Result<Vec<u32>> {
310    values
311        .iter()
312        .map(|&v| {
313            u32::try_from(v).map_err(|_| Error::InvalidTensorOperands {
314                message: format!("label id {} does not fit into u32", v),
315            })
316        })
317        .collect()
318}