1use crate::defaults::DynIndex;
2use crate::index_like::IndexLike;
3use crate::index_ops::{common_ind_positions, prepare_contraction, prepare_contraction_pairs};
4use crate::tensor_like::LinearizationOrder;
5use crate::AnyScalar;
6use anyhow::{Context, Result};
7use num_complex::Complex64;
8use num_traits::Zero;
9use rand::Rng;
10use rand_distr::{Distribution, StandardNormal};
11use std::cell::RefCell;
12use std::cmp::Reverse;
13use std::collections::{HashMap, HashSet};
14use std::env;
15use std::sync::{Arc, OnceLock};
16use std::time::{Duration, Instant};
17use tenferro::eager_tensor::einsum_subscripts as eager_einsum_ad;
18use tenferro::{
19 CpuBackend, DType, DotGeneralConfig, EagerTensor, EinsumSubscripts, Tensor as NativeTensor,
20};
21use tensor4all_tensorbackend::{
22 axpby_native_tensor, contract_native_tensor, default_eager_ctx,
23 dense_native_tensor_from_col_major, diag_native_tensor_from_col_major,
24 native_tensor_primal_to_dense_col_major, native_tensor_primal_to_diag_c64,
25 native_tensor_primal_to_diag_f64, native_tensor_primal_to_storage, scale_native_tensor,
26 storage_payload_native_read_input, storage_to_native_tensor, AnyScalar as BackendScalar,
27 StorageScalar, TensorElement,
28};
29use tensor4all_tensorbackend::{Storage, StorageKind};
30
31use super::contract::PairwiseContractionOptions;
32use super::structured_contraction::{
33 normalize_payload_read_for_roots, storage_from_payload_native, storage_payload_native,
34 OperandLayout, StructuredContractionPlan, StructuredContractionSpec,
35};
36
/// Accumulated statistics for one named section of the pairwise-contraction
/// profiler.
#[derive(Debug, Default, Clone)]
struct PairwiseContractProfileEntry {
    /// Number of timed calls recorded via `record_pairwise_contract_profile`.
    calls: usize,
    /// Sum of the elapsed wall-clock time of those calls.
    total_time: Duration,
    /// Sum of bytes reported via `record_pairwise_contract_profile_bytes`
    /// (this does not bump `calls`).
    total_bytes: usize,
}
43
// Per-thread profiler state: section name -> accumulated statistics.
// Being thread-local, each thread records into (and prints/resets) its own map.
thread_local! {
    static PAIRWISE_CONTRACT_PROFILE_STATE: RefCell<HashMap<&'static str, PairwiseContractProfileEntry>> =
        RefCell::new(HashMap::new());
}
48
/// Reports whether pairwise-contraction profiling is switched on.
///
/// Profiling is enabled when the `T4A_PROFILE_PAIRWISE_CONTRACT` environment
/// variable is set to a valid Unicode value. The variable is read once per
/// process and the answer is cached, so later environment changes are ignored.
fn pairwise_contract_profile_enabled() -> bool {
    static ENABLED: OnceLock<bool> = OnceLock::new();
    let cached = ENABLED.get_or_init(|| {
        let lookup = env::var("T4A_PROFILE_PAIRWISE_CONTRACT");
        lookup.is_ok()
    });
    *cached
}
53
54fn record_pairwise_contract_profile(section: &'static str, elapsed: Duration) {
55 if !pairwise_contract_profile_enabled() {
56 return;
57 }
58 PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| {
59 let mut state = state.borrow_mut();
60 let entry = state.entry(section).or_default();
61 entry.calls += 1;
62 entry.total_time += elapsed;
63 });
64}
65
66fn record_pairwise_contract_profile_bytes(section: &'static str, bytes: usize) {
67 if !pairwise_contract_profile_enabled() {
68 return;
69 }
70 PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| {
71 let mut state = state.borrow_mut();
72 let entry = state.entry(section).or_default();
73 entry.total_bytes += bytes;
74 });
75}
76
77fn profile_pairwise_contract_section<T>(section: &'static str, f: impl FnOnce() -> T) -> T {
78 if !pairwise_contract_profile_enabled() {
79 return f();
80 }
81 let started = Instant::now();
82 let result = f();
83 record_pairwise_contract_profile(section, started.elapsed());
84 result
85}
86
87pub fn reset_pairwise_contract_profile() {
89 PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| state.borrow_mut().clear());
90}
91
92pub fn print_and_reset_pairwise_contract_profile() {
94 if !pairwise_contract_profile_enabled() {
95 return;
96 }
97 PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| {
98 let mut entries: Vec<_> = state
99 .borrow()
100 .iter()
101 .map(|(section, entry)| (*section, entry.clone()))
102 .collect();
103 state.borrow_mut().clear();
104 entries.sort_by_key(|(_, entry)| Reverse(entry.total_time));
105
106 eprintln!("=== TensorDynLen pairwise contract profile ===");
107 for (section, entry) in entries {
108 let per_call_us = if entry.calls == 0 {
109 0.0
110 } else {
111 entry.total_time.as_secs_f64() * 1.0e6 / entry.calls as f64
112 };
113 eprintln!(
114 "{section}: calls={} total={:.6}ms per_call={:.3}us bytes={}",
115 entry.calls,
116 entry.total_time.as_secs_f64() * 1.0e3,
117 per_call_us,
118 entry.total_bytes,
119 );
120 }
121 });
122}
123
124fn native_tensor_profile_bytes(native: &NativeTensor) -> usize {
125 let element_size = match native.dtype() {
126 DType::F32 => 4,
127 DType::F64 => 8,
128 DType::C32 => 8,
129 DType::C64 => 16,
130 DType::I64 => 8,
131 };
132 native.shape().iter().product::<usize>() * element_size
133}
134
/// Tensor element types that can be drawn at random.
///
/// In this file it is implemented for `f64` and `Complex64`, both sampling
/// from `StandardNormal`.
pub trait RandomScalar: TensorElement {
    /// Draws one random value of `Self` from `rng`.
    fn random_value<R: Rng>(rng: &mut R) -> Self;
}
143
144impl RandomScalar for f64 {
145 fn random_value<R: Rng>(rng: &mut R) -> Self {
146 StandardNormal.sample(rng)
147 }
148}
149
150impl RandomScalar for Complex64 {
151 fn random_value<R: Rng>(rng: &mut R) -> Self {
152 Complex64::new(StandardNormal.sample(rng), StandardNormal.sample(rng))
153 }
154}
155
156pub fn compute_permutation_from_indices(
189 original_indices: &[DynIndex],
190 new_indices: &[DynIndex],
191) -> Result<Vec<usize>> {
192 anyhow::ensure!(
193 new_indices.len() == original_indices.len(),
194 "new_indices length must match original_indices length"
195 );
196
197 let mut perm = Vec::with_capacity(new_indices.len());
198 let mut used = std::collections::HashSet::new();
199
200 for new_idx in new_indices {
201 let pos = original_indices
204 .iter()
205 .position(|old_idx| old_idx == new_idx)
206 .ok_or_else(|| {
207 anyhow::anyhow!("new_indices must be a permutation of original_indices")
208 })?;
209
210 anyhow::ensure!(used.insert(pos), "duplicate index in new_indices");
211 perm.push(pos);
212 }
213
214 Ok(perm)
215}
216
/// Compact payload retained for autodiff on tensors produced by structured
/// contractions (built in `TensorDynLen::from_structured_payload_inner`).
#[derive(Clone)]
pub(crate) struct StructuredAdValue {
    /// Eager payload tensor. Its shape equals `payload_dims`, which is checked
    /// at construction.
    payload: Arc<EagerTensor<CpuBackend>>,
    /// Shape of the compact payload, one entry per payload class.
    payload_dims: Vec<usize>,
    /// Payload class of each logical axis.
    axis_classes: Vec<usize>,
}
223
/// Backing data of a `TensorDynLen`.
///
/// This is either already-materialized `Storage`, or an eager tensor together
/// with the axis-class layout that maps logical axes onto its axes. The eager
/// form is converted to `Storage` on demand by `materialize`.
#[derive(Clone)]
pub(crate) enum TensorDynLenStorage {
    /// Materialized storage, shared by reference count.
    Materialized(Arc<Storage>),
    /// Eager payload plus its layout.
    Eager {
        /// Payload tensor; its data shape is reported as the payload dims.
        inner: Arc<EagerTensor<CpuBackend>>,
        /// Payload class of each logical axis (one entry per logical axis).
        axis_classes: Vec<usize>,
    },
}
232
233impl TensorDynLenStorage {
234 fn from_storage(storage: Arc<Storage>) -> Self {
235 Self::Materialized(storage)
236 }
237
238 fn from_eager_dense(inner: EagerTensor<CpuBackend>, rank: usize) -> Self {
239 Self::Eager {
240 inner: Arc::new(inner),
241 axis_classes: TensorDynLen::dense_axis_classes(rank),
242 }
243 }
244
245 fn eager(&self) -> Option<&EagerTensor<CpuBackend>> {
246 match self {
247 Self::Materialized(_) => None,
248 Self::Eager { inner, .. } => Some(inner.as_ref()),
249 }
250 }
251
252 fn axis_classes(&self) -> &[usize] {
253 match self {
254 Self::Materialized(storage) => storage.axis_classes(),
255 Self::Eager { axis_classes, .. } => axis_classes,
256 }
257 }
258
259 fn payload_dims(&self) -> &[usize] {
260 match self {
261 Self::Materialized(storage) => storage.payload_dims(),
262 Self::Eager { inner, .. } => inner.data().shape(),
263 }
264 }
265
266 fn payload_strides_vec(&self) -> Vec<isize> {
267 match self {
268 Self::Materialized(storage) => storage.payload_strides().to_vec(),
269 Self::Eager { inner, .. } => {
270 let mut stride = 1isize;
271 inner
272 .data()
273 .shape()
274 .iter()
275 .map(|&dim| {
276 let current = stride;
277 stride *= isize::try_from(dim).unwrap_or(isize::MAX);
278 current
279 })
280 .collect()
281 }
282 }
283 }
284
285 fn is_f64(&self) -> bool {
286 match self {
287 Self::Materialized(storage) => storage.is_f64(),
288 Self::Eager { inner, .. } => inner.data().dtype() == DType::F64,
289 }
290 }
291
292 fn is_c64(&self) -> bool {
293 match self {
294 Self::Materialized(storage) => storage.is_c64(),
295 Self::Eager { inner, .. } => inner.data().dtype() == DType::C64,
296 }
297 }
298
299 fn is_complex(&self) -> bool {
300 match self {
301 Self::Materialized(storage) => storage.is_complex(),
302 Self::Eager { inner, .. } => matches!(inner.data().dtype(), DType::C32 | DType::C64),
303 }
304 }
305
306 fn is_diag(&self) -> bool {
307 match self {
308 Self::Materialized(storage) => storage.is_diag(),
309 Self::Eager { axis_classes, .. } => TensorDynLen::is_diag_axis_classes(axis_classes),
310 }
311 }
312
313 fn storage_kind(&self) -> StorageKind {
314 match self {
315 Self::Materialized(storage) => storage.storage_kind(),
316 Self::Eager { axis_classes, .. } => {
317 if axis_classes.iter().copied().eq(0..axis_classes.len()) {
318 StorageKind::Dense
319 } else if TensorDynLen::is_diag_axis_classes(axis_classes) {
320 StorageKind::Diagonal
321 } else {
322 StorageKind::Structured
323 }
324 }
325 }
326 }
327
328 fn materialize(&self, logical_rank: usize) -> Result<Arc<Storage>> {
329 match self {
330 Self::Materialized(storage) => Ok(Arc::clone(storage)),
331 Self::Eager {
332 inner,
333 axis_classes,
334 } => Ok(Arc::new(
335 TensorDynLen::storage_from_native_with_axis_classes(
336 inner.data(),
337 axis_classes,
338 logical_rank,
339 )?,
340 )),
341 }
342 }
343
344 fn scale(&self, scalar: &BackendScalar) -> Result<Storage> {
345 Ok(self.materialize(self.axis_classes().len())?.scale(scalar))
346 }
347
348 fn conj(&self) -> Result<Self> {
349 match self {
350 Self::Materialized(storage) => Ok(Self::Materialized(Arc::new(storage.conj()))),
351 Self::Eager {
352 inner,
353 axis_classes,
354 } => Ok(Self::Eager {
355 inner: Arc::new(inner.conj()?),
356 axis_classes: axis_classes.clone(),
357 }),
358 }
359 }
360
361 fn max_abs(&self) -> Result<f64> {
362 Ok(self.materialize(self.axis_classes().len())?.max_abs())
363 }
364}
365
/// Tensor with a dynamic number of labeled axes.
///
/// The data lives in `storage`, which may use a compact layout where several
/// logical axes share one payload axis (see the axis-class helpers on this type).
#[derive(Clone)]
pub struct TensorDynLen {
    /// One index per logical axis; `validate_indices` requires these to be unique.
    pub indices: Vec<DynIndex>,
    /// Backing data: materialized storage or an eager payload.
    pub(crate) storage: TensorDynLenStorage,
    /// Compact payload kept for autodiff. `from_structured_payload_inner` sets
    /// this; it is `None` on tensors built through other paths (presumably;
    /// confirm against the other constructors).
    pub(crate) structured_ad: Option<Arc<StructuredAdValue>>,
    /// Lazily filled eager view of this tensor, shared between clones via `Arc`.
    pub(crate) eager_cache: Arc<OnceLock<Arc<EagerTensor<CpuBackend>>>>,
}
428
429impl TensorDynLen {
430 fn dense_axis_classes(rank: usize) -> Vec<usize> {
431 (0..rank).collect()
432 }
433
434 fn diag_axis_classes(rank: usize) -> Vec<usize> {
435 if rank == 0 {
436 vec![]
437 } else {
438 vec![0; rank]
439 }
440 }
441
442 fn canonicalize_axis_classes(axis_classes: &[usize]) -> Vec<usize> {
443 let mut map = std::collections::HashMap::new();
444 let mut next = 0usize;
445 axis_classes
446 .iter()
447 .map(|&class_id| {
448 *map.entry(class_id).or_insert_with(|| {
449 let canonical = next;
450 next += 1;
451 canonical
452 })
453 })
454 .collect()
455 }
456
457 fn permute_axis_classes(&self, perm: &[usize]) -> Vec<usize> {
458 let axis_classes = self.storage.axis_classes();
459 let permuted: Vec<usize> = perm.iter().map(|&index| axis_classes[index]).collect();
460 Self::canonicalize_axis_classes(&permuted)
461 }
462
463 fn normalize_insert_axis(op: &str, axis: isize, rank: usize) -> Result<usize> {
464 let normalized = if axis < 0 {
465 rank as isize + 1 + axis
466 } else {
467 axis
468 };
469 anyhow::ensure!(
470 normalized >= 0 && normalized <= rank as isize,
471 "{op}: axis {axis} is out of bounds for inserting into rank {rank}"
472 );
473 Ok(normalized as usize)
474 }
475
476 fn is_diag_axis_classes(axis_classes: &[usize]) -> bool {
477 axis_classes.len() >= 2 && axis_classes.iter().all(|&class_id| class_id == 0)
478 }
479
480 fn einsum_subscripts_from_usize_ids(
481 inputs: &[Vec<usize>],
482 output: &[usize],
483 ) -> Result<EinsumSubscripts> {
484 let input_labels = inputs
485 .iter()
486 .map(|ids| {
487 ids.iter()
488 .map(|&id| {
489 u32::try_from(id)
490 .map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range"))
491 })
492 .collect::<Result<Vec<_>>>()
493 })
494 .collect::<Result<Vec<_>>>()?;
495 let output_labels = output
496 .iter()
497 .map(|&id| {
498 u32::try_from(id)
499 .map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range"))
500 })
501 .collect::<Result<Vec<_>>>()?;
502 let input_refs = input_labels.iter().map(Vec::as_slice).collect::<Vec<_>>();
503 Ok(EinsumSubscripts::new(&input_refs, &output_labels))
504 }
505
506 fn build_binary_einsum_subscripts(
507 lhs_rank: usize,
508 axes_a: &[usize],
509 rhs_rank: usize,
510 axes_b: &[usize],
511 ) -> Result<EinsumSubscripts> {
512 anyhow::ensure!(
513 axes_a.len() == axes_b.len(),
514 "contract axis length mismatch: lhs {:?}, rhs {:?}",
515 axes_a,
516 axes_b
517 );
518
519 let mut lhs_ids = vec![usize::MAX; lhs_rank];
520 let mut rhs_ids = vec![usize::MAX; rhs_rank];
521 let mut next_id = 0usize;
522
523 let mut seen_lhs = vec![false; lhs_rank];
524 let mut seen_rhs = vec![false; rhs_rank];
525
526 for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
527 anyhow::ensure!(
528 lhs_axis < lhs_rank,
529 "lhs contract axis {lhs_axis} out of range"
530 );
531 anyhow::ensure!(
532 rhs_axis < rhs_rank,
533 "rhs contract axis {rhs_axis} out of range"
534 );
535 anyhow::ensure!(
536 !seen_lhs[lhs_axis],
537 "duplicate lhs contract axis {lhs_axis}"
538 );
539 anyhow::ensure!(
540 !seen_rhs[rhs_axis],
541 "duplicate rhs contract axis {rhs_axis}"
542 );
543 seen_lhs[lhs_axis] = true;
544 seen_rhs[rhs_axis] = true;
545 lhs_ids[lhs_axis] = next_id;
546 rhs_ids[rhs_axis] = next_id;
547 next_id += 1;
548 }
549
550 let mut output_ids = Vec::with_capacity(lhs_rank + rhs_rank - 2 * axes_a.len());
551 for id in &mut lhs_ids {
552 if *id == usize::MAX {
553 *id = next_id;
554 output_ids.push(next_id);
555 next_id += 1;
556 }
557 }
558 for id in &mut rhs_ids {
559 if *id == usize::MAX {
560 *id = next_id;
561 output_ids.push(next_id);
562 next_id += 1;
563 }
564 }
565
566 Self::einsum_subscripts_from_usize_ids(&[lhs_ids, rhs_ids], &output_ids)
567 }
568
569 fn binary_dot_general_config(axes_a: &[usize], axes_b: &[usize]) -> Result<DotGeneralConfig> {
570 anyhow::ensure!(
571 axes_a.len() == axes_b.len(),
572 "contract axis length mismatch: lhs {:?}, rhs {:?}",
573 axes_a,
574 axes_b
575 );
576 Ok(DotGeneralConfig {
577 lhs_contracting_dims: axes_a.to_vec(),
578 rhs_contracting_dims: axes_b.to_vec(),
579 lhs_batch_dims: vec![],
580 rhs_batch_dims: vec![],
581 })
582 }
583
584 fn binary_contraction_axis_classes(
585 lhs_axis_classes: &[usize],
586 axes_a: &[usize],
587 rhs_axis_classes: &[usize],
588 axes_b: &[usize],
589 ) -> Vec<usize> {
590 debug_assert_eq!(axes_a.len(), axes_b.len());
591
592 fn find(parent: &mut [usize], value: usize) -> usize {
593 if parent[value] != value {
594 parent[value] = find(parent, parent[value]);
595 }
596 parent[value]
597 }
598
599 fn union(parent: &mut [usize], lhs: usize, rhs: usize) {
600 let lhs_root = find(parent, lhs);
601 let rhs_root = find(parent, rhs);
602 if lhs_root != rhs_root {
603 parent[rhs_root] = lhs_root;
604 }
605 }
606
607 let lhs_payload_rank = lhs_axis_classes
608 .iter()
609 .copied()
610 .max()
611 .map(|value| value + 1)
612 .unwrap_or(0);
613 let rhs_payload_rank = rhs_axis_classes
614 .iter()
615 .copied()
616 .max()
617 .map(|value| value + 1)
618 .unwrap_or(0);
619 let rhs_offset = lhs_payload_rank;
620 let mut parent: Vec<usize> = (0..lhs_payload_rank + rhs_payload_rank).collect();
621
622 for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
623 union(
624 &mut parent,
625 lhs_axis_classes[lhs_axis],
626 rhs_offset + rhs_axis_classes[rhs_axis],
627 );
628 }
629
630 let mut lhs_contracted = vec![false; lhs_axis_classes.len()];
631 for &axis in axes_a {
632 lhs_contracted[axis] = true;
633 }
634 let mut rhs_contracted = vec![false; rhs_axis_classes.len()];
635 for &axis in axes_b {
636 rhs_contracted[axis] = true;
637 }
638
639 let mut root_to_class = std::collections::HashMap::new();
640 let mut next_class = 0usize;
641 let mut axis_classes = Vec::new();
642
643 for (axis, &class_id) in lhs_axis_classes.iter().enumerate() {
644 if !lhs_contracted[axis] {
645 let root = find(&mut parent, class_id);
646 let class = *root_to_class.entry(root).or_insert_with(|| {
647 let value = next_class;
648 next_class += 1;
649 value
650 });
651 axis_classes.push(class);
652 }
653 }
654 for (axis, &class_id) in rhs_axis_classes.iter().enumerate() {
655 if !rhs_contracted[axis] {
656 let root = find(&mut parent, rhs_offset + class_id);
657 let class = *root_to_class.entry(root).or_insert_with(|| {
658 let value = next_class;
659 next_class += 1;
660 value
661 });
662 axis_classes.push(class);
663 }
664 }
665
666 axis_classes
667 }
668
669 fn scale_subscripts(rank: usize) -> Result<EinsumSubscripts> {
670 let ids: Vec<usize> = (0..rank).collect();
671 Self::einsum_subscripts_from_usize_ids(&[ids.clone(), Vec::new()], &ids)
672 }
673
674 fn validate_indices(indices: &[DynIndex]) -> Result<()> {
675 let mut seen = HashSet::new();
676 for idx in indices {
677 anyhow::ensure!(
678 seen.insert(idx.clone()),
679 "Tensor indices must all be unique"
680 );
681 }
682 Ok(())
683 }
684
685 fn validate_diag_dims(dims: &[usize]) -> Result<()> {
686 if !dims.is_empty() {
687 let first_dim = dims[0];
688 for (i, &dim) in dims.iter().enumerate() {
689 anyhow::ensure!(
690 dim == first_dim,
691 "DiagTensor requires all indices to have the same dimension, but dims[{i}] = {dim} != dims[0] = {first_dim}"
692 );
693 }
694 }
695 Ok(())
696 }
697
698 fn seed_native_payload(storage: &Storage, dims: &[usize]) -> Result<NativeTensor> {
699 storage_to_native_tensor(storage, dims)
700 }
701
702 fn empty_eager_cache() -> Arc<OnceLock<Arc<EagerTensor<CpuBackend>>>> {
703 Arc::new(OnceLock::new())
704 }
705
706 fn eager_cache_with(
707 inner: EagerTensor<CpuBackend>,
708 ) -> Arc<OnceLock<Arc<EagerTensor<CpuBackend>>>> {
709 let cache = Arc::new(OnceLock::new());
710 let _ = cache.set(Arc::new(inner));
711 cache
712 }
713
714 fn compact_payload_inner(&self) -> Result<EagerTensor<CpuBackend>> {
715 Ok(EagerTensor::from_tensor_in(
716 storage_payload_native(self.storage.materialize(self.indices.len())?.as_ref())?,
717 default_eager_ctx(),
718 ))
719 }
720
721 fn tracked_compact_payload_value(&self) -> Option<&StructuredAdValue> {
722 self.structured_ad.as_deref()
723 }
724
725 fn compact_payload_is_logical_dense(&self, payload_dims: &[usize]) -> bool {
726 self.storage.axis_classes() == Self::dense_axis_classes(self.indices.len())
727 && payload_dims == self.dims()
728 }
729
730 fn uses_tracked_compact_storage(&self) -> bool {
731 self.tracked_compact_payload_value()
732 .is_some_and(|value| !self.compact_payload_is_logical_dense(&value.payload_dims))
733 }
734
735 fn ensure_shape_packing_preserves_ad(&self, op_name: &str) -> Result<()> {
736 anyhow::ensure!(
737 !self.uses_tracked_compact_storage(),
738 "{op_name}: structured AD tensors with compact storage are not supported because materializing compact storage would detach gradients"
739 );
740 Ok(())
741 }
742
743 fn operand_indices_for_contraction(&self, conjugate: bool) -> Vec<DynIndex> {
744 if conjugate {
745 self.indices.iter().map(|index| index.conj()).collect()
746 } else {
747 self.indices.clone()
748 }
749 }
750
751 fn build_binary_contraction_labels(
752 lhs_rank: usize,
753 axes_a: &[usize],
754 rhs_rank: usize,
755 axes_b: &[usize],
756 ) -> Result<(Vec<usize>, Vec<usize>, Vec<usize>)> {
757 anyhow::ensure!(
758 axes_a.len() == axes_b.len(),
759 "contract axis length mismatch: lhs {:?}, rhs {:?}",
760 axes_a,
761 axes_b
762 );
763
764 let mut lhs_ids = vec![usize::MAX; lhs_rank];
765 let mut rhs_ids = vec![usize::MAX; rhs_rank];
766 let mut next_id = 0usize;
767
768 let mut seen_lhs = vec![false; lhs_rank];
769 let mut seen_rhs = vec![false; rhs_rank];
770
771 for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
772 anyhow::ensure!(
773 lhs_axis < lhs_rank,
774 "lhs contract axis {lhs_axis} out of range"
775 );
776 anyhow::ensure!(
777 rhs_axis < rhs_rank,
778 "rhs contract axis {rhs_axis} out of range"
779 );
780 anyhow::ensure!(
781 !seen_lhs[lhs_axis],
782 "duplicate lhs contract axis {lhs_axis}"
783 );
784 anyhow::ensure!(
785 !seen_rhs[rhs_axis],
786 "duplicate rhs contract axis {rhs_axis}"
787 );
788 seen_lhs[lhs_axis] = true;
789 seen_rhs[rhs_axis] = true;
790 lhs_ids[lhs_axis] = next_id;
791 rhs_ids[rhs_axis] = next_id;
792 next_id += 1;
793 }
794
795 let mut output_ids = Vec::with_capacity(lhs_rank + rhs_rank - 2 * axes_a.len());
796 for id in &mut lhs_ids {
797 if *id == usize::MAX {
798 *id = next_id;
799 output_ids.push(next_id);
800 next_id += 1;
801 }
802 }
803 for id in &mut rhs_ids {
804 if *id == usize::MAX {
805 *id = next_id;
806 output_ids.push(next_id);
807 next_id += 1;
808 }
809 }
810
811 Ok((lhs_ids, rhs_ids, output_ids))
812 }
813
814 fn build_payload_einsum_subscripts(
815 input_roots: &[Vec<usize>],
816 output_roots: &[usize],
817 ) -> Result<EinsumSubscripts> {
818 Self::einsum_subscripts_from_usize_ids(input_roots, output_roots)
819 }
820
821 fn normalize_eager_payload_for_roots(
822 payload: &EagerTensor<CpuBackend>,
823 roots: &[usize],
824 ) -> Result<(Option<EagerTensor<CpuBackend>>, Vec<usize>)> {
825 anyhow::ensure!(
826 payload.data().shape().len() == roots.len(),
827 "payload rank {} does not match root label count {}",
828 payload.data().shape().len(),
829 roots.len()
830 );
831
832 let mut current_payload = None;
833 let mut current_roots = roots.to_vec();
834 while let Some((axis_a, axis_b)) = Self::first_duplicate_pair(¤t_roots) {
835 let source = current_payload.as_ref().unwrap_or(payload);
836 current_payload = Some(source.extract_diag(axis_a, axis_b)?);
837 current_roots.remove(axis_b);
838 }
839
840 Ok((current_payload, current_roots))
841 }
842
843 fn first_duplicate_pair(values: &[usize]) -> Option<(usize, usize)> {
844 let mut first_axis_by_value = std::collections::HashMap::new();
845 for (axis, &value) in values.iter().enumerate() {
846 if let Some(&first_axis) = first_axis_by_value.get(&value) {
847 return Some((first_axis, axis));
848 }
849 first_axis_by_value.insert(value, axis);
850 }
851 None
852 }
853
854 fn binary_structured_contraction_plan(
855 &self,
856 other: &Self,
857 axes_a: &[usize],
858 axes_b: &[usize],
859 ) -> Result<(StructuredContractionPlan, Vec<Vec<usize>>, Vec<usize>)> {
860 let (lhs_labels, rhs_labels, output_labels) = Self::build_binary_contraction_labels(
861 self.indices.len(),
862 axes_a,
863 other.indices.len(),
864 axes_b,
865 )?;
866 let operands = vec![
867 OperandLayout::new(self.dims(), self.storage.axis_classes().to_vec())?,
868 OperandLayout::new(other.dims(), other.storage.axis_classes().to_vec())?,
869 ];
870 let spec = StructuredContractionSpec {
871 input_labels: vec![lhs_labels, rhs_labels],
872 output_labels,
873 retained_labels: Default::default(),
874 };
875 let plan = StructuredContractionPlan::new(&operands, &spec)?;
876 Ok((plan, spec.input_labels, spec.output_labels))
877 }
878
879 fn from_structured_payload_inner(
880 indices: Vec<DynIndex>,
881 payload_inner: EagerTensor<CpuBackend>,
882 payload_dims: Vec<usize>,
883 axis_classes: Vec<usize>,
884 ) -> Result<Self> {
885 Self::validate_indices(&indices)?;
886 if payload_inner.data().shape() != payload_dims {
887 return Err(anyhow::anyhow!(
888 "structured payload dims {:?} do not match planned payload dims {:?}",
889 payload_inner.data().shape(),
890 payload_dims
891 ));
892 }
893 let storage = storage_from_payload_native(
894 payload_inner.data().clone(),
895 &payload_dims,
896 axis_classes.clone(),
897 )?;
898 Self::validate_storage_matches_indices(&indices, &storage)?;
899 Ok(Self {
900 indices,
901 storage: TensorDynLenStorage::from_storage(Arc::new(storage)),
902 structured_ad: Some(Arc::new(StructuredAdValue {
903 payload: Arc::new(payload_inner),
904 payload_dims,
905 axis_classes,
906 })),
907 eager_cache: Self::empty_eager_cache(),
908 })
909 }
910
911 fn contract_structured_payloads(
912 &self,
913 other: &Self,
914 result_indices: Vec<DynIndex>,
915 axes_a: &[usize],
916 axes_b: &[usize],
917 ) -> Result<Self> {
918 let (plan, _, _) = self.binary_structured_contraction_plan(other, axes_a, axes_b)?;
919 let lhs_roots = plan.operand_plans[0].class_roots.clone();
920 let rhs_roots = plan.operand_plans[1].class_roots.clone();
921 let scalar_multiply =
922 lhs_roots.is_empty() && rhs_roots.is_empty() && plan.output_payload_roots.is_empty();
923
924 if let (Some(lhs_ad), Some(rhs_ad)) = (
925 self.tracked_compact_payload_value(),
926 other.tracked_compact_payload_value(),
927 ) {
928 if lhs_ad.payload.data().dtype() != rhs_ad.payload.data().dtype() {
929 return Err(anyhow::anyhow!(
930 "structured AD contraction requires matching payload dtypes"
931 ));
932 }
933 let (lhs_normalized, lhs_labels) =
934 Self::normalize_eager_payload_for_roots(lhs_ad.payload.as_ref(), &lhs_roots)?;
935 let (rhs_normalized, rhs_labels) =
936 Self::normalize_eager_payload_for_roots(rhs_ad.payload.as_ref(), &rhs_roots)?;
937 let lhs_payload = lhs_normalized
938 .as_ref()
939 .unwrap_or_else(|| lhs_ad.payload.as_ref());
940 let rhs_payload = rhs_normalized
941 .as_ref()
942 .unwrap_or_else(|| rhs_ad.payload.as_ref());
943 let payload = if scalar_multiply {
944 lhs_payload.mul(rhs_payload)?
945 } else {
946 let subscripts = Self::build_payload_einsum_subscripts(
947 &[lhs_labels, rhs_labels],
948 &plan.output_payload_roots,
949 )?;
950 eager_einsum_ad(&[lhs_payload, rhs_payload], &subscripts)?
951 };
952 return Self::from_structured_payload_inner(
953 result_indices,
954 payload,
955 plan.output_payload_dims,
956 plan.output_axis_classes,
957 );
958 }
959
960 if self.tracked_compact_payload_value().is_some()
961 || other.tracked_compact_payload_value().is_some()
962 {
963 let lhs_owned = if self.tracked_compact_payload_value().is_some() {
964 None
965 } else {
966 Some(self.compact_payload_inner()?)
967 };
968 let rhs_owned = if other.tracked_compact_payload_value().is_some() {
969 None
970 } else {
971 Some(other.compact_payload_inner()?)
972 };
973 let lhs = if let Some(value) = self.tracked_compact_payload_value() {
974 value.payload.as_ref()
975 } else {
976 lhs_owned
977 .as_ref()
978 .ok_or_else(|| anyhow::anyhow!("missing untracked left compact payload"))?
979 };
980 let rhs = if let Some(value) = other.tracked_compact_payload_value() {
981 value.payload.as_ref()
982 } else {
983 rhs_owned
984 .as_ref()
985 .ok_or_else(|| anyhow::anyhow!("missing untracked right compact payload"))?
986 };
987 if lhs.data().dtype() != rhs.data().dtype() {
988 return Err(anyhow::anyhow!(
989 "structured AD contraction requires matching payload dtypes"
990 ));
991 }
992 let (lhs_normalized, lhs_labels) =
993 Self::normalize_eager_payload_for_roots(lhs, &lhs_roots)?;
994 let (rhs_normalized, rhs_labels) =
995 Self::normalize_eager_payload_for_roots(rhs, &rhs_roots)?;
996 let lhs_payload = lhs_normalized.as_ref().unwrap_or(lhs);
997 let rhs_payload = rhs_normalized.as_ref().unwrap_or(rhs);
998 let payload = if scalar_multiply {
999 lhs_payload.mul(rhs_payload)?
1000 } else {
1001 let subscripts = Self::build_payload_einsum_subscripts(
1002 &[lhs_labels, rhs_labels],
1003 &plan.output_payload_roots,
1004 )?;
1005 eager_einsum_ad(&[lhs_payload, rhs_payload], &subscripts)?
1006 };
1007 return Self::from_structured_payload_inner(
1008 result_indices,
1009 payload,
1010 plan.output_payload_dims,
1011 plan.output_axis_classes,
1012 );
1013 }
1014
1015 let lhs_storage = self.storage.materialize(self.indices.len())?;
1016 let rhs_storage = other.storage.materialize(other.indices.len())?;
1017 let lhs = storage_payload_native_read_input(lhs_storage.as_ref())?;
1018 let rhs = storage_payload_native_read_input(rhs_storage.as_ref())?;
1019 if lhs.dtype() != rhs.dtype() {
1020 return Err(anyhow::anyhow!(
1021 "structured payload contraction requires matching payload dtypes"
1022 ));
1023 }
1024 let (lhs, lhs_labels) = normalize_payload_read_for_roots(lhs, &lhs_roots)?;
1025 let (rhs, rhs_labels) = normalize_payload_read_for_roots(rhs, &rhs_roots)?;
1026 let payload = tensor4all_tensorbackend::einsum_native_tensor_reads(
1027 &[(&lhs, lhs_labels.as_slice()), (&rhs, rhs_labels.as_slice())],
1028 &plan.output_payload_roots,
1029 )?;
1030 let storage = storage_from_payload_native(
1031 payload,
1032 &plan.output_payload_dims,
1033 plan.output_axis_classes,
1034 )?;
1035 Self::from_storage(result_indices, Arc::new(storage))
1036 }
1037
1038 fn should_use_structured_payload_contract(&self, other: &Self) -> bool {
1039 let same_payload_dtype = self.storage.is_f64() == other.storage.is_f64()
1040 && self.storage.is_complex() == other.storage.is_complex();
1041 same_payload_dtype
1042 && (self.tracked_compact_payload_value().is_some()
1043 || other.tracked_compact_payload_value().is_some()
1044 || self.storage.axis_classes() != Self::dense_axis_classes(self.indices.len())
1045 || other.storage.axis_classes() != Self::dense_axis_classes(other.indices.len()))
1046 }
1047
1048 fn storage_from_native_with_axis_classes(
1049 native: &NativeTensor,
1050 axis_classes: &[usize],
1051 logical_rank: usize,
1052 ) -> Result<Storage> {
1053 if Self::is_diag_axis_classes(axis_classes) {
1054 match native.dtype() {
1055 DType::F32 | DType::F64 | DType::I64 => Storage::from_diag_col_major(
1056 native_tensor_primal_to_diag_f64(native)?,
1057 logical_rank,
1058 ),
1059 DType::C32 | DType::C64 => Storage::from_diag_col_major(
1060 native_tensor_primal_to_diag_c64(native)?,
1061 logical_rank,
1062 ),
1063 }
1064 } else {
1065 native_tensor_primal_to_storage(native)
1066 }
1067 }
1068
1069 fn dense_selected_diag_payload<T: TensorElement + Copy + Zero>(
1070 payload: Vec<T>,
1071 kept_dims: &[usize],
1072 selected_positions: &[usize],
1073 ) -> Vec<T> {
1074 let output_len = kept_dims.iter().product::<usize>();
1075 let mut data = vec![T::zero(); output_len];
1076 if output_len == 0 {
1077 return data;
1078 }
1079
1080 let Some((&first_position, rest)) = selected_positions.split_first() else {
1081 return data;
1082 };
1083 if rest.iter().any(|&position| position != first_position) {
1084 return data;
1085 }
1086
1087 let value = payload[first_position];
1088 if kept_dims.is_empty() {
1089 data[0] = value;
1090 return data;
1091 }
1092
1093 let mut offset = 0usize;
1094 let mut stride = 1usize;
1095 for &dim in kept_dims {
1096 offset += first_position * stride;
1097 stride *= dim;
1098 }
1099 data[offset] = value;
1100 data
1101 }
1102
1103 fn select_diag_indices(
1104 &self,
1105 kept_indices: Vec<DynIndex>,
1106 kept_dims: Vec<usize>,
1107 positions: &[usize],
1108 ) -> Result<Self> {
1109 if self.storage.is_f64() {
1110 let storage = self.storage.materialize(self.indices.len())?;
1111 let payload = storage
1112 .payload_f64_col_major_vec()
1113 .map_err(anyhow::Error::msg)?;
1114 let data = Self::dense_selected_diag_payload(payload, &kept_dims, positions);
1115 Self::from_dense(kept_indices, data)
1116 } else if self.storage.is_c64() {
1117 let storage = self.storage.materialize(self.indices.len())?;
1118 let payload = storage
1119 .payload_c64_col_major_vec()
1120 .map_err(anyhow::Error::msg)?;
1121 let data = Self::dense_selected_diag_payload(payload, &kept_dims, positions);
1122 Self::from_dense(kept_indices, data)
1123 } else {
1124 Err(anyhow::anyhow!("unsupported diagonal storage scalar type"))
1125 }
1126 }
1127
1128 fn col_major_strides(dims: &[usize]) -> Result<Vec<isize>> {
1129 let mut strides = Vec::with_capacity(dims.len());
1130 let mut stride = 1isize;
1131 for &dim in dims {
1132 strides.push(stride);
1133 let dim = isize::try_from(dim)
1134 .map_err(|_| anyhow::anyhow!("dimension does not fit in isize"))?;
1135 stride = stride
1136 .checked_mul(dim)
1137 .ok_or_else(|| anyhow::anyhow!("column-major stride overflow"))?;
1138 }
1139 Ok(strides)
1140 }
1141
1142 fn zero_structured_selection<T>(
1143 kept_indices: Vec<DynIndex>,
1144 kept_dims: &[usize],
1145 ) -> Result<Self>
1146 where
1147 T: TensorElement + Zero,
1148 {
1149 let output_len = checked_product(kept_dims)?;
1150 Self::from_dense(kept_indices, vec![T::zero(); output_len])
1151 }
1152
1153 fn select_structured_indices_typed<T>(
1154 &self,
1155 payload: Vec<T>,
1156 kept_axes: &[usize],
1157 kept_indices: Vec<DynIndex>,
1158 kept_dims: Vec<usize>,
1159 selected_axes: &[usize],
1160 positions: &[usize],
1161 ) -> Result<Self>
1162 where
1163 T: TensorElement + StorageScalar + Zero,
1164 {
1165 let payload_dims = self.storage.payload_dims();
1166 let axis_classes = self.storage.axis_classes();
1167 let payload_rank = payload_dims.len();
1168 let mut selected_class_positions = vec![None; payload_rank];
1169
1170 for (&axis, &position) in selected_axes.iter().zip(positions.iter()) {
1171 let class_id = axis_classes[axis];
1172 if let Some(existing) = selected_class_positions[class_id] {
1173 if existing != position {
1174 return Self::zero_structured_selection::<T>(kept_indices, &kept_dims);
1175 }
1176 } else {
1177 selected_class_positions[class_id] = Some(position);
1178 }
1179 }
1180
1181 let selected_class_kept = kept_axes
1182 .iter()
1183 .any(|&axis| selected_class_positions[axis_classes[axis]].is_some());
1184 if selected_class_kept {
1185 return self.select_structured_indices_dense(
1186 payload,
1187 kept_axes,
1188 kept_indices,
1189 kept_dims,
1190 &selected_class_positions,
1191 );
1192 }
1193
1194 let mut old_to_new_class = vec![None; payload_rank];
1195 let mut output_payload_dims = Vec::new();
1196 let mut output_axis_classes = Vec::with_capacity(kept_axes.len());
1197 for &axis in kept_axes {
1198 let class_id = axis_classes[axis];
1199 let new_class = match old_to_new_class[class_id] {
1200 Some(new_class) => new_class,
1201 None => {
1202 let new_class = output_payload_dims.len();
1203 old_to_new_class[class_id] = Some(new_class);
1204 output_payload_dims.push(payload_dims[class_id]);
1205 new_class
1206 }
1207 };
1208 output_axis_classes.push(new_class);
1209 }
1210
1211 let output_len = checked_product(&output_payload_dims)?;
1212 let mut output_payload = Vec::with_capacity(output_len);
1213 for linear in 0..output_len {
1214 let output_payload_index = decode_col_major_linear(linear, &output_payload_dims)?;
1215 let mut input_payload_index = vec![0usize; payload_rank];
1216 for class_id in 0..payload_rank {
1217 input_payload_index[class_id] =
1218 if let Some(position) = selected_class_positions[class_id] {
1219 position
1220 } else if let Some(new_class) = old_to_new_class[class_id] {
1221 output_payload_index[new_class]
1222 } else {
1223 return Err(anyhow::anyhow!(
1224 "structured payload class {class_id} is neither selected nor kept"
1225 ));
1226 };
1227 }
1228 let input_linear = encode_col_major_linear(&input_payload_index, payload_dims)?;
1229 output_payload.push(payload[input_linear]);
1230 }
1231
1232 let output_strides = Self::col_major_strides(&output_payload_dims)?;
1233 let storage = Storage::new_structured(
1234 output_payload,
1235 output_payload_dims,
1236 output_strides,
1237 output_axis_classes,
1238 )?;
1239 Self::from_storage(kept_indices, Arc::new(storage))
1240 }
1241
1242 fn select_structured_indices_dense<T>(
1243 &self,
1244 payload: Vec<T>,
1245 kept_axes: &[usize],
1246 kept_indices: Vec<DynIndex>,
1247 kept_dims: Vec<usize>,
1248 selected_class_positions: &[Option<usize>],
1249 ) -> Result<Self>
1250 where
1251 T: TensorElement + Zero,
1252 {
1253 let payload_dims = self.storage.payload_dims();
1254 let axis_classes = self.storage.axis_classes();
1255 let output_len = checked_product(&kept_dims)?;
1256 let mut output = Vec::with_capacity(output_len);
1257
1258 for linear in 0..output_len {
1259 let kept_position = decode_col_major_linear(linear, &kept_dims)?;
1260 let mut input_payload_index = selected_class_positions.to_vec();
1261 let mut is_structural_zero = false;
1262
1263 for (&axis, &position) in kept_axes.iter().zip(kept_position.iter()) {
1264 let class_id = axis_classes[axis];
1265 match input_payload_index[class_id] {
1266 Some(existing) if existing != position => {
1267 is_structural_zero = true;
1268 break;
1269 }
1270 Some(_) => {}
1271 None => input_payload_index[class_id] = Some(position),
1272 }
1273 }
1274
1275 if is_structural_zero {
1276 output.push(T::zero());
1277 continue;
1278 }
1279
1280 let input_payload_index = input_payload_index
1281 .into_iter()
1282 .enumerate()
1283 .map(|(class_id, position)| {
1284 position.ok_or_else(|| {
1285 anyhow::anyhow!(
1286 "structured payload class {class_id} is neither selected nor kept"
1287 )
1288 })
1289 })
1290 .collect::<Result<Vec<_>>>()?;
1291 let input_linear = encode_col_major_linear(&input_payload_index, payload_dims)?;
1292 output.push(payload[input_linear]);
1293 }
1294
1295 Self::from_dense(kept_indices, output)
1296 }
1297
1298 fn select_structured_indices(
1299 &self,
1300 kept_axes: &[usize],
1301 kept_indices: Vec<DynIndex>,
1302 kept_dims: Vec<usize>,
1303 selected_axes: &[usize],
1304 positions: &[usize],
1305 ) -> Result<Self> {
1306 if self.storage.is_f64() {
1307 let storage = self.storage.materialize(self.indices.len())?;
1308 let payload = storage
1309 .payload_f64_col_major_vec()
1310 .map_err(anyhow::Error::msg)?;
1311 self.select_structured_indices_typed(
1312 payload,
1313 kept_axes,
1314 kept_indices,
1315 kept_dims,
1316 selected_axes,
1317 positions,
1318 )
1319 } else if self.storage.is_c64() {
1320 let storage = self.storage.materialize(self.indices.len())?;
1321 let payload = storage
1322 .payload_c64_col_major_vec()
1323 .map_err(anyhow::Error::msg)?;
1324 self.select_structured_indices_typed(
1325 payload,
1326 kept_axes,
1327 kept_indices,
1328 kept_dims,
1329 selected_axes,
1330 positions,
1331 )
1332 } else {
1333 Err(anyhow::anyhow!(
1334 "unsupported structured storage scalar type"
1335 ))
1336 }
1337 }
1338
1339 fn validate_storage_matches_indices(indices: &[DynIndex], storage: &Storage) -> Result<()> {
1340 let dims = Self::expected_dims_from_indices(indices);
1341 let storage_dims = storage.logical_dims();
1342 if storage_dims != dims {
1343 return Err(anyhow::anyhow!(
1344 "storage logical dims {:?} do not match indices dims {:?}",
1345 storage_dims,
1346 dims
1347 ));
1348 }
1349 if storage.is_diag() {
1350 Self::validate_diag_dims(&dims)?;
1351 }
1352 Ok(())
1353 }
1354
1355 fn try_materialized_inner(&self) -> Result<&EagerTensor<CpuBackend>> {
1356 if let Some(value) = self.tracked_compact_payload_value() {
1357 if self.compact_payload_is_logical_dense(&value.payload_dims) {
1358 return Ok(value.payload.as_ref());
1359 }
1360 }
1361 if let Some(inner) = self.storage.eager() {
1362 return Ok(inner);
1363 }
1364 if self.eager_cache.get().is_none() {
1365 let dims = self.dims();
1366 let native = profile_pairwise_contract_section("materialize_storage_to_native", || {
1367 let storage = self.storage.materialize(self.indices.len())?;
1368 Self::seed_native_payload(storage.as_ref(), &dims)
1369 })
1370 .context("TensorDynLen materialization failed")?;
1371 record_pairwise_contract_profile_bytes(
1372 "materialize_storage_to_native",
1373 native_tensor_profile_bytes(&native),
1374 );
1375 let _ = self.eager_cache.set(Arc::new(EagerTensor::from_tensor_in(
1376 native,
1377 default_eager_ctx(),
1378 )));
1379 }
1380 self.eager_cache
1381 .get()
1382 .map(|inner| inner.as_ref())
1383 .ok_or_else(|| {
1384 anyhow::anyhow!("TensorDynLen materialization cache was not initialized")
1385 })
1386 }
1387
1388 pub(crate) fn as_inner(&self) -> Result<&EagerTensor<CpuBackend>> {
1389 self.try_materialized_inner()
1390 }
1391
1392 #[inline]
1394 fn expected_dims_from_indices(indices: &[DynIndex]) -> Vec<usize> {
1395 indices.iter().map(|idx| idx.dim()).collect()
1396 }
1397
1398 pub fn dims(&self) -> Vec<usize> {
1417 Self::expected_dims_from_indices(&self.indices)
1418 }
1419
1420 pub fn select_indices(
1462 &self,
1463 selected_indices: &[DynIndex],
1464 positions: &[usize],
1465 ) -> Result<Self> {
1466 if selected_indices.len() != positions.len() {
1467 return Err(anyhow::anyhow!(
1468 "selected_indices length {} does not match positions length {}",
1469 selected_indices.len(),
1470 positions.len()
1471 ));
1472 }
1473 if selected_indices.is_empty() {
1474 return Ok(self.clone());
1475 }
1476
1477 let mut selected_axes = Vec::with_capacity(selected_indices.len());
1478 let mut seen_axes = HashSet::with_capacity(selected_indices.len());
1479 for (selected, &position) in selected_indices.iter().zip(positions.iter()) {
1480 let axis = self
1481 .indices
1482 .iter()
1483 .position(|index| index == selected)
1484 .ok_or_else(|| anyhow::anyhow!("selected index is not present in tensor"))?;
1485 if !seen_axes.insert(axis) {
1486 return Err(anyhow::anyhow!("selected index appears more than once"));
1487 }
1488 let dim = self.indices[axis].dim();
1489 if position >= dim {
1490 return Err(anyhow::anyhow!(
1491 "selected coordinate {position} is out of range for axis {axis} with dim {dim}"
1492 ));
1493 }
1494 selected_axes.push(axis);
1495 }
1496
1497 let kept_axes = self
1498 .indices
1499 .iter()
1500 .enumerate()
1501 .filter(|(axis, _)| !seen_axes.contains(axis))
1502 .map(|(axis, _)| axis)
1503 .collect::<Vec<_>>();
1504 let kept_indices = kept_axes
1505 .iter()
1506 .map(|&axis| self.indices[axis].clone())
1507 .collect::<Vec<_>>();
1508 let kept_dims = kept_axes
1509 .iter()
1510 .map(|&axis| self.indices[axis].dim())
1511 .collect::<Vec<_>>();
1512
1513 if self.storage.storage_kind() == StorageKind::Diagonal {
1514 return self.select_diag_indices(kept_indices, kept_dims, positions);
1515 }
1516 if self.storage.storage_kind() == StorageKind::Structured {
1517 return self.select_structured_indices(
1518 &kept_axes,
1519 kept_indices,
1520 kept_dims,
1521 &selected_axes,
1522 positions,
1523 );
1524 }
1525 if self.storage.storage_kind() != StorageKind::Dense {
1526 return Err(anyhow::anyhow!(
1527 "select_indices got unsupported storage kind {:?}",
1528 self.storage.storage_kind()
1529 ));
1530 }
1531
1532 let rank = self.indices.len();
1533 let mut starts = vec![0_i64; rank];
1534 let mut slice_sizes = self.dims();
1535 for (&axis, &position) in selected_axes.iter().zip(positions.iter()) {
1536 starts[axis] = i64::try_from(position)
1537 .map_err(|_| anyhow::anyhow!("selected coordinate does not fit in i64"))?;
1538 slice_sizes[axis] = 1;
1539 }
1540
1541 let starts_tensor = EagerTensor::from_tensor_in(
1542 NativeTensor::from_vec(vec![rank], starts),
1543 default_eager_ctx(),
1544 );
1545 let sliced = self
1546 .try_materialized_inner()?
1547 .dynamic_slice(&starts_tensor, &slice_sizes)?;
1548 Self::from_inner(kept_indices, sliced.reshape(&kept_dims)?)
1549 }
1550
1551 pub fn stack_along_new_index(
1584 tensors: &[&Self],
1585 new_index: DynIndex,
1586 axis: isize,
1587 ) -> Result<Self> {
1588 let first = tensors
1589 .first()
1590 .copied()
1591 .ok_or_else(|| anyhow::anyhow!("stack_along_new_index requires at least one tensor"))?;
1592 anyhow::ensure!(
1593 new_index.dim() == tensors.len(),
1594 "stack_along_new_index: new index dim {} does not match tensor count {}",
1595 new_index.dim(),
1596 tensors.len()
1597 );
1598
1599 let base_indices = first.indices.clone();
1600 for tensor in tensors.iter().copied().skip(1) {
1601 anyhow::ensure!(
1602 tensor.indices == base_indices,
1603 "stack_along_new_index: input tensors must have identical index order"
1604 );
1605 }
1606 for &tensor in tensors {
1607 tensor.ensure_shape_packing_preserves_ad("stack_along_new_index")?;
1608 }
1609
1610 let insert_axis =
1611 Self::normalize_insert_axis("stack_along_new_index", axis, base_indices.len())?;
1612 let mut result_indices = base_indices;
1613 result_indices.insert(insert_axis, new_index);
1614
1615 let inner_refs = tensors
1616 .iter()
1617 .map(|tensor| tensor.try_materialized_inner())
1618 .collect::<Result<Vec<_>>>()?;
1619 let stacked = EagerTensor::stack(&inner_refs, axis)?;
1620 Self::from_inner(result_indices, stacked)
1621 }
1622
1623 pub fn index_select(
1656 &self,
1657 source_index: &DynIndex,
1658 target_index: DynIndex,
1659 positions: &[usize],
1660 ) -> Result<Self> {
1661 anyhow::ensure!(
1662 target_index.dim() == positions.len(),
1663 "index_select: target index dim {} does not match position count {}",
1664 target_index.dim(),
1665 positions.len()
1666 );
1667 let axis = self
1668 .indices
1669 .iter()
1670 .position(|index| index == source_index)
1671 .ok_or_else(|| anyhow::anyhow!("index_select: source index is not present"))?;
1672 let source_dim = self.indices[axis].dim();
1673 for &position in positions {
1674 anyhow::ensure!(
1675 position < source_dim,
1676 "index_select: position {position} is out of range for source dim {source_dim}"
1677 );
1678 }
1679 self.ensure_shape_packing_preserves_ad("index_select")?;
1680
1681 let axis = isize::try_from(axis)
1682 .map_err(|_| anyhow::anyhow!("index_select: axis does not fit in isize"))?;
1683 let selected = self
1684 .try_materialized_inner()?
1685 .index_select(axis, positions)?;
1686 let mut result_indices = self.indices.clone();
1687 result_indices[axis as usize] = target_index;
1688 Self::from_inner(result_indices, selected)
1689 }
1690
1691 pub fn new(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
1711 Self::from_storage(indices, storage)
1712 }
1713
1714 pub fn from_indices(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
1736 Self::new(indices, storage)
1737 }
1738
1739 pub fn from_storage(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
1755 Self::validate_indices(&indices)?;
1756 Self::validate_storage_matches_indices(&indices, storage.as_ref())?;
1757 Ok(Self {
1758 indices,
1759 storage: TensorDynLenStorage::from_storage(storage),
1760 structured_ad: None,
1761 eager_cache: Self::empty_eager_cache(),
1762 })
1763 }
1764
1765 pub fn from_structured_storage(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
1789 Self::from_storage(indices, storage)
1790 }
1791
1792 pub(crate) fn from_native(indices: Vec<DynIndex>, native: NativeTensor) -> Result<Self> {
1794 let axis_classes = Self::dense_axis_classes(indices.len());
1795 Self::from_native_with_axis_classes(indices, native, axis_classes)
1796 }
1797
1798 pub(crate) fn from_native_with_axis_classes(
1799 indices: Vec<DynIndex>,
1800 native: NativeTensor,
1801 axis_classes: Vec<usize>,
1802 ) -> Result<Self> {
1803 Self::from_inner_with_axis_classes(
1804 indices,
1805 EagerTensor::from_tensor_in(native, default_eager_ctx()),
1806 axis_classes,
1807 )
1808 }
1809
1810 pub(crate) fn from_inner(
1811 indices: Vec<DynIndex>,
1812 inner: EagerTensor<CpuBackend>,
1813 ) -> Result<Self> {
1814 let axis_classes = Self::dense_axis_classes(indices.len());
1815 Self::from_inner_with_axis_classes(indices, inner, axis_classes)
1816 }
1817
1818 pub(crate) fn from_diag_inner(
1819 indices: Vec<DynIndex>,
1820 payload_inner: EagerTensor<CpuBackend>,
1821 ) -> Result<Self> {
1822 let dims = Self::expected_dims_from_indices(&indices);
1823 Self::validate_indices(&indices)?;
1824 Self::validate_diag_dims(&dims)?;
1825 Self::validate_diag_payload_len(payload_inner.data().shape().iter().product(), &dims)?;
1826 let axis_classes = Self::diag_axis_classes(dims.len());
1827 let diag_inner = payload_inner.embed_diag(0, 1)?;
1828 Self::from_inner_with_axis_classes(indices, diag_inner, axis_classes)
1829 }
1830
1831 pub(crate) fn from_inner_with_axis_classes(
1832 indices: Vec<DynIndex>,
1833 inner: EagerTensor<CpuBackend>,
1834 axis_classes: Vec<usize>,
1835 ) -> Result<Self> {
1836 let dims = profile_pairwise_contract_section("from_inner_expected_dims", || {
1837 Self::expected_dims_from_indices(&indices)
1838 });
1839 profile_pairwise_contract_section("from_inner_validate_indices", || {
1840 Self::validate_indices(&indices)
1841 })?;
1842 if dims != inner.data().shape() {
1843 return Err(anyhow::anyhow!(
1844 "native payload dims {:?} do not match indices dims {:?}",
1845 inner.data().shape(),
1846 dims
1847 ));
1848 }
1849 if Self::is_diag_axis_classes(&axis_classes) {
1850 profile_pairwise_contract_section("from_inner_validate_diag_dims", || {
1851 Self::validate_diag_dims(&dims)
1852 })?;
1853 }
1854 let (storage, eager_cache) = if axis_classes == Self::dense_axis_classes(indices.len()) {
1855 (
1856 TensorDynLenStorage::from_eager_dense(inner, indices.len()),
1857 Self::empty_eager_cache(),
1858 )
1859 } else {
1860 let storage = profile_pairwise_contract_section("from_inner_storage_snapshot", || {
1861 Self::storage_from_native_with_axis_classes(
1862 inner.data(),
1863 &axis_classes,
1864 indices.len(),
1865 )
1866 })?;
1867 record_pairwise_contract_profile_bytes(
1868 "from_inner_storage_snapshot",
1869 native_tensor_profile_bytes(inner.data()),
1870 );
1871 (
1872 TensorDynLenStorage::from_storage(Arc::new(storage)),
1873 profile_pairwise_contract_section("from_inner_eager_cache", || {
1874 Self::eager_cache_with(inner)
1875 }),
1876 )
1877 };
1878 Ok(Self {
1879 indices,
1880 storage,
1881 structured_ad: None,
1882 eager_cache,
1883 })
1884 }
1885
1886 pub fn indices(&self) -> &[DynIndex] {
1888 &self.indices
1889 }
1890
1891 pub(crate) fn as_native(&self) -> Result<&NativeTensor> {
1893 Ok(self.try_materialized_inner()?.data())
1894 }
1895
1896 pub fn enable_grad(self) -> Result<Self> {
1898 let materialized = self.storage.materialize(self.indices.len())?;
1899 let payload = storage_payload_native(materialized.as_ref())
1900 .context("TensorDynLen::enable_grad failed")?;
1901 let payload_dims = self.storage.payload_dims().to_vec();
1902 let axis_classes = self.storage.axis_classes().to_vec();
1903 Ok(Self {
1904 indices: self.indices,
1905 storage: self.storage,
1906 structured_ad: Some(Arc::new(StructuredAdValue {
1907 payload: Arc::new(EagerTensor::requires_grad_in(payload, default_eager_ctx())),
1908 payload_dims,
1909 axis_classes,
1910 })),
1911 eager_cache: Self::empty_eager_cache(),
1912 })
1913 }
1914
1915 pub fn tracks_grad(&self) -> bool {
1917 self.structured_ad
1918 .as_ref()
1919 .is_some_and(|value| value.payload.tracks_grad())
1920 || self.storage.eager().is_some_and(EagerTensor::tracks_grad)
1921 || self
1922 .eager_cache
1923 .get()
1924 .is_some_and(|inner| inner.tracks_grad())
1925 }
1926
1927 pub fn grad(&self) -> Result<Option<Self>> {
1929 if let Some(value) = self.tracked_compact_payload_value() {
1930 return value
1931 .payload
1932 .grad()
1933 .map(|grad| {
1934 let storage = storage_from_payload_native(
1935 grad.as_ref().clone(),
1936 &value.payload_dims,
1937 value.axis_classes.clone(),
1938 )?;
1939 Self::from_storage(self.indices.clone(), Arc::new(storage))
1940 })
1941 .transpose();
1942 }
1943 self.try_materialized_inner()?
1944 .grad()
1945 .map(|grad| {
1946 Self::from_native_with_axis_classes(
1947 self.indices.clone(),
1948 grad.as_ref().clone(),
1949 self.storage.axis_classes().to_vec(),
1950 )
1951 })
1952 .transpose()
1953 }
1954
1955 pub fn clear_grad(&self) -> Result<()> {
1957 if let Some(value) = self.tracked_compact_payload_value() {
1958 value.payload.clear_grad();
1959 }
1960 if let Some(inner) = self.storage.eager() {
1961 inner.clear_grad();
1962 }
1963 if let Some(inner) = self.eager_cache.get() {
1964 inner.clear_grad();
1965 }
1966 Ok(())
1967 }
1968
1969 pub fn backward(&self) -> Result<()> {
1971 if let Some(value) = self.tracked_compact_payload_value() {
1972 return value
1973 .payload
1974 .backward()
1975 .map(|_| ())
1976 .map_err(|e| anyhow::anyhow!("TensorDynLen::backward failed: {e}"));
1977 }
1978 self.try_materialized_inner()?
1979 .backward()
1980 .map(|_| ())
1981 .map_err(|e| anyhow::anyhow!("TensorDynLen::backward failed: {e}"))
1982 }
1983
1984 pub fn detach(&self) -> Result<Self> {
1986 if self.tracked_compact_payload_value().is_some() {
1987 return Self::from_storage(
1988 self.indices.clone(),
1989 self.storage.materialize(self.indices.len())?,
1990 );
1991 }
1992 Self::from_inner_with_axis_classes(
1993 self.indices.clone(),
1994 self.try_materialized_inner()?.detach(),
1995 self.storage.axis_classes().to_vec(),
1996 )
1997 }
1998
1999 pub fn is_simple(&self) -> bool {
2001 true
2002 }
2003
2004 pub fn to_storage(&self) -> Result<Arc<Storage>> {
2006 self.storage.materialize(self.indices.len())
2007 }
2008
2009 pub fn storage(&self) -> Arc<Storage> {
2011 self.storage
2012 .materialize(self.indices.len())
2013 .expect("TensorDynLen storage materialization failed")
2014 }
2015
2016 pub fn sum(&self) -> Result<AnyScalar> {
2029 if self.indices.is_empty() {
2030 return AnyScalar::from_tensor(self.clone());
2031 }
2032 let axes: Vec<usize> = (0..self.indices.len()).collect();
2033 let reduced = self.try_materialized_inner()?.reduce_sum(&axes)?;
2034 AnyScalar::from_tensor(Self::from_inner(Vec::new(), reduced)?)
2035 }
2036
2037 pub fn only(&self) -> Result<AnyScalar> {
2058 let dims = self.dims();
2059 let total_size = checked_product(&dims)?;
2060 anyhow::ensure!(
2061 total_size == 1 || dims.is_empty(),
2062 "only() requires a scalar tensor (1 element), got {} elements with dims {:?}",
2063 if dims.is_empty() { 1 } else { total_size },
2064 dims
2065 );
2066 self.sum()
2067 }
2068
2069 pub fn permute_indices(&self, new_indices: &[DynIndex]) -> Result<Self> {
2100 let perm = compute_permutation_from_indices(&self.indices, new_indices)?;
2102 if perm.iter().copied().eq(0..perm.len()) {
2103 return Ok(Self {
2104 indices: new_indices.to_vec(),
2105 storage: self.storage.clone(),
2106 structured_ad: self.structured_ad.clone(),
2107 eager_cache: Arc::clone(&self.eager_cache),
2108 });
2109 }
2110
2111 let permuted = self.try_materialized_inner()?.transpose(&perm)?;
2112 let axis_classes = self.permute_axis_classes(&perm);
2113 Self::from_inner_with_axis_classes(new_indices.to_vec(), permuted, axis_classes)
2114 }
2115
2116 pub fn permute(&self, perm: &[usize]) -> Result<Self> {
2145 anyhow::ensure!(
2146 perm.len() == self.indices.len(),
2147 "permutation length must match tensor rank"
2148 );
2149 let mut seen = HashSet::new();
2150 for &axis in perm {
2151 anyhow::ensure!(
2152 axis < self.indices.len(),
2153 "permutation axis {axis} out of range"
2154 );
2155 anyhow::ensure!(seen.insert(axis), "duplicate axis {axis} in permutation");
2156 }
2157 if perm.iter().copied().eq(0..perm.len()) {
2158 return Ok(self.clone());
2159 }
2160
2161 let new_indices: Vec<DynIndex> = perm.iter().map(|&i| self.indices[i].clone()).collect();
2163 let permuted = self.try_materialized_inner()?.transpose(perm)?;
2164 let axis_classes = self.permute_axis_classes(perm);
2165 Self::from_inner_with_axis_classes(new_indices, permuted, axis_classes)
2166 }
2167
2168 pub(crate) fn try_contract_pairwise_default(&self, other: &Self) -> Result<Self> {
2169 self.try_contract_pairwise_default_with_options(other, PairwiseContractionOptions::new())
2170 }
2171
2172 pub(crate) fn try_contract_pairwise_default_with_options(
2173 &self,
2174 other: &Self,
2175 options: PairwiseContractionOptions,
2176 ) -> Result<Self> {
2177 let self_indices = profile_pairwise_contract_section("operand_indices", || {
2178 self.operand_indices_for_contraction(options.lhs_conj)
2179 });
2180 let other_indices = profile_pairwise_contract_section("operand_indices", || {
2181 other.operand_indices_for_contraction(options.rhs_conj)
2182 });
2183 let self_dims = profile_pairwise_contract_section("expected_dims", || {
2184 Self::expected_dims_from_indices(&self_indices)
2185 });
2186 let other_dims = profile_pairwise_contract_section("expected_dims", || {
2187 Self::expected_dims_from_indices(&other_indices)
2188 });
2189 let spec = profile_pairwise_contract_section("prepare_contraction", || {
2190 prepare_contraction(&self_indices, &self_dims, &other_indices, &other_dims)
2191 })
2192 .context("contraction preparation failed")?;
2193 let result_axis_classes = profile_pairwise_contract_section("result_axis_classes", || {
2194 Self::binary_contraction_axis_classes(
2195 self.storage.axis_classes(),
2196 &spec.axes_a,
2197 other.storage.axis_classes(),
2198 &spec.axes_b,
2199 )
2200 });
2201
2202 if profile_pairwise_contract_section("structured_check", || {
2203 self.should_use_structured_payload_contract(other)
2204 }) {
2205 if options.has_conj() {
2206 let lhs = if options.lhs_conj {
2207 self.conj()
2208 } else {
2209 self.clone()
2210 };
2211 let rhs = if options.rhs_conj {
2212 other.conj()
2213 } else {
2214 other.clone()
2215 };
2216 return profile_pairwise_contract_section("structured_conj_fallback", || {
2217 lhs.try_contract_pairwise_default(&rhs)
2218 });
2219 }
2220 return profile_pairwise_contract_section("structured_payload_contract", || {
2221 self.contract_structured_payloads(
2222 other,
2223 spec.result_indices.into_vec(),
2224 &spec.axes_a,
2225 &spec.axes_b,
2226 )
2227 });
2228 }
2229
2230 if self.indices.is_empty() && other.indices.is_empty() {
2231 if options.has_conj() {
2232 let lhs = if options.lhs_conj {
2233 self.conj()
2234 } else {
2235 self.clone()
2236 };
2237 let rhs = if options.rhs_conj {
2238 other.conj()
2239 } else {
2240 other.clone()
2241 };
2242 return lhs.try_contract_pairwise_default(&rhs);
2243 }
2244 let result = profile_pairwise_contract_section("scalar_mul", || {
2245 Ok::<_, anyhow::Error>(
2246 self.try_materialized_inner()?
2247 .mul(other.try_materialized_inner()?)?,
2248 )
2249 })?;
2250 return profile_pairwise_contract_section("from_inner", || {
2251 Self::from_inner(spec.result_indices.into_vec(), result)
2252 });
2253 }
2254
2255 let self_native = profile_pairwise_contract_section("as_native", || self.as_native())?;
2256 let other_native = profile_pairwise_contract_section("as_native", || other.as_native())?;
2257 if self_native.dtype() != other_native.dtype() {
2258 if options.has_conj() {
2259 let lhs = if options.lhs_conj {
2260 self.conj()
2261 } else {
2262 self.clone()
2263 };
2264 let rhs = if options.rhs_conj {
2265 other.conj()
2266 } else {
2267 other.clone()
2268 };
2269 return lhs.try_contract_pairwise_default(&rhs);
2270 }
2271 let result_native = profile_pairwise_contract_section("native_contract", || {
2272 contract_native_tensor(self_native, &spec.axes_a, other_native, &spec.axes_b)
2273 })?;
2274 return profile_pairwise_contract_section("from_native", || {
2275 Self::from_native_with_axis_classes(
2276 spec.result_indices.into_vec(),
2277 result_native,
2278 result_axis_classes,
2279 )
2280 });
2281 }
2282
2283 let config = profile_pairwise_contract_section("build_dot_general_config", || {
2284 Self::binary_dot_general_config(&spec.axes_a, &spec.axes_b)
2285 })?;
2286 let result = profile_pairwise_contract_section("dot_general_with_conj", || {
2287 let lhs = profile_pairwise_contract_section("lhs_try_materialized_inner", || {
2288 self.try_materialized_inner()
2289 })?;
2290 let rhs = profile_pairwise_contract_section("rhs_try_materialized_inner", || {
2291 other.try_materialized_inner()
2292 })?;
2293 profile_pairwise_contract_section("dot_general_execute", || {
2294 lhs.dot_general_with_conj(rhs, &config, options.lhs_conj, options.rhs_conj)
2295 })
2296 .map_err(anyhow::Error::from)
2297 })?;
2298 record_pairwise_contract_profile_bytes(
2299 "dot_general_output",
2300 native_tensor_profile_bytes(result.data()),
2301 );
2302 profile_pairwise_contract_section("from_inner_axis_classes", || {
2303 Self::from_inner_with_axis_classes(
2304 spec.result_indices.into_vec(),
2305 result,
2306 result_axis_classes,
2307 )
2308 })
2309 }
2310
2311 pub(crate) fn try_tensordot_pairwise_explicit(
2312 &self,
2313 other: &Self,
2314 pairs: &[(DynIndex, DynIndex)],
2315 ) -> Result<Self> {
2316 use crate::index_ops::ContractionError;
2317
2318 let self_dims = Self::expected_dims_from_indices(&self.indices);
2319 let other_dims = Self::expected_dims_from_indices(&other.indices);
2320 let spec = prepare_contraction_pairs(
2321 &self.indices,
2322 &self_dims,
2323 &other.indices,
2324 &other_dims,
2325 pairs,
2326 )
2327 .map_err(|e| match e {
2328 ContractionError::NoCommonIndices => {
2329 anyhow::anyhow!("tensordot: No pairs specified for contraction")
2330 }
2331 ContractionError::BatchContractionNotImplemented => anyhow::anyhow!(
2332 "tensordot: Common index found but not in contraction pairs. \
2333 Batch contraction is not yet implemented."
2334 ),
2335 ContractionError::IndexNotFound { tensor } => {
2336 anyhow::anyhow!("tensordot: Index not found in {} tensor", tensor)
2337 }
2338 ContractionError::DimensionMismatch {
2339 pos_a,
2340 pos_b,
2341 dim_a,
2342 dim_b,
2343 } => anyhow::anyhow!(
2344 "tensordot: Dimension mismatch: self[{}]={} != other[{}]={}",
2345 pos_a,
2346 dim_a,
2347 pos_b,
2348 dim_b
2349 ),
2350 ContractionError::DuplicateAxis { tensor, pos } => {
2351 anyhow::anyhow!("tensordot: Duplicate axis {} in {} tensor", pos, tensor)
2352 }
2353 })?;
2354 let result_axis_classes = Self::binary_contraction_axis_classes(
2355 self.storage.axis_classes(),
2356 &spec.axes_a,
2357 other.storage.axis_classes(),
2358 &spec.axes_b,
2359 );
2360
2361 if self.should_use_structured_payload_contract(other) {
2362 return self.contract_structured_payloads(
2363 other,
2364 spec.result_indices.into_vec(),
2365 &spec.axes_a,
2366 &spec.axes_b,
2367 );
2368 }
2369
2370 if self.indices.is_empty() && other.indices.is_empty() {
2371 let result = self
2372 .try_materialized_inner()?
2373 .mul(other.try_materialized_inner()?)
2374 .map_err(|e| anyhow::anyhow!("tensordot scalar multiply failed: {e}"))?;
2375 return Self::from_inner(spec.result_indices.into_vec(), result);
2376 }
2377
2378 let self_native = self.as_native()?;
2379 let other_native = other.as_native()?;
2380 if self_native.dtype() != other_native.dtype() {
2381 let result_native =
2382 contract_native_tensor(self_native, &spec.axes_a, other_native, &spec.axes_b)?;
2383 return Self::from_native_with_axis_classes(
2384 spec.result_indices.into_vec(),
2385 result_native,
2386 result_axis_classes,
2387 );
2388 }
2389
2390 let subscripts = Self::build_binary_einsum_subscripts(
2391 self.indices.len(),
2392 &spec.axes_a,
2393 other.indices.len(),
2394 &spec.axes_b,
2395 )?;
2396 let result = eager_einsum_ad(
2397 &[
2398 self.try_materialized_inner()?,
2399 other.try_materialized_inner()?,
2400 ],
2401 &subscripts,
2402 )
2403 .map_err(|e| anyhow::anyhow!("tensordot failed: {e}"))?;
2404 Self::from_inner_with_axis_classes(
2405 spec.result_indices.into_vec(),
2406 result,
2407 result_axis_classes,
2408 )
2409 }
2410
2411 pub(crate) fn try_outer_product_pairwise(&self, other: &Self) -> Result<Self> {
2412 use anyhow::Context;
2413
2414 let common_positions = common_ind_positions(&self.indices, &other.indices);
2416 if !common_positions.is_empty() {
2417 let common_ids: Vec<_> = common_positions
2418 .iter()
2419 .map(|(pos_a, _)| self.indices[*pos_a].id())
2420 .collect();
2421 return Err(anyhow::anyhow!(
2422 "outer_product: tensors have common indices {:?}. \
2423 Use tensordot to contract common indices, or use sim() to replace \
2424 indices with fresh IDs before computing outer product.",
2425 common_ids
2426 ))
2427 .context("outer_product: common indices found");
2428 }
2429
2430 let mut result_indices = self.indices.clone();
2432 result_indices.extend(other.indices.iter().cloned());
2433 let result_axis_classes = Self::binary_contraction_axis_classes(
2434 self.storage.axis_classes(),
2435 &[],
2436 other.storage.axis_classes(),
2437 &[],
2438 );
2439 if self.should_use_structured_payload_contract(other) {
2440 return self.contract_structured_payloads(other, result_indices, &[], &[]);
2441 }
2442 let self_native = self.as_native()?;
2443 let other_native = other.as_native()?;
2444 if self_native.dtype() != other_native.dtype() {
2445 let result_native = contract_native_tensor(self_native, &[], other_native, &[])?;
2446 return Self::from_native_with_axis_classes(
2447 result_indices,
2448 result_native,
2449 result_axis_classes,
2450 );
2451 }
2452
2453 let subscripts = Self::build_binary_einsum_subscripts(
2454 self.indices.len(),
2455 &[],
2456 other.indices.len(),
2457 &[],
2458 )?;
2459 let result = eager_einsum_ad(
2460 &[
2461 self.try_materialized_inner()?,
2462 other.try_materialized_inner()?,
2463 ],
2464 &subscripts,
2465 )
2466 .map_err(|e| anyhow::anyhow!("outer_product failed: {e}"))?;
2467 Self::from_inner_with_axis_classes(result_indices, result, result_axis_classes)
2468 }
2469}
2470
2471impl TensorDynLen {
2476 pub fn random<T: RandomScalar, R: Rng>(rng: &mut R, indices: Vec<DynIndex>) -> Result<Self> {
2503 let dims: Vec<usize> = indices.iter().map(|idx| idx.dim()).collect();
2504 let size = checked_product(&dims)?;
2505 let data: Vec<T> = (0..size).map(|_| T::random_value(rng)).collect();
2506 Self::from_dense(indices, data)
2507 }
2508}
2509
2510impl TensorDynLen {
2511 pub fn add(&self, other: &Self) -> Result<Self> {
2545 if self.indices.len() != other.indices.len() {
2547 return Err(anyhow::anyhow!(
2548 "Index count mismatch: self has {} indices, other has {}",
2549 self.indices.len(),
2550 other.indices.len()
2551 ));
2552 }
2553
2554 let self_set: HashSet<_> = self.indices.iter().collect();
2556 let other_set: HashSet<_> = other.indices.iter().collect();
2557
2558 if self_set != other_set {
2559 return Err(anyhow::anyhow!(
2560 "Index set mismatch: tensors must have the same indices"
2561 ));
2562 }
2563
2564 let other_aligned = other.permute_indices(&self.indices)?;
2566
2567 let self_expected_dims = Self::expected_dims_from_indices(&self.indices);
2569 let other_expected_dims = Self::expected_dims_from_indices(&other_aligned.indices);
2570 if self_expected_dims != other_expected_dims {
2571 use crate::TagSetLike;
2572 let fmt = |indices: &[DynIndex]| -> Vec<String> {
2573 indices
2574 .iter()
2575 .map(|idx| {
2576 let tags: Vec<String> = idx.tags().iter().collect();
2577 format!("{:?}(dim={},tags={:?})", idx.id(), idx.dim(), tags)
2578 })
2579 .collect()
2580 };
2581 return Err(anyhow::anyhow!(
2582 "Dimension mismatch after alignment.\n\
2583 self: dims={:?}, indices(order)={:?}\n\
2584 other_aligned: dims={:?}, indices(order)={:?}",
2585 self_expected_dims,
2586 fmt(&self.indices),
2587 other_expected_dims,
2588 fmt(&other_aligned.indices)
2589 ));
2590 }
2591
2592 self.axpby(
2593 AnyScalar::new_real(1.0),
2594 &other_aligned,
2595 AnyScalar::new_real(1.0),
2596 )
2597 }
2598
2599 pub fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result<Self> {
2621 if self.indices.len() != other.indices.len() {
2623 return Err(anyhow::anyhow!(
2624 "Index count mismatch: self has {} indices, other has {}",
2625 self.indices.len(),
2626 other.indices.len()
2627 ));
2628 }
2629
2630 let self_set: HashSet<_> = self.indices.iter().collect();
2632 let other_set: HashSet<_> = other.indices.iter().collect();
2633 if self_set != other_set {
2634 return Err(anyhow::anyhow!(
2635 "Index set mismatch: tensors must have the same indices"
2636 ));
2637 }
2638
2639 let other_aligned = other.permute_indices(&self.indices)?;
2641
2642 let self_expected_dims = Self::expected_dims_from_indices(&self.indices);
2644 let other_expected_dims = Self::expected_dims_from_indices(&other_aligned.indices);
2645 if self_expected_dims != other_expected_dims {
2646 return Err(anyhow::anyhow!(
2647 "Dimension mismatch after alignment: self={:?}, other_aligned={:?}",
2648 self_expected_dims,
2649 other_expected_dims
2650 ));
2651 }
2652
2653 let axis_classes = if self.storage.axis_classes() == other_aligned.storage.axis_classes() {
2654 self.storage.axis_classes().to_vec()
2655 } else {
2656 Self::dense_axis_classes(self.indices.len())
2657 };
2658
2659 let same_compact_layout = self.storage.payload_dims()
2660 == other_aligned.storage.payload_dims()
2661 && self.storage.payload_strides_vec() == other_aligned.storage.payload_strides_vec()
2662 && self.storage.axis_classes() == other_aligned.storage.axis_classes();
2663 if same_compact_layout
2664 && !self.tracks_grad()
2665 && !other_aligned.tracks_grad()
2666 && !a.tracks_grad()
2667 && !b.tracks_grad()
2668 {
2669 let lhs_storage = self.storage.materialize(self.indices.len())?;
2670 let rhs_storage = other_aligned
2671 .storage
2672 .materialize(other_aligned.indices.len())?;
2673 let combined = lhs_storage
2674 .axpby(
2675 &a.to_backend_scalar(),
2676 rhs_storage.as_ref(),
2677 &b.to_backend_scalar(),
2678 )
2679 .map_err(|e| anyhow::anyhow!("storage axpby failed: {e}"))?;
2680 return Self::from_storage(self.indices.clone(), Arc::new(combined));
2681 }
2682
2683 let self_native = self.as_native()?;
2684 let other_native = other_aligned.as_native()?;
2685 let a_native = a.as_tensor()?.as_native()?;
2686 let b_native = b.as_tensor()?.as_native()?;
2687 if self_native.dtype() != other_native.dtype()
2688 || self_native.dtype() != a_native.dtype()
2689 || other_native.dtype() != b_native.dtype()
2690 {
2691 let combined = axpby_native_tensor(
2692 self_native,
2693 &a.to_backend_scalar(),
2694 other_native,
2695 &b.to_backend_scalar(),
2696 )?;
2697 return Self::from_native_with_axis_classes(
2698 self.indices.clone(),
2699 combined,
2700 axis_classes,
2701 );
2702 }
2703
2704 let lhs = self.scale(a)?;
2705 let rhs = other_aligned.scale(b)?;
2706 let combined = lhs
2707 .try_materialized_inner()?
2708 .add(rhs.try_materialized_inner()?)
2709 .map_err(|e| anyhow::anyhow!("tensor addition failed: {e}"))?;
2710 Self::from_inner_with_axis_classes(self.indices.clone(), combined, axis_classes)
2711 }
2712
2713 pub fn scale(&self, scalar: AnyScalar) -> Result<Self> {
2728 if !self.tracks_grad() && !scalar.tracks_grad() {
2729 let scaled = self.storage.scale(&scalar.to_backend_scalar())?;
2730 return Self::from_storage(self.indices.clone(), Arc::new(scaled));
2731 }
2732
2733 let self_native = self.as_native()?;
2734 let scalar_native = scalar.as_tensor()?.as_native()?;
2735 if self_native.dtype() != scalar_native.dtype() {
2736 let scaled = scale_native_tensor(self_native, &scalar.to_backend_scalar())?;
2737 return Self::from_native_with_axis_classes(
2738 self.indices.clone(),
2739 scaled,
2740 self.storage.axis_classes().to_vec(),
2741 );
2742 }
2743
2744 let scaled = if self.indices.is_empty() {
2745 self.try_materialized_inner()?
2746 .mul(scalar.as_tensor()?.try_materialized_inner()?)
2747 .map_err(|e| anyhow::anyhow!("scalar multiplication failed: {e}"))?
2748 } else {
2749 let subscripts = Self::scale_subscripts(self.indices.len())?;
2750 eager_einsum_ad(
2751 &[
2752 self.try_materialized_inner()?,
2753 scalar.as_tensor()?.try_materialized_inner()?,
2754 ],
2755 &subscripts,
2756 )
2757 .map_err(|e| anyhow::anyhow!("tensor scaling failed: {e}"))?
2758 };
2759 Self::from_inner_with_axis_classes(
2760 self.indices.clone(),
2761 scaled,
2762 self.storage.axis_classes().to_vec(),
2763 )
2764 }
2765
2766 pub fn inner_product(&self, other: &Self) -> Result<AnyScalar> {
2784 if self.indices.len() == other.indices.len() {
2785 let self_set: HashSet<_> = self.indices.iter().collect();
2786 let other_set: HashSet<_> = other.indices.iter().collect();
2787 if self_set == other_set {
2788 let other_aligned = other.permute_indices(&self.indices)?;
2789 let result = super::contract::contract_pair_with_operand_options(
2790 self,
2791 &other_aligned,
2792 PairwiseContractionOptions::new().with_lhs_conj(true),
2793 )?;
2794 return result.sum();
2795 }
2796 }
2797
2798 let result = super::contract::contract_pair_with_operand_options(
2800 self,
2801 other,
2802 PairwiseContractionOptions::new().with_lhs_conj(true),
2803 )?;
2804 result.sum()
2806 }
2807}
2808
2809impl TensorDynLen {
2814 pub fn replaceind(&self, old_index: &DynIndex, new_index: &DynIndex) -> Result<Self> {
2848 if old_index.dim() != new_index.dim() {
2850 return Err(anyhow::anyhow!(
2851 "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
2852 old_index.dim(),
2853 new_index.dim()
2854 ));
2855 }
2856
2857 let new_indices: Vec<_> = self
2858 .indices
2859 .iter()
2860 .map(|idx| {
2861 if *idx == *old_index {
2862 new_index.clone()
2863 } else {
2864 idx.clone()
2865 }
2866 })
2867 .collect();
2868
2869 Ok(Self {
2870 indices: new_indices,
2871 storage: self.storage.clone(),
2872 structured_ad: self.structured_ad.clone(),
2873 eager_cache: Arc::clone(&self.eager_cache),
2874 })
2875 }
2876
2877 pub fn replaceinds(&self, old_indices: &[DynIndex], new_indices: &[DynIndex]) -> Result<Self> {
2915 anyhow::ensure!(
2916 old_indices.len() == new_indices.len(),
2917 "old_indices and new_indices must have the same length"
2918 );
2919
2920 for (old, new) in old_indices.iter().zip(new_indices.iter()) {
2922 if old.dim() != new.dim() {
2923 return Err(anyhow::anyhow!(
2924 "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
2925 old.dim(),
2926 new.dim()
2927 ));
2928 }
2929 }
2930
2931 let replacement_map: std::collections::HashMap<_, _> =
2933 old_indices.iter().zip(new_indices.iter()).collect();
2934
2935 let new_indices_vec: Vec<_> = self
2936 .indices
2937 .iter()
2938 .map(|idx| {
2939 if let Some(new_idx) = replacement_map.get(idx) {
2940 (*new_idx).clone()
2941 } else {
2942 idx.clone()
2943 }
2944 })
2945 .collect();
2946
2947 Ok(Self {
2948 indices: new_indices_vec,
2949 storage: self.storage.clone(),
2950 structured_ad: self.structured_ad.clone(),
2951 eager_cache: Arc::clone(&self.eager_cache),
2952 })
2953 }
2954}
2955
2956impl TensorDynLen {
2961 pub fn conj(&self) -> Self {
2984 let new_indices: Vec<DynIndex> = self.indices.iter().map(|idx| idx.conj()).collect();
2988 let structured_ad = self.tracked_compact_payload_value().and_then(|value| {
2989 value.payload.conj().ok().map(|payload| {
2990 Arc::new(StructuredAdValue {
2991 payload: Arc::new(payload),
2992 payload_dims: value.payload_dims.clone(),
2993 axis_classes: value.axis_classes.clone(),
2994 })
2995 })
2996 });
2997 let eager_cache = self
2998 .eager_cache
2999 .get()
3000 .and_then(|inner| inner.conj().ok())
3001 .map(Self::eager_cache_with)
3002 .unwrap_or_else(Self::empty_eager_cache);
3003 Self {
3004 indices: new_indices,
3005 storage: self.storage.conj().unwrap_or_else(|_| {
3006 TensorDynLenStorage::from_storage(Arc::new(self.storage().conj()))
3007 }),
3008 structured_ad,
3009 eager_cache,
3010 }
3011 }
3012}
3013
3014impl TensorDynLen {
3019 pub fn norm_squared(&self) -> f64 {
3037 self.try_norm_squared().unwrap_or(f64::NAN)
3038 }
3039
3040 pub fn try_norm_squared(&self) -> Result<f64> {
3045 if self.indices.is_empty() {
3047 let value = self.sum()?;
3049 let abs_val = value.abs();
3050 return Ok(abs_val * abs_val);
3051 }
3052
3053 let conj = self.conj();
3056 let scalar = super::contract::contract_pair(self, &conj)?;
3057 Ok(scalar.sum()?.real().max(0.0))
3060 }
3061
3062 pub fn norm(&self) -> f64 {
3076 self.norm_squared().sqrt()
3077 }
3078
3079 pub fn maxabs(&self) -> f64 {
3091 self.storage.max_abs().unwrap_or(0.0)
3092 }
3093
3094 pub fn sub(&self, other: &Self) -> Result<Self> {
3102 self.axpby(AnyScalar::new_real(1.0), other, AnyScalar::new_real(-1.0))
3103 }
3104
3105 pub fn neg(&self) -> Result<Self> {
3110 self.scale(AnyScalar::new_real(-1.0))
3111 }
3112
3113 pub fn isapprox(&self, other: &Self, atol: f64, rtol: f64) -> bool {
3118 let diff = match self.sub(other) {
3119 Ok(d) => d,
3120 Err(_) => return false,
3121 };
3122 let diff_norm = diff.norm();
3123 diff_norm <= atol.max(rtol * self.norm().max(other.norm()))
3124 }
3125
3126 pub fn diagonal(input_index: &DynIndex, output_index: &DynIndex) -> Result<Self> {
3131 <Self as TensorConstructionLike>::diagonal(input_index, output_index)
3132 }
3133
3134 pub fn delta(input_indices: &[DynIndex], output_indices: &[DynIndex]) -> Result<Self> {
3140 <Self as TensorConstructionLike>::delta(input_indices, output_indices)
3141 }
3142
3143 pub fn scalar_one() -> Result<Self> {
3148 <Self as TensorConstructionLike>::scalar_one()
3149 }
3150
3151 pub fn ones(indices: &[DynIndex]) -> Result<Self> {
3156 <Self as TensorConstructionLike>::ones(indices)
3157 }
3158
3159 pub fn onehot(index_vals: &[(DynIndex, usize)]) -> Result<Self> {
3164 <Self as TensorConstructionLike>::onehot(index_vals)
3165 }
3166
3167 pub fn distance(&self, other: &Self) -> Result<f64> {
3198 let norm_self = self.norm();
3199
3200 let neg_other = other.scale(AnyScalar::new_real(-1.0))?;
3202 let diff = self.add(&neg_other)?;
3203 let norm_diff = diff.norm();
3204
3205 if norm_self > 0.0 {
3206 Ok(norm_diff / norm_self)
3207 } else {
3208 Ok(norm_diff)
3209 }
3210 }
3211}
3212
3213impl std::fmt::Debug for TensorDynLen {
3214 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3215 f.debug_struct("TensorDynLen")
3216 .field("indices", &self.indices)
3217 .field("dims", &self.dims())
3218 .field("is_diag", &self.is_diag())
3219 .finish()
3220 }
3221}
3222
3223pub fn diag_tensor_dyn_len(indices: Vec<DynIndex>, diag_data: Vec<f64>) -> Result<TensorDynLen> {
3248 TensorDynLen::from_diag(indices, diag_data)
3249}
3250
/// Result of `unfold_split_inner`: a tensor reshaped into an `m x n` matrix.
///
/// The fields follow the order returned by `unfold_split_inner`.
#[allow(clippy::type_complexity)]
pub(crate) type UnfoldSplitInnerResult = (
    // The unfolded tensor, reshaped to `[m, n]`.
    EagerTensor<CpuBackend>,
    // Number of left (row) indices.
    usize,
    // m: product of the left index dimensions.
    usize,
    // n: product of the right index dimensions.
    usize,
    // Left indices, in the order the caller supplied them.
    Vec<DynIndex>,
    // Remaining (right) indices, in their original tensor order.
    Vec<DynIndex>,
);
3260
3261#[allow(clippy::type_complexity)]
3308pub fn unfold_split(
3309 t: &TensorDynLen,
3310 left_inds: &[DynIndex],
3311) -> Result<(
3312 NativeTensor,
3313 usize,
3314 usize,
3315 usize,
3316 Vec<DynIndex>,
3317 Vec<DynIndex>,
3318)> {
3319 let (matrix_inner, left_len, m, n, left_indices, right_indices) =
3320 unfold_split_inner(t, left_inds)?;
3321
3322 Ok((
3323 matrix_inner.data().clone(),
3324 left_len,
3325 m,
3326 n,
3327 left_indices,
3328 right_indices,
3329 ))
3330}
3331
3332pub(crate) fn unfold_split_inner(
3333 t: &TensorDynLen,
3334 left_inds: &[DynIndex],
3335) -> Result<UnfoldSplitInnerResult> {
3336 let rank = t.indices.len();
3337
3338 anyhow::ensure!(rank >= 2, "Tensor must have rank >= 2, got rank {}", rank);
3340
3341 let left_len = left_inds.len();
3342
3343 anyhow::ensure!(
3345 left_len > 0 && left_len < rank,
3346 "Left indices must be a non-empty proper subset of tensor indices (0 < left_len < rank), got left_len={}, rank={}",
3347 left_len,
3348 rank
3349 );
3350
3351 let tensor_set: HashSet<_> = t.indices.iter().collect();
3353 let mut left_set = HashSet::new();
3354
3355 for left_idx in left_inds {
3356 anyhow::ensure!(
3357 tensor_set.contains(left_idx),
3358 "Index in left_inds not found in tensor"
3359 );
3360 anyhow::ensure!(left_set.insert(left_idx), "Duplicate index in left_inds");
3361 }
3362
3363 let mut right_inds = Vec::new();
3365 for idx in &t.indices {
3366 if !left_set.contains(idx) {
3367 right_inds.push(idx.clone());
3368 }
3369 }
3370
3371 let mut new_indices = Vec::with_capacity(rank);
3373 new_indices.extend_from_slice(left_inds);
3374 new_indices.extend_from_slice(&right_inds);
3375
3376 let unfolded = t.permute_indices(&new_indices)?;
3378
3379 let unfolded_dims = unfolded.dims();
3381 let m: usize = unfolded_dims[..left_len].iter().product();
3382 let n: usize = unfolded_dims[left_len..].iter().product();
3383
3384 let matrix_tensor = unfolded.try_materialized_inner()?.reshape(&[m, n])?;
3385
3386 Ok((
3387 matrix_tensor,
3388 left_len,
3389 m,
3390 n,
3391 left_inds.to_vec(),
3392 right_inds,
3393 ))
3394}
3395
3396use crate::tensor_index::TensorIndex;
3401
3402impl TensorIndex for TensorDynLen {
3403 type Index = DynIndex;
3404
3405 fn external_indices(&self) -> Vec<DynIndex> {
3406 self.indices.clone()
3408 }
3409
3410 fn num_external_indices(&self) -> usize {
3411 self.indices.len()
3412 }
3413
3414 fn replaceind(&self, old_index: &DynIndex, new_index: &DynIndex) -> Result<Self> {
3415 TensorDynLen::replaceind(self, old_index, new_index)
3417 }
3418
3419 fn replaceinds(&self, old_indices: &[DynIndex], new_indices: &[DynIndex]) -> Result<Self> {
3420 TensorDynLen::replaceinds(self, old_indices, new_indices)
3422 }
3423}
3424
3425use crate::tensor_like::{
3430 FactorizeError, FactorizeOptions, FactorizeResult, TensorConstructionLike,
3431 TensorContractionLike, TensorFactorizationLike, TensorVectorSpace,
3432};
3433
3434impl TensorVectorSpace for TensorDynLen {
3435 fn norm_squared(&self) -> f64 {
3436 TensorDynLen::norm_squared(self)
3437 }
3438
3439 fn maxabs(&self) -> f64 {
3440 TensorDynLen::maxabs(self)
3441 }
3442
3443 fn axpby(&self, a: crate::AnyScalar, other: &Self, b: crate::AnyScalar) -> Result<Self> {
3444 TensorDynLen::axpby(self, a, other, b)
3445 }
3446
3447 fn scale(&self, scalar: crate::AnyScalar) -> Result<Self> {
3448 TensorDynLen::scale(self, scalar)
3449 }
3450
3451 fn inner_product(&self, other: &Self) -> Result<crate::AnyScalar> {
3452 TensorDynLen::inner_product(self, other)
3453 }
3454}
3455
3456impl TensorFactorizationLike for TensorDynLen {
3457 fn factorize(
3458 &self,
3459 left_inds: &[DynIndex],
3460 options: &FactorizeOptions,
3461 ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
3462 crate::factorize::factorize(self, left_inds, options)
3463 }
3464
3465 fn factorize_full_rank(
3466 &self,
3467 left_inds: &[DynIndex],
3468 alg: crate::FactorizeAlg,
3469 canonical: crate::Canonical,
3470 ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
3471 crate::factorize::factorize_full_rank(self, left_inds, alg, canonical)
3472 }
3473}
3474
3475impl TensorContractionLike for TensorDynLen {
3476 fn conj(&self) -> Self {
3477 TensorDynLen::conj(self)
3479 }
3480
3481 fn direct_sum(
3482 &self,
3483 other: &Self,
3484 pairs: &[(DynIndex, DynIndex)],
3485 ) -> Result<crate::tensor_like::DirectSumResult<Self>> {
3486 let (tensor, new_indices) = crate::direct_sum::direct_sum(self, other, pairs)?;
3487 Ok(crate::tensor_like::DirectSumResult {
3488 tensor,
3489 new_indices,
3490 })
3491 }
3492
3493 fn outer_product(&self, other: &Self) -> Result<Self> {
3494 super::contract::outer_product(self, other)
3495 }
3496
3497 fn permuteinds(&self, new_order: &[DynIndex]) -> Result<Self> {
3498 TensorDynLen::permute_indices(self, new_order)
3500 }
3501
3502 fn fuse_indices(
3503 &self,
3504 old_indices: &[DynIndex],
3505 new_index: DynIndex,
3506 order: LinearizationOrder,
3507 ) -> Result<Self> {
3508 TensorDynLen::fuse_indices(self, old_indices, new_index, order)
3509 }
3510
3511 fn contract(tensors: &[&Self]) -> Result<Self> {
3512 super::contract::contract(tensors)
3513 }
3514
3515 fn contract_pair(&self, other: &Self) -> Result<Self> {
3516 super::contract::contract_pair(self, other)
3517 }
3518}
3519
3520impl TensorConstructionLike for TensorDynLen {
3521 fn select_indices(&self, selected_indices: &[DynIndex], positions: &[usize]) -> Result<Self> {
3522 TensorDynLen::select_indices(self, selected_indices, positions)
3523 }
3524
3525 fn diagonal(input_index: &DynIndex, output_index: &DynIndex) -> Result<Self> {
3526 let dim = input_index.dim();
3527 if dim != output_index.dim() {
3528 return Err(anyhow::anyhow!(
3529 "Dimension mismatch: input index has dim {}, output has dim {}",
3530 dim,
3531 output_index.dim(),
3532 ));
3533 }
3534
3535 TensorDynLen::from_diag(
3536 vec![input_index.clone(), output_index.clone()],
3537 vec![1.0_f64; dim],
3538 )
3539 }
3540
3541 fn scalar_one() -> Result<Self> {
3542 TensorDynLen::from_dense(vec![], vec![1.0_f64])
3543 }
3544
3545 fn ones(indices: &[DynIndex]) -> Result<Self> {
3546 if indices.is_empty() {
3547 return Self::scalar_one();
3548 }
3549 let dims: Vec<usize> = indices.iter().map(|idx| idx.size()).collect();
3550 let total_size = checked_total_size(&dims)?;
3551 TensorDynLen::from_dense(indices.to_vec(), vec![1.0_f64; total_size])
3552 }
3553
3554 fn onehot(index_vals: &[(DynIndex, usize)]) -> Result<Self> {
3555 if index_vals.is_empty() {
3556 return Self::scalar_one();
3557 }
3558 let indices: Vec<DynIndex> = index_vals.iter().map(|(idx, _)| idx.clone()).collect();
3559 let vals: Vec<usize> = index_vals.iter().map(|(_, v)| *v).collect();
3560 let dims: Vec<usize> = indices.iter().map(|idx| idx.size()).collect();
3561
3562 for (k, (&v, &d)) in vals.iter().zip(dims.iter()).enumerate() {
3563 if v >= d {
3564 return Err(anyhow::anyhow!(
3565 "onehot: value {} at position {} is >= dimension {}",
3566 v,
3567 k,
3568 d
3569 ));
3570 }
3571 }
3572
3573 let total_size = checked_total_size(&dims)?;
3574 let mut data = vec![0.0_f64; total_size];
3575
3576 let offset = column_major_offset(&dims, &vals)?;
3577 data[offset] = 1.0;
3578
3579 Self::from_dense(indices, data)
3580 }
3581
3582 }
3584
3585fn checked_total_size(dims: &[usize]) -> Result<usize> {
3586 dims.iter().try_fold(1_usize, |acc, &d| {
3587 if d == 0 {
3588 return Err(anyhow::anyhow!("invalid dimension 0"));
3589 }
3590 acc.checked_mul(d)
3591 .ok_or_else(|| anyhow::anyhow!("tensor size overflow"))
3592 })
3593}
3594
3595fn column_major_offset(dims: &[usize], vals: &[usize]) -> Result<usize> {
3596 if dims.len() != vals.len() {
3597 return Err(anyhow::anyhow!(
3598 "column_major_offset: dims.len() != vals.len()"
3599 ));
3600 }
3601 checked_total_size(dims)?;
3602
3603 let mut offset = 0usize;
3604 let mut stride = 1usize;
3605 for (k, (&v, &d)) in vals.iter().zip(dims.iter()).enumerate() {
3606 if d == 0 {
3607 return Err(anyhow::anyhow!("invalid dimension 0 at position {}", k));
3608 }
3609 if v >= d {
3610 return Err(anyhow::anyhow!(
3611 "column_major_offset: value {} at position {} is >= dimension {}",
3612 v,
3613 k,
3614 d
3615 ));
3616 }
3617 let term = v
3618 .checked_mul(stride)
3619 .ok_or_else(|| anyhow::anyhow!("column_major_offset: overflow"))?;
3620 offset = offset
3621 .checked_add(term)
3622 .ok_or_else(|| anyhow::anyhow!("column_major_offset: overflow"))?;
3623 stride = stride
3624 .checked_mul(d)
3625 .ok_or_else(|| anyhow::anyhow!("column_major_offset: overflow"))?;
3626 }
3627 Ok(offset)
3628}
3629
3630impl TensorDynLen {
3635 fn any_scalar_payload_to_complex(data: Vec<AnyScalar>) -> Vec<Complex64> {
3636 data.into_iter()
3637 .map(|value| {
3638 value
3639 .as_c64()
3640 .unwrap_or_else(|| Complex64::new(value.real(), 0.0))
3641 })
3642 .collect()
3643 }
3644
3645 fn any_scalar_payload_to_real(data: Vec<AnyScalar>) -> Vec<f64> {
3646 data.into_iter().map(|value| value.real()).collect()
3647 }
3648
3649 fn validate_dense_payload_len(data_len: usize, dims: &[usize]) -> Result<()> {
3650 let expected_len = checked_total_size(dims)?;
3651 anyhow::ensure!(
3652 data_len == expected_len,
3653 "dense payload length {} does not match dims {:?} (expected {})",
3654 data_len,
3655 dims,
3656 expected_len
3657 );
3658 Ok(())
3659 }
3660
3661 fn validate_diag_payload_len(data_len: usize, dims: &[usize]) -> Result<()> {
3662 anyhow::ensure!(
3663 !dims.is_empty(),
3664 "diagonal tensor construction requires at least one index"
3665 );
3666 Self::validate_diag_dims(dims)?;
3667 anyhow::ensure!(
3668 data_len == dims[0],
3669 "diagonal payload length {} does not match diagonal dimension {}",
3670 data_len,
3671 dims[0]
3672 );
3673 Ok(())
3674 }
3675
3676 pub fn from_dense<T: TensorElement>(indices: Vec<DynIndex>, data: Vec<T>) -> Result<Self> {
3703 let dims = Self::expected_dims_from_indices(&indices);
3704 Self::validate_indices(&indices)?;
3705 Self::validate_dense_payload_len(data.len(), &dims)?;
3706 let native = dense_native_tensor_from_col_major(&data, &dims)?;
3707 Self::from_native(indices, native)
3708 }
3709
3710 pub fn from_dense_any(indices: Vec<DynIndex>, data: Vec<AnyScalar>) -> Result<Self> {
3736 if data.iter().any(AnyScalar::is_complex) {
3737 Self::from_dense(indices, Self::any_scalar_payload_to_complex(data))
3738 } else {
3739 Self::from_dense(indices, Self::any_scalar_payload_to_real(data))
3740 }
3741 }
3742
3743 pub fn from_diag<T: TensorElement>(indices: Vec<DynIndex>, data: Vec<T>) -> Result<Self> {
3771 let dims = Self::expected_dims_from_indices(&indices);
3772 Self::validate_indices(&indices)?;
3773 Self::validate_diag_payload_len(data.len(), &dims)?;
3774 let native = diag_native_tensor_from_col_major(&data, dims.len())?;
3775 Self::from_native_with_axis_classes(indices, native, Self::diag_axis_classes(dims.len()))
3776 }
3777
3778 pub fn from_diag_any(indices: Vec<DynIndex>, data: Vec<AnyScalar>) -> Result<Self> {
3800 if data.iter().any(AnyScalar::is_complex) {
3801 Self::from_diag(indices, Self::any_scalar_payload_to_complex(data))
3802 } else {
3803 Self::from_diag(indices, Self::any_scalar_payload_to_real(data))
3804 }
3805 }
3806
3807 pub fn copy_tensor(indices: Vec<DynIndex>, value: AnyScalar) -> Result<Self> {
3828 if indices.is_empty() {
3829 return Self::from_dense_any(vec![], vec![value]);
3830 }
3831 let dim = indices[0].dim();
3832 let data = vec![value; dim];
3833 Self::from_diag_any(indices, data)
3834 }
3835
3836 pub fn fuse_indices(
3890 &self,
3891 old_indices: &[DynIndex],
3892 new_index: DynIndex,
3893 order: LinearizationOrder,
3894 ) -> Result<Self> {
3895 anyhow::ensure!(
3896 !old_indices.is_empty(),
3897 "fuse_indices requires at least one index to fuse"
3898 );
3899
3900 let old_dims = self.dims();
3901 let mut seen_indices = HashSet::new();
3902 let mut old_axes = Vec::with_capacity(old_indices.len());
3903 for old_index in old_indices {
3904 anyhow::ensure!(
3905 seen_indices.insert(old_index),
3906 "duplicate index in old_indices"
3907 );
3908 let axis = self
3909 .indices
3910 .iter()
3911 .position(|idx| idx == old_index)
3912 .ok_or_else(|| anyhow::anyhow!("index {:?} not found in tensor", old_index))?;
3913 anyhow::ensure!(
3914 old_index.dim() == old_dims[axis],
3915 "old index dimension does not match tensor axis dimension"
3916 );
3917 old_axes.push(axis);
3918 }
3919
3920 let fused_dims: Vec<usize> = old_axes.iter().map(|&axis| old_dims[axis]).collect();
3921 let fused_product = checked_product(&fused_dims)?;
3922 anyhow::ensure!(
3923 fused_product == new_index.dim(),
3924 "product of old index dimensions must match the replacement index dimension"
3925 );
3926
3927 let insertion_axis =
3928 old_axes.iter().copied().min().ok_or_else(|| {
3929 anyhow::anyhow!("fuse_indices requires at least one index to fuse")
3930 })?;
3931 let old_axis_set: HashSet<usize> = old_axes.iter().copied().collect();
3932
3933 let mut result_indices =
3934 Vec::with_capacity(self.indices.len() - old_indices.len() + 1usize);
3935 for (axis, index) in self.indices.iter().enumerate() {
3936 if axis == insertion_axis {
3937 result_indices.push(new_index.clone());
3938 }
3939 if !old_axis_set.contains(&axis) {
3940 result_indices.push(index.clone());
3941 }
3942 }
3943 let mut result_seen = HashSet::new();
3944 for index in &result_indices {
3945 anyhow::ensure!(
3946 result_seen.insert(index),
3947 "fuse_indices result would contain duplicate index"
3948 );
3949 }
3950 Self::validate_indices(&result_indices)?;
3951
3952 let mut new_dims = Vec::with_capacity(old_dims.len() - old_indices.len() + 1usize);
3953 for (axis, dim) in old_dims.iter().copied().enumerate() {
3954 if axis == insertion_axis {
3955 new_dims.push(new_index.dim());
3956 }
3957 if !old_axis_set.contains(&axis) {
3958 new_dims.push(dim);
3959 }
3960 }
3961
3962 let old_data = self.to_vec_any()?;
3963 let mut new_data = vec![AnyScalar::new_real(0.0); old_data.len()];
3964 for (old_linear, value) in old_data.into_iter().enumerate() {
3965 let old_multi = decode_col_major_linear(old_linear, &old_dims)?;
3966 let fused_multi: Vec<usize> = old_axes.iter().map(|&axis| old_multi[axis]).collect();
3967 let fused_linear = encode_linear_with_order(&fused_multi, &fused_dims, order)?;
3968
3969 let mut new_multi = Vec::with_capacity(new_dims.len());
3970 for (axis, old_coord) in old_multi.iter().copied().enumerate() {
3971 if axis == insertion_axis {
3972 new_multi.push(fused_linear);
3973 }
3974 if !old_axis_set.contains(&axis) {
3975 new_multi.push(old_coord);
3976 }
3977 }
3978 let new_linear = encode_col_major_linear(&new_multi, &new_dims)?;
3979 new_data[new_linear] = value;
3980 }
3981
3982 Self::from_dense_any(result_indices, new_data)
3983 }
3984
3985 pub fn unfuse_index(
4007 &self,
4008 old_index: &DynIndex,
4009 new_indices: &[DynIndex],
4010 order: LinearizationOrder,
4011 ) -> Result<Self> {
4012 anyhow::ensure!(
4013 !new_indices.is_empty(),
4014 "unfuse_index requires at least one replacement index"
4015 );
4016
4017 let axis = self
4018 .indices
4019 .iter()
4020 .position(|idx| idx == old_index)
4021 .ok_or_else(|| anyhow::anyhow!("index {:?} not found in tensor", old_index))?;
4022
4023 let replacement_dims: Vec<usize> = new_indices.iter().map(DynIndex::dim).collect();
4024 let replacement_product = checked_product(&replacement_dims)?;
4025 anyhow::ensure!(
4026 replacement_product == old_index.dim(),
4027 "product of new index dimensions must match the replaced index dimension"
4028 );
4029
4030 let mut result_indices =
4031 Vec::with_capacity(self.indices.len() - 1usize + new_indices.len());
4032 result_indices.extend_from_slice(&self.indices[..axis]);
4033 result_indices.extend(new_indices.iter().cloned());
4034 result_indices.extend_from_slice(&self.indices[axis + 1..]);
4035 Self::validate_indices(&result_indices)?;
4036
4037 let old_dims = self.dims();
4038 let mut new_dims = Vec::with_capacity(old_dims.len() - 1usize + replacement_dims.len());
4039 new_dims.extend_from_slice(&old_dims[..axis]);
4040 new_dims.extend_from_slice(&replacement_dims);
4041 new_dims.extend_from_slice(&old_dims[axis + 1..]);
4042
4043 let old_data = self.to_vec_any()?;
4044 let mut new_data = vec![AnyScalar::new_real(0.0); old_data.len()];
4045 for (old_linear, value) in old_data.into_iter().enumerate() {
4046 let old_multi = decode_col_major_linear(old_linear, &old_dims)?;
4047 let split_multi = decode_linear_with_order(old_multi[axis], &replacement_dims, order)?;
4048 let mut new_multi = Vec::with_capacity(new_dims.len());
4049 new_multi.extend_from_slice(&old_multi[..axis]);
4050 new_multi.extend_from_slice(&split_multi);
4051 new_multi.extend_from_slice(&old_multi[axis + 1..]);
4052 let new_linear = encode_col_major_linear(&new_multi, &new_dims)?;
4053 new_data[new_linear] = value;
4054 }
4055
4056 Self::from_dense_any(result_indices, new_data)
4057 }
4058
4059 pub fn scalar<T: TensorElement>(value: T) -> Result<Self> {
4070 Self::from_dense(vec![], vec![value])
4071 }
4072
4073 pub fn zeros<T: TensorElement + Zero + Clone>(indices: Vec<DynIndex>) -> Result<Self> {
4086 let dims: Vec<usize> = indices.iter().map(|idx| idx.dim()).collect();
4087 let size: usize = dims.iter().product();
4088 Self::from_dense(indices, vec![T::zero(); size])
4089 }
4090}
4091
4092impl TensorDynLen {
4097 pub fn to_vec<T: TensorElement>(&self) -> Result<Vec<T>> {
4119 native_tensor_primal_to_dense_col_major(self.as_native()?)
4120 }
4121
4122 pub fn into_dense_col_major_parts<T: TensorElement>(self) -> Result<(Vec<DynIndex>, Vec<T>)> {
4158 anyhow::ensure!(
4159 self.structured_ad.is_none() && !self.tracks_grad(),
4160 "TensorDynLen::into_dense_col_major_parts cannot consume tensors with tracked autodiff state"
4161 );
4162 let data = self.to_vec::<T>()?;
4163 Ok((self.indices, data))
4164 }
4165
4166 fn to_vec_any(&self) -> Result<Vec<AnyScalar>> {
4167 if self.is_complex() {
4168 self.to_vec::<Complex64>().map(|data| {
4169 data.into_iter()
4170 .map(|value| AnyScalar::new_complex(value.re, value.im))
4171 .collect()
4172 })
4173 } else {
4174 self.to_vec::<f64>()
4175 .map(|data| data.into_iter().map(AnyScalar::new_real).collect())
4176 }
4177 }
4178
4179 pub fn as_slice_f64(&self) -> Result<Vec<f64>> {
4184 self.to_vec::<f64>()
4185 }
4186
4187 pub fn as_slice_c64(&self) -> Result<Vec<Complex64>> {
4192 self.to_vec::<Complex64>()
4193 }
4194
4195 pub fn is_f64(&self) -> bool {
4208 self.storage.is_f64()
4209 }
4210
4211 pub fn is_diag(&self) -> bool {
4238 self.storage.is_diag()
4239 }
4240
4241 pub fn is_complex(&self) -> bool {
4260 self.storage.is_complex()
4261 }
4262}
4263
4264fn checked_product(dims: &[usize]) -> Result<usize> {
4265 dims.iter().try_fold(1usize, |acc, &dim| {
4266 acc.checked_mul(dim)
4267 .ok_or_else(|| anyhow::anyhow!("dimension product overflow"))
4268 })
4269}
4270
4271fn decode_col_major_linear(linear: usize, dims: &[usize]) -> Result<Vec<usize>> {
4272 let total = checked_product(dims)?;
4273 anyhow::ensure!(
4274 linear < total,
4275 "linear offset {} out of bounds for dims {:?}",
4276 linear,
4277 dims
4278 );
4279 let mut remaining = linear;
4280 let mut out = Vec::with_capacity(dims.len());
4281 for &dim in dims {
4282 out.push(remaining % dim);
4283 remaining /= dim;
4284 }
4285 Ok(out)
4286}
4287
4288fn encode_col_major_linear(indices: &[usize], dims: &[usize]) -> Result<usize> {
4289 anyhow::ensure!(
4290 indices.len() == dims.len(),
4291 "index rank {} does not match dims {:?}",
4292 indices.len(),
4293 dims
4294 );
4295 let mut linear = 0usize;
4296 let mut stride = 1usize;
4297 for (&index, &dim) in indices.iter().zip(dims.iter()) {
4298 anyhow::ensure!(
4299 index < dim,
4300 "index {} out of bounds for dimension {}",
4301 index,
4302 dim
4303 );
4304 linear += index * stride;
4305 stride = stride
4306 .checked_mul(dim)
4307 .ok_or_else(|| anyhow::anyhow!("stride overflow"))?;
4308 }
4309 Ok(linear)
4310}
4311
4312fn decode_linear_with_order(
4313 linear: usize,
4314 dims: &[usize],
4315 order: LinearizationOrder,
4316) -> Result<Vec<usize>> {
4317 let total = checked_product(dims)?;
4318 anyhow::ensure!(
4319 linear < total,
4320 "linear offset {} out of bounds for dims {:?}",
4321 linear,
4322 dims
4323 );
4324
4325 let mut remaining = linear;
4326 let mut out = vec![0usize; dims.len()];
4327 match order {
4328 LinearizationOrder::ColumnMajor => {
4329 for (slot, &dim) in out.iter_mut().zip(dims.iter()) {
4330 *slot = remaining % dim;
4331 remaining /= dim;
4332 }
4333 }
4334 LinearizationOrder::RowMajor => {
4335 for (slot, &dim) in out.iter_mut().rev().zip(dims.iter().rev()) {
4336 *slot = remaining % dim;
4337 remaining /= dim;
4338 }
4339 }
4340 }
4341 Ok(out)
4342}
4343
4344fn encode_linear_with_order(
4345 indices: &[usize],
4346 dims: &[usize],
4347 order: LinearizationOrder,
4348) -> Result<usize> {
4349 match order {
4350 LinearizationOrder::ColumnMajor => encode_col_major_linear(indices, dims),
4351 LinearizationOrder::RowMajor => {
4352 anyhow::ensure!(
4353 indices.len() == dims.len(),
4354 "index rank {} does not match dims {:?}",
4355 indices.len(),
4356 dims
4357 );
4358 let mut linear = 0usize;
4359 let mut stride = 1usize;
4360 for (&index, &dim) in indices.iter().rev().zip(dims.iter().rev()) {
4361 anyhow::ensure!(
4362 index < dim,
4363 "index {} out of bounds for dimension {}",
4364 index,
4365 dim
4366 );
4367 linear += index * stride;
4368 stride = stride
4369 .checked_mul(dim)
4370 .ok_or_else(|| anyhow::anyhow!("stride overflow"))?;
4371 }
4372 Ok(linear)
4373 }
4374 }
4375}