// tenferro_internal_frontend_core/structured_einsum.rs
use std::collections::{HashMap, HashSet};
2
3use chainrules_core::Differentiable as _;
4use tenferro_algebra::{Conjugate, HasAlgebra, Scalar, Standard};
5use tenferro_einsum::{self as tf_einsum, EinsumBackend, Subscripts};
6use tenferro_internal_error::{Error, Result};
7use tenferro_prims::{TensorSemiringCore, TensorSemiringFastPath, TensorTempPoolContext};
8use tenferro_tensor::Tensor;
9
10use crate::structured_meta::{plan_axis_classes_for_subscripts, OperandAxisClasses};
11use crate::{DynTensorTyped, StructuredTensor};
12
/// Marker trait bundling the bounds a scalar type needs to flow through the
/// structured-einsum runtime: a standard-algebra `Scalar` that supports
/// dynamic tensor typing and conjugation.
///
/// Hidden from docs: this is an internal convenience alias, not public API.
#[doc(hidden)]
pub trait StructuredEinsumRuntimeValue:
    Scalar + DynTensorTyped + HasAlgebra<Algebra = Standard<Self>> + Conjugate + 'static
{
}

/// Blanket impl: any type satisfying the bounds is automatically a
/// `StructuredEinsumRuntimeValue`; nothing is ever implemented manually.
impl<T> StructuredEinsumRuntimeValue for T where
    T: Scalar + DynTensorTyped + HasAlgebra<Algebra = Standard<T>> + Conjugate + 'static
{
}
23
/// Marker trait bundling the backend capabilities required by the dense
/// fallbacks in this module: einsum execution plus the semiring core and
/// fast-path traits, all sharing one context type `C` and one plan type.
///
/// The `Plan = <Self as TensorSemiringCore<..>>::Plan` constraint pins the
/// fast-path plan to the core plan so the two trait impls stay coherent.
#[doc(hidden)]
pub trait StructuredDenseEinsumBackend<T, C>:
    EinsumBackend<Standard<T>>
    + TensorSemiringCore<Standard<T>, Context = C>
    + TensorSemiringFastPath<
        Standard<T>,
        Context = C,
        Plan = <Self as TensorSemiringCore<Standard<T>>>::Plan,
    >
where
    T: StructuredEinsumRuntimeValue,
{
}

/// Blanket impl: any backend providing the three capabilities (with matching
/// context and plan types) qualifies automatically.
impl<T, C, B> StructuredDenseEinsumBackend<T, C> for B
where
    T: StructuredEinsumRuntimeValue,
    B: EinsumBackend<Standard<T>>
        + TensorSemiringCore<Standard<T>, Context = C>
        + TensorSemiringFastPath<
            Standard<T>,
            Context = C,
            Plan = <B as TensorSemiringCore<Standard<T>>>::Plan,
        >,
{
}
50
51pub fn to_dense_in_ctx<B, C, T>(ctx: &mut C, tensor: &StructuredTensor<T>) -> Result<Tensor<T>>
52where
53 T: StructuredEinsumRuntimeValue,
54 B: StructuredDenseEinsumBackend<T, C>,
55 C: TensorTempPoolContext,
56{
57 if tensor.is_dense() {
58 return Ok(tensor.payload().clone());
59 }
60
61 let input_labels = usize_vec_to_u32(&(0..tensor.payload().dims().len()).collect::<Vec<_>>())?;
62 let output_labels = usize_vec_to_u32(tensor.axis_classes())?;
63 let inputs = [input_labels.as_slice()];
64 let subs = Subscripts::new(&inputs, &output_labels);
65 let out =
66 tf_einsum::einsum_with_subscripts::<Standard<T>, B>(ctx, &subs, &[tensor.payload()], None)
67 .map_err(Error::from)?;
68 if out.dims() != tensor.logical_dims() {
69 return Err(Error::InvalidTensorOperands {
70 message: format!(
71 "structured_to_dense output shape mismatch: expected {:?}, got {:?}",
72 tensor.logical_dims(),
73 out.dims()
74 ),
75 });
76 }
77 Ok(out)
78}
79
80pub fn compress_dense_to_layout_in_ctx<B, C, T>(
81 ctx: &mut C,
82 dense: &Tensor<T>,
83 layout: &StructuredTensor<T>,
84) -> Result<StructuredTensor<T>>
85where
86 T: StructuredEinsumRuntimeValue,
87 B: StructuredDenseEinsumBackend<T, C>,
88 C: TensorTempPoolContext,
89{
90 if dense.dims() != layout.logical_dims() {
91 return Err(Error::InvalidTensorOperands {
92 message: format!(
93 "structured compression shape mismatch: expected {:?}, got {:?}",
94 layout.logical_dims(),
95 dense.dims()
96 ),
97 });
98 }
99 if layout.is_dense() {
100 return Ok(StructuredTensor(
101 tenferro_tensor::StructuredTensor::from_dense(dense.clone()),
102 ));
103 }
104
105 let input_labels = usize_vec_to_u32(layout.axis_classes())?;
106 let output_labels = usize_vec_to_u32(&(0..layout.class_count()).collect::<Vec<_>>())?;
107 let inputs = [input_labels.as_slice()];
108 let subs = Subscripts::new(&inputs, &output_labels);
109 let payload = tf_einsum::einsum_with_subscripts::<Standard<T>, B>(ctx, &subs, &[dense], None)
110 .map_err(Error::from)?;
111 Ok(StructuredTensor(layout.0.with_payload_like(payload)?))
112}
113
/// Runs an einsum over structured operands, producing a structured result.
///
/// Pipeline:
/// 1. Extract each operand's (logical dims, axis classes) metadata.
/// 2. Plan the output layout and per-operand class roots for `subscripts`.
/// 3. Normalize each payload so its root labels are pairwise distinct
///    (see `normalize_payload_for_roots`).
/// 4. Run one backend einsum over the normalized payloads using the planned
///    compressed root labels, then wrap the result with the planned output
///    dims and axis classes.
///
/// # Errors
/// Fails on invalid operand metadata, planning failure, a payload whose rank
/// disagrees with the plan, label-id overflow, or a backend einsum error.
pub fn einsum_with_subscripts_in_ctx<B, C, T>(
    ctx: &mut C,
    subscripts: &Subscripts,
    operands: &[&StructuredTensor<T>],
) -> Result<StructuredTensor<T>>
where
    T: StructuredEinsumRuntimeValue,
    B: StructuredDenseEinsumBackend<T, C>,
    C: TensorTempPoolContext,
{
    // Step 1: collect per-operand metadata; constructor errors become
    // InvalidTensorOperands. Note the explicit std Result: the closure's
    // error type differs from this crate's Result alias.
    let operand_meta: Vec<OperandAxisClasses> = operands
        .iter()
        .map(|operand| {
            OperandAxisClasses::new(
                operand.logical_dims().to_vec(),
                operand.axis_classes().to_vec(),
            )
        })
        .collect::<std::result::Result<Vec<_>, _>>()
        .map_err(|e| Error::InvalidTensorOperands {
            message: format!("invalid structured operand metadata: {e}"),
        })?;
    // Step 2: plan compressed roots and the output layout for these subscripts.
    let plan = plan_axis_classes_for_subscripts(&operand_meta, subscripts).map_err(|e| {
        Error::InvalidTensorOperands {
            message: format!("failed to plan structured einsum: {e}"),
        }
    })?;

    let mut normalized_payloads: Vec<Tensor<T>> = Vec::with_capacity(operands.len());
    let mut normalized_roots: Vec<Vec<usize>> = Vec::with_capacity(operands.len());

    // Step 3: per operand, check that the payload rank matches the planned
    // class roots, then deduplicate repeated roots in the payload.
    for (operand_idx, operand) in operands.iter().enumerate() {
        let class_roots = &plan.operand_plans[operand_idx].class_roots;
        if operand.payload().dims().len() != class_roots.len() {
            return Err(Error::InvalidTensorOperands {
                message: format!(
                    "operand {} payload rank {} does not match planned local class count {}",
                    operand_idx,
                    operand.payload().dims().len(),
                    class_roots.len()
                ),
            });
        }
        let (normalized, roots) =
            normalize_payload_for_roots::<B, _, T>(ctx, operand.payload(), class_roots)?;
        normalized_payloads.push(normalized);
        normalized_roots.push(roots);
    }

    // Step 4: translate root ids to u32 backend labels and run one einsum
    // over all normalized payloads at once.
    let input_labels_u32: Vec<Vec<u32>> = normalized_roots
        .iter()
        .map(|roots| usize_vec_to_u32(roots))
        .collect::<Result<_>>()?;
    let output_labels_u32 = usize_vec_to_u32(&plan.output_compressed_roots)?;
    let input_refs: Vec<&[u32]> = input_labels_u32.iter().map(Vec::as_slice).collect();
    let payload_refs: Vec<&Tensor<T>> = normalized_payloads.iter().collect();
    let backend_subs = Subscripts::new(&input_refs, &output_labels_u32);

    let compressed_output = tf_einsum::einsum_with_subscripts::<Standard<T>, B>(
        ctx,
        &backend_subs,
        &payload_refs,
        None,
    )
    .map_err(Error::from)?;

    // Wrap the compressed backend output with the planned structured layout.
    Ok(StructuredTensor(tenferro_tensor::StructuredTensor::new(
        plan.output_dims.clone(),
        plan.output_axis_classes.clone(),
        compressed_output,
    )?))
}
186
187pub fn accumulate_tangent<T>(
188 lhs: StructuredTensor<T>,
189 rhs: &StructuredTensor<T>,
190) -> Result<StructuredTensor<T>>
191where
192 T: Scalar,
193{
194 if lhs.logical_dims() != rhs.logical_dims() || lhs.axis_classes() != rhs.axis_classes() {
195 return Err(Error::InvalidTensorOperands {
196 message: format!(
197 "structured tangent layout mismatch: lhs dims {:?} classes {:?}, rhs dims {:?} classes {:?}",
198 lhs.logical_dims(),
199 lhs.axis_classes(),
200 rhs.logical_dims(),
201 rhs.axis_classes(),
202 ),
203 });
204 }
205
206 let logical_dims = lhs.logical_dims().to_vec();
207 let axis_classes = lhs.axis_classes().to_vec();
208 let payload = Tensor::<T>::accumulate_tangent(lhs.0.into_payload(), rhs.payload());
209 Ok(StructuredTensor(tenferro_tensor::StructuredTensor::new(
210 logical_dims,
211 axis_classes,
212 payload,
213 )?))
214}
215
216pub fn reverse_subscripts(subscripts: &Subscripts, input_idx: usize) -> Subscripts {
217 let mut rev_inputs = vec![subscripts.output.clone()];
218 for (idx, input) in subscripts.inputs.iter().enumerate() {
219 if idx != input_idx {
220 rev_inputs.push(input.clone());
221 }
222 }
223 Subscripts {
224 inputs: rev_inputs,
225 output: subscripts.inputs[input_idx].clone(),
226 }
227}
228
/// Returns the distinct ids of `ids`, ordered by first appearance.
#[doc(hidden)]
pub fn unique_ids_first_appearance(ids: &[usize]) -> Vec<usize> {
    let mut seen = HashSet::new();
    // `insert` returns true only the first time an id is encountered, so the
    // filter keeps exactly one copy of each id in first-appearance order.
    ids.iter().copied().filter(|&id| seen.insert(id)).collect()
}
240
/// Finds the first repeated id in `ids` and returns the positions of its
/// first occurrence and of the duplicate, or `None` if all ids are distinct.
#[doc(hidden)]
pub fn first_duplicate_pair(ids: &[usize]) -> Option<(usize, usize)> {
    use std::collections::hash_map::Entry;

    let mut earliest: HashMap<usize, usize> = HashMap::new();
    for (pos, &id) in ids.iter().enumerate() {
        match earliest.entry(id) {
            // Already seen: report (first occurrence, current position).
            Entry::Occupied(slot) => return Some((*slot.get(), pos)),
            Entry::Vacant(slot) => {
                slot.insert(pos);
            }
        }
    }
    None
}
252
253#[doc(hidden)]
254pub fn normalize_payload_for_roots<B, C, T>(
255 ctx: &mut C,
256 payload: &Tensor<T>,
257 roots: &[usize],
258) -> Result<(Tensor<T>, Vec<usize>)>
259where
260 T: StructuredEinsumRuntimeValue,
261 B: StructuredDenseEinsumBackend<T, C>,
262 C: TensorTempPoolContext,
263{
264 if payload.dims().len() != roots.len() {
265 return Err(Error::InvalidTensorOperands {
266 message: format!(
267 "payload rank {} must match roots length {}",
268 payload.dims().len(),
269 roots.len()
270 ),
271 });
272 }
273 if unique_ids_first_appearance(roots).len() == roots.len() {
274 return Ok((payload.clone(), roots.to_vec()));
275 }
276
277 let mut current_payload = payload.clone();
278 let mut current_roots = roots.to_vec();
279 let mut round = 0u32;
280
281 while let Some((pos_a, pos_b)) = first_duplicate_pair(¤t_roots) {
282 let rank = current_roots.len();
283 debug_assert!(pos_b < rank);
284 let base = 1_000_000u32.saturating_add(round.saturating_mul(10_000));
285 let mut input_labels: Vec<u32> = (0..rank).map(|i| base + i as u32).collect();
286 input_labels[pos_b] = input_labels[pos_a];
287 let output_labels: Vec<u32> = input_labels
288 .iter()
289 .enumerate()
290 .filter_map(|(axis, &label)| (axis != pos_b).then_some(label))
291 .collect();
292 let inputs = [input_labels.as_slice()];
293 let subs = Subscripts::new(&inputs, &output_labels);
294 current_payload = tf_einsum::einsum_with_subscripts::<Standard<T>, B>(
295 ctx,
296 &subs,
297 &[¤t_payload],
298 None,
299 )
300 .map_err(Error::from)?;
301 current_roots.remove(pos_b);
302 round = round.saturating_add(1);
303 }
304
305 Ok((current_payload, current_roots))
306}
307
308#[doc(hidden)]
309pub fn usize_vec_to_u32(values: &[usize]) -> Result<Vec<u32>> {
310 values
311 .iter()
312 .map(|&v| {
313 u32::try_from(v).map_err(|_| Error::InvalidTensorOperands {
314 message: format!("label id {} does not fit into u32", v),
315 })
316 })
317 .collect()
318}