1use crate::defaults::DynIndex;
2use crate::index_like::IndexLike;
3use crate::index_ops::{common_ind_positions, prepare_contraction, prepare_contraction_pairs};
4use crate::tensor_like::LinearizationOrder;
5use crate::AnyScalar;
6use anyhow::{Context, Result};
7use num_complex::Complex64;
8use num_traits::Zero;
9use rand::Rng;
10use rand_distr::{Distribution, StandardNormal};
11use std::cell::RefCell;
12use std::cmp::Reverse;
13use std::collections::{HashMap, HashSet};
14use std::env;
15use std::sync::{Arc, OnceLock};
16use std::time::{Duration, Instant};
17use tenferro::{DType, DotGeneralConfig, Tensor as NativeTensor};
18use tenferro_ad::EagerTensor;
19use tenferro_einsum::eager_tensor::einsum_subscripts as eager_einsum_ad;
20use tenferro_einsum::EinsumSubscripts;
21use tensor4all_tensorbackend::{
22 axpby_native_tensor, contract_native_tensor, default_eager_ctx,
23 dense_native_tensor_from_col_major, diag_native_tensor_from_col_major,
24 native_tensor_primal_to_dense_col_major, native_tensor_primal_to_diag_c64,
25 native_tensor_primal_to_diag_f64, native_tensor_primal_to_storage, scale_native_tensor,
26 storage_payload_native_read_input, storage_to_native_tensor, AnyScalar as BackendScalar,
27 StorageScalar, TensorElement,
28};
29use tensor4all_tensorbackend::{Storage, StorageKind};
30
31use super::contract::PairwiseContractionOptions;
32use super::structured_contraction::{
33 normalize_payload_read_for_roots, storage_from_payload_native, storage_payload_native,
34 OperandLayout, StructuredContractionPlan, StructuredContractionSpec,
35};
36
/// Accumulated profiling counters for one named pairwise-contraction section.
#[derive(Debug, Default, Clone)]
struct PairwiseContractProfileEntry {
    /// Number of timed invocations recorded for the section.
    calls: usize,
    /// Sum of the elapsed durations recorded for the section.
    total_time: Duration,
    /// Sum of the byte counts recorded for the section.
    total_bytes: usize,
}
43
// Per-thread profiling table: maps a section name to its accumulated counters.
// Each thread owns its own table, so entries recorded on one thread are not
// visible from another.
thread_local! {
    static PAIRWISE_CONTRACT_PROFILE_STATE: RefCell<HashMap<&'static str, PairwiseContractProfileEntry>> =
        RefCell::new(HashMap::new());
}
48
/// Reports whether pairwise-contraction profiling is switched on.
///
/// Profiling is on when the `T4A_PROFILE_PAIRWISE_CONTRACT` environment
/// variable can be read successfully. The lookup happens on the first call
/// only; the answer is cached in a process-wide `OnceLock` afterwards.
fn pairwise_contract_profile_enabled() -> bool {
    static ENABLED: OnceLock<bool> = OnceLock::new();
    let cached = ENABLED.get_or_init(|| {
        let lookup = env::var("T4A_PROFILE_PAIRWISE_CONTRACT");
        lookup.is_ok()
    });
    *cached
}
53
54fn record_pairwise_contract_profile(section: &'static str, elapsed: Duration) {
55 if !pairwise_contract_profile_enabled() {
56 return;
57 }
58 PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| {
59 let mut state = state.borrow_mut();
60 let entry = state.entry(section).or_default();
61 entry.calls += 1;
62 entry.total_time += elapsed;
63 });
64}
65
66fn record_pairwise_contract_profile_bytes(section: &'static str, bytes: usize) {
67 if !pairwise_contract_profile_enabled() {
68 return;
69 }
70 PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| {
71 let mut state = state.borrow_mut();
72 let entry = state.entry(section).or_default();
73 entry.total_bytes += bytes;
74 });
75}
76
77fn profile_pairwise_contract_section<T>(section: &'static str, f: impl FnOnce() -> T) -> T {
78 if !pairwise_contract_profile_enabled() {
79 return f();
80 }
81 let started = Instant::now();
82 let result = f();
83 record_pairwise_contract_profile(section, started.elapsed());
84 result
85}
86
87pub fn reset_pairwise_contract_profile() {
89 PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| state.borrow_mut().clear());
90}
91
92pub fn print_and_reset_pairwise_contract_profile() {
94 if !pairwise_contract_profile_enabled() {
95 return;
96 }
97 PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| {
98 let mut entries: Vec<_> = state
99 .borrow()
100 .iter()
101 .map(|(section, entry)| (*section, entry.clone()))
102 .collect();
103 state.borrow_mut().clear();
104 entries.sort_by_key(|(_, entry)| Reverse(entry.total_time));
105
106 eprintln!("=== TensorDynLen pairwise contract profile ===");
107 for (section, entry) in entries {
108 let per_call_us = if entry.calls == 0 {
109 0.0
110 } else {
111 entry.total_time.as_secs_f64() * 1.0e6 / entry.calls as f64
112 };
113 eprintln!(
114 "{section}: calls={} total={:.6}ms per_call={:.3}us bytes={}",
115 entry.calls,
116 entry.total_time.as_secs_f64() * 1.0e3,
117 per_call_us,
118 entry.total_bytes,
119 );
120 }
121 });
122}
123
124fn native_tensor_profile_bytes(native: &NativeTensor) -> usize {
125 let element_size = match native.dtype() {
126 DType::F32 => 4,
127 DType::F64 => 8,
128 DType::C32 => 8,
129 DType::C64 => 16,
130 DType::I32 => 4,
131 DType::I64 => 8,
132 DType::Bool => 1,
133 };
134 native.shape().iter().product::<usize>() * element_size
135}
136
/// Tensor element types that can be drawn at random.
pub trait RandomScalar: TensorElement {
    /// Draws one random value of this type using `rng`.
    fn random_value<R: Rng>(rng: &mut R) -> Self;
}
145
146impl RandomScalar for f64 {
147 fn random_value<R: Rng>(rng: &mut R) -> Self {
148 StandardNormal.sample(rng)
149 }
150}
151
152impl RandomScalar for Complex64 {
153 fn random_value<R: Rng>(rng: &mut R) -> Self {
154 Complex64::new(StandardNormal.sample(rng), StandardNormal.sample(rng))
155 }
156}
157
158pub fn compute_permutation_from_indices(
191 original_indices: &[DynIndex],
192 new_indices: &[DynIndex],
193) -> Result<Vec<usize>> {
194 anyhow::ensure!(
195 new_indices.len() == original_indices.len(),
196 "new_indices length must match original_indices length"
197 );
198
199 let mut perm = Vec::with_capacity(new_indices.len());
200 let mut used = std::collections::HashSet::new();
201
202 for new_idx in new_indices {
203 let pos = original_indices
206 .iter()
207 .position(|old_idx| old_idx == new_idx)
208 .ok_or_else(|| {
209 anyhow::anyhow!("new_indices must be a permutation of original_indices")
210 })?;
211
212 anyhow::ensure!(used.insert(pos), "duplicate index in new_indices");
213 perm.push(pos);
214 }
215
216 Ok(perm)
217}
218
/// Eager payload tensor kept alongside a tensor's storage, together with the
/// layout it was created with.
// NOTE(review): the name suggests this carries the autodiff-tracked value of
// a structured tensor — confirm against the AD entry points.
#[derive(Clone)]
pub(crate) struct StructuredAdValue {
    /// Shared eager tensor holding the payload.
    payload: Arc<EagerTensor>,
    /// Dimensions of the payload tensor.
    payload_dims: Vec<usize>,
    /// Class id per logical axis; presumably axes sharing an id share a
    /// payload axis — TODO confirm.
    axis_classes: Vec<usize>,
}
225
/// Backing data of a `TensorDynLen`: either a materialized `Storage` or an
/// eager tensor paired with its axis classes.
#[derive(Clone)]
pub(crate) enum TensorDynLenStorage {
    /// Data held as a shared `Storage` value.
    Materialized(Arc<Storage>),
    /// Data held as a shared eager tensor.
    Eager {
        /// The eager tensor holding the payload.
        inner: Arc<EagerTensor>,
        /// Class id for each logical axis of the tensor.
        axis_classes: Vec<usize>,
    },
}
234
235impl TensorDynLenStorage {
236 fn from_storage(storage: Arc<Storage>) -> Self {
237 Self::Materialized(storage)
238 }
239
240 fn from_eager_dense(inner: EagerTensor, rank: usize) -> Self {
241 Self::Eager {
242 inner: Arc::new(inner),
243 axis_classes: TensorDynLen::dense_axis_classes(rank),
244 }
245 }
246
247 fn eager(&self) -> Option<&EagerTensor> {
248 match self {
249 Self::Materialized(_) => None,
250 Self::Eager { inner, .. } => Some(inner.as_ref()),
251 }
252 }
253
254 fn axis_classes(&self) -> &[usize] {
255 match self {
256 Self::Materialized(storage) => storage.axis_classes(),
257 Self::Eager { axis_classes, .. } => axis_classes,
258 }
259 }
260
261 fn payload_dims(&self) -> &[usize] {
262 match self {
263 Self::Materialized(storage) => storage.payload_dims(),
264 Self::Eager { inner, .. } => inner.data().shape(),
265 }
266 }
267
268 fn payload_strides_vec(&self) -> Vec<isize> {
269 match self {
270 Self::Materialized(storage) => storage.payload_strides().to_vec(),
271 Self::Eager { inner, .. } => {
272 let mut stride = 1isize;
273 inner
274 .data()
275 .shape()
276 .iter()
277 .map(|&dim| {
278 let current = stride;
279 stride *= isize::try_from(dim).unwrap_or(isize::MAX);
280 current
281 })
282 .collect()
283 }
284 }
285 }
286
287 fn is_f64(&self) -> bool {
288 match self {
289 Self::Materialized(storage) => storage.is_f64(),
290 Self::Eager { inner, .. } => inner.data().dtype() == DType::F64,
291 }
292 }
293
294 fn is_c64(&self) -> bool {
295 match self {
296 Self::Materialized(storage) => storage.is_c64(),
297 Self::Eager { inner, .. } => inner.data().dtype() == DType::C64,
298 }
299 }
300
301 fn is_complex(&self) -> bool {
302 match self {
303 Self::Materialized(storage) => storage.is_complex(),
304 Self::Eager { inner, .. } => matches!(inner.data().dtype(), DType::C32 | DType::C64),
305 }
306 }
307
308 fn is_diag(&self) -> bool {
309 match self {
310 Self::Materialized(storage) => storage.is_diag(),
311 Self::Eager { axis_classes, .. } => TensorDynLen::is_diag_axis_classes(axis_classes),
312 }
313 }
314
315 fn storage_kind(&self) -> StorageKind {
316 match self {
317 Self::Materialized(storage) => storage.storage_kind(),
318 Self::Eager { axis_classes, .. } => {
319 if axis_classes.iter().copied().eq(0..axis_classes.len()) {
320 StorageKind::Dense
321 } else if TensorDynLen::is_diag_axis_classes(axis_classes) {
322 StorageKind::Diagonal
323 } else {
324 StorageKind::Structured
325 }
326 }
327 }
328 }
329
330 fn materialize(&self, logical_rank: usize) -> Result<Arc<Storage>> {
331 match self {
332 Self::Materialized(storage) => Ok(Arc::clone(storage)),
333 Self::Eager {
334 inner,
335 axis_classes,
336 } => Ok(Arc::new(
337 TensorDynLen::storage_from_native_with_axis_classes(
338 inner.data(),
339 axis_classes,
340 logical_rank,
341 )?,
342 )),
343 }
344 }
345
346 fn scale(&self, scalar: &BackendScalar) -> Result<Storage> {
347 Ok(self.materialize(self.axis_classes().len())?.scale(scalar))
348 }
349
350 fn conj(&self) -> Result<Self> {
351 match self {
352 Self::Materialized(storage) => Ok(Self::Materialized(Arc::new(storage.conj()))),
353 Self::Eager {
354 inner,
355 axis_classes,
356 } => Ok(Self::Eager {
357 inner: Arc::new(inner.conj()?),
358 axis_classes: axis_classes.clone(),
359 }),
360 }
361 }
362
363 fn max_abs(&self) -> Result<f64> {
364 Ok(self.materialize(self.axis_classes().len())?.max_abs())
365 }
366}
367
/// Tensor whose rank is only known at run time: a list of indices plus the
/// data they label.
#[derive(Clone)]
pub struct TensorDynLen {
    /// Indices labelling the tensor's axes, in axis order.
    pub indices: Vec<DynIndex>,
    /// Backing data (materialized storage or an eager tensor).
    pub(crate) storage: TensorDynLenStorage,
    /// Optional eager payload kept alongside `storage`.
    // NOTE(review): presumably present only for tensors taking part in
    // structured AD — confirm against the constructors.
    pub(crate) structured_ad: Option<Arc<StructuredAdValue>>,
    /// Shared write-once cell that can hold an eager tensor for this value.
    // NOTE(review): looks like a lazily filled cache of the eager form —
    // confirm where it is populated.
    pub(crate) eager_cache: Arc<OnceLock<Arc<EagerTensor>>>,
}
430
431impl TensorDynLen {
432 fn dense_axis_classes(rank: usize) -> Vec<usize> {
433 (0..rank).collect()
434 }
435
436 fn diag_axis_classes(rank: usize) -> Vec<usize> {
437 if rank == 0 {
438 vec![]
439 } else {
440 vec![0; rank]
441 }
442 }
443
444 fn canonicalize_axis_classes(axis_classes: &[usize]) -> Vec<usize> {
445 let mut map = std::collections::HashMap::new();
446 let mut next = 0usize;
447 axis_classes
448 .iter()
449 .map(|&class_id| {
450 *map.entry(class_id).or_insert_with(|| {
451 let canonical = next;
452 next += 1;
453 canonical
454 })
455 })
456 .collect()
457 }
458
459 fn permute_axis_classes(&self, perm: &[usize]) -> Vec<usize> {
460 let axis_classes = self.storage.axis_classes();
461 let permuted: Vec<usize> = perm.iter().map(|&index| axis_classes[index]).collect();
462 Self::canonicalize_axis_classes(&permuted)
463 }
464
465 fn normalize_insert_axis(op: &str, axis: isize, rank: usize) -> Result<usize> {
466 let normalized = if axis < 0 {
467 rank as isize + 1 + axis
468 } else {
469 axis
470 };
471 anyhow::ensure!(
472 normalized >= 0 && normalized <= rank as isize,
473 "{op}: axis {axis} is out of bounds for inserting into rank {rank}"
474 );
475 Ok(normalized as usize)
476 }
477
478 fn is_diag_axis_classes(axis_classes: &[usize]) -> bool {
479 axis_classes.len() >= 2 && axis_classes.iter().all(|&class_id| class_id == 0)
480 }
481
482 fn einsum_subscripts_from_usize_ids(
483 inputs: &[Vec<usize>],
484 output: &[usize],
485 ) -> Result<EinsumSubscripts> {
486 let input_labels = inputs
487 .iter()
488 .map(|ids| {
489 ids.iter()
490 .map(|&id| {
491 u32::try_from(id)
492 .map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range"))
493 })
494 .collect::<Result<Vec<_>>>()
495 })
496 .collect::<Result<Vec<_>>>()?;
497 let output_labels = output
498 .iter()
499 .map(|&id| {
500 u32::try_from(id)
501 .map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range"))
502 })
503 .collect::<Result<Vec<_>>>()?;
504 let input_refs = input_labels.iter().map(Vec::as_slice).collect::<Vec<_>>();
505 Ok(EinsumSubscripts::new(&input_refs, &output_labels))
506 }
507
508 fn build_binary_einsum_subscripts(
509 lhs_rank: usize,
510 axes_a: &[usize],
511 rhs_rank: usize,
512 axes_b: &[usize],
513 ) -> Result<EinsumSubscripts> {
514 anyhow::ensure!(
515 axes_a.len() == axes_b.len(),
516 "contract axis length mismatch: lhs {:?}, rhs {:?}",
517 axes_a,
518 axes_b
519 );
520
521 let mut lhs_ids = vec![usize::MAX; lhs_rank];
522 let mut rhs_ids = vec![usize::MAX; rhs_rank];
523 let mut next_id = 0usize;
524
525 let mut seen_lhs = vec![false; lhs_rank];
526 let mut seen_rhs = vec![false; rhs_rank];
527
528 for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
529 anyhow::ensure!(
530 lhs_axis < lhs_rank,
531 "lhs contract axis {lhs_axis} out of range"
532 );
533 anyhow::ensure!(
534 rhs_axis < rhs_rank,
535 "rhs contract axis {rhs_axis} out of range"
536 );
537 anyhow::ensure!(
538 !seen_lhs[lhs_axis],
539 "duplicate lhs contract axis {lhs_axis}"
540 );
541 anyhow::ensure!(
542 !seen_rhs[rhs_axis],
543 "duplicate rhs contract axis {rhs_axis}"
544 );
545 seen_lhs[lhs_axis] = true;
546 seen_rhs[rhs_axis] = true;
547 lhs_ids[lhs_axis] = next_id;
548 rhs_ids[rhs_axis] = next_id;
549 next_id += 1;
550 }
551
552 let mut output_ids = Vec::with_capacity(lhs_rank + rhs_rank - 2 * axes_a.len());
553 for id in &mut lhs_ids {
554 if *id == usize::MAX {
555 *id = next_id;
556 output_ids.push(next_id);
557 next_id += 1;
558 }
559 }
560 for id in &mut rhs_ids {
561 if *id == usize::MAX {
562 *id = next_id;
563 output_ids.push(next_id);
564 next_id += 1;
565 }
566 }
567
568 Self::einsum_subscripts_from_usize_ids(&[lhs_ids, rhs_ids], &output_ids)
569 }
570
571 fn binary_dot_general_config(axes_a: &[usize], axes_b: &[usize]) -> Result<DotGeneralConfig> {
572 anyhow::ensure!(
573 axes_a.len() == axes_b.len(),
574 "contract axis length mismatch: lhs {:?}, rhs {:?}",
575 axes_a,
576 axes_b
577 );
578 Ok(DotGeneralConfig {
579 lhs_contracting_dims: axes_a.to_vec(),
580 rhs_contracting_dims: axes_b.to_vec(),
581 lhs_batch_dims: vec![],
582 rhs_batch_dims: vec![],
583 })
584 }
585
586 fn binary_contraction_axis_classes(
587 lhs_axis_classes: &[usize],
588 axes_a: &[usize],
589 rhs_axis_classes: &[usize],
590 axes_b: &[usize],
591 ) -> Vec<usize> {
592 debug_assert_eq!(axes_a.len(), axes_b.len());
593
594 fn find(parent: &mut [usize], value: usize) -> usize {
595 if parent[value] != value {
596 parent[value] = find(parent, parent[value]);
597 }
598 parent[value]
599 }
600
601 fn union(parent: &mut [usize], lhs: usize, rhs: usize) {
602 let lhs_root = find(parent, lhs);
603 let rhs_root = find(parent, rhs);
604 if lhs_root != rhs_root {
605 parent[rhs_root] = lhs_root;
606 }
607 }
608
609 let lhs_payload_rank = lhs_axis_classes
610 .iter()
611 .copied()
612 .max()
613 .map(|value| value + 1)
614 .unwrap_or(0);
615 let rhs_payload_rank = rhs_axis_classes
616 .iter()
617 .copied()
618 .max()
619 .map(|value| value + 1)
620 .unwrap_or(0);
621 let rhs_offset = lhs_payload_rank;
622 let mut parent: Vec<usize> = (0..lhs_payload_rank + rhs_payload_rank).collect();
623
624 for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
625 union(
626 &mut parent,
627 lhs_axis_classes[lhs_axis],
628 rhs_offset + rhs_axis_classes[rhs_axis],
629 );
630 }
631
632 let mut lhs_contracted = vec![false; lhs_axis_classes.len()];
633 for &axis in axes_a {
634 lhs_contracted[axis] = true;
635 }
636 let mut rhs_contracted = vec![false; rhs_axis_classes.len()];
637 for &axis in axes_b {
638 rhs_contracted[axis] = true;
639 }
640
641 let mut root_to_class = std::collections::HashMap::new();
642 let mut next_class = 0usize;
643 let mut axis_classes = Vec::new();
644
645 for (axis, &class_id) in lhs_axis_classes.iter().enumerate() {
646 if !lhs_contracted[axis] {
647 let root = find(&mut parent, class_id);
648 let class = *root_to_class.entry(root).or_insert_with(|| {
649 let value = next_class;
650 next_class += 1;
651 value
652 });
653 axis_classes.push(class);
654 }
655 }
656 for (axis, &class_id) in rhs_axis_classes.iter().enumerate() {
657 if !rhs_contracted[axis] {
658 let root = find(&mut parent, rhs_offset + class_id);
659 let class = *root_to_class.entry(root).or_insert_with(|| {
660 let value = next_class;
661 next_class += 1;
662 value
663 });
664 axis_classes.push(class);
665 }
666 }
667
668 axis_classes
669 }
670
671 fn scale_subscripts(rank: usize) -> Result<EinsumSubscripts> {
672 let ids: Vec<usize> = (0..rank).collect();
673 Self::einsum_subscripts_from_usize_ids(&[ids.clone(), Vec::new()], &ids)
674 }
675
676 fn validate_indices(indices: &[DynIndex]) -> Result<()> {
677 let mut seen = HashSet::new();
678 for idx in indices {
679 anyhow::ensure!(
680 seen.insert(idx.clone()),
681 "Tensor indices must all be unique"
682 );
683 }
684 Ok(())
685 }
686
687 fn validate_diag_dims(dims: &[usize]) -> Result<()> {
688 if !dims.is_empty() {
689 let first_dim = dims[0];
690 for (i, &dim) in dims.iter().enumerate() {
691 anyhow::ensure!(
692 dim == first_dim,
693 "DiagTensor requires all indices to have the same dimension, but dims[{i}] = {dim} != dims[0] = {first_dim}"
694 );
695 }
696 }
697 Ok(())
698 }
699
700 fn seed_native_payload(storage: &Storage, dims: &[usize]) -> Result<NativeTensor> {
701 storage_to_native_tensor(storage, dims)
702 }
703
704 fn empty_eager_cache() -> Arc<OnceLock<Arc<EagerTensor>>> {
705 Arc::new(OnceLock::new())
706 }
707
708 fn eager_cache_with(inner: EagerTensor) -> Arc<OnceLock<Arc<EagerTensor>>> {
709 let cache = Arc::new(OnceLock::new());
710 let _ = cache.set(Arc::new(inner));
711 cache
712 }
713
714 fn compact_payload_inner(&self) -> Result<EagerTensor> {
715 Ok(EagerTensor::from_tensor_in(
716 storage_payload_native(self.storage.materialize(self.indices.len())?.as_ref())?,
717 default_eager_ctx(),
718 ))
719 }
720
721 fn tracked_compact_payload_value(&self) -> Option<&StructuredAdValue> {
722 self.structured_ad.as_deref()
723 }
724
725 fn compact_payload_is_logical_dense(&self, payload_dims: &[usize]) -> bool {
726 self.storage.axis_classes() == Self::dense_axis_classes(self.indices.len())
727 && payload_dims == self.dims()
728 }
729
730 fn uses_tracked_compact_storage(&self) -> bool {
731 self.tracked_compact_payload_value()
732 .is_some_and(|value| !self.compact_payload_is_logical_dense(&value.payload_dims))
733 }
734
735 fn ensure_shape_packing_preserves_ad(&self, op_name: &str) -> Result<()> {
736 anyhow::ensure!(
737 !self.uses_tracked_compact_storage(),
738 "{op_name}: structured AD tensors with compact storage are not supported because materializing compact storage would detach gradients"
739 );
740 Ok(())
741 }
742
743 fn operand_indices_for_contraction(&self, conjugate: bool) -> Vec<DynIndex> {
744 if conjugate {
745 self.indices.iter().map(|index| index.conj()).collect()
746 } else {
747 self.indices.clone()
748 }
749 }
750
751 fn build_binary_contraction_labels(
752 lhs_rank: usize,
753 axes_a: &[usize],
754 rhs_rank: usize,
755 axes_b: &[usize],
756 ) -> Result<(Vec<usize>, Vec<usize>, Vec<usize>)> {
757 anyhow::ensure!(
758 axes_a.len() == axes_b.len(),
759 "contract axis length mismatch: lhs {:?}, rhs {:?}",
760 axes_a,
761 axes_b
762 );
763
764 let mut lhs_ids = vec![usize::MAX; lhs_rank];
765 let mut rhs_ids = vec![usize::MAX; rhs_rank];
766 let mut next_id = 0usize;
767
768 let mut seen_lhs = vec![false; lhs_rank];
769 let mut seen_rhs = vec![false; rhs_rank];
770
771 for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
772 anyhow::ensure!(
773 lhs_axis < lhs_rank,
774 "lhs contract axis {lhs_axis} out of range"
775 );
776 anyhow::ensure!(
777 rhs_axis < rhs_rank,
778 "rhs contract axis {rhs_axis} out of range"
779 );
780 anyhow::ensure!(
781 !seen_lhs[lhs_axis],
782 "duplicate lhs contract axis {lhs_axis}"
783 );
784 anyhow::ensure!(
785 !seen_rhs[rhs_axis],
786 "duplicate rhs contract axis {rhs_axis}"
787 );
788 seen_lhs[lhs_axis] = true;
789 seen_rhs[rhs_axis] = true;
790 lhs_ids[lhs_axis] = next_id;
791 rhs_ids[rhs_axis] = next_id;
792 next_id += 1;
793 }
794
795 let mut output_ids = Vec::with_capacity(lhs_rank + rhs_rank - 2 * axes_a.len());
796 for id in &mut lhs_ids {
797 if *id == usize::MAX {
798 *id = next_id;
799 output_ids.push(next_id);
800 next_id += 1;
801 }
802 }
803 for id in &mut rhs_ids {
804 if *id == usize::MAX {
805 *id = next_id;
806 output_ids.push(next_id);
807 next_id += 1;
808 }
809 }
810
811 Ok((lhs_ids, rhs_ids, output_ids))
812 }
813
814 fn build_payload_einsum_subscripts(
815 input_roots: &[Vec<usize>],
816 output_roots: &[usize],
817 ) -> Result<EinsumSubscripts> {
818 Self::einsum_subscripts_from_usize_ids(input_roots, output_roots)
819 }
820
821 fn normalize_eager_payload_for_roots(
822 payload: &EagerTensor,
823 roots: &[usize],
824 ) -> Result<(Option<EagerTensor>, Vec<usize>)> {
825 anyhow::ensure!(
826 payload.data().shape().len() == roots.len(),
827 "payload rank {} does not match root label count {}",
828 payload.data().shape().len(),
829 roots.len()
830 );
831
832 let mut current_payload = None;
833 let mut current_roots = roots.to_vec();
834 while let Some((axis_a, axis_b)) = Self::first_duplicate_pair(¤t_roots) {
835 let source = current_payload.as_ref().unwrap_or(payload);
836 current_payload = Some(source.extract_diag(axis_a, axis_b)?);
837 current_roots.remove(axis_b);
838 }
839
840 Ok((current_payload, current_roots))
841 }
842
843 fn first_duplicate_pair(values: &[usize]) -> Option<(usize, usize)> {
844 let mut first_axis_by_value = std::collections::HashMap::new();
845 for (axis, &value) in values.iter().enumerate() {
846 if let Some(&first_axis) = first_axis_by_value.get(&value) {
847 return Some((first_axis, axis));
848 }
849 first_axis_by_value.insert(value, axis);
850 }
851 None
852 }
853
854 fn binary_structured_contraction_plan(
855 &self,
856 other: &Self,
857 axes_a: &[usize],
858 axes_b: &[usize],
859 ) -> Result<(StructuredContractionPlan, Vec<Vec<usize>>, Vec<usize>)> {
860 let (lhs_labels, rhs_labels, output_labels) = Self::build_binary_contraction_labels(
861 self.indices.len(),
862 axes_a,
863 other.indices.len(),
864 axes_b,
865 )?;
866 let operands = vec![
867 OperandLayout::new(self.dims(), self.storage.axis_classes().to_vec())?,
868 OperandLayout::new(other.dims(), other.storage.axis_classes().to_vec())?,
869 ];
870 let spec = StructuredContractionSpec {
871 input_labels: vec![lhs_labels, rhs_labels],
872 output_labels,
873 retained_labels: Default::default(),
874 };
875 let plan = StructuredContractionPlan::new(&operands, &spec)?;
876 Ok((plan, spec.input_labels, spec.output_labels))
877 }
878
879 fn from_structured_payload_inner(
880 indices: Vec<DynIndex>,
881 payload_inner: EagerTensor,
882 payload_dims: Vec<usize>,
883 axis_classes: Vec<usize>,
884 ) -> Result<Self> {
885 Self::validate_indices(&indices)?;
886 if payload_inner.data().shape() != payload_dims {
887 return Err(anyhow::anyhow!(
888 "structured payload dims {:?} do not match planned payload dims {:?}",
889 payload_inner.data().shape(),
890 payload_dims
891 ));
892 }
893 let storage = storage_from_payload_native(
894 payload_inner.data().clone(),
895 &payload_dims,
896 axis_classes.clone(),
897 )?;
898 Self::validate_storage_matches_indices(&indices, &storage)?;
899 Ok(Self {
900 indices,
901 storage: TensorDynLenStorage::from_storage(Arc::new(storage)),
902 structured_ad: Some(Arc::new(StructuredAdValue {
903 payload: Arc::new(payload_inner),
904 payload_dims,
905 axis_classes,
906 })),
907 eager_cache: Self::empty_eager_cache(),
908 })
909 }
910
911 fn contract_structured_payloads(
912 &self,
913 other: &Self,
914 result_indices: Vec<DynIndex>,
915 axes_a: &[usize],
916 axes_b: &[usize],
917 ) -> Result<Self> {
918 let (plan, _, _) = self.binary_structured_contraction_plan(other, axes_a, axes_b)?;
919 let lhs_roots = plan.operand_plans[0].class_roots.clone();
920 let rhs_roots = plan.operand_plans[1].class_roots.clone();
921 let scalar_multiply =
922 lhs_roots.is_empty() && rhs_roots.is_empty() && plan.output_payload_roots.is_empty();
923
924 if let (Some(lhs_ad), Some(rhs_ad)) = (
925 self.tracked_compact_payload_value(),
926 other.tracked_compact_payload_value(),
927 ) {
928 if lhs_ad.payload.data().dtype() != rhs_ad.payload.data().dtype() {
929 return Err(anyhow::anyhow!(
930 "structured AD contraction requires matching payload dtypes"
931 ));
932 }
933 let (lhs_normalized, lhs_labels) =
934 Self::normalize_eager_payload_for_roots(lhs_ad.payload.as_ref(), &lhs_roots)?;
935 let (rhs_normalized, rhs_labels) =
936 Self::normalize_eager_payload_for_roots(rhs_ad.payload.as_ref(), &rhs_roots)?;
937 let lhs_payload = lhs_normalized
938 .as_ref()
939 .unwrap_or_else(|| lhs_ad.payload.as_ref());
940 let rhs_payload = rhs_normalized
941 .as_ref()
942 .unwrap_or_else(|| rhs_ad.payload.as_ref());
943 let payload = if scalar_multiply {
944 lhs_payload.mul(rhs_payload)?
945 } else {
946 let subscripts = Self::build_payload_einsum_subscripts(
947 &[lhs_labels, rhs_labels],
948 &plan.output_payload_roots,
949 )?;
950 eager_einsum_ad(&[lhs_payload, rhs_payload], &subscripts)?
951 };
952 return Self::from_structured_payload_inner(
953 result_indices,
954 payload,
955 plan.output_payload_dims,
956 plan.output_axis_classes,
957 );
958 }
959
960 if self.tracked_compact_payload_value().is_some()
961 || other.tracked_compact_payload_value().is_some()
962 {
963 let lhs_owned = if self.tracked_compact_payload_value().is_some() {
964 None
965 } else {
966 Some(self.compact_payload_inner()?)
967 };
968 let rhs_owned = if other.tracked_compact_payload_value().is_some() {
969 None
970 } else {
971 Some(other.compact_payload_inner()?)
972 };
973 let lhs = if let Some(value) = self.tracked_compact_payload_value() {
974 value.payload.as_ref()
975 } else {
976 lhs_owned
977 .as_ref()
978 .ok_or_else(|| anyhow::anyhow!("missing untracked left compact payload"))?
979 };
980 let rhs = if let Some(value) = other.tracked_compact_payload_value() {
981 value.payload.as_ref()
982 } else {
983 rhs_owned
984 .as_ref()
985 .ok_or_else(|| anyhow::anyhow!("missing untracked right compact payload"))?
986 };
987 if lhs.data().dtype() != rhs.data().dtype() {
988 return Err(anyhow::anyhow!(
989 "structured AD contraction requires matching payload dtypes"
990 ));
991 }
992 let (lhs_normalized, lhs_labels) =
993 Self::normalize_eager_payload_for_roots(lhs, &lhs_roots)?;
994 let (rhs_normalized, rhs_labels) =
995 Self::normalize_eager_payload_for_roots(rhs, &rhs_roots)?;
996 let lhs_payload = lhs_normalized.as_ref().unwrap_or(lhs);
997 let rhs_payload = rhs_normalized.as_ref().unwrap_or(rhs);
998 let payload = if scalar_multiply {
999 lhs_payload.mul(rhs_payload)?
1000 } else {
1001 let subscripts = Self::build_payload_einsum_subscripts(
1002 &[lhs_labels, rhs_labels],
1003 &plan.output_payload_roots,
1004 )?;
1005 eager_einsum_ad(&[lhs_payload, rhs_payload], &subscripts)?
1006 };
1007 return Self::from_structured_payload_inner(
1008 result_indices,
1009 payload,
1010 plan.output_payload_dims,
1011 plan.output_axis_classes,
1012 );
1013 }
1014
1015 let lhs_storage = self.storage.materialize(self.indices.len())?;
1016 let rhs_storage = other.storage.materialize(other.indices.len())?;
1017 let lhs = storage_payload_native_read_input(lhs_storage.as_ref())?;
1018 let rhs = storage_payload_native_read_input(rhs_storage.as_ref())?;
1019 if lhs.dtype() != rhs.dtype() {
1020 return Err(anyhow::anyhow!(
1021 "structured payload contraction requires matching payload dtypes"
1022 ));
1023 }
1024 let (lhs, lhs_labels) = normalize_payload_read_for_roots(lhs, &lhs_roots)?;
1025 let (rhs, rhs_labels) = normalize_payload_read_for_roots(rhs, &rhs_roots)?;
1026 let payload = tensor4all_tensorbackend::einsum_native_tensor_reads(
1027 &[(&lhs, lhs_labels.as_slice()), (&rhs, rhs_labels.as_slice())],
1028 &plan.output_payload_roots,
1029 )?;
1030 let storage = storage_from_payload_native(
1031 payload,
1032 &plan.output_payload_dims,
1033 plan.output_axis_classes,
1034 )?;
1035 Self::from_storage(result_indices, Arc::new(storage))
1036 }
1037
1038 fn should_use_structured_payload_contract(&self, other: &Self) -> bool {
1039 let same_payload_dtype = self.storage.is_f64() == other.storage.is_f64()
1040 && self.storage.is_complex() == other.storage.is_complex();
1041 same_payload_dtype
1042 && (self.tracked_compact_payload_value().is_some()
1043 || other.tracked_compact_payload_value().is_some()
1044 || self.storage.axis_classes() != Self::dense_axis_classes(self.indices.len())
1045 || other.storage.axis_classes() != Self::dense_axis_classes(other.indices.len()))
1046 }
1047
1048 fn storage_from_native_with_axis_classes(
1049 native: &NativeTensor,
1050 axis_classes: &[usize],
1051 logical_rank: usize,
1052 ) -> Result<Storage> {
1053 if Self::is_diag_axis_classes(axis_classes) {
1054 match native.dtype() {
1055 DType::F32 | DType::F64 | DType::I32 | DType::I64 | DType::Bool => {
1056 Storage::from_diag_col_major(
1057 native_tensor_primal_to_diag_f64(native)?,
1058 logical_rank,
1059 )
1060 }
1061 DType::C32 | DType::C64 => Storage::from_diag_col_major(
1062 native_tensor_primal_to_diag_c64(native)?,
1063 logical_rank,
1064 ),
1065 }
1066 } else {
1067 native_tensor_primal_to_storage(native)
1068 }
1069 }
1070
1071 fn dense_selected_diag_payload<T: TensorElement + Copy + Zero>(
1072 payload: Vec<T>,
1073 kept_dims: &[usize],
1074 selected_positions: &[usize],
1075 ) -> Vec<T> {
1076 let output_len = kept_dims.iter().product::<usize>();
1077 let mut data = vec![T::zero(); output_len];
1078 if output_len == 0 {
1079 return data;
1080 }
1081
1082 let Some((&first_position, rest)) = selected_positions.split_first() else {
1083 return data;
1084 };
1085 if rest.iter().any(|&position| position != first_position) {
1086 return data;
1087 }
1088
1089 let value = payload[first_position];
1090 if kept_dims.is_empty() {
1091 data[0] = value;
1092 return data;
1093 }
1094
1095 let mut offset = 0usize;
1096 let mut stride = 1usize;
1097 for &dim in kept_dims {
1098 offset += first_position * stride;
1099 stride *= dim;
1100 }
1101 data[offset] = value;
1102 data
1103 }
1104
1105 fn select_diag_indices(
1106 &self,
1107 kept_indices: Vec<DynIndex>,
1108 kept_dims: Vec<usize>,
1109 positions: &[usize],
1110 ) -> Result<Self> {
1111 if self.storage.is_f64() {
1112 let storage = self.storage.materialize(self.indices.len())?;
1113 let payload = storage
1114 .payload_f64_col_major_vec()
1115 .map_err(anyhow::Error::msg)?;
1116 let data = Self::dense_selected_diag_payload(payload, &kept_dims, positions);
1117 Self::from_dense(kept_indices, data)
1118 } else if self.storage.is_c64() {
1119 let storage = self.storage.materialize(self.indices.len())?;
1120 let payload = storage
1121 .payload_c64_col_major_vec()
1122 .map_err(anyhow::Error::msg)?;
1123 let data = Self::dense_selected_diag_payload(payload, &kept_dims, positions);
1124 Self::from_dense(kept_indices, data)
1125 } else {
1126 Err(anyhow::anyhow!("unsupported diagonal storage scalar type"))
1127 }
1128 }
1129
1130 fn col_major_strides(dims: &[usize]) -> Result<Vec<isize>> {
1131 let mut strides = Vec::with_capacity(dims.len());
1132 let mut stride = 1isize;
1133 for &dim in dims {
1134 strides.push(stride);
1135 let dim = isize::try_from(dim)
1136 .map_err(|_| anyhow::anyhow!("dimension does not fit in isize"))?;
1137 stride = stride
1138 .checked_mul(dim)
1139 .ok_or_else(|| anyhow::anyhow!("column-major stride overflow"))?;
1140 }
1141 Ok(strides)
1142 }
1143
1144 fn zero_structured_selection<T>(
1145 kept_indices: Vec<DynIndex>,
1146 kept_dims: &[usize],
1147 ) -> Result<Self>
1148 where
1149 T: TensorElement + Zero,
1150 {
1151 let output_len = checked_product(kept_dims)?;
1152 Self::from_dense(kept_indices, vec![T::zero(); output_len])
1153 }
1154
1155 fn select_structured_indices_typed<T>(
1156 &self,
1157 payload: Vec<T>,
1158 kept_axes: &[usize],
1159 kept_indices: Vec<DynIndex>,
1160 kept_dims: Vec<usize>,
1161 selected_axes: &[usize],
1162 positions: &[usize],
1163 ) -> Result<Self>
1164 where
1165 T: TensorElement + StorageScalar + Zero,
1166 {
1167 let payload_dims = self.storage.payload_dims();
1168 let axis_classes = self.storage.axis_classes();
1169 let payload_rank = payload_dims.len();
1170 let mut selected_class_positions = vec![None; payload_rank];
1171
1172 for (&axis, &position) in selected_axes.iter().zip(positions.iter()) {
1173 let class_id = axis_classes[axis];
1174 if let Some(existing) = selected_class_positions[class_id] {
1175 if existing != position {
1176 return Self::zero_structured_selection::<T>(kept_indices, &kept_dims);
1177 }
1178 } else {
1179 selected_class_positions[class_id] = Some(position);
1180 }
1181 }
1182
1183 let selected_class_kept = kept_axes
1184 .iter()
1185 .any(|&axis| selected_class_positions[axis_classes[axis]].is_some());
1186 if selected_class_kept {
1187 return self.select_structured_indices_dense(
1188 payload,
1189 kept_axes,
1190 kept_indices,
1191 kept_dims,
1192 &selected_class_positions,
1193 );
1194 }
1195
1196 let mut old_to_new_class = vec![None; payload_rank];
1197 let mut output_payload_dims = Vec::new();
1198 let mut output_axis_classes = Vec::with_capacity(kept_axes.len());
1199 for &axis in kept_axes {
1200 let class_id = axis_classes[axis];
1201 let new_class = match old_to_new_class[class_id] {
1202 Some(new_class) => new_class,
1203 None => {
1204 let new_class = output_payload_dims.len();
1205 old_to_new_class[class_id] = Some(new_class);
1206 output_payload_dims.push(payload_dims[class_id]);
1207 new_class
1208 }
1209 };
1210 output_axis_classes.push(new_class);
1211 }
1212
1213 let output_len = checked_product(&output_payload_dims)?;
1214 let mut output_payload = Vec::with_capacity(output_len);
1215 for linear in 0..output_len {
1216 let output_payload_index = decode_col_major_linear(linear, &output_payload_dims)?;
1217 let mut input_payload_index = vec![0usize; payload_rank];
1218 for class_id in 0..payload_rank {
1219 input_payload_index[class_id] =
1220 if let Some(position) = selected_class_positions[class_id] {
1221 position
1222 } else if let Some(new_class) = old_to_new_class[class_id] {
1223 output_payload_index[new_class]
1224 } else {
1225 return Err(anyhow::anyhow!(
1226 "structured payload class {class_id} is neither selected nor kept"
1227 ));
1228 };
1229 }
1230 let input_linear = encode_col_major_linear(&input_payload_index, payload_dims)?;
1231 output_payload.push(payload[input_linear]);
1232 }
1233
1234 let output_strides = Self::col_major_strides(&output_payload_dims)?;
1235 let storage = Storage::new_structured(
1236 output_payload,
1237 output_payload_dims,
1238 output_strides,
1239 output_axis_classes,
1240 )?;
1241 Self::from_storage(kept_indices, Arc::new(storage))
1242 }
1243
1244 fn select_structured_indices_dense<T>(
1245 &self,
1246 payload: Vec<T>,
1247 kept_axes: &[usize],
1248 kept_indices: Vec<DynIndex>,
1249 kept_dims: Vec<usize>,
1250 selected_class_positions: &[Option<usize>],
1251 ) -> Result<Self>
1252 where
1253 T: TensorElement + Zero,
1254 {
1255 let payload_dims = self.storage.payload_dims();
1256 let axis_classes = self.storage.axis_classes();
1257 let output_len = checked_product(&kept_dims)?;
1258 let mut output = Vec::with_capacity(output_len);
1259
1260 for linear in 0..output_len {
1261 let kept_position = decode_col_major_linear(linear, &kept_dims)?;
1262 let mut input_payload_index = selected_class_positions.to_vec();
1263 let mut is_structural_zero = false;
1264
1265 for (&axis, &position) in kept_axes.iter().zip(kept_position.iter()) {
1266 let class_id = axis_classes[axis];
1267 match input_payload_index[class_id] {
1268 Some(existing) if existing != position => {
1269 is_structural_zero = true;
1270 break;
1271 }
1272 Some(_) => {}
1273 None => input_payload_index[class_id] = Some(position),
1274 }
1275 }
1276
1277 if is_structural_zero {
1278 output.push(T::zero());
1279 continue;
1280 }
1281
1282 let input_payload_index = input_payload_index
1283 .into_iter()
1284 .enumerate()
1285 .map(|(class_id, position)| {
1286 position.ok_or_else(|| {
1287 anyhow::anyhow!(
1288 "structured payload class {class_id} is neither selected nor kept"
1289 )
1290 })
1291 })
1292 .collect::<Result<Vec<_>>>()?;
1293 let input_linear = encode_col_major_linear(&input_payload_index, payload_dims)?;
1294 output.push(payload[input_linear]);
1295 }
1296
1297 Self::from_dense(kept_indices, output)
1298 }
1299
1300 fn select_structured_indices(
1301 &self,
1302 kept_axes: &[usize],
1303 kept_indices: Vec<DynIndex>,
1304 kept_dims: Vec<usize>,
1305 selected_axes: &[usize],
1306 positions: &[usize],
1307 ) -> Result<Self> {
1308 if self.storage.is_f64() {
1309 let storage = self.storage.materialize(self.indices.len())?;
1310 let payload = storage
1311 .payload_f64_col_major_vec()
1312 .map_err(anyhow::Error::msg)?;
1313 self.select_structured_indices_typed(
1314 payload,
1315 kept_axes,
1316 kept_indices,
1317 kept_dims,
1318 selected_axes,
1319 positions,
1320 )
1321 } else if self.storage.is_c64() {
1322 let storage = self.storage.materialize(self.indices.len())?;
1323 let payload = storage
1324 .payload_c64_col_major_vec()
1325 .map_err(anyhow::Error::msg)?;
1326 self.select_structured_indices_typed(
1327 payload,
1328 kept_axes,
1329 kept_indices,
1330 kept_dims,
1331 selected_axes,
1332 positions,
1333 )
1334 } else {
1335 Err(anyhow::anyhow!(
1336 "unsupported structured storage scalar type"
1337 ))
1338 }
1339 }
1340
1341 fn validate_storage_matches_indices(indices: &[DynIndex], storage: &Storage) -> Result<()> {
1342 let dims = Self::expected_dims_from_indices(indices);
1343 let storage_dims = storage.logical_dims();
1344 if storage_dims != dims {
1345 return Err(anyhow::anyhow!(
1346 "storage logical dims {:?} do not match indices dims {:?}",
1347 storage_dims,
1348 dims
1349 ));
1350 }
1351 if storage.is_diag() {
1352 Self::validate_diag_dims(&dims)?;
1353 }
1354 Ok(())
1355 }
1356
1357 fn try_materialized_inner(&self) -> Result<&EagerTensor> {
1358 if let Some(value) = self.tracked_compact_payload_value() {
1359 if self.compact_payload_is_logical_dense(&value.payload_dims) {
1360 return Ok(value.payload.as_ref());
1361 }
1362 }
1363 if let Some(inner) = self.storage.eager() {
1364 return Ok(inner);
1365 }
1366 if self.eager_cache.get().is_none() {
1367 let dims = self.dims();
1368 let native = profile_pairwise_contract_section("materialize_storage_to_native", || {
1369 let storage = self.storage.materialize(self.indices.len())?;
1370 Self::seed_native_payload(storage.as_ref(), &dims)
1371 })
1372 .context("TensorDynLen materialization failed")?;
1373 record_pairwise_contract_profile_bytes(
1374 "materialize_storage_to_native",
1375 native_tensor_profile_bytes(&native),
1376 );
1377 let _ = self.eager_cache.set(Arc::new(EagerTensor::from_tensor_in(
1378 native,
1379 default_eager_ctx(),
1380 )));
1381 }
1382 self.eager_cache
1383 .get()
1384 .map(|inner| inner.as_ref())
1385 .ok_or_else(|| {
1386 anyhow::anyhow!("TensorDynLen materialization cache was not initialized")
1387 })
1388 }
1389
1390 pub(crate) fn as_inner(&self) -> Result<&EagerTensor> {
1391 self.try_materialized_inner()
1392 }
1393
1394 #[inline]
1396 fn expected_dims_from_indices(indices: &[DynIndex]) -> Vec<usize> {
1397 indices.iter().map(|idx| idx.dim()).collect()
1398 }
1399
1400 pub fn dims(&self) -> Vec<usize> {
1419 Self::expected_dims_from_indices(&self.indices)
1420 }
1421
1422 pub fn select_indices(
1464 &self,
1465 selected_indices: &[DynIndex],
1466 positions: &[usize],
1467 ) -> Result<Self> {
1468 if selected_indices.len() != positions.len() {
1469 return Err(anyhow::anyhow!(
1470 "selected_indices length {} does not match positions length {}",
1471 selected_indices.len(),
1472 positions.len()
1473 ));
1474 }
1475 if selected_indices.is_empty() {
1476 return Ok(self.clone());
1477 }
1478
1479 let mut selected_axes = Vec::with_capacity(selected_indices.len());
1480 let mut seen_axes = HashSet::with_capacity(selected_indices.len());
1481 for (selected, &position) in selected_indices.iter().zip(positions.iter()) {
1482 let axis = self
1483 .indices
1484 .iter()
1485 .position(|index| index == selected)
1486 .ok_or_else(|| anyhow::anyhow!("selected index is not present in tensor"))?;
1487 if !seen_axes.insert(axis) {
1488 return Err(anyhow::anyhow!("selected index appears more than once"));
1489 }
1490 let dim = self.indices[axis].dim();
1491 if position >= dim {
1492 return Err(anyhow::anyhow!(
1493 "selected coordinate {position} is out of range for axis {axis} with dim {dim}"
1494 ));
1495 }
1496 selected_axes.push(axis);
1497 }
1498
1499 let kept_axes = self
1500 .indices
1501 .iter()
1502 .enumerate()
1503 .filter(|(axis, _)| !seen_axes.contains(axis))
1504 .map(|(axis, _)| axis)
1505 .collect::<Vec<_>>();
1506 let kept_indices = kept_axes
1507 .iter()
1508 .map(|&axis| self.indices[axis].clone())
1509 .collect::<Vec<_>>();
1510 let kept_dims = kept_axes
1511 .iter()
1512 .map(|&axis| self.indices[axis].dim())
1513 .collect::<Vec<_>>();
1514
1515 if self.storage.storage_kind() == StorageKind::Diagonal {
1516 return self.select_diag_indices(kept_indices, kept_dims, positions);
1517 }
1518 if self.storage.storage_kind() == StorageKind::Structured {
1519 return self.select_structured_indices(
1520 &kept_axes,
1521 kept_indices,
1522 kept_dims,
1523 &selected_axes,
1524 positions,
1525 );
1526 }
1527 if self.storage.storage_kind() != StorageKind::Dense {
1528 return Err(anyhow::anyhow!(
1529 "select_indices got unsupported storage kind {:?}",
1530 self.storage.storage_kind()
1531 ));
1532 }
1533
1534 let rank = self.indices.len();
1535 let mut starts = vec![0_i64; rank];
1536 let mut slice_sizes = self.dims();
1537 for (&axis, &position) in selected_axes.iter().zip(positions.iter()) {
1538 starts[axis] = i64::try_from(position)
1539 .map_err(|_| anyhow::anyhow!("selected coordinate does not fit in i64"))?;
1540 slice_sizes[axis] = 1;
1541 }
1542
1543 let starts_tensor = EagerTensor::from_tensor_in(
1544 NativeTensor::from_vec_col_major(vec![rank], starts),
1545 default_eager_ctx(),
1546 );
1547 let sliced = self
1548 .try_materialized_inner()?
1549 .dynamic_slice(&starts_tensor, &slice_sizes)?;
1550 Self::from_inner(kept_indices, sliced.reshape(&kept_dims)?)
1551 }
1552
1553 pub fn stack_along_new_index(
1586 tensors: &[&Self],
1587 new_index: DynIndex,
1588 axis: isize,
1589 ) -> Result<Self> {
1590 let first = tensors
1591 .first()
1592 .copied()
1593 .ok_or_else(|| anyhow::anyhow!("stack_along_new_index requires at least one tensor"))?;
1594 anyhow::ensure!(
1595 new_index.dim() == tensors.len(),
1596 "stack_along_new_index: new index dim {} does not match tensor count {}",
1597 new_index.dim(),
1598 tensors.len()
1599 );
1600
1601 let base_indices = first.indices.clone();
1602 for tensor in tensors.iter().copied().skip(1) {
1603 anyhow::ensure!(
1604 tensor.indices == base_indices,
1605 "stack_along_new_index: input tensors must have identical index order"
1606 );
1607 }
1608 for &tensor in tensors {
1609 tensor.ensure_shape_packing_preserves_ad("stack_along_new_index")?;
1610 }
1611
1612 let insert_axis =
1613 Self::normalize_insert_axis("stack_along_new_index", axis, base_indices.len())?;
1614 let mut result_indices = base_indices;
1615 result_indices.insert(insert_axis, new_index);
1616
1617 let inner_refs = tensors
1618 .iter()
1619 .map(|tensor| tensor.try_materialized_inner())
1620 .collect::<Result<Vec<_>>>()?;
1621 let stacked = EagerTensor::stack(&inner_refs, axis)?;
1622 Self::from_inner(result_indices, stacked)
1623 }
1624
1625 pub fn index_select(
1658 &self,
1659 source_index: &DynIndex,
1660 target_index: DynIndex,
1661 positions: &[usize],
1662 ) -> Result<Self> {
1663 anyhow::ensure!(
1664 target_index.dim() == positions.len(),
1665 "index_select: target index dim {} does not match position count {}",
1666 target_index.dim(),
1667 positions.len()
1668 );
1669 let axis = self
1670 .indices
1671 .iter()
1672 .position(|index| index == source_index)
1673 .ok_or_else(|| anyhow::anyhow!("index_select: source index is not present"))?;
1674 let source_dim = self.indices[axis].dim();
1675 for &position in positions {
1676 anyhow::ensure!(
1677 position < source_dim,
1678 "index_select: position {position} is out of range for source dim {source_dim}"
1679 );
1680 }
1681 self.ensure_shape_packing_preserves_ad("index_select")?;
1682
1683 let axis = isize::try_from(axis)
1684 .map_err(|_| anyhow::anyhow!("index_select: axis does not fit in isize"))?;
1685 let selected = self
1686 .try_materialized_inner()?
1687 .index_select(axis, positions)?;
1688 let mut result_indices = self.indices.clone();
1689 result_indices[axis as usize] = target_index;
1690 Self::from_inner(result_indices, selected)
1691 }
1692
1693 pub fn new(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
1713 Self::from_storage(indices, storage)
1714 }
1715
1716 pub fn from_indices(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
1738 Self::new(indices, storage)
1739 }
1740
1741 pub fn from_storage(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
1757 Self::validate_indices(&indices)?;
1758 Self::validate_storage_matches_indices(&indices, storage.as_ref())?;
1759 Ok(Self {
1760 indices,
1761 storage: TensorDynLenStorage::from_storage(storage),
1762 structured_ad: None,
1763 eager_cache: Self::empty_eager_cache(),
1764 })
1765 }
1766
1767 pub fn from_structured_storage(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
1791 Self::from_storage(indices, storage)
1792 }
1793
1794 pub(crate) fn from_native(indices: Vec<DynIndex>, native: NativeTensor) -> Result<Self> {
1796 let axis_classes = Self::dense_axis_classes(indices.len());
1797 Self::from_native_with_axis_classes(indices, native, axis_classes)
1798 }
1799
1800 pub(crate) fn from_native_with_axis_classes(
1801 indices: Vec<DynIndex>,
1802 native: NativeTensor,
1803 axis_classes: Vec<usize>,
1804 ) -> Result<Self> {
1805 Self::from_inner_with_axis_classes(
1806 indices,
1807 EagerTensor::from_tensor_in(native, default_eager_ctx()),
1808 axis_classes,
1809 )
1810 }
1811
1812 pub(crate) fn from_inner(indices: Vec<DynIndex>, inner: EagerTensor) -> Result<Self> {
1813 let axis_classes = Self::dense_axis_classes(indices.len());
1814 Self::from_inner_with_axis_classes(indices, inner, axis_classes)
1815 }
1816
1817 pub(crate) fn from_diag_inner(
1818 indices: Vec<DynIndex>,
1819 payload_inner: EagerTensor,
1820 ) -> Result<Self> {
1821 let dims = Self::expected_dims_from_indices(&indices);
1822 Self::validate_indices(&indices)?;
1823 Self::validate_diag_dims(&dims)?;
1824 Self::validate_diag_payload_len(payload_inner.data().shape().iter().product(), &dims)?;
1825 let axis_classes = Self::diag_axis_classes(dims.len());
1826 let diag_inner = payload_inner.embed_diag(0, 1)?;
1827 Self::from_inner_with_axis_classes(indices, diag_inner, axis_classes)
1828 }
1829
1830 pub(crate) fn from_inner_with_axis_classes(
1831 indices: Vec<DynIndex>,
1832 inner: EagerTensor,
1833 axis_classes: Vec<usize>,
1834 ) -> Result<Self> {
1835 let dims = profile_pairwise_contract_section("from_inner_expected_dims", || {
1836 Self::expected_dims_from_indices(&indices)
1837 });
1838 profile_pairwise_contract_section("from_inner_validate_indices", || {
1839 Self::validate_indices(&indices)
1840 })?;
1841 if dims != inner.data().shape() {
1842 return Err(anyhow::anyhow!(
1843 "native payload dims {:?} do not match indices dims {:?}",
1844 inner.data().shape(),
1845 dims
1846 ));
1847 }
1848 if Self::is_diag_axis_classes(&axis_classes) {
1849 profile_pairwise_contract_section("from_inner_validate_diag_dims", || {
1850 Self::validate_diag_dims(&dims)
1851 })?;
1852 }
1853 let (storage, eager_cache) = if axis_classes == Self::dense_axis_classes(indices.len()) {
1854 (
1855 TensorDynLenStorage::from_eager_dense(inner, indices.len()),
1856 Self::empty_eager_cache(),
1857 )
1858 } else {
1859 let storage = profile_pairwise_contract_section("from_inner_storage_snapshot", || {
1860 Self::storage_from_native_with_axis_classes(
1861 inner.data(),
1862 &axis_classes,
1863 indices.len(),
1864 )
1865 })?;
1866 record_pairwise_contract_profile_bytes(
1867 "from_inner_storage_snapshot",
1868 native_tensor_profile_bytes(inner.data()),
1869 );
1870 (
1871 TensorDynLenStorage::from_storage(Arc::new(storage)),
1872 profile_pairwise_contract_section("from_inner_eager_cache", || {
1873 Self::eager_cache_with(inner)
1874 }),
1875 )
1876 };
1877 Ok(Self {
1878 indices,
1879 storage,
1880 structured_ad: None,
1881 eager_cache,
1882 })
1883 }
1884
1885 pub fn indices(&self) -> &[DynIndex] {
1887 &self.indices
1888 }
1889
1890 pub(crate) fn as_native(&self) -> Result<&NativeTensor> {
1892 Ok(self.try_materialized_inner()?.data())
1893 }
1894
1895 pub fn enable_grad(self) -> Result<Self> {
1897 let materialized = self.storage.materialize(self.indices.len())?;
1898 let payload = storage_payload_native(materialized.as_ref())
1899 .context("TensorDynLen::enable_grad failed")?;
1900 let payload_dims = self.storage.payload_dims().to_vec();
1901 let axis_classes = self.storage.axis_classes().to_vec();
1902 Ok(Self {
1903 indices: self.indices,
1904 storage: self.storage,
1905 structured_ad: Some(Arc::new(StructuredAdValue {
1906 payload: Arc::new(EagerTensor::requires_grad_in(payload, default_eager_ctx())),
1907 payload_dims,
1908 axis_classes,
1909 })),
1910 eager_cache: Self::empty_eager_cache(),
1911 })
1912 }
1913
1914 pub fn tracks_grad(&self) -> bool {
1916 self.structured_ad
1917 .as_ref()
1918 .is_some_and(|value| value.payload.tracks_grad())
1919 || self.storage.eager().is_some_and(EagerTensor::tracks_grad)
1920 || self
1921 .eager_cache
1922 .get()
1923 .is_some_and(|inner| inner.tracks_grad())
1924 }
1925
1926 pub fn grad(&self) -> Result<Option<Self>> {
1928 if let Some(value) = self.tracked_compact_payload_value() {
1929 return value
1930 .payload
1931 .grad()
1932 .map(|grad| {
1933 let storage = storage_from_payload_native(
1934 grad.as_ref().clone(),
1935 &value.payload_dims,
1936 value.axis_classes.clone(),
1937 )?;
1938 Self::from_storage(self.indices.clone(), Arc::new(storage))
1939 })
1940 .transpose();
1941 }
1942 self.try_materialized_inner()?
1943 .grad()
1944 .map(|grad| {
1945 Self::from_native_with_axis_classes(
1946 self.indices.clone(),
1947 grad.as_ref().clone(),
1948 self.storage.axis_classes().to_vec(),
1949 )
1950 })
1951 .transpose()
1952 }
1953
1954 pub fn clear_grad(&self) -> Result<()> {
1956 if let Some(value) = self.tracked_compact_payload_value() {
1957 value.payload.clear_grad();
1958 }
1959 if let Some(inner) = self.storage.eager() {
1960 inner.clear_grad();
1961 }
1962 if let Some(inner) = self.eager_cache.get() {
1963 inner.clear_grad();
1964 }
1965 Ok(())
1966 }
1967
1968 pub fn backward(&self) -> Result<()> {
1970 if let Some(value) = self.tracked_compact_payload_value() {
1971 return value
1972 .payload
1973 .backward()
1974 .map(|_| ())
1975 .map_err(|e| anyhow::anyhow!("TensorDynLen::backward failed: {e}"));
1976 }
1977 self.try_materialized_inner()?
1978 .backward()
1979 .map(|_| ())
1980 .map_err(|e| anyhow::anyhow!("TensorDynLen::backward failed: {e}"))
1981 }
1982
1983 pub fn detach(&self) -> Result<Self> {
1985 if self.tracked_compact_payload_value().is_some() {
1986 return Self::from_storage(
1987 self.indices.clone(),
1988 self.storage.materialize(self.indices.len())?,
1989 );
1990 }
1991 Self::from_inner_with_axis_classes(
1992 self.indices.clone(),
1993 self.try_materialized_inner()?.detach(),
1994 self.storage.axis_classes().to_vec(),
1995 )
1996 }
1997
1998 pub fn is_simple(&self) -> bool {
2000 true
2001 }
2002
2003 pub fn to_storage(&self) -> Result<Arc<Storage>> {
2005 self.storage.materialize(self.indices.len())
2006 }
2007
2008 pub fn storage(&self) -> Arc<Storage> {
2010 self.storage
2011 .materialize(self.indices.len())
2012 .expect("TensorDynLen storage materialization failed")
2013 }
2014
2015 pub fn sum(&self) -> Result<AnyScalar> {
2028 if self.indices.is_empty() {
2029 return AnyScalar::from_tensor(self.clone());
2030 }
2031 let axes: Vec<usize> = (0..self.indices.len()).collect();
2032 let reduced = self.try_materialized_inner()?.reduce_sum(&axes)?;
2033 AnyScalar::from_tensor(Self::from_inner(Vec::new(), reduced)?)
2034 }
2035
2036 pub fn only(&self) -> Result<AnyScalar> {
2057 let dims = self.dims();
2058 let total_size = checked_product(&dims)?;
2059 anyhow::ensure!(
2060 total_size == 1 || dims.is_empty(),
2061 "only() requires a scalar tensor (1 element), got {} elements with dims {:?}",
2062 if dims.is_empty() { 1 } else { total_size },
2063 dims
2064 );
2065 self.sum()
2066 }
2067
2068 pub fn permute_indices(&self, new_indices: &[DynIndex]) -> Result<Self> {
2099 let perm = compute_permutation_from_indices(&self.indices, new_indices)?;
2101 if perm.iter().copied().eq(0..perm.len()) {
2102 return Ok(Self {
2103 indices: new_indices.to_vec(),
2104 storage: self.storage.clone(),
2105 structured_ad: self.structured_ad.clone(),
2106 eager_cache: Arc::clone(&self.eager_cache),
2107 });
2108 }
2109
2110 let permuted = self.try_materialized_inner()?.transpose(&perm)?;
2111 let axis_classes = self.permute_axis_classes(&perm);
2112 Self::from_inner_with_axis_classes(new_indices.to_vec(), permuted, axis_classes)
2113 }
2114
2115 pub fn permute(&self, perm: &[usize]) -> Result<Self> {
2144 anyhow::ensure!(
2145 perm.len() == self.indices.len(),
2146 "permutation length must match tensor rank"
2147 );
2148 let mut seen = HashSet::new();
2149 for &axis in perm {
2150 anyhow::ensure!(
2151 axis < self.indices.len(),
2152 "permutation axis {axis} out of range"
2153 );
2154 anyhow::ensure!(seen.insert(axis), "duplicate axis {axis} in permutation");
2155 }
2156 if perm.iter().copied().eq(0..perm.len()) {
2157 return Ok(self.clone());
2158 }
2159
2160 let new_indices: Vec<DynIndex> = perm.iter().map(|&i| self.indices[i].clone()).collect();
2162 let permuted = self.try_materialized_inner()?.transpose(perm)?;
2163 let axis_classes = self.permute_axis_classes(perm);
2164 Self::from_inner_with_axis_classes(new_indices, permuted, axis_classes)
2165 }
2166
2167 pub(crate) fn try_contract_pairwise_default(&self, other: &Self) -> Result<Self> {
2168 self.try_contract_pairwise_default_with_options(other, PairwiseContractionOptions::new())
2169 }
2170
2171 pub(crate) fn try_contract_pairwise_default_with_options(
2172 &self,
2173 other: &Self,
2174 options: PairwiseContractionOptions,
2175 ) -> Result<Self> {
2176 let self_indices = profile_pairwise_contract_section("operand_indices", || {
2177 self.operand_indices_for_contraction(options.lhs_conj)
2178 });
2179 let other_indices = profile_pairwise_contract_section("operand_indices", || {
2180 other.operand_indices_for_contraction(options.rhs_conj)
2181 });
2182 let self_dims = profile_pairwise_contract_section("expected_dims", || {
2183 Self::expected_dims_from_indices(&self_indices)
2184 });
2185 let other_dims = profile_pairwise_contract_section("expected_dims", || {
2186 Self::expected_dims_from_indices(&other_indices)
2187 });
2188 let spec = profile_pairwise_contract_section("prepare_contraction", || {
2189 prepare_contraction(&self_indices, &self_dims, &other_indices, &other_dims)
2190 })
2191 .context("contraction preparation failed")?;
2192 let result_axis_classes = profile_pairwise_contract_section("result_axis_classes", || {
2193 Self::binary_contraction_axis_classes(
2194 self.storage.axis_classes(),
2195 &spec.axes_a,
2196 other.storage.axis_classes(),
2197 &spec.axes_b,
2198 )
2199 });
2200
2201 if profile_pairwise_contract_section("structured_check", || {
2202 self.should_use_structured_payload_contract(other)
2203 }) {
2204 if options.has_conj() {
2205 let lhs = if options.lhs_conj {
2206 self.conj()
2207 } else {
2208 self.clone()
2209 };
2210 let rhs = if options.rhs_conj {
2211 other.conj()
2212 } else {
2213 other.clone()
2214 };
2215 return profile_pairwise_contract_section("structured_conj_fallback", || {
2216 lhs.try_contract_pairwise_default(&rhs)
2217 });
2218 }
2219 return profile_pairwise_contract_section("structured_payload_contract", || {
2220 self.contract_structured_payloads(
2221 other,
2222 spec.result_indices.into_vec(),
2223 &spec.axes_a,
2224 &spec.axes_b,
2225 )
2226 });
2227 }
2228
2229 if self.indices.is_empty() && other.indices.is_empty() {
2230 if options.has_conj() {
2231 let lhs = if options.lhs_conj {
2232 self.conj()
2233 } else {
2234 self.clone()
2235 };
2236 let rhs = if options.rhs_conj {
2237 other.conj()
2238 } else {
2239 other.clone()
2240 };
2241 return lhs.try_contract_pairwise_default(&rhs);
2242 }
2243 let result = profile_pairwise_contract_section("scalar_mul", || {
2244 Ok::<_, anyhow::Error>(
2245 self.try_materialized_inner()?
2246 .mul(other.try_materialized_inner()?)?,
2247 )
2248 })?;
2249 return profile_pairwise_contract_section("from_inner", || {
2250 Self::from_inner(spec.result_indices.into_vec(), result)
2251 });
2252 }
2253
2254 let self_native = profile_pairwise_contract_section("as_native", || self.as_native())?;
2255 let other_native = profile_pairwise_contract_section("as_native", || other.as_native())?;
2256 if self_native.dtype() != other_native.dtype() {
2257 if options.has_conj() {
2258 let lhs = if options.lhs_conj {
2259 self.conj()
2260 } else {
2261 self.clone()
2262 };
2263 let rhs = if options.rhs_conj {
2264 other.conj()
2265 } else {
2266 other.clone()
2267 };
2268 return lhs.try_contract_pairwise_default(&rhs);
2269 }
2270 let result_native = profile_pairwise_contract_section("native_contract", || {
2271 contract_native_tensor(self_native, &spec.axes_a, other_native, &spec.axes_b)
2272 })?;
2273 return profile_pairwise_contract_section("from_native", || {
2274 Self::from_native_with_axis_classes(
2275 spec.result_indices.into_vec(),
2276 result_native,
2277 result_axis_classes,
2278 )
2279 });
2280 }
2281
2282 let config = profile_pairwise_contract_section("build_dot_general_config", || {
2283 Self::binary_dot_general_config(&spec.axes_a, &spec.axes_b)
2284 })?;
2285 let result = profile_pairwise_contract_section("dot_general_with_conj", || {
2286 let lhs = profile_pairwise_contract_section("lhs_try_materialized_inner", || {
2287 self.try_materialized_inner()
2288 })?;
2289 let rhs = profile_pairwise_contract_section("rhs_try_materialized_inner", || {
2290 other.try_materialized_inner()
2291 })?;
2292 profile_pairwise_contract_section("dot_general_execute", || {
2293 lhs.dot_general_with_conj(rhs, &config, options.lhs_conj, options.rhs_conj)
2294 })
2295 .map_err(anyhow::Error::from)
2296 })?;
2297 record_pairwise_contract_profile_bytes(
2298 "dot_general_output",
2299 native_tensor_profile_bytes(result.data()),
2300 );
2301 profile_pairwise_contract_section("from_inner_axis_classes", || {
2302 Self::from_inner_with_axis_classes(
2303 spec.result_indices.into_vec(),
2304 result,
2305 result_axis_classes,
2306 )
2307 })
2308 }
2309
2310 pub(crate) fn try_tensordot_pairwise_explicit(
2311 &self,
2312 other: &Self,
2313 pairs: &[(DynIndex, DynIndex)],
2314 ) -> Result<Self> {
2315 use crate::index_ops::ContractionError;
2316
2317 let self_dims = Self::expected_dims_from_indices(&self.indices);
2318 let other_dims = Self::expected_dims_from_indices(&other.indices);
2319 let spec = prepare_contraction_pairs(
2320 &self.indices,
2321 &self_dims,
2322 &other.indices,
2323 &other_dims,
2324 pairs,
2325 )
2326 .map_err(|e| match e {
2327 ContractionError::NoCommonIndices => {
2328 anyhow::anyhow!("tensordot: No pairs specified for contraction")
2329 }
2330 ContractionError::BatchContractionNotImplemented => anyhow::anyhow!(
2331 "tensordot: Common index found but not in contraction pairs. \
2332 Batch contraction is not yet implemented."
2333 ),
2334 ContractionError::IndexNotFound { tensor } => {
2335 anyhow::anyhow!("tensordot: Index not found in {} tensor", tensor)
2336 }
2337 ContractionError::DimensionMismatch {
2338 pos_a,
2339 pos_b,
2340 dim_a,
2341 dim_b,
2342 } => anyhow::anyhow!(
2343 "tensordot: Dimension mismatch: self[{}]={} != other[{}]={}",
2344 pos_a,
2345 dim_a,
2346 pos_b,
2347 dim_b
2348 ),
2349 ContractionError::DuplicateAxis { tensor, pos } => {
2350 anyhow::anyhow!("tensordot: Duplicate axis {} in {} tensor", pos, tensor)
2351 }
2352 })?;
2353 let result_axis_classes = Self::binary_contraction_axis_classes(
2354 self.storage.axis_classes(),
2355 &spec.axes_a,
2356 other.storage.axis_classes(),
2357 &spec.axes_b,
2358 );
2359
2360 if self.should_use_structured_payload_contract(other) {
2361 return self.contract_structured_payloads(
2362 other,
2363 spec.result_indices.into_vec(),
2364 &spec.axes_a,
2365 &spec.axes_b,
2366 );
2367 }
2368
2369 if self.indices.is_empty() && other.indices.is_empty() {
2370 let result = self
2371 .try_materialized_inner()?
2372 .mul(other.try_materialized_inner()?)
2373 .map_err(|e| anyhow::anyhow!("tensordot scalar multiply failed: {e}"))?;
2374 return Self::from_inner(spec.result_indices.into_vec(), result);
2375 }
2376
2377 let self_native = self.as_native()?;
2378 let other_native = other.as_native()?;
2379 if self_native.dtype() != other_native.dtype() {
2380 let result_native =
2381 contract_native_tensor(self_native, &spec.axes_a, other_native, &spec.axes_b)?;
2382 return Self::from_native_with_axis_classes(
2383 spec.result_indices.into_vec(),
2384 result_native,
2385 result_axis_classes,
2386 );
2387 }
2388
2389 let subscripts = Self::build_binary_einsum_subscripts(
2390 self.indices.len(),
2391 &spec.axes_a,
2392 other.indices.len(),
2393 &spec.axes_b,
2394 )?;
2395 let result = eager_einsum_ad(
2396 &[
2397 self.try_materialized_inner()?,
2398 other.try_materialized_inner()?,
2399 ],
2400 &subscripts,
2401 )
2402 .map_err(|e| anyhow::anyhow!("tensordot failed: {e}"))?;
2403 Self::from_inner_with_axis_classes(
2404 spec.result_indices.into_vec(),
2405 result,
2406 result_axis_classes,
2407 )
2408 }
2409
2410 pub(crate) fn try_outer_product_pairwise(&self, other: &Self) -> Result<Self> {
2411 use anyhow::Context;
2412
2413 let common_positions = common_ind_positions(&self.indices, &other.indices);
2415 if !common_positions.is_empty() {
2416 let common_ids: Vec<_> = common_positions
2417 .iter()
2418 .map(|(pos_a, _)| self.indices[*pos_a].id())
2419 .collect();
2420 return Err(anyhow::anyhow!(
2421 "outer_product: tensors have common indices {:?}. \
2422 Use tensordot to contract common indices, or use sim() to replace \
2423 indices with fresh IDs before computing outer product.",
2424 common_ids
2425 ))
2426 .context("outer_product: common indices found");
2427 }
2428
2429 let mut result_indices = self.indices.clone();
2431 result_indices.extend(other.indices.iter().cloned());
2432 let result_axis_classes = Self::binary_contraction_axis_classes(
2433 self.storage.axis_classes(),
2434 &[],
2435 other.storage.axis_classes(),
2436 &[],
2437 );
2438 if self.should_use_structured_payload_contract(other) {
2439 return self.contract_structured_payloads(other, result_indices, &[], &[]);
2440 }
2441 let self_native = self.as_native()?;
2442 let other_native = other.as_native()?;
2443 if self_native.dtype() != other_native.dtype() {
2444 let result_native = contract_native_tensor(self_native, &[], other_native, &[])?;
2445 return Self::from_native_with_axis_classes(
2446 result_indices,
2447 result_native,
2448 result_axis_classes,
2449 );
2450 }
2451
2452 let subscripts = Self::build_binary_einsum_subscripts(
2453 self.indices.len(),
2454 &[],
2455 other.indices.len(),
2456 &[],
2457 )?;
2458 let result = eager_einsum_ad(
2459 &[
2460 self.try_materialized_inner()?,
2461 other.try_materialized_inner()?,
2462 ],
2463 &subscripts,
2464 )
2465 .map_err(|e| anyhow::anyhow!("outer_product failed: {e}"))?;
2466 Self::from_inner_with_axis_classes(result_indices, result, result_axis_classes)
2467 }
2468}
2469
2470impl TensorDynLen {
2475 pub fn random<T: RandomScalar, R: Rng>(rng: &mut R, indices: Vec<DynIndex>) -> Result<Self> {
2502 let dims: Vec<usize> = indices.iter().map(|idx| idx.dim()).collect();
2503 let size = checked_product(&dims)?;
2504 let data: Vec<T> = (0..size).map(|_| T::random_value(rng)).collect();
2505 Self::from_dense(indices, data)
2506 }
2507}
2508
2509impl TensorDynLen {
2510 pub fn add(&self, other: &Self) -> Result<Self> {
2544 if self.indices.len() != other.indices.len() {
2546 return Err(anyhow::anyhow!(
2547 "Index count mismatch: self has {} indices, other has {}",
2548 self.indices.len(),
2549 other.indices.len()
2550 ));
2551 }
2552
2553 let self_set: HashSet<_> = self.indices.iter().collect();
2555 let other_set: HashSet<_> = other.indices.iter().collect();
2556
2557 if self_set != other_set {
2558 return Err(anyhow::anyhow!(
2559 "Index set mismatch: tensors must have the same indices"
2560 ));
2561 }
2562
2563 let other_aligned = other.permute_indices(&self.indices)?;
2565
2566 let self_expected_dims = Self::expected_dims_from_indices(&self.indices);
2568 let other_expected_dims = Self::expected_dims_from_indices(&other_aligned.indices);
2569 if self_expected_dims != other_expected_dims {
2570 use crate::TagSetLike;
2571 let fmt = |indices: &[DynIndex]| -> Vec<String> {
2572 indices
2573 .iter()
2574 .map(|idx| {
2575 let tags: Vec<String> = idx.tags().iter().collect();
2576 format!("{:?}(dim={},tags={:?})", idx.id(), idx.dim(), tags)
2577 })
2578 .collect()
2579 };
2580 return Err(anyhow::anyhow!(
2581 "Dimension mismatch after alignment.\n\
2582 self: dims={:?}, indices(order)={:?}\n\
2583 other_aligned: dims={:?}, indices(order)={:?}",
2584 self_expected_dims,
2585 fmt(&self.indices),
2586 other_expected_dims,
2587 fmt(&other_aligned.indices)
2588 ));
2589 }
2590
2591 self.axpby(
2592 AnyScalar::new_real(1.0),
2593 &other_aligned,
2594 AnyScalar::new_real(1.0),
2595 )
2596 }
2597
2598 pub fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result<Self> {
2620 if self.indices.len() != other.indices.len() {
2622 return Err(anyhow::anyhow!(
2623 "Index count mismatch: self has {} indices, other has {}",
2624 self.indices.len(),
2625 other.indices.len()
2626 ));
2627 }
2628
2629 let self_set: HashSet<_> = self.indices.iter().collect();
2631 let other_set: HashSet<_> = other.indices.iter().collect();
2632 if self_set != other_set {
2633 return Err(anyhow::anyhow!(
2634 "Index set mismatch: tensors must have the same indices"
2635 ));
2636 }
2637
2638 let other_aligned = other.permute_indices(&self.indices)?;
2640
2641 let self_expected_dims = Self::expected_dims_from_indices(&self.indices);
2643 let other_expected_dims = Self::expected_dims_from_indices(&other_aligned.indices);
2644 if self_expected_dims != other_expected_dims {
2645 return Err(anyhow::anyhow!(
2646 "Dimension mismatch after alignment: self={:?}, other_aligned={:?}",
2647 self_expected_dims,
2648 other_expected_dims
2649 ));
2650 }
2651
2652 let axis_classes = if self.storage.axis_classes() == other_aligned.storage.axis_classes() {
2653 self.storage.axis_classes().to_vec()
2654 } else {
2655 Self::dense_axis_classes(self.indices.len())
2656 };
2657
2658 let same_compact_layout = self.storage.payload_dims()
2659 == other_aligned.storage.payload_dims()
2660 && self.storage.payload_strides_vec() == other_aligned.storage.payload_strides_vec()
2661 && self.storage.axis_classes() == other_aligned.storage.axis_classes();
2662 if same_compact_layout
2663 && !self.tracks_grad()
2664 && !other_aligned.tracks_grad()
2665 && !a.tracks_grad()
2666 && !b.tracks_grad()
2667 {
2668 let lhs_storage = self.storage.materialize(self.indices.len())?;
2669 let rhs_storage = other_aligned
2670 .storage
2671 .materialize(other_aligned.indices.len())?;
2672 let combined = lhs_storage
2673 .axpby(
2674 &a.to_backend_scalar(),
2675 rhs_storage.as_ref(),
2676 &b.to_backend_scalar(),
2677 )
2678 .map_err(|e| anyhow::anyhow!("storage axpby failed: {e}"))?;
2679 return Self::from_storage(self.indices.clone(), Arc::new(combined));
2680 }
2681
2682 let self_native = self.as_native()?;
2683 let other_native = other_aligned.as_native()?;
2684 let a_native = a.as_tensor()?.as_native()?;
2685 let b_native = b.as_tensor()?.as_native()?;
2686 if self_native.dtype() != other_native.dtype()
2687 || self_native.dtype() != a_native.dtype()
2688 || other_native.dtype() != b_native.dtype()
2689 {
2690 let combined = axpby_native_tensor(
2691 self_native,
2692 &a.to_backend_scalar(),
2693 other_native,
2694 &b.to_backend_scalar(),
2695 )?;
2696 return Self::from_native_with_axis_classes(
2697 self.indices.clone(),
2698 combined,
2699 axis_classes,
2700 );
2701 }
2702
2703 let lhs = self.scale(a)?;
2704 let rhs = other_aligned.scale(b)?;
2705 let combined = lhs
2706 .try_materialized_inner()?
2707 .add(rhs.try_materialized_inner()?)
2708 .map_err(|e| anyhow::anyhow!("tensor addition failed: {e}"))?;
2709 Self::from_inner_with_axis_classes(self.indices.clone(), combined, axis_classes)
2710 }
2711
2712 pub fn scale(&self, scalar: AnyScalar) -> Result<Self> {
2727 if !self.tracks_grad() && !scalar.tracks_grad() {
2728 let scaled = self.storage.scale(&scalar.to_backend_scalar())?;
2729 return Self::from_storage(self.indices.clone(), Arc::new(scaled));
2730 }
2731
2732 let self_native = self.as_native()?;
2733 let scalar_native = scalar.as_tensor()?.as_native()?;
2734 if self_native.dtype() != scalar_native.dtype() {
2735 let scaled = scale_native_tensor(self_native, &scalar.to_backend_scalar())?;
2736 return Self::from_native_with_axis_classes(
2737 self.indices.clone(),
2738 scaled,
2739 self.storage.axis_classes().to_vec(),
2740 );
2741 }
2742
2743 let scaled = if self.indices.is_empty() {
2744 self.try_materialized_inner()?
2745 .mul(scalar.as_tensor()?.try_materialized_inner()?)
2746 .map_err(|e| anyhow::anyhow!("scalar multiplication failed: {e}"))?
2747 } else {
2748 let subscripts = Self::scale_subscripts(self.indices.len())?;
2749 eager_einsum_ad(
2750 &[
2751 self.try_materialized_inner()?,
2752 scalar.as_tensor()?.try_materialized_inner()?,
2753 ],
2754 &subscripts,
2755 )
2756 .map_err(|e| anyhow::anyhow!("tensor scaling failed: {e}"))?
2757 };
2758 Self::from_inner_with_axis_classes(
2759 self.indices.clone(),
2760 scaled,
2761 self.storage.axis_classes().to_vec(),
2762 )
2763 }
2764
2765 pub fn inner_product(&self, other: &Self) -> Result<AnyScalar> {
2783 if self.indices.len() == other.indices.len() {
2784 let self_set: HashSet<_> = self.indices.iter().collect();
2785 let other_set: HashSet<_> = other.indices.iter().collect();
2786 if self_set == other_set {
2787 let other_aligned = other.permute_indices(&self.indices)?;
2788 let result = super::contract::contract_pair_with_operand_options(
2789 self,
2790 &other_aligned,
2791 PairwiseContractionOptions::new().with_lhs_conj(true),
2792 )?;
2793 return result.sum();
2794 }
2795 }
2796
2797 let result = super::contract::contract_pair_with_operand_options(
2799 self,
2800 other,
2801 PairwiseContractionOptions::new().with_lhs_conj(true),
2802 )?;
2803 result.sum()
2805 }
2806}
2807
2808impl TensorDynLen {
2813 pub fn replaceind(&self, old_index: &DynIndex, new_index: &DynIndex) -> Result<Self> {
2847 if old_index.dim() != new_index.dim() {
2849 return Err(anyhow::anyhow!(
2850 "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
2851 old_index.dim(),
2852 new_index.dim()
2853 ));
2854 }
2855
2856 let new_indices: Vec<_> = self
2857 .indices
2858 .iter()
2859 .map(|idx| {
2860 if *idx == *old_index {
2861 new_index.clone()
2862 } else {
2863 idx.clone()
2864 }
2865 })
2866 .collect();
2867
2868 Ok(Self {
2869 indices: new_indices,
2870 storage: self.storage.clone(),
2871 structured_ad: self.structured_ad.clone(),
2872 eager_cache: Arc::clone(&self.eager_cache),
2873 })
2874 }
2875
2876 pub fn replaceinds(&self, old_indices: &[DynIndex], new_indices: &[DynIndex]) -> Result<Self> {
2914 anyhow::ensure!(
2915 old_indices.len() == new_indices.len(),
2916 "old_indices and new_indices must have the same length"
2917 );
2918
2919 for (old, new) in old_indices.iter().zip(new_indices.iter()) {
2921 if old.dim() != new.dim() {
2922 return Err(anyhow::anyhow!(
2923 "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
2924 old.dim(),
2925 new.dim()
2926 ));
2927 }
2928 }
2929
2930 let replacement_map: std::collections::HashMap<_, _> =
2932 old_indices.iter().zip(new_indices.iter()).collect();
2933
2934 let new_indices_vec: Vec<_> = self
2935 .indices
2936 .iter()
2937 .map(|idx| {
2938 if let Some(new_idx) = replacement_map.get(idx) {
2939 (*new_idx).clone()
2940 } else {
2941 idx.clone()
2942 }
2943 })
2944 .collect();
2945
2946 Ok(Self {
2947 indices: new_indices_vec,
2948 storage: self.storage.clone(),
2949 structured_ad: self.structured_ad.clone(),
2950 eager_cache: Arc::clone(&self.eager_cache),
2951 })
2952 }
2953}
2954
2955impl TensorDynLen {
2960 pub fn conj(&self) -> Self {
2983 let new_indices: Vec<DynIndex> = self.indices.iter().map(|idx| idx.conj()).collect();
2987 let structured_ad = self.tracked_compact_payload_value().and_then(|value| {
2988 value.payload.conj().ok().map(|payload| {
2989 Arc::new(StructuredAdValue {
2990 payload: Arc::new(payload),
2991 payload_dims: value.payload_dims.clone(),
2992 axis_classes: value.axis_classes.clone(),
2993 })
2994 })
2995 });
2996 let eager_cache = self
2997 .eager_cache
2998 .get()
2999 .and_then(|inner| inner.conj().ok())
3000 .map(Self::eager_cache_with)
3001 .unwrap_or_else(Self::empty_eager_cache);
3002 Self {
3003 indices: new_indices,
3004 storage: self.storage.conj().unwrap_or_else(|_| {
3005 TensorDynLenStorage::from_storage(Arc::new(self.storage().conj()))
3006 }),
3007 structured_ad,
3008 eager_cache,
3009 }
3010 }
3011}
3012
3013impl TensorDynLen {
3018 pub fn norm_squared(&self) -> f64 {
3036 self.try_norm_squared().unwrap_or(f64::NAN)
3037 }
3038
3039 pub fn try_norm_squared(&self) -> Result<f64> {
3044 if self.indices.is_empty() {
3046 let value = self.sum()?;
3048 let abs_val = value.abs();
3049 return Ok(abs_val * abs_val);
3050 }
3051
3052 let conj = self.conj();
3055 let scalar = super::contract::contract_pair(self, &conj)?;
3056 Ok(scalar.sum()?.real().max(0.0))
3059 }
3060
3061 pub fn norm(&self) -> f64 {
3075 self.norm_squared().sqrt()
3076 }
3077
3078 pub fn maxabs(&self) -> f64 {
3090 self.storage.max_abs().unwrap_or(0.0)
3091 }
3092
3093 pub fn sub(&self, other: &Self) -> Result<Self> {
3101 self.axpby(AnyScalar::new_real(1.0), other, AnyScalar::new_real(-1.0))
3102 }
3103
3104 pub fn neg(&self) -> Result<Self> {
3109 self.scale(AnyScalar::new_real(-1.0))
3110 }
3111
3112 pub fn isapprox(&self, other: &Self, atol: f64, rtol: f64) -> bool {
3117 let diff = match self.sub(other) {
3118 Ok(d) => d,
3119 Err(_) => return false,
3120 };
3121 let diff_norm = diff.norm();
3122 diff_norm <= atol.max(rtol * self.norm().max(other.norm()))
3123 }
3124
3125 pub fn diagonal(input_index: &DynIndex, output_index: &DynIndex) -> Result<Self> {
3130 <Self as TensorConstructionLike>::diagonal(input_index, output_index)
3131 }
3132
3133 pub fn delta(input_indices: &[DynIndex], output_indices: &[DynIndex]) -> Result<Self> {
3139 <Self as TensorConstructionLike>::delta(input_indices, output_indices)
3140 }
3141
3142 pub fn scalar_one() -> Result<Self> {
3147 <Self as TensorConstructionLike>::scalar_one()
3148 }
3149
3150 pub fn ones(indices: &[DynIndex]) -> Result<Self> {
3155 <Self as TensorConstructionLike>::ones(indices)
3156 }
3157
3158 pub fn onehot(index_vals: &[(DynIndex, usize)]) -> Result<Self> {
3163 <Self as TensorConstructionLike>::onehot(index_vals)
3164 }
3165
3166 pub fn distance(&self, other: &Self) -> Result<f64> {
3197 let norm_self = self.norm();
3198
3199 let neg_other = other.scale(AnyScalar::new_real(-1.0))?;
3201 let diff = self.add(&neg_other)?;
3202 let norm_diff = diff.norm();
3203
3204 if norm_self > 0.0 {
3205 Ok(norm_diff / norm_self)
3206 } else {
3207 Ok(norm_diff)
3208 }
3209 }
3210}
3211
3212impl std::fmt::Debug for TensorDynLen {
3213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3214 f.debug_struct("TensorDynLen")
3215 .field("indices", &self.indices)
3216 .field("dims", &self.dims())
3217 .field("is_diag", &self.is_diag())
3218 .finish()
3219 }
3220}
3221
3222pub fn diag_tensor_dyn_len(indices: Vec<DynIndex>, diag_data: Vec<f64>) -> Result<TensorDynLen> {
3247 TensorDynLen::from_diag(indices, diag_data)
3248}
3249
3250#[allow(clippy::type_complexity)]
3251pub(crate) type UnfoldSplitInnerResult = (
3252 EagerTensor,
3253 usize,
3254 usize,
3255 usize,
3256 Vec<DynIndex>,
3257 Vec<DynIndex>,
3258);
3259
3260#[allow(clippy::type_complexity)]
3307pub fn unfold_split(
3308 t: &TensorDynLen,
3309 left_inds: &[DynIndex],
3310) -> Result<(
3311 NativeTensor,
3312 usize,
3313 usize,
3314 usize,
3315 Vec<DynIndex>,
3316 Vec<DynIndex>,
3317)> {
3318 let (matrix_inner, left_len, m, n, left_indices, right_indices) =
3319 unfold_split_inner(t, left_inds)?;
3320
3321 Ok((
3322 matrix_inner.data().clone(),
3323 left_len,
3324 m,
3325 n,
3326 left_indices,
3327 right_indices,
3328 ))
3329}
3330
3331pub(crate) fn unfold_split_inner(
3332 t: &TensorDynLen,
3333 left_inds: &[DynIndex],
3334) -> Result<UnfoldSplitInnerResult> {
3335 let rank = t.indices.len();
3336
3337 anyhow::ensure!(rank >= 2, "Tensor must have rank >= 2, got rank {}", rank);
3339
3340 let left_len = left_inds.len();
3341
3342 anyhow::ensure!(
3344 left_len > 0 && left_len < rank,
3345 "Left indices must be a non-empty proper subset of tensor indices (0 < left_len < rank), got left_len={}, rank={}",
3346 left_len,
3347 rank
3348 );
3349
3350 let tensor_set: HashSet<_> = t.indices.iter().collect();
3352 let mut left_set = HashSet::new();
3353
3354 for left_idx in left_inds {
3355 anyhow::ensure!(
3356 tensor_set.contains(left_idx),
3357 "Index in left_inds not found in tensor"
3358 );
3359 anyhow::ensure!(left_set.insert(left_idx), "Duplicate index in left_inds");
3360 }
3361
3362 let mut right_inds = Vec::new();
3364 for idx in &t.indices {
3365 if !left_set.contains(idx) {
3366 right_inds.push(idx.clone());
3367 }
3368 }
3369
3370 let mut new_indices = Vec::with_capacity(rank);
3372 new_indices.extend_from_slice(left_inds);
3373 new_indices.extend_from_slice(&right_inds);
3374
3375 let unfolded = t.permute_indices(&new_indices)?;
3377
3378 let unfolded_dims = unfolded.dims();
3380 let m: usize = unfolded_dims[..left_len].iter().product();
3381 let n: usize = unfolded_dims[left_len..].iter().product();
3382
3383 let matrix_tensor = unfolded.try_materialized_inner()?.reshape(&[m, n])?;
3384
3385 Ok((
3386 matrix_tensor,
3387 left_len,
3388 m,
3389 n,
3390 left_inds.to_vec(),
3391 right_inds,
3392 ))
3393}
3394
3395use crate::tensor_index::TensorIndex;
3400
3401impl TensorIndex for TensorDynLen {
3402 type Index = DynIndex;
3403
3404 fn external_indices(&self) -> Vec<DynIndex> {
3405 self.indices.clone()
3407 }
3408
3409 fn num_external_indices(&self) -> usize {
3410 self.indices.len()
3411 }
3412
3413 fn replaceind(&self, old_index: &DynIndex, new_index: &DynIndex) -> Result<Self> {
3414 TensorDynLen::replaceind(self, old_index, new_index)
3416 }
3417
3418 fn replaceinds(&self, old_indices: &[DynIndex], new_indices: &[DynIndex]) -> Result<Self> {
3419 TensorDynLen::replaceinds(self, old_indices, new_indices)
3421 }
3422}
3423
3424use crate::tensor_like::{
3429 FactorizeError, FactorizeOptions, FactorizeResult, TensorConstructionLike,
3430 TensorContractionLike, TensorFactorizationLike, TensorVectorSpace,
3431};
3432
3433impl TensorVectorSpace for TensorDynLen {
3434 fn norm_squared(&self) -> f64 {
3435 TensorDynLen::norm_squared(self)
3436 }
3437
3438 fn maxabs(&self) -> f64 {
3439 TensorDynLen::maxabs(self)
3440 }
3441
3442 fn axpby(&self, a: crate::AnyScalar, other: &Self, b: crate::AnyScalar) -> Result<Self> {
3443 TensorDynLen::axpby(self, a, other, b)
3444 }
3445
3446 fn scale(&self, scalar: crate::AnyScalar) -> Result<Self> {
3447 TensorDynLen::scale(self, scalar)
3448 }
3449
3450 fn inner_product(&self, other: &Self) -> Result<crate::AnyScalar> {
3451 TensorDynLen::inner_product(self, other)
3452 }
3453}
3454
3455impl TensorFactorizationLike for TensorDynLen {
3456 fn factorize(
3457 &self,
3458 left_inds: &[DynIndex],
3459 options: &FactorizeOptions,
3460 ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
3461 crate::factorize::factorize(self, left_inds, options)
3462 }
3463
3464 fn factorize_full_rank(
3465 &self,
3466 left_inds: &[DynIndex],
3467 alg: crate::FactorizeAlg,
3468 canonical: crate::Canonical,
3469 ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
3470 crate::factorize::factorize_full_rank(self, left_inds, alg, canonical)
3471 }
3472}
3473
3474impl TensorContractionLike for TensorDynLen {
3475 fn conj(&self) -> Self {
3476 TensorDynLen::conj(self)
3478 }
3479
3480 fn direct_sum(
3481 &self,
3482 other: &Self,
3483 pairs: &[(DynIndex, DynIndex)],
3484 ) -> Result<crate::tensor_like::DirectSumResult<Self>> {
3485 let (tensor, new_indices) = crate::direct_sum::direct_sum(self, other, pairs)?;
3486 Ok(crate::tensor_like::DirectSumResult {
3487 tensor,
3488 new_indices,
3489 })
3490 }
3491
3492 fn outer_product(&self, other: &Self) -> Result<Self> {
3493 super::contract::outer_product(self, other)
3494 }
3495
3496 fn permuteinds(&self, new_order: &[DynIndex]) -> Result<Self> {
3497 TensorDynLen::permute_indices(self, new_order)
3499 }
3500
3501 fn fuse_indices(
3502 &self,
3503 old_indices: &[DynIndex],
3504 new_index: DynIndex,
3505 order: LinearizationOrder,
3506 ) -> Result<Self> {
3507 TensorDynLen::fuse_indices(self, old_indices, new_index, order)
3508 }
3509
3510 fn contract(tensors: &[&Self]) -> Result<Self> {
3511 super::contract::contract(tensors)
3512 }
3513
3514 fn contract_pair(&self, other: &Self) -> Result<Self> {
3515 super::contract::contract_pair(self, other)
3516 }
3517}
3518
3519impl TensorConstructionLike for TensorDynLen {
3520 fn select_indices(&self, selected_indices: &[DynIndex], positions: &[usize]) -> Result<Self> {
3521 TensorDynLen::select_indices(self, selected_indices, positions)
3522 }
3523
3524 fn diagonal(input_index: &DynIndex, output_index: &DynIndex) -> Result<Self> {
3525 let dim = input_index.dim();
3526 if dim != output_index.dim() {
3527 return Err(anyhow::anyhow!(
3528 "Dimension mismatch: input index has dim {}, output has dim {}",
3529 dim,
3530 output_index.dim(),
3531 ));
3532 }
3533
3534 TensorDynLen::from_diag(
3535 vec![input_index.clone(), output_index.clone()],
3536 vec![1.0_f64; dim],
3537 )
3538 }
3539
3540 fn scalar_one() -> Result<Self> {
3541 TensorDynLen::from_dense(vec![], vec![1.0_f64])
3542 }
3543
3544 fn ones(indices: &[DynIndex]) -> Result<Self> {
3545 if indices.is_empty() {
3546 return Self::scalar_one();
3547 }
3548 let dims: Vec<usize> = indices.iter().map(|idx| idx.size()).collect();
3549 let total_size = checked_total_size(&dims)?;
3550 TensorDynLen::from_dense(indices.to_vec(), vec![1.0_f64; total_size])
3551 }
3552
3553 fn onehot(index_vals: &[(DynIndex, usize)]) -> Result<Self> {
3554 if index_vals.is_empty() {
3555 return Self::scalar_one();
3556 }
3557 let indices: Vec<DynIndex> = index_vals.iter().map(|(idx, _)| idx.clone()).collect();
3558 let vals: Vec<usize> = index_vals.iter().map(|(_, v)| *v).collect();
3559 let dims: Vec<usize> = indices.iter().map(|idx| idx.size()).collect();
3560
3561 for (k, (&v, &d)) in vals.iter().zip(dims.iter()).enumerate() {
3562 if v >= d {
3563 return Err(anyhow::anyhow!(
3564 "onehot: value {} at position {} is >= dimension {}",
3565 v,
3566 k,
3567 d
3568 ));
3569 }
3570 }
3571
3572 let total_size = checked_total_size(&dims)?;
3573 let mut data = vec![0.0_f64; total_size];
3574
3575 let offset = column_major_offset(&dims, &vals)?;
3576 data[offset] = 1.0;
3577
3578 Self::from_dense(indices, data)
3579 }
3580
3581 }
3583
3584fn checked_total_size(dims: &[usize]) -> Result<usize> {
3585 dims.iter().try_fold(1_usize, |acc, &d| {
3586 if d == 0 {
3587 return Err(anyhow::anyhow!("invalid dimension 0"));
3588 }
3589 acc.checked_mul(d)
3590 .ok_or_else(|| anyhow::anyhow!("tensor size overflow"))
3591 })
3592}
3593
3594fn column_major_offset(dims: &[usize], vals: &[usize]) -> Result<usize> {
3595 if dims.len() != vals.len() {
3596 return Err(anyhow::anyhow!(
3597 "column_major_offset: dims.len() != vals.len()"
3598 ));
3599 }
3600 checked_total_size(dims)?;
3601
3602 let mut offset = 0usize;
3603 let mut stride = 1usize;
3604 for (k, (&v, &d)) in vals.iter().zip(dims.iter()).enumerate() {
3605 if d == 0 {
3606 return Err(anyhow::anyhow!("invalid dimension 0 at position {}", k));
3607 }
3608 if v >= d {
3609 return Err(anyhow::anyhow!(
3610 "column_major_offset: value {} at position {} is >= dimension {}",
3611 v,
3612 k,
3613 d
3614 ));
3615 }
3616 let term = v
3617 .checked_mul(stride)
3618 .ok_or_else(|| anyhow::anyhow!("column_major_offset: overflow"))?;
3619 offset = offset
3620 .checked_add(term)
3621 .ok_or_else(|| anyhow::anyhow!("column_major_offset: overflow"))?;
3622 stride = stride
3623 .checked_mul(d)
3624 .ok_or_else(|| anyhow::anyhow!("column_major_offset: overflow"))?;
3625 }
3626 Ok(offset)
3627}
3628
3629impl TensorDynLen {
3634 fn any_scalar_payload_to_complex(data: Vec<AnyScalar>) -> Vec<Complex64> {
3635 data.into_iter()
3636 .map(|value| {
3637 value
3638 .as_c64()
3639 .unwrap_or_else(|| Complex64::new(value.real(), 0.0))
3640 })
3641 .collect()
3642 }
3643
3644 fn any_scalar_payload_to_real(data: Vec<AnyScalar>) -> Vec<f64> {
3645 data.into_iter().map(|value| value.real()).collect()
3646 }
3647
3648 fn validate_dense_payload_len(data_len: usize, dims: &[usize]) -> Result<()> {
3649 let expected_len = checked_total_size(dims)?;
3650 anyhow::ensure!(
3651 data_len == expected_len,
3652 "dense payload length {} does not match dims {:?} (expected {})",
3653 data_len,
3654 dims,
3655 expected_len
3656 );
3657 Ok(())
3658 }
3659
3660 fn validate_diag_payload_len(data_len: usize, dims: &[usize]) -> Result<()> {
3661 anyhow::ensure!(
3662 !dims.is_empty(),
3663 "diagonal tensor construction requires at least one index"
3664 );
3665 Self::validate_diag_dims(dims)?;
3666 anyhow::ensure!(
3667 data_len == dims[0],
3668 "diagonal payload length {} does not match diagonal dimension {}",
3669 data_len,
3670 dims[0]
3671 );
3672 Ok(())
3673 }
3674
3675 pub fn from_dense<T: TensorElement>(indices: Vec<DynIndex>, data: Vec<T>) -> Result<Self> {
3702 let dims = Self::expected_dims_from_indices(&indices);
3703 Self::validate_indices(&indices)?;
3704 Self::validate_dense_payload_len(data.len(), &dims)?;
3705 let native = dense_native_tensor_from_col_major(&data, &dims)?;
3706 Self::from_native(indices, native)
3707 }
3708
3709 pub fn from_dense_any(indices: Vec<DynIndex>, data: Vec<AnyScalar>) -> Result<Self> {
3735 if data.iter().any(AnyScalar::is_complex) {
3736 Self::from_dense(indices, Self::any_scalar_payload_to_complex(data))
3737 } else {
3738 Self::from_dense(indices, Self::any_scalar_payload_to_real(data))
3739 }
3740 }
3741
3742 pub fn from_diag<T: TensorElement>(indices: Vec<DynIndex>, data: Vec<T>) -> Result<Self> {
3770 let dims = Self::expected_dims_from_indices(&indices);
3771 Self::validate_indices(&indices)?;
3772 Self::validate_diag_payload_len(data.len(), &dims)?;
3773 let native = diag_native_tensor_from_col_major(&data, dims.len())?;
3774 Self::from_native_with_axis_classes(indices, native, Self::diag_axis_classes(dims.len()))
3775 }
3776
3777 pub fn from_diag_any(indices: Vec<DynIndex>, data: Vec<AnyScalar>) -> Result<Self> {
3799 if data.iter().any(AnyScalar::is_complex) {
3800 Self::from_diag(indices, Self::any_scalar_payload_to_complex(data))
3801 } else {
3802 Self::from_diag(indices, Self::any_scalar_payload_to_real(data))
3803 }
3804 }
3805
3806 pub fn copy_tensor(indices: Vec<DynIndex>, value: AnyScalar) -> Result<Self> {
3827 if indices.is_empty() {
3828 return Self::from_dense_any(vec![], vec![value]);
3829 }
3830 let dim = indices[0].dim();
3831 let data = vec![value; dim];
3832 Self::from_diag_any(indices, data)
3833 }
3834
3835 pub fn fuse_indices(
3889 &self,
3890 old_indices: &[DynIndex],
3891 new_index: DynIndex,
3892 order: LinearizationOrder,
3893 ) -> Result<Self> {
3894 anyhow::ensure!(
3895 !old_indices.is_empty(),
3896 "fuse_indices requires at least one index to fuse"
3897 );
3898
3899 let old_dims = self.dims();
3900 let mut seen_indices = HashSet::new();
3901 let mut old_axes = Vec::with_capacity(old_indices.len());
3902 for old_index in old_indices {
3903 anyhow::ensure!(
3904 seen_indices.insert(old_index),
3905 "duplicate index in old_indices"
3906 );
3907 let axis = self
3908 .indices
3909 .iter()
3910 .position(|idx| idx == old_index)
3911 .ok_or_else(|| anyhow::anyhow!("index {:?} not found in tensor", old_index))?;
3912 anyhow::ensure!(
3913 old_index.dim() == old_dims[axis],
3914 "old index dimension does not match tensor axis dimension"
3915 );
3916 old_axes.push(axis);
3917 }
3918
3919 let fused_dims: Vec<usize> = old_axes.iter().map(|&axis| old_dims[axis]).collect();
3920 let fused_product = checked_product(&fused_dims)?;
3921 anyhow::ensure!(
3922 fused_product == new_index.dim(),
3923 "product of old index dimensions must match the replacement index dimension"
3924 );
3925
3926 let insertion_axis =
3927 old_axes.iter().copied().min().ok_or_else(|| {
3928 anyhow::anyhow!("fuse_indices requires at least one index to fuse")
3929 })?;
3930 let old_axis_set: HashSet<usize> = old_axes.iter().copied().collect();
3931
3932 let mut result_indices =
3933 Vec::with_capacity(self.indices.len() - old_indices.len() + 1usize);
3934 for (axis, index) in self.indices.iter().enumerate() {
3935 if axis == insertion_axis {
3936 result_indices.push(new_index.clone());
3937 }
3938 if !old_axis_set.contains(&axis) {
3939 result_indices.push(index.clone());
3940 }
3941 }
3942 let mut result_seen = HashSet::new();
3943 for index in &result_indices {
3944 anyhow::ensure!(
3945 result_seen.insert(index),
3946 "fuse_indices result would contain duplicate index"
3947 );
3948 }
3949 Self::validate_indices(&result_indices)?;
3950
3951 let mut new_dims = Vec::with_capacity(old_dims.len() - old_indices.len() + 1usize);
3952 for (axis, dim) in old_dims.iter().copied().enumerate() {
3953 if axis == insertion_axis {
3954 new_dims.push(new_index.dim());
3955 }
3956 if !old_axis_set.contains(&axis) {
3957 new_dims.push(dim);
3958 }
3959 }
3960
3961 let old_data = self.to_vec_any()?;
3962 let mut new_data = vec![AnyScalar::new_real(0.0); old_data.len()];
3963 for (old_linear, value) in old_data.into_iter().enumerate() {
3964 let old_multi = decode_col_major_linear(old_linear, &old_dims)?;
3965 let fused_multi: Vec<usize> = old_axes.iter().map(|&axis| old_multi[axis]).collect();
3966 let fused_linear = encode_linear_with_order(&fused_multi, &fused_dims, order)?;
3967
3968 let mut new_multi = Vec::with_capacity(new_dims.len());
3969 for (axis, old_coord) in old_multi.iter().copied().enumerate() {
3970 if axis == insertion_axis {
3971 new_multi.push(fused_linear);
3972 }
3973 if !old_axis_set.contains(&axis) {
3974 new_multi.push(old_coord);
3975 }
3976 }
3977 let new_linear = encode_col_major_linear(&new_multi, &new_dims)?;
3978 new_data[new_linear] = value;
3979 }
3980
3981 Self::from_dense_any(result_indices, new_data)
3982 }
3983
3984 pub fn unfuse_index(
4006 &self,
4007 old_index: &DynIndex,
4008 new_indices: &[DynIndex],
4009 order: LinearizationOrder,
4010 ) -> Result<Self> {
4011 anyhow::ensure!(
4012 !new_indices.is_empty(),
4013 "unfuse_index requires at least one replacement index"
4014 );
4015
4016 let axis = self
4017 .indices
4018 .iter()
4019 .position(|idx| idx == old_index)
4020 .ok_or_else(|| anyhow::anyhow!("index {:?} not found in tensor", old_index))?;
4021
4022 let replacement_dims: Vec<usize> = new_indices.iter().map(DynIndex::dim).collect();
4023 let replacement_product = checked_product(&replacement_dims)?;
4024 anyhow::ensure!(
4025 replacement_product == old_index.dim(),
4026 "product of new index dimensions must match the replaced index dimension"
4027 );
4028
4029 let mut result_indices =
4030 Vec::with_capacity(self.indices.len() - 1usize + new_indices.len());
4031 result_indices.extend_from_slice(&self.indices[..axis]);
4032 result_indices.extend(new_indices.iter().cloned());
4033 result_indices.extend_from_slice(&self.indices[axis + 1..]);
4034 Self::validate_indices(&result_indices)?;
4035
4036 let old_dims = self.dims();
4037 let mut new_dims = Vec::with_capacity(old_dims.len() - 1usize + replacement_dims.len());
4038 new_dims.extend_from_slice(&old_dims[..axis]);
4039 new_dims.extend_from_slice(&replacement_dims);
4040 new_dims.extend_from_slice(&old_dims[axis + 1..]);
4041
4042 let old_data = self.to_vec_any()?;
4043 let mut new_data = vec![AnyScalar::new_real(0.0); old_data.len()];
4044 for (old_linear, value) in old_data.into_iter().enumerate() {
4045 let old_multi = decode_col_major_linear(old_linear, &old_dims)?;
4046 let split_multi = decode_linear_with_order(old_multi[axis], &replacement_dims, order)?;
4047 let mut new_multi = Vec::with_capacity(new_dims.len());
4048 new_multi.extend_from_slice(&old_multi[..axis]);
4049 new_multi.extend_from_slice(&split_multi);
4050 new_multi.extend_from_slice(&old_multi[axis + 1..]);
4051 let new_linear = encode_col_major_linear(&new_multi, &new_dims)?;
4052 new_data[new_linear] = value;
4053 }
4054
4055 Self::from_dense_any(result_indices, new_data)
4056 }
4057
4058 pub fn scalar<T: TensorElement>(value: T) -> Result<Self> {
4069 Self::from_dense(vec![], vec![value])
4070 }
4071
4072 pub fn zeros<T: TensorElement + Zero + Clone>(indices: Vec<DynIndex>) -> Result<Self> {
4085 let dims: Vec<usize> = indices.iter().map(|idx| idx.dim()).collect();
4086 let size: usize = dims.iter().product();
4087 Self::from_dense(indices, vec![T::zero(); size])
4088 }
4089}
4090
4091impl TensorDynLen {
4096 pub fn to_vec<T: TensorElement>(&self) -> Result<Vec<T>> {
4118 native_tensor_primal_to_dense_col_major(self.as_native()?)
4119 }
4120
4121 pub fn into_dense_col_major_parts<T: TensorElement>(self) -> Result<(Vec<DynIndex>, Vec<T>)> {
4157 anyhow::ensure!(
4158 self.structured_ad.is_none() && !self.tracks_grad(),
4159 "TensorDynLen::into_dense_col_major_parts cannot consume tensors with tracked autodiff state"
4160 );
4161 let data = self.to_vec::<T>()?;
4162 Ok((self.indices, data))
4163 }
4164
4165 fn to_vec_any(&self) -> Result<Vec<AnyScalar>> {
4166 if self.is_complex() {
4167 self.to_vec::<Complex64>().map(|data| {
4168 data.into_iter()
4169 .map(|value| AnyScalar::new_complex(value.re, value.im))
4170 .collect()
4171 })
4172 } else {
4173 self.to_vec::<f64>()
4174 .map(|data| data.into_iter().map(AnyScalar::new_real).collect())
4175 }
4176 }
4177
4178 pub fn as_slice_f64(&self) -> Result<Vec<f64>> {
4183 self.to_vec::<f64>()
4184 }
4185
4186 pub fn as_slice_c64(&self) -> Result<Vec<Complex64>> {
4191 self.to_vec::<Complex64>()
4192 }
4193
4194 pub fn is_f64(&self) -> bool {
4207 self.storage.is_f64()
4208 }
4209
4210 pub fn is_diag(&self) -> bool {
4237 self.storage.is_diag()
4238 }
4239
4240 pub fn is_complex(&self) -> bool {
4259 self.storage.is_complex()
4260 }
4261}
4262
4263fn checked_product(dims: &[usize]) -> Result<usize> {
4264 dims.iter().try_fold(1usize, |acc, &dim| {
4265 acc.checked_mul(dim)
4266 .ok_or_else(|| anyhow::anyhow!("dimension product overflow"))
4267 })
4268}
4269
4270fn decode_col_major_linear(linear: usize, dims: &[usize]) -> Result<Vec<usize>> {
4271 let total = checked_product(dims)?;
4272 anyhow::ensure!(
4273 linear < total,
4274 "linear offset {} out of bounds for dims {:?}",
4275 linear,
4276 dims
4277 );
4278 let mut remaining = linear;
4279 let mut out = Vec::with_capacity(dims.len());
4280 for &dim in dims {
4281 out.push(remaining % dim);
4282 remaining /= dim;
4283 }
4284 Ok(out)
4285}
4286
4287fn encode_col_major_linear(indices: &[usize], dims: &[usize]) -> Result<usize> {
4288 anyhow::ensure!(
4289 indices.len() == dims.len(),
4290 "index rank {} does not match dims {:?}",
4291 indices.len(),
4292 dims
4293 );
4294 let mut linear = 0usize;
4295 let mut stride = 1usize;
4296 for (&index, &dim) in indices.iter().zip(dims.iter()) {
4297 anyhow::ensure!(
4298 index < dim,
4299 "index {} out of bounds for dimension {}",
4300 index,
4301 dim
4302 );
4303 linear += index * stride;
4304 stride = stride
4305 .checked_mul(dim)
4306 .ok_or_else(|| anyhow::anyhow!("stride overflow"))?;
4307 }
4308 Ok(linear)
4309}
4310
4311fn decode_linear_with_order(
4312 linear: usize,
4313 dims: &[usize],
4314 order: LinearizationOrder,
4315) -> Result<Vec<usize>> {
4316 let total = checked_product(dims)?;
4317 anyhow::ensure!(
4318 linear < total,
4319 "linear offset {} out of bounds for dims {:?}",
4320 linear,
4321 dims
4322 );
4323
4324 let mut remaining = linear;
4325 let mut out = vec![0usize; dims.len()];
4326 match order {
4327 LinearizationOrder::ColumnMajor => {
4328 for (slot, &dim) in out.iter_mut().zip(dims.iter()) {
4329 *slot = remaining % dim;
4330 remaining /= dim;
4331 }
4332 }
4333 LinearizationOrder::RowMajor => {
4334 for (slot, &dim) in out.iter_mut().rev().zip(dims.iter().rev()) {
4335 *slot = remaining % dim;
4336 remaining /= dim;
4337 }
4338 }
4339 }
4340 Ok(out)
4341}
4342
4343fn encode_linear_with_order(
4344 indices: &[usize],
4345 dims: &[usize],
4346 order: LinearizationOrder,
4347) -> Result<usize> {
4348 match order {
4349 LinearizationOrder::ColumnMajor => encode_col_major_linear(indices, dims),
4350 LinearizationOrder::RowMajor => {
4351 anyhow::ensure!(
4352 indices.len() == dims.len(),
4353 "index rank {} does not match dims {:?}",
4354 indices.len(),
4355 dims
4356 );
4357 let mut linear = 0usize;
4358 let mut stride = 1usize;
4359 for (&index, &dim) in indices.iter().rev().zip(dims.iter().rev()) {
4360 anyhow::ensure!(
4361 index < dim,
4362 "index {} out of bounds for dimension {}",
4363 index,
4364 dim
4365 );
4366 linear += index * stride;
4367 stride = stride
4368 .checked_mul(dim)
4369 .ok_or_else(|| anyhow::anyhow!("stride overflow"))?;
4370 }
4371 Ok(linear)
4372 }
4373 }
4374}