1use crate::defaults::DynIndex;
2use crate::index_like::IndexLike;
3use crate::index_ops::{common_ind_positions, prepare_contraction, prepare_contraction_pairs};
4use crate::tensor_like::LinearizationOrder;
5use crate::{storage::Storage, storage::StorageKind, AnyScalar};
6use anyhow::Result;
7use num_complex::Complex64;
8use num_traits::Zero;
9use rand::Rng;
10use rand_distr::{Distribution, StandardNormal};
11use std::collections::HashSet;
12use std::ops::{Mul, Neg, Sub};
13use std::sync::{Arc, OnceLock};
14use tenferro::eager_einsum::eager_einsum_ad;
15use tenferro::{CpuBackend, DType, EagerTensor, Tensor as NativeTensor};
16use tensor4all_tensorbackend::{
17 axpby_native_tensor, contract_native_tensor, default_eager_ctx,
18 dense_native_tensor_from_col_major, diag_native_tensor_from_col_major,
19 native_tensor_primal_to_dense_col_major, native_tensor_primal_to_diag_c64,
20 native_tensor_primal_to_diag_f64, native_tensor_primal_to_storage,
21 reshape_col_major_native_tensor, scale_native_tensor, storage_to_native_tensor, StorageScalar,
22 TensorElement,
23};
24
25use super::structured_contraction::{
26 normalize_payload_for_roots, storage_from_payload_native, storage_payload_native,
27 OperandLayout, StructuredContractionPlan, StructuredContractionSpec,
28};
29
/// Scalars that can be drawn from a standard normal distribution, used for
/// random tensor initialization.
pub trait RandomScalar: TensorElement {
    /// Draws one sample from `rng`.
    fn random_value<R: Rng>(rng: &mut R) -> Self;
}
38
39impl RandomScalar for f64 {
40 fn random_value<R: Rng>(rng: &mut R) -> Self {
41 StandardNormal.sample(rng)
42 }
43}
44
45impl RandomScalar for Complex64 {
46 fn random_value<R: Rng>(rng: &mut R) -> Self {
47 Complex64::new(StandardNormal.sample(rng), StandardNormal.sample(rng))
48 }
49}
50
51pub fn compute_permutation_from_indices(
83 original_indices: &[DynIndex],
84 new_indices: &[DynIndex],
85) -> Vec<usize> {
86 assert_eq!(
87 new_indices.len(),
88 original_indices.len(),
89 "new_indices length must match original_indices length"
90 );
91
92 let mut perm = Vec::with_capacity(new_indices.len());
93 let mut used = std::collections::HashSet::new();
94
95 for new_idx in new_indices {
96 let pos = original_indices
99 .iter()
100 .position(|old_idx| old_idx == new_idx)
101 .expect("new_indices must be a permutation of original_indices");
102
103 if used.contains(&pos) {
104 panic!("duplicate index in new_indices");
105 }
106 used.insert(pos);
107 perm.push(pos);
108 }
109
110 perm
111}
112
/// Read-only access to a tensor's index list.
pub trait TensorAccess {
    /// The tensor's indices, one per axis.
    fn indices(&self) -> &[DynIndex];
}
118
/// Compact payload of a tensor tracked for structured autodiff.
#[derive(Clone)]
pub(crate) struct StructuredAdValue {
    // Eager payload kept so later contractions can go through the AD-aware
    // einsum path (see `contract_structured_payloads`).
    payload: Arc<EagerTensor<CpuBackend>>,
    // Shape of `payload`; checked against the eager tensor's shape on
    // construction in `from_structured_payload_inner`.
    payload_dims: Vec<usize>,
    // Maps each logical axis to a payload axis class.
    axis_classes: Vec<usize>,
}
125
/// A tensor with a dynamic number of labelled indices.
#[derive(Clone)]
pub struct TensorDynLen {
    /// Logical indices, one per axis; all must be unique.
    pub indices: Vec<DynIndex>,
    // Backing storage; may be dense, diagonal, or structured (see the
    // `StorageKind` dispatch in `select_indices`).
    pub(crate) storage: Arc<Storage>,
    // Present when this value carries a tracked structured-AD payload.
    pub(crate) structured_ad: Option<Arc<StructuredAdValue>>,
    // Lazily materialized dense eager representation; filled once by
    // `materialized_inner`.
    pub(crate) eager_cache: Arc<OnceLock<Arc<EagerTensor<CpuBackend>>>>,
}
188
189impl TensorAccess for TensorDynLen {
190 fn indices(&self) -> &[DynIndex] {
191 &self.indices
192 }
193}
194
195impl TensorDynLen {
    // Single-character einsum labels; caps the number of distinct axis ids at 52.
    const EINSUM_LABELS: &'static [u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
197
198 fn dense_axis_classes(rank: usize) -> Vec<usize> {
199 (0..rank).collect()
200 }
201
202 fn diag_axis_classes(rank: usize) -> Vec<usize> {
203 if rank == 0 {
204 vec![]
205 } else {
206 vec![0; rank]
207 }
208 }
209
210 fn canonicalize_axis_classes(axis_classes: &[usize]) -> Vec<usize> {
211 let mut map = std::collections::HashMap::new();
212 let mut next = 0usize;
213 axis_classes
214 .iter()
215 .map(|&class_id| {
216 *map.entry(class_id).or_insert_with(|| {
217 let canonical = next;
218 next += 1;
219 canonical
220 })
221 })
222 .collect()
223 }
224
225 fn permute_axis_classes(&self, perm: &[usize]) -> Vec<usize> {
226 let axis_classes = self.storage.axis_classes();
227 let permuted: Vec<usize> = perm.iter().map(|&index| axis_classes[index]).collect();
228 Self::canonicalize_axis_classes(&permuted)
229 }
230
231 fn is_diag_axis_classes(axis_classes: &[usize]) -> bool {
232 axis_classes.len() >= 2 && axis_classes.iter().all(|&class_id| class_id == 0)
233 }
234
235 fn einsum_labels(ids: &[usize]) -> Result<String> {
236 let mut out = String::with_capacity(ids.len());
237 for &id in ids {
238 let label = Self::EINSUM_LABELS.get(id).ok_or_else(|| {
239 anyhow::anyhow!("einsum label {id} exceeds supported label range")
240 })?;
241 out.push(char::from(*label));
242 }
243 Ok(out)
244 }
245
246 fn build_binary_einsum_subscripts(
247 lhs_rank: usize,
248 axes_a: &[usize],
249 rhs_rank: usize,
250 axes_b: &[usize],
251 ) -> Result<String> {
252 anyhow::ensure!(
253 axes_a.len() == axes_b.len(),
254 "contract axis length mismatch: lhs {:?}, rhs {:?}",
255 axes_a,
256 axes_b
257 );
258
259 let mut lhs_ids = vec![usize::MAX; lhs_rank];
260 let mut rhs_ids = vec![usize::MAX; rhs_rank];
261 let mut next_id = 0usize;
262
263 let mut seen_lhs = vec![false; lhs_rank];
264 let mut seen_rhs = vec![false; rhs_rank];
265
266 for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
267 anyhow::ensure!(
268 lhs_axis < lhs_rank,
269 "lhs contract axis {lhs_axis} out of range"
270 );
271 anyhow::ensure!(
272 rhs_axis < rhs_rank,
273 "rhs contract axis {rhs_axis} out of range"
274 );
275 anyhow::ensure!(
276 !seen_lhs[lhs_axis],
277 "duplicate lhs contract axis {lhs_axis}"
278 );
279 anyhow::ensure!(
280 !seen_rhs[rhs_axis],
281 "duplicate rhs contract axis {rhs_axis}"
282 );
283 seen_lhs[lhs_axis] = true;
284 seen_rhs[rhs_axis] = true;
285 lhs_ids[lhs_axis] = next_id;
286 rhs_ids[rhs_axis] = next_id;
287 next_id += 1;
288 }
289
290 let mut output_ids = Vec::with_capacity(lhs_rank + rhs_rank - 2 * axes_a.len());
291 for id in &mut lhs_ids {
292 if *id == usize::MAX {
293 *id = next_id;
294 output_ids.push(next_id);
295 next_id += 1;
296 }
297 }
298 for id in &mut rhs_ids {
299 if *id == usize::MAX {
300 *id = next_id;
301 output_ids.push(next_id);
302 next_id += 1;
303 }
304 }
305
306 Ok(format!(
307 "{},{}->{}",
308 Self::einsum_labels(&lhs_ids)?,
309 Self::einsum_labels(&rhs_ids)?,
310 Self::einsum_labels(&output_ids)?,
311 ))
312 }
313
    /// Computes the axis classes of the result of a pairwise contraction.
    ///
    /// Contracting an axis pair ties together the payload classes the two
    /// axes belong to, so a union-find over the combined class space (rhs
    /// classes offset by the lhs payload rank) merges every pair linked via
    /// `axes_a[i] <-> axes_b[i]`. Surviving (uncontracted) axes are then
    /// assigned fresh class ids in first-appearance order, lhs axes first.
    fn binary_contraction_axis_classes(
        lhs_axis_classes: &[usize],
        axes_a: &[usize],
        rhs_axis_classes: &[usize],
        axes_b: &[usize],
    ) -> Vec<usize> {
        debug_assert_eq!(axes_a.len(), axes_b.len());

        // Union-find `find` with path compression.
        fn find(parent: &mut [usize], value: usize) -> usize {
            if parent[value] != value {
                parent[value] = find(parent, parent[value]);
            }
            parent[value]
        }

        fn union(parent: &mut [usize], lhs: usize, rhs: usize) {
            let lhs_root = find(parent, lhs);
            let rhs_root = find(parent, rhs);
            if lhs_root != rhs_root {
                parent[rhs_root] = lhs_root;
            }
        }

        // Payload rank = number of distinct classes = max class id + 1
        // (classes are canonical, i.e. densely numbered from 0).
        let lhs_payload_rank = lhs_axis_classes
            .iter()
            .copied()
            .max()
            .map(|value| value + 1)
            .unwrap_or(0);
        let rhs_payload_rank = rhs_axis_classes
            .iter()
            .copied()
            .max()
            .map(|value| value + 1)
            .unwrap_or(0);
        let rhs_offset = lhs_payload_rank;
        let mut parent: Vec<usize> = (0..lhs_payload_rank + rhs_payload_rank).collect();

        // Merge the classes of each contracted axis pair.
        for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
            union(
                &mut parent,
                lhs_axis_classes[lhs_axis],
                rhs_offset + rhs_axis_classes[rhs_axis],
            );
        }

        let mut lhs_contracted = vec![false; lhs_axis_classes.len()];
        for &axis in axes_a {
            lhs_contracted[axis] = true;
        }
        let mut rhs_contracted = vec![false; rhs_axis_classes.len()];
        for &axis in axes_b {
            rhs_contracted[axis] = true;
        }

        // Renumber surviving roots densely, lhs axes first.
        let mut root_to_class = std::collections::HashMap::new();
        let mut next_class = 0usize;
        let mut axis_classes = Vec::new();

        for (axis, &class_id) in lhs_axis_classes.iter().enumerate() {
            if !lhs_contracted[axis] {
                let root = find(&mut parent, class_id);
                let class = *root_to_class.entry(root).or_insert_with(|| {
                    let value = next_class;
                    next_class += 1;
                    value
                });
                axis_classes.push(class);
            }
        }
        for (axis, &class_id) in rhs_axis_classes.iter().enumerate() {
            if !rhs_contracted[axis] {
                let root = find(&mut parent, rhs_offset + class_id);
                let class = *root_to_class.entry(root).or_insert_with(|| {
                    let value = next_class;
                    next_class += 1;
                    value
                });
                axis_classes.push(class);
            }
        }

        axis_classes
    }
398
399 fn scale_subscripts(rank: usize) -> Result<String> {
400 if rank == 0 {
401 Ok("->".to_string())
402 } else {
403 let ids: Vec<usize> = (0..rank).collect();
404 let labels = Self::einsum_labels(&ids)?;
405 Ok(format!("{labels},->{labels}"))
406 }
407 }
408
409 fn validate_indices(indices: &[DynIndex]) {
410 let mut seen = HashSet::new();
411 for idx in indices {
412 assert!(
413 seen.insert(idx.clone()),
414 "Tensor indices must all be unique"
415 );
416 }
417 }
418
419 fn validate_diag_dims(dims: &[usize]) -> Result<()> {
420 if !dims.is_empty() {
421 let first_dim = dims[0];
422 for (i, &dim) in dims.iter().enumerate() {
423 anyhow::ensure!(
424 dim == first_dim,
425 "DiagTensor requires all indices to have the same dimension, but dims[{i}] = {dim} != dims[0] = {first_dim}"
426 );
427 }
428 }
429 Ok(())
430 }
431
    // Thin wrapper: delegates storage -> native-tensor conversion to the backend.
    fn seed_native_payload(storage: &Storage, dims: &[usize]) -> Result<NativeTensor> {
        storage_to_native_tensor(storage, dims)
    }
435
436 fn empty_eager_cache() -> Arc<OnceLock<Arc<EagerTensor<CpuBackend>>>> {
437 Arc::new(OnceLock::new())
438 }
439
440 fn eager_cache_with(
441 inner: EagerTensor<CpuBackend>,
442 ) -> Arc<OnceLock<Arc<EagerTensor<CpuBackend>>>> {
443 let cache = Arc::new(OnceLock::new());
444 let _ = cache.set(Arc::new(inner));
445 cache
446 }
447
448 fn compact_payload_inner(&self) -> Result<EagerTensor<CpuBackend>> {
449 Ok(EagerTensor::from_tensor_in(
450 storage_payload_native(self.storage.as_ref())?,
451 default_eager_ctx(),
452 ))
453 }
454
455 fn tracked_compact_payload_value(&self) -> Option<&StructuredAdValue> {
456 self.structured_ad.as_deref()
457 }
458
459 fn compact_payload_is_logical_dense(&self, payload_dims: &[usize]) -> bool {
460 self.storage.axis_classes() == Self::dense_axis_classes(self.indices.len())
461 && payload_dims == self.dims()
462 }
463
    /// Assigns numeric einsum ids for a pairwise contraction.
    ///
    /// Each contracted axis pair `(axes_a[i], axes_b[i])` shares one id;
    /// every remaining axis gets a fresh id, lhs axes first, and those ids
    /// also form the output id list (so output axis order is: free lhs axes,
    /// then free rhs axes).
    ///
    /// Returns `(lhs_ids, rhs_ids, output_ids)`. Fails on mismatched axis
    /// list lengths, out-of-range axes, or duplicated contraction axes.
    fn build_binary_contraction_labels(
        lhs_rank: usize,
        axes_a: &[usize],
        rhs_rank: usize,
        axes_b: &[usize],
    ) -> Result<(Vec<usize>, Vec<usize>, Vec<usize>)> {
        anyhow::ensure!(
            axes_a.len() == axes_b.len(),
            "contract axis length mismatch: lhs {:?}, rhs {:?}",
            axes_a,
            axes_b
        );

        // usize::MAX marks an axis that has not been assigned an id yet.
        let mut lhs_ids = vec![usize::MAX; lhs_rank];
        let mut rhs_ids = vec![usize::MAX; rhs_rank];
        let mut next_id = 0usize;

        let mut seen_lhs = vec![false; lhs_rank];
        let mut seen_rhs = vec![false; rhs_rank];

        // Contracted pairs first: both members of a pair share one id.
        for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
            anyhow::ensure!(
                lhs_axis < lhs_rank,
                "lhs contract axis {lhs_axis} out of range"
            );
            anyhow::ensure!(
                rhs_axis < rhs_rank,
                "rhs contract axis {rhs_axis} out of range"
            );
            anyhow::ensure!(
                !seen_lhs[lhs_axis],
                "duplicate lhs contract axis {lhs_axis}"
            );
            anyhow::ensure!(
                !seen_rhs[rhs_axis],
                "duplicate rhs contract axis {rhs_axis}"
            );
            seen_lhs[lhs_axis] = true;
            seen_rhs[rhs_axis] = true;
            lhs_ids[lhs_axis] = next_id;
            rhs_ids[rhs_axis] = next_id;
            next_id += 1;
        }

        // Free axes keep their relative order and appear in the output.
        let mut output_ids = Vec::with_capacity(lhs_rank + rhs_rank - 2 * axes_a.len());
        for id in &mut lhs_ids {
            if *id == usize::MAX {
                *id = next_id;
                output_ids.push(next_id);
                next_id += 1;
            }
        }
        for id in &mut rhs_ids {
            if *id == usize::MAX {
                *id = next_id;
                output_ids.push(next_id);
                next_id += 1;
            }
        }

        Ok((lhs_ids, rhs_ids, output_ids))
    }
526
527 fn build_payload_einsum_subscripts(
528 input_roots: &[Vec<usize>],
529 output_roots: &[usize],
530 ) -> Result<String> {
531 let input_labels = input_roots
532 .iter()
533 .map(|roots| Self::einsum_labels(roots))
534 .collect::<Result<Vec<_>>>()?;
535 let output = Self::einsum_labels(output_roots)?;
536 Ok(format!("{}->{}", input_labels.join(","), output))
537 }
538
    /// Collapses duplicated root labels on an eager payload by repeatedly
    /// extracting the diagonal of the first pair of axes that share a root,
    /// until every surviving axis carries a unique root.
    ///
    /// Returns the rebuilt payload (`None` when no change was needed) and the
    /// deduplicated root list. Fails when the payload rank does not match the
    /// root label count.
    fn normalize_eager_payload_for_roots(
        payload: &EagerTensor<CpuBackend>,
        roots: &[usize],
    ) -> Result<(Option<EagerTensor<CpuBackend>>, Vec<usize>)> {
        anyhow::ensure!(
            payload.data().shape().len() == roots.len(),
            "payload rank {} does not match root label count {}",
            payload.data().shape().len(),
            roots.len()
        );

        let mut current_payload = None;
        let mut current_roots = roots.to_vec();
        // Each extract_diag merges axes (axis_a, axis_b) into axis_a, so the
        // later axis is dropped from the root list to stay in sync.
        while let Some((axis_a, axis_b)) = Self::first_duplicate_pair(&current_roots) {
            let source = current_payload.as_ref().unwrap_or(payload);
            current_payload = Some(source.extract_diag(axis_a, axis_b)?);
            current_roots.remove(axis_b);
        }

        Ok((current_payload, current_roots))
    }
560
561 fn first_duplicate_pair(values: &[usize]) -> Option<(usize, usize)> {
562 let mut first_axis_by_value = std::collections::HashMap::new();
563 for (axis, &value) in values.iter().enumerate() {
564 if let Some(&first_axis) = first_axis_by_value.get(&value) {
565 return Some((first_axis, axis));
566 }
567 first_axis_by_value.insert(value, axis);
568 }
569 None
570 }
571
572 fn binary_structured_contraction_plan(
573 &self,
574 other: &Self,
575 axes_a: &[usize],
576 axes_b: &[usize],
577 ) -> Result<(StructuredContractionPlan, Vec<Vec<usize>>, Vec<usize>)> {
578 let (lhs_labels, rhs_labels, output_labels) = Self::build_binary_contraction_labels(
579 self.indices.len(),
580 axes_a,
581 other.indices.len(),
582 axes_b,
583 )?;
584 let operands = vec![
585 OperandLayout::new(self.dims(), self.storage.axis_classes().to_vec())?,
586 OperandLayout::new(other.dims(), other.storage.axis_classes().to_vec())?,
587 ];
588 let spec = StructuredContractionSpec {
589 input_labels: vec![lhs_labels, rhs_labels],
590 output_labels,
591 retained_labels: Default::default(),
592 };
593 let plan = StructuredContractionPlan::new(&operands, &spec)?;
594 Ok((plan, spec.input_labels, spec.output_labels))
595 }
596
    /// Builds a tracked tensor from a structured eager payload.
    ///
    /// Validates index uniqueness, that the payload's actual shape matches
    /// the planned `payload_dims`, and that the resulting storage agrees with
    /// the index dims; the payload is then retained in `structured_ad` so
    /// later contractions can stay on the AD path.
    fn from_structured_payload_inner(
        indices: Vec<DynIndex>,
        payload_inner: EagerTensor<CpuBackend>,
        payload_dims: Vec<usize>,
        axis_classes: Vec<usize>,
    ) -> Result<Self> {
        Self::validate_indices(&indices);
        if payload_inner.data().shape() != payload_dims {
            return Err(anyhow::anyhow!(
                "structured payload dims {:?} do not match planned payload dims {:?}",
                payload_inner.data().shape(),
                payload_dims
            ));
        }
        let storage = storage_from_payload_native(
            payload_inner.data().clone(),
            &payload_dims,
            axis_classes.clone(),
        )?;
        Self::validate_storage_matches_indices(&indices, &storage)?;
        Ok(Self {
            indices,
            storage: Arc::new(storage),
            structured_ad: Some(Arc::new(StructuredAdValue {
                payload: Arc::new(payload_inner),
                payload_dims,
                axis_classes,
            })),
            // The eager cache stays empty; it is filled lazily on demand.
            eager_cache: Self::empty_eager_cache(),
        })
    }
628
629 fn contract_structured_payloads(
630 &self,
631 other: &Self,
632 result_indices: Vec<DynIndex>,
633 axes_a: &[usize],
634 axes_b: &[usize],
635 ) -> Result<Self> {
636 let (plan, _, _) = self.binary_structured_contraction_plan(other, axes_a, axes_b)?;
637 let lhs_roots = plan.operand_plans[0].class_roots.clone();
638 let rhs_roots = plan.operand_plans[1].class_roots.clone();
639 let scalar_multiply =
640 lhs_roots.is_empty() && rhs_roots.is_empty() && plan.output_payload_roots.is_empty();
641
642 if let (Some(lhs_ad), Some(rhs_ad)) = (
643 self.tracked_compact_payload_value(),
644 other.tracked_compact_payload_value(),
645 ) {
646 if lhs_ad.payload.data().dtype() != rhs_ad.payload.data().dtype() {
647 return Err(anyhow::anyhow!(
648 "structured AD contraction requires matching payload dtypes"
649 ));
650 }
651 let (lhs_normalized, lhs_labels) =
652 Self::normalize_eager_payload_for_roots(lhs_ad.payload.as_ref(), &lhs_roots)?;
653 let (rhs_normalized, rhs_labels) =
654 Self::normalize_eager_payload_for_roots(rhs_ad.payload.as_ref(), &rhs_roots)?;
655 let lhs_payload = lhs_normalized
656 .as_ref()
657 .unwrap_or_else(|| lhs_ad.payload.as_ref());
658 let rhs_payload = rhs_normalized
659 .as_ref()
660 .unwrap_or_else(|| rhs_ad.payload.as_ref());
661 let payload = if scalar_multiply {
662 lhs_payload.mul(rhs_payload)?
663 } else {
664 let subscripts = Self::build_payload_einsum_subscripts(
665 &[lhs_labels, rhs_labels],
666 &plan.output_payload_roots,
667 )?;
668 eager_einsum_ad(&[lhs_payload, rhs_payload], &subscripts)?
669 };
670 return Self::from_structured_payload_inner(
671 result_indices,
672 payload,
673 plan.output_payload_dims,
674 plan.output_axis_classes,
675 );
676 }
677
678 if self.tracked_compact_payload_value().is_some()
679 || other.tracked_compact_payload_value().is_some()
680 {
681 let lhs_owned = if self.tracked_compact_payload_value().is_some() {
682 None
683 } else {
684 Some(self.compact_payload_inner()?)
685 };
686 let rhs_owned = if other.tracked_compact_payload_value().is_some() {
687 None
688 } else {
689 Some(other.compact_payload_inner()?)
690 };
691 let lhs = if let Some(value) = self.tracked_compact_payload_value() {
692 value.payload.as_ref()
693 } else {
694 lhs_owned.as_ref().unwrap()
695 };
696 let rhs = if let Some(value) = other.tracked_compact_payload_value() {
697 value.payload.as_ref()
698 } else {
699 rhs_owned.as_ref().unwrap()
700 };
701 if lhs.data().dtype() != rhs.data().dtype() {
702 return Err(anyhow::anyhow!(
703 "structured AD contraction requires matching payload dtypes"
704 ));
705 }
706 let (lhs_normalized, lhs_labels) =
707 Self::normalize_eager_payload_for_roots(lhs, &lhs_roots)?;
708 let (rhs_normalized, rhs_labels) =
709 Self::normalize_eager_payload_for_roots(rhs, &rhs_roots)?;
710 let lhs_payload = lhs_normalized.as_ref().unwrap_or(lhs);
711 let rhs_payload = rhs_normalized.as_ref().unwrap_or(rhs);
712 let payload = if scalar_multiply {
713 lhs_payload.mul(rhs_payload)?
714 } else {
715 let subscripts = Self::build_payload_einsum_subscripts(
716 &[lhs_labels, rhs_labels],
717 &plan.output_payload_roots,
718 )?;
719 eager_einsum_ad(&[lhs_payload, rhs_payload], &subscripts)?
720 };
721 return Self::from_structured_payload_inner(
722 result_indices,
723 payload,
724 plan.output_payload_dims,
725 plan.output_axis_classes,
726 );
727 }
728
729 let lhs = storage_payload_native(self.storage.as_ref())?;
730 let rhs = storage_payload_native(other.storage.as_ref())?;
731 if lhs.dtype() != rhs.dtype() {
732 return Err(anyhow::anyhow!(
733 "structured payload contraction requires matching payload dtypes"
734 ));
735 }
736 let (lhs, lhs_labels) = normalize_payload_for_roots(&lhs, &lhs_roots)?;
737 let (rhs, rhs_labels) = normalize_payload_for_roots(&rhs, &rhs_roots)?;
738 let payload = tensor4all_tensorbackend::einsum_native_tensors(
739 &[(&lhs, lhs_labels.as_slice()), (&rhs, rhs_labels.as_slice())],
740 &plan.output_payload_roots,
741 )?;
742 let storage = storage_from_payload_native(
743 payload,
744 &plan.output_payload_dims,
745 plan.output_axis_classes,
746 )?;
747 Self::from_storage(result_indices, Arc::new(storage))
748 }
749
750 fn should_use_structured_payload_contract(&self, other: &Self) -> bool {
751 let same_payload_dtype = self.storage.is_f64() == other.storage.is_f64()
752 && self.storage.is_complex() == other.storage.is_complex();
753 same_payload_dtype
754 && (self.tracked_compact_payload_value().is_some()
755 || other.tracked_compact_payload_value().is_some()
756 || self.storage.axis_classes() != Self::dense_axis_classes(self.indices.len())
757 || other.storage.axis_classes() != Self::dense_axis_classes(other.indices.len()))
758 }
759
760 fn storage_from_native_with_axis_classes(
761 native: &NativeTensor,
762 axis_classes: &[usize],
763 logical_rank: usize,
764 ) -> Result<Storage> {
765 if Self::is_diag_axis_classes(axis_classes) {
766 match native.dtype() {
767 DType::F32 | DType::F64 | DType::I64 => Storage::from_diag_col_major(
768 native_tensor_primal_to_diag_f64(native)?,
769 logical_rank,
770 ),
771 DType::C32 | DType::C64 => Storage::from_diag_col_major(
772 native_tensor_primal_to_diag_c64(native)?,
773 logical_rank,
774 ),
775 }
776 } else {
777 native_tensor_primal_to_storage(native)
778 }
779 }
780
    /// Dense payload produced by point-selecting axes of a diagonal tensor.
    ///
    /// A diagonal tensor is zero everywhere off the main diagonal, so the
    /// result is all zeros unless every selected position is the same value
    /// `p`; then the single entry `payload[p]` is placed at dense coordinate
    /// `(p, p, …, p)` over the kept axes. `first_position` is in range for
    /// `payload` because positions are bounds-checked in `select_indices`
    /// and diag dims are validated equal (`validate_diag_dims`).
    fn dense_selected_diag_payload<T: TensorElement + Copy + Zero>(
        payload: Vec<T>,
        kept_dims: &[usize],
        selected_positions: &[usize],
    ) -> Vec<T> {
        let output_len = kept_dims.iter().product::<usize>();
        let mut data = vec![T::zero(); output_len];
        if output_len == 0 {
            return data;
        }

        // No selected axes at all: caller handles that case; empty split
        // yields the zero tensor here.
        let Some((&first_position, rest)) = selected_positions.split_first() else {
            return data;
        };
        // Differing positions select an off-diagonal entry, which is zero.
        if rest.iter().any(|&position| position != first_position) {
            return data;
        }

        let value = payload[first_position];
        if kept_dims.is_empty() {
            data[0] = value;
            return data;
        }

        // Column-major offset of coordinate (p, p, …, p):
        // p * (1 + d0 + d0*d1 + …).
        let mut offset = 0usize;
        let mut stride = 1usize;
        for &dim in kept_dims {
            offset += first_position * stride;
            stride *= dim;
        }
        data[offset] = value;
        data
    }
814
815 fn select_diag_indices(
816 &self,
817 kept_indices: Vec<DynIndex>,
818 kept_dims: Vec<usize>,
819 positions: &[usize],
820 ) -> Result<Self> {
821 if self.storage.is_f64() {
822 let payload = self
823 .storage
824 .payload_f64_col_major_vec()
825 .map_err(anyhow::Error::msg)?;
826 let data = Self::dense_selected_diag_payload(payload, &kept_dims, positions);
827 Self::from_dense(kept_indices, data)
828 } else if self.storage.is_c64() {
829 let payload = self
830 .storage
831 .payload_c64_col_major_vec()
832 .map_err(anyhow::Error::msg)?;
833 let data = Self::dense_selected_diag_payload(payload, &kept_dims, positions);
834 Self::from_dense(kept_indices, data)
835 } else {
836 Err(anyhow::anyhow!("unsupported diagonal storage scalar type"))
837 }
838 }
839
840 fn col_major_strides(dims: &[usize]) -> Result<Vec<isize>> {
841 let mut strides = Vec::with_capacity(dims.len());
842 let mut stride = 1isize;
843 for &dim in dims {
844 strides.push(stride);
845 let dim = isize::try_from(dim)
846 .map_err(|_| anyhow::anyhow!("dimension does not fit in isize"))?;
847 stride = stride
848 .checked_mul(dim)
849 .ok_or_else(|| anyhow::anyhow!("column-major stride overflow"))?;
850 }
851 Ok(strides)
852 }
853
854 fn zero_structured_selection<T>(
855 kept_indices: Vec<DynIndex>,
856 kept_dims: &[usize],
857 ) -> Result<Self>
858 where
859 T: TensorElement + Zero,
860 {
861 let output_len = checked_product(kept_dims)?;
862 Self::from_dense(kept_indices, vec![T::zero(); output_len])
863 }
864
    /// Point-selects `selected_axes` at `positions` on structured storage,
    /// keeping the payload structured whenever possible.
    ///
    /// Strategy:
    /// 1. Record the selected position for each payload class. Two selections
    ///    on one class with different positions only hit structural zeros, so
    ///    the result is an all-zero dense tensor.
    /// 2. If any kept axis shares a class with a selected axis, structure
    ///    cannot be preserved — fall back to a dense gather.
    /// 3. Otherwise rebuild a smaller structured payload over the kept
    ///    classes, pinning selected classes at their chosen position.
    fn select_structured_indices_typed<T>(
        &self,
        payload: Vec<T>,
        kept_axes: &[usize],
        kept_indices: Vec<DynIndex>,
        kept_dims: Vec<usize>,
        selected_axes: &[usize],
        positions: &[usize],
    ) -> Result<Self>
    where
        T: TensorElement + StorageScalar + Zero,
    {
        let payload_dims = self.storage.payload_dims();
        let axis_classes = self.storage.axis_classes();
        let payload_rank = payload_dims.len();
        let mut selected_class_positions = vec![None; payload_rank];

        // Step 1: one position per class; conflicts are structural zeros.
        for (&axis, &position) in selected_axes.iter().zip(positions.iter()) {
            let class_id = axis_classes[axis];
            if let Some(existing) = selected_class_positions[class_id] {
                if existing != position {
                    return Self::zero_structured_selection::<T>(kept_indices, &kept_dims);
                }
            } else {
                selected_class_positions[class_id] = Some(position);
            }
        }

        // Step 2: kept axis tied to a selected class forces the dense path.
        let selected_class_kept = kept_axes
            .iter()
            .any(|&axis| selected_class_positions[axis_classes[axis]].is_some());
        if selected_class_kept {
            return self.select_structured_indices_dense(
                payload,
                kept_axes,
                kept_indices,
                kept_dims,
                &selected_class_positions,
            );
        }

        // Step 3: renumber surviving classes densely for the output payload.
        let mut old_to_new_class = vec![None; payload_rank];
        let mut output_payload_dims = Vec::new();
        let mut output_axis_classes = Vec::with_capacity(kept_axes.len());
        for &axis in kept_axes {
            let class_id = axis_classes[axis];
            let new_class = match old_to_new_class[class_id] {
                Some(new_class) => new_class,
                None => {
                    let new_class = output_payload_dims.len();
                    old_to_new_class[class_id] = Some(new_class);
                    output_payload_dims.push(payload_dims[class_id]);
                    new_class
                }
            };
            output_axis_classes.push(new_class);
        }

        // Gather the sub-payload: each input class index is either the
        // pinned selected position or the matching output coordinate.
        let output_len = checked_product(&output_payload_dims)?;
        let mut output_payload = Vec::with_capacity(output_len);
        for linear in 0..output_len {
            let output_payload_index = decode_col_major_linear(linear, &output_payload_dims)?;
            let mut input_payload_index = vec![0usize; payload_rank];
            for class_id in 0..payload_rank {
                input_payload_index[class_id] =
                    if let Some(position) = selected_class_positions[class_id] {
                        position
                    } else if let Some(new_class) = old_to_new_class[class_id] {
                        output_payload_index[new_class]
                    } else {
                        return Err(anyhow::anyhow!(
                            "structured payload class {class_id} is neither selected nor kept"
                        ));
                    };
            }
            let input_linear = encode_col_major_linear(&input_payload_index, payload_dims)?;
            output_payload.push(payload[input_linear]);
        }

        let output_strides = Self::col_major_strides(&output_payload_dims)?;
        let storage = Storage::new_structured(
            output_payload,
            output_payload_dims,
            output_strides,
            output_axis_classes,
        )?;
        Self::from_storage(kept_indices, Arc::new(storage))
    }
953
    /// Dense fallback for structured selection: walks every output coordinate
    /// and maps kept-axis positions (plus pre-selected class positions) back
    /// into the structured payload. A coordinate that would bind one payload
    /// class to two different positions is a structural zero.
    fn select_structured_indices_dense<T>(
        &self,
        payload: Vec<T>,
        kept_axes: &[usize],
        kept_indices: Vec<DynIndex>,
        kept_dims: Vec<usize>,
        selected_class_positions: &[Option<usize>],
    ) -> Result<Self>
    where
        T: TensorElement + Zero,
    {
        let payload_dims = self.storage.payload_dims();
        let axis_classes = self.storage.axis_classes();
        let output_len = checked_product(&kept_dims)?;
        let mut output = Vec::with_capacity(output_len);

        for linear in 0..output_len {
            let kept_position = decode_col_major_linear(linear, &kept_dims)?;
            // Start from the selections, then fill in kept-axis positions.
            let mut input_payload_index = selected_class_positions.to_vec();
            let mut is_structural_zero = false;

            for (&axis, &position) in kept_axes.iter().zip(kept_position.iter()) {
                let class_id = axis_classes[axis];
                match input_payload_index[class_id] {
                    // Conflicting positions for a shared class: off the
                    // structured support, so the entry is zero.
                    Some(existing) if existing != position => {
                        is_structural_zero = true;
                        break;
                    }
                    Some(_) => {}
                    None => input_payload_index[class_id] = Some(position),
                }
            }

            if is_structural_zero {
                output.push(T::zero());
                continue;
            }

            let input_payload_index = input_payload_index
                .into_iter()
                .enumerate()
                .map(|(class_id, position)| {
                    position.ok_or_else(|| {
                        anyhow::anyhow!(
                            "structured payload class {class_id} is neither selected nor kept"
                        )
                    })
                })
                .collect::<Result<Vec<_>>>()?;
            let input_linear = encode_col_major_linear(&input_payload_index, payload_dims)?;
            output.push(payload[input_linear]);
        }

        Self::from_dense(kept_indices, output)
    }
1009
1010 fn select_structured_indices(
1011 &self,
1012 kept_axes: &[usize],
1013 kept_indices: Vec<DynIndex>,
1014 kept_dims: Vec<usize>,
1015 selected_axes: &[usize],
1016 positions: &[usize],
1017 ) -> Result<Self> {
1018 if self.storage.is_f64() {
1019 let payload = self
1020 .storage
1021 .payload_f64_col_major_vec()
1022 .map_err(anyhow::Error::msg)?;
1023 self.select_structured_indices_typed(
1024 payload,
1025 kept_axes,
1026 kept_indices,
1027 kept_dims,
1028 selected_axes,
1029 positions,
1030 )
1031 } else if self.storage.is_c64() {
1032 let payload = self
1033 .storage
1034 .payload_c64_col_major_vec()
1035 .map_err(anyhow::Error::msg)?;
1036 self.select_structured_indices_typed(
1037 payload,
1038 kept_axes,
1039 kept_indices,
1040 kept_dims,
1041 selected_axes,
1042 positions,
1043 )
1044 } else {
1045 Err(anyhow::anyhow!(
1046 "unsupported structured storage scalar type"
1047 ))
1048 }
1049 }
1050
1051 fn validate_storage_matches_indices(indices: &[DynIndex], storage: &Storage) -> Result<()> {
1052 let dims = Self::expected_dims_from_indices(indices);
1053 let storage_dims = storage.logical_dims();
1054 if storage_dims != dims {
1055 return Err(anyhow::anyhow!(
1056 "storage logical dims {:?} do not match indices dims {:?}",
1057 storage_dims,
1058 dims
1059 ));
1060 }
1061 if storage.is_diag() {
1062 Self::validate_diag_dims(&dims)?;
1063 }
1064 Ok(())
1065 }
1066
    /// Dense eager view of this tensor, computed once and cached.
    ///
    /// A tracked structured payload that is already logically dense is
    /// returned directly (keeping AD tracking intact); otherwise the storage
    /// is converted once into the `eager_cache` `OnceLock`. Panics if that
    /// conversion fails, since `get_or_init` cannot propagate errors.
    fn materialized_inner(&self) -> &EagerTensor<CpuBackend> {
        if let Some(value) = self.tracked_compact_payload_value() {
            if self.compact_payload_is_logical_dense(&value.payload_dims) {
                return value.payload.as_ref();
            }
        }
        self.eager_cache
            .get_or_init(|| {
                let native = Self::seed_native_payload(self.storage.as_ref(), &self.dims())
                    .unwrap_or_else(|err| panic!("TensorDynLen materialization failed: {err}"));
                Arc::new(EagerTensor::from_tensor_in(native, default_eager_ctx()))
            })
            .as_ref()
    }
1081
    /// Crate-visible accessor for the materialized eager tensor.
    pub(crate) fn as_inner(&self) -> &EagerTensor<CpuBackend> {
        self.materialized_inner()
    }
1085
1086 #[inline]
1088 fn expected_dims_from_indices(indices: &[DynIndex]) -> Vec<usize> {
1089 indices.iter().map(|idx| idx.dim()).collect()
1090 }
1091
    /// Logical dimensions of this tensor, one entry per index.
    pub fn dims(&self) -> Vec<usize> {
        Self::expected_dims_from_indices(&self.indices)
    }
1113
    /// Fixes each index in `selected_indices` to the corresponding coordinate
    /// in `positions` and returns the tensor over the remaining indices.
    ///
    /// Validates the argument lengths, that each selected index exists
    /// exactly once in the tensor, and that every coordinate is in range.
    /// Diagonal and structured storage take specialized paths; dense storage
    /// uses a backend dynamic-slice followed by a reshape that drops the
    /// selected (size-1) axes.
    pub fn select_indices(
        &self,
        selected_indices: &[DynIndex],
        positions: &[usize],
    ) -> Result<Self> {
        if selected_indices.len() != positions.len() {
            return Err(anyhow::anyhow!(
                "selected_indices length {} does not match positions length {}",
                selected_indices.len(),
                positions.len()
            ));
        }
        // Selecting nothing is the identity.
        if selected_indices.is_empty() {
            return Ok(self.clone());
        }

        // Resolve each selected index to its axis, rejecting repeats and
        // out-of-range coordinates.
        let mut selected_axes = Vec::with_capacity(selected_indices.len());
        let mut seen_axes = HashSet::with_capacity(selected_indices.len());
        for (selected, &position) in selected_indices.iter().zip(positions.iter()) {
            let axis = self
                .indices
                .iter()
                .position(|index| index == selected)
                .ok_or_else(|| anyhow::anyhow!("selected index is not present in tensor"))?;
            if !seen_axes.insert(axis) {
                return Err(anyhow::anyhow!("selected index appears more than once"));
            }
            let dim = self.indices[axis].dim();
            if position >= dim {
                return Err(anyhow::anyhow!(
                    "selected coordinate {position} is out of range for axis {axis} with dim {dim}"
                ));
            }
            selected_axes.push(axis);
        }

        // Everything not selected is kept, in original axis order.
        let kept_axes = self
            .indices
            .iter()
            .enumerate()
            .filter(|(axis, _)| !seen_axes.contains(axis))
            .map(|(axis, _)| axis)
            .collect::<Vec<_>>();
        let kept_indices = kept_axes
            .iter()
            .map(|&axis| self.indices[axis].clone())
            .collect::<Vec<_>>();
        let kept_dims = kept_axes
            .iter()
            .map(|&axis| self.indices[axis].dim())
            .collect::<Vec<_>>();

        // Specialized paths for non-dense storage kinds.
        if self.storage.storage_kind() == StorageKind::Diagonal {
            return self.select_diag_indices(kept_indices, kept_dims, positions);
        }
        if self.storage.storage_kind() == StorageKind::Structured {
            return self.select_structured_indices(
                &kept_axes,
                kept_indices,
                kept_dims,
                &selected_axes,
                positions,
            );
        }
        if self.storage.storage_kind() != StorageKind::Dense {
            return Err(anyhow::anyhow!(
                "select_indices got unsupported storage kind {:?}",
                self.storage.storage_kind()
            ));
        }

        // Dense path: slice each selected axis down to length 1 at its
        // coordinate, then reshape away the unit axes.
        let rank = self.indices.len();
        let mut starts = vec![0_i64; rank];
        let mut slice_sizes = self.dims();
        for (&axis, &position) in selected_axes.iter().zip(positions.iter()) {
            starts[axis] = i64::try_from(position)
                .map_err(|_| anyhow::anyhow!("selected coordinate does not fit in i64"))?;
            slice_sizes[axis] = 1;
        }

        let starts_tensor = EagerTensor::from_tensor_in(
            NativeTensor::from_vec(vec![rank], starts),
            default_eager_ctx(),
        );
        let sliced = self
            .materialized_inner()
            .dynamic_slice(&starts_tensor, &slice_sizes)?;
        Self::from_inner(kept_indices, sliced.reshape(&kept_dims)?)
    }
1244
1245 pub fn new(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Self {
1263 match Self::from_storage(indices, storage) {
1264 Ok(tensor) => tensor,
1265 Err(err) => panic!("TensorDynLen::new failed: {err}"),
1266 }
1267 }
1268
    /// Alias for [`Self::new`]; panics if `storage` does not match `indices`.
    pub fn from_indices(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Self {
        Self::new(indices, storage)
    }
1291
    /// Fallible constructor: checks the indices and verifies that `storage`'s
    /// layout matches them before building the tensor.
    ///
    /// The eager cache starts empty and is filled lazily on first
    /// materialization; no AD payload is attached.
    pub fn from_storage(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
        // Index validation has no error return; storage validation may fail.
        Self::validate_indices(&indices);
        Self::validate_storage_matches_indices(&indices, storage.as_ref())?;
        Ok(Self {
            indices,
            storage,
            structured_ad: None,
            eager_cache: Self::empty_eager_cache(),
        })
    }
1316
    /// Builds a tensor from structured storage; currently identical to
    /// [`Self::from_storage`].
    pub fn from_structured_storage(indices: Vec<DynIndex>, storage: Arc<Storage>) -> Result<Self> {
        Self::from_storage(indices, storage)
    }
1342
    /// Wraps a backend tensor, treating every axis as dense.
    pub(crate) fn from_native(indices: Vec<DynIndex>, native: NativeTensor) -> Result<Self> {
        let axis_classes = Self::dense_axis_classes(indices.len());
        Self::from_native_with_axis_classes(indices, native, axis_classes)
    }
1348
    /// Wraps a backend tensor with explicit per-axis classes by lifting it
    /// into an eager tensor on the default CPU context.
    pub(crate) fn from_native_with_axis_classes(
        indices: Vec<DynIndex>,
        native: NativeTensor,
        axis_classes: Vec<usize>,
    ) -> Result<Self> {
        Self::from_inner_with_axis_classes(
            indices,
            EagerTensor::from_tensor_in(native, default_eager_ctx()),
            axis_classes,
        )
    }
1360
    /// Wraps an eager tensor, treating every axis as dense.
    pub(crate) fn from_inner(
        indices: Vec<DynIndex>,
        inner: EagerTensor<CpuBackend>,
    ) -> Result<Self> {
        let axis_classes = Self::dense_axis_classes(indices.len());
        Self::from_inner_with_axis_classes(indices, inner, axis_classes)
    }
1368
    /// Core constructor from an eager tensor.
    ///
    /// Validates that the payload shape matches the index dims, runs the
    /// extra diagonal-dims check when `axis_classes` denotes a diagonal
    /// layout, derives storage from the native data, and seeds the eager
    /// cache with `inner` so later materializations reuse it.
    ///
    /// # Errors
    /// Errors on a dims mismatch, an invalid diagonal shape, or a storage
    /// derivation failure.
    pub(crate) fn from_inner_with_axis_classes(
        indices: Vec<DynIndex>,
        inner: EagerTensor<CpuBackend>,
        axis_classes: Vec<usize>,
    ) -> Result<Self> {
        let dims = Self::expected_dims_from_indices(&indices);
        Self::validate_indices(&indices);
        if dims != inner.data().shape() {
            return Err(anyhow::anyhow!(
                "native payload dims {:?} do not match indices dims {:?}",
                inner.data().shape(),
                dims
            ));
        }
        if Self::is_diag_axis_classes(&axis_classes) {
            Self::validate_diag_dims(&dims)?;
        }
        let storage = Self::storage_from_native_with_axis_classes(
            inner.data(),
            &axis_classes,
            indices.len(),
        )?;
        Ok(Self {
            indices,
            storage: Arc::new(storage),
            structured_ad: None,
            // Seed the cache so the first materialization is free.
            eager_cache: Self::eager_cache_with(inner),
        })
    }
1398
    /// Borrows the tensor's external indices in axis order.
    pub fn indices(&self) -> &[DynIndex] {
        &self.indices
    }
1403
    /// Borrows the materialized backend tensor (materializing it if needed).
    pub(crate) fn as_native(&self) -> &NativeTensor {
        self.materialized_inner().data()
    }
1408
    /// Consumes the tensor and returns one whose compact storage payload
    /// tracks gradients for automatic differentiation.
    ///
    /// The eager cache is reset so later materializations flow through the
    /// grad-tracking payload rather than a stale cached value.
    ///
    /// # Panics
    /// Panics if the storage payload cannot be extracted.
    pub fn enable_grad(self) -> Self {
        let payload = storage_payload_native(self.storage.as_ref())
            .unwrap_or_else(|err| panic!("TensorDynLen::enable_grad failed: {err}"));
        let payload_dims = self.storage.payload_dims().to_vec();
        let axis_classes = self.storage.axis_classes().to_vec();
        Self {
            indices: self.indices,
            storage: self.storage,
            structured_ad: Some(Arc::new(StructuredAdValue {
                payload: Arc::new(EagerTensor::requires_grad_in(payload, default_eager_ctx())),
                payload_dims,
                axis_classes,
            })),
            eager_cache: Self::empty_eager_cache(),
        }
    }
1426
    /// Whether this tensor participates in gradient tracking, either through
    /// the structured-AD payload or a grad-tracking cached eager tensor.
    pub fn tracks_grad(&self) -> bool {
        self.structured_ad
            .as_ref()
            .is_some_and(|value| value.payload.tracks_grad())
            || self
                .eager_cache
                .get()
                .is_some_and(|inner| inner.tracks_grad())
    }
1437
    /// Returns the accumulated gradient as a tensor with the same indices,
    /// or `Ok(None)` when no gradient has been recorded.
    ///
    /// The compact grad-tracking payload (if present) takes precedence; the
    /// materialized eager tensor's gradient is the fallback.
    pub fn grad(&self) -> Result<Option<Self>> {
        if let Some(value) = self.tracked_compact_payload_value() {
            return value
                .payload
                .grad()
                .map(|grad| {
                    // Rebuild storage from the gradient payload, preserving
                    // the primal's payload dims and axis classes.
                    let storage = storage_from_payload_native(
                        grad.as_ref().clone(),
                        &value.payload_dims,
                        value.axis_classes.clone(),
                    )?;
                    Self::from_storage(self.indices.clone(), Arc::new(storage))
                })
                .transpose();
        }
        self.materialized_inner()
            .grad()
            .map(|grad| {
                Self::from_native_with_axis_classes(
                    self.indices.clone(),
                    grad.as_ref().clone(),
                    self.storage.axis_classes().to_vec(),
                )
            })
            .transpose()
    }
1465
    /// Clears accumulated gradients on both the tracked compact payload and
    /// the cached eager tensor; a no-op when neither exists. Always `Ok`.
    pub fn clear_grad(&self) -> Result<()> {
        if let Some(value) = self.tracked_compact_payload_value() {
            value.payload.clear_grad();
        }
        if let Some(inner) = self.eager_cache.get() {
            inner.clear_grad();
        }
        Ok(())
    }
1476
    /// Runs reverse-mode differentiation from this tensor, preferring the
    /// compact grad-tracking payload over the materialized eager tensor.
    ///
    /// # Errors
    /// Errors if the backend backward pass fails.
    pub fn backward(&self) -> Result<()> {
        if let Some(value) = self.tracked_compact_payload_value() {
            return value
                .payload
                .backward()
                .map(|_| ())
                .map_err(|e| anyhow::anyhow!("TensorDynLen::backward failed: {e}"));
        }
        self.materialized_inner()
            .backward()
            .map(|_| ())
            .map_err(|e| anyhow::anyhow!("TensorDynLen::backward failed: {e}"))
    }
1491
    /// Returns an equivalent tensor with gradient tracking severed.
    ///
    /// With a tracked compact payload, rebuilding from the shared storage is
    /// enough (the new tensor carries no `structured_ad`); otherwise the
    /// cached eager tensor is explicitly detached.
    pub fn detach(&self) -> Self {
        if self.tracked_compact_payload_value().is_some() {
            return Self::from_storage(self.indices.clone(), Arc::clone(&self.storage))
                .expect("TensorDynLen::detach returned invalid tensor");
        }
        Self::from_inner_with_axis_classes(
            self.indices.clone(),
            self.materialized_inner().detach(),
            self.storage.axis_classes().to_vec(),
        )
        .expect("TensorDynLen::detach returned invalid tensor")
    }
1505
    /// Always `true` for this concrete tensor type.
    pub fn is_simple(&self) -> bool {
        true
    }
1510
    /// Returns the storage handle; kept fallible for interface symmetry even
    /// though it currently cannot fail.
    pub fn to_storage(&self) -> Result<Arc<Storage>> {
        Ok(Arc::clone(&self.storage))
    }
1515
    /// Returns a cloned handle to the underlying storage.
    pub fn storage(&self) -> Arc<Storage> {
        Arc::clone(&self.storage)
    }
1520
    /// Sums over all axes and returns the result as a scalar.
    ///
    /// A rank-0 tensor is wrapped as-is, avoiding a zero-axis reduction.
    ///
    /// # Panics
    /// Panics if the backend reduction fails.
    pub fn sum(&self) -> AnyScalar {
        if self.indices.is_empty() {
            return AnyScalar::from_tensor_unchecked(self.clone());
        }
        let axes: Vec<usize> = (0..self.indices.len()).collect();
        let reduced = self
            .materialized_inner()
            .reduce_sum(&axes)
            .unwrap_or_else(|e| panic!("TensorDynLen::sum failed: {e}"));
        AnyScalar::from_tensor_unchecked(
            Self::from_inner(Vec::new(), reduced)
                .unwrap_or_else(|e| panic!("TensorDynLen::sum returned invalid scalar: {e}")),
        )
    }
1547
1548 pub fn only(&self) -> AnyScalar {
1569 let dims = self.dims();
1570 let total_size: usize = dims.iter().product();
1571 assert!(
1572 total_size == 1 || dims.is_empty(),
1573 "only() requires a scalar tensor (1 element), got {} elements with dims {:?}",
1574 if dims.is_empty() { 1 } else { total_size },
1575 dims
1576 );
1577 self.sum()
1578 }
1579
    /// Reorders axes so they match `new_indices`, which must be a
    /// permutation of the current indices.
    ///
    /// # Panics
    /// Panics if `new_indices` is not a permutation of the current indices
    /// or the backend transpose fails.
    pub fn permute_indices(&self, new_indices: &[DynIndex]) -> Self {
        let perm = compute_permutation_from_indices(&self.indices, new_indices);
        // Identity permutation: metadata-only copy sharing storage, AD state
        // and the eager cache.
        if perm.iter().copied().eq(0..perm.len()) {
            return Self {
                indices: new_indices.to_vec(),
                storage: Arc::clone(&self.storage),
                structured_ad: self.structured_ad.clone(),
                eager_cache: Arc::clone(&self.eager_cache),
            };
        }

        let permuted = self
            .materialized_inner()
            .transpose(&perm)
            .unwrap_or_else(|e| panic!("TensorDynLen::permute_indices failed: {e}"));
        let axis_classes = self.permute_axis_classes(&perm);
        Self::from_inner_with_axis_classes(new_indices.to_vec(), permuted, axis_classes)
            .expect("TensorDynLen::permute_indices returned invalid tensor")
    }
1630
    /// Reorders axes by position: axis `i` of the result is axis `perm[i]`
    /// of `self`.
    ///
    /// # Panics
    /// Panics if `perm.len()` differs from the rank or the transpose fails.
    pub fn permute(&self, perm: &[usize]) -> Self {
        assert_eq!(
            perm.len(),
            self.indices.len(),
            "permutation length must match tensor rank"
        );
        // Identity permutation: nothing to do.
        if perm.iter().copied().eq(0..perm.len()) {
            return self.clone();
        }

        let new_indices: Vec<DynIndex> = perm.iter().map(|&i| self.indices[i].clone()).collect();
        let permuted = self
            .materialized_inner()
            .transpose(perm)
            .unwrap_or_else(|e| panic!("TensorDynLen::permute failed: {e}"));
        let axis_classes = self.permute_axis_classes(perm);
        Self::from_inner_with_axis_classes(new_indices, permuted, axis_classes)
            .expect("TensorDynLen::permute returned invalid tensor")
    }
1679
    /// Contracts all common indices with `other` (panicking path); see
    /// [`Self::contract_with_options`] for the fallible variant.
    pub fn contract(&self, other: &Self) -> Self {
        self.contract_pairwise_default(other)
    }
1720
    /// Fallible contraction of `self` with `other` through the multi-tensor
    /// contraction machinery, honoring `options`.
    pub fn contract_with_options(
        &self,
        other: &Self,
        options: crate::defaults::contract::ContractionOptions<'_>,
    ) -> Result<Self> {
        crate::defaults::contract::contract_multi_with_options(&[self, other], options)
    }
1766
    /// Pairwise contraction over all indices common to both operands.
    ///
    /// Dispatch order:
    /// 1. structured-payload contraction when both operands qualify;
    /// 2. plain elementwise multiply when both operands are rank-0;
    /// 3. native contraction helper when the dtypes differ;
    /// 4. AD-aware eager einsum otherwise.
    ///
    /// # Panics
    /// Panics on preparation or backend failure; this is the infallible path
    /// behind `contract` and the `Mul` operators.
    pub(crate) fn contract_pairwise_default(&self, other: &Self) -> Self {
        let self_dims = Self::expected_dims_from_indices(&self.indices);
        let other_dims = Self::expected_dims_from_indices(&other.indices);
        let spec = prepare_contraction(&self.indices, &self_dims, &other.indices, &other_dims)
            .expect("contraction preparation failed");
        let result_axis_classes = Self::binary_contraction_axis_classes(
            self.storage.axis_classes(),
            &spec.axes_a,
            other.storage.axis_classes(),
            &spec.axes_b,
        );

        if self.should_use_structured_payload_contract(other) {
            return self
                .contract_structured_payloads(
                    other,
                    spec.result_indices,
                    &spec.axes_a,
                    &spec.axes_b,
                )
                .expect("TensorDynLen::contract structured payload path failed");
        }

        // Both rank-0: a contraction degenerates to a scalar multiply.
        if self.indices.is_empty() && other.indices.is_empty() {
            let result = self
                .materialized_inner()
                .mul(other.materialized_inner())
                .unwrap_or_else(|e| panic!("TensorDynLen::contract scalar multiply failed: {e}"));
            return Self::from_inner(spec.result_indices, result)
                .expect("TensorDynLen::contract returned invalid scalar");
        }

        // Mixed dtypes cannot go through the einsum path below.
        if self.as_native().dtype() != other.as_native().dtype() {
            let result_native = contract_native_tensor(
                self.as_native(),
                &spec.axes_a,
                other.as_native(),
                &spec.axes_b,
            )
            .unwrap_or_else(|e| panic!("TensorDynLen::contract native fallback failed: {e}"));
            return Self::from_native_with_axis_classes(
                spec.result_indices,
                result_native,
                result_axis_classes,
            )
            .expect("TensorDynLen::contract native fallback returned invalid tensor");
        }

        let subscripts = Self::build_binary_einsum_subscripts(
            self.indices.len(),
            &spec.axes_a,
            other.indices.len(),
            &spec.axes_b,
        )
        .expect("TensorDynLen::contract failed to build einsum subscripts");
        let result = eager_einsum_ad(
            &[self.materialized_inner(), other.materialized_inner()],
            &subscripts,
        )
        .unwrap_or_else(|e| panic!("TensorDynLen::contract failed: {e}"));
        Self::from_inner_with_axis_classes(spec.result_indices, result, result_axis_classes)
            .expect("TensorDynLen::contract returned invalid tensor")
    }
1830
    /// Contracts exactly the given `(self_index, other_index)` pairs,
    /// following the same dispatch order as `contract_pairwise_default` but
    /// with fallible error reporting.
    ///
    /// # Errors
    /// Errors if no pairs are given, a pair references a missing index, a
    /// paired dimension mismatches, an axis is duplicated, or a common index
    /// exists that is not listed in `pairs` (batch contraction is
    /// unimplemented).
    pub fn tensordot(&self, other: &Self, pairs: &[(DynIndex, DynIndex)]) -> Result<Self> {
        use crate::index_ops::ContractionError;

        let self_dims = Self::expected_dims_from_indices(&self.indices);
        let other_dims = Self::expected_dims_from_indices(&other.indices);
        let spec = prepare_contraction_pairs(
            &self.indices,
            &self_dims,
            &other.indices,
            &other_dims,
            pairs,
        )
        // Translate structured preparation errors into user-facing messages.
        .map_err(|e| match e {
            ContractionError::NoCommonIndices => {
                anyhow::anyhow!("tensordot: No pairs specified for contraction")
            }
            ContractionError::BatchContractionNotImplemented => anyhow::anyhow!(
                "tensordot: Common index found but not in contraction pairs. \
                 Batch contraction is not yet implemented."
            ),
            ContractionError::IndexNotFound { tensor } => {
                anyhow::anyhow!("tensordot: Index not found in {} tensor", tensor)
            }
            ContractionError::DimensionMismatch {
                pos_a,
                pos_b,
                dim_a,
                dim_b,
            } => anyhow::anyhow!(
                "tensordot: Dimension mismatch: self[{}]={} != other[{}]={}",
                pos_a,
                dim_a,
                pos_b,
                dim_b
            ),
            ContractionError::DuplicateAxis { tensor, pos } => {
                anyhow::anyhow!("tensordot: Duplicate axis {} in {} tensor", pos, tensor)
            }
        })?;
        let result_axis_classes = Self::binary_contraction_axis_classes(
            self.storage.axis_classes(),
            &spec.axes_a,
            other.storage.axis_classes(),
            &spec.axes_b,
        );

        if self.should_use_structured_payload_contract(other) {
            return self.contract_structured_payloads(
                other,
                spec.result_indices,
                &spec.axes_a,
                &spec.axes_b,
            );
        }

        // Both rank-0: degenerates to a scalar multiply.
        if self.indices.is_empty() && other.indices.is_empty() {
            let result = self
                .materialized_inner()
                .mul(other.materialized_inner())
                .map_err(|e| anyhow::anyhow!("tensordot scalar multiply failed: {e}"))?;
            return Self::from_inner(spec.result_indices, result);
        }

        // Mixed dtypes cannot go through the einsum path below.
        if self.as_native().dtype() != other.as_native().dtype() {
            let result_native = contract_native_tensor(
                self.as_native(),
                &spec.axes_a,
                other.as_native(),
                &spec.axes_b,
            )?;
            return Self::from_native_with_axis_classes(
                spec.result_indices,
                result_native,
                result_axis_classes,
            );
        }

        let subscripts = Self::build_binary_einsum_subscripts(
            self.indices.len(),
            &spec.axes_a,
            other.indices.len(),
            &spec.axes_b,
        )?;
        let result = eager_einsum_ad(
            &[self.materialized_inner(), other.materialized_inner()],
            &subscripts,
        )
        .map_err(|e| anyhow::anyhow!("tensordot failed: {e}"))?;
        Self::from_inner_with_axis_classes(spec.result_indices, result, result_axis_classes)
    }
1965
    /// Tensor (outer) product: concatenates the index lists and multiplies
    /// with no contracted axes, using the same dispatch order as the other
    /// contraction entry points.
    ///
    /// # Errors
    /// Errors if the tensors share any index (use `tensordot`, or refresh
    /// the index IDs first), or if the backend product fails.
    pub fn outer_product(&self, other: &Self) -> Result<Self> {
        use anyhow::Context;

        let common_positions = common_ind_positions(&self.indices, &other.indices);
        if !common_positions.is_empty() {
            let common_ids: Vec<_> = common_positions
                .iter()
                .map(|(pos_a, _)| self.indices[*pos_a].id())
                .collect();
            return Err(anyhow::anyhow!(
                "outer_product: tensors have common indices {:?}. \
                 Use tensordot to contract common indices, or use sim() to replace \
                 indices with fresh IDs before computing outer product.",
                common_ids
            ))
            .context("outer_product: common indices found");
        }

        // Result indices: self's axes followed by other's axes.
        let mut result_indices = self.indices.clone();
        result_indices.extend(other.indices.iter().cloned());
        let result_axis_classes = Self::binary_contraction_axis_classes(
            self.storage.axis_classes(),
            &[],
            other.storage.axis_classes(),
            &[],
        );
        if self.should_use_structured_payload_contract(other) {
            return self.contract_structured_payloads(other, result_indices, &[], &[]);
        }
        // Mixed dtypes cannot go through the einsum path below.
        if self.as_native().dtype() != other.as_native().dtype() {
            let result_native =
                contract_native_tensor(self.as_native(), &[], other.as_native(), &[])?;
            return Self::from_native_with_axis_classes(
                result_indices,
                result_native,
                result_axis_classes,
            );
        }

        let subscripts = Self::build_binary_einsum_subscripts(
            self.indices.len(),
            &[],
            other.indices.len(),
            &[],
        )?;
        let result = eager_einsum_ad(
            &[self.materialized_inner(), other.materialized_inner()],
            &subscripts,
        )
        .map_err(|e| anyhow::anyhow!("outer_product failed: {e}"))?;
        Self::from_inner_with_axis_classes(result_indices, result, result_axis_classes)
    }
2049}
2050
impl TensorDynLen {
    /// Fills a dense tensor with i.i.d. samples drawn via
    /// [`RandomScalar::random_value`] (standard normal for `f64`, complex
    /// standard normal for `Complex64`).
    ///
    /// # Panics
    /// Panics if dense construction fails.
    pub fn random<T: RandomScalar, R: Rng>(rng: &mut R, indices: Vec<DynIndex>) -> Self {
        let dims: Vec<usize> = indices.iter().map(|idx| idx.dim()).collect();
        let size: usize = dims.iter().product();
        let data: Vec<T> = (0..size).map(|_| T::random_value(rng)).collect();
        Self::from_dense(indices, data).expect("TensorDynLen::random failed")
    }
}
2089
/// `&a * &b` contracts all common indices via [`TensorDynLen::contract`].
impl Mul<&TensorDynLen> for &TensorDynLen {
    type Output = TensorDynLen;

    fn mul(self, other: &TensorDynLen) -> Self::Output {
        self.contract(other)
    }
}
2122
2123impl Mul<TensorDynLen> for TensorDynLen {
2127 type Output = TensorDynLen;
2128
2129 fn mul(self, other: TensorDynLen) -> Self::Output {
2130 self.contract(&other)
2131 }
2132}
2133
2134impl Mul<TensorDynLen> for &TensorDynLen {
2136 type Output = TensorDynLen;
2137
2138 fn mul(self, other: TensorDynLen) -> Self::Output {
2139 self.contract(&other)
2140 }
2141}
2142
2143impl Mul<&TensorDynLen> for TensorDynLen {
2145 type Output = TensorDynLen;
2146
2147 fn mul(self, other: &TensorDynLen) -> Self::Output {
2148 self.contract(other)
2149 }
2150}
2151
/// `&a - &b`, computed as `1.0 * a + (-1.0) * b` through
/// [`TensorDynLen::axpby`].
///
/// # Panics
/// Panics if the tensors' index sets differ or the combination fails.
impl Sub<&TensorDynLen> for &TensorDynLen {
    type Output = TensorDynLen;

    fn sub(self, other: &TensorDynLen) -> Self::Output {
        TensorDynLen::axpby(
            self,
            AnyScalar::new_real(1.0),
            other,
            AnyScalar::new_real(-1.0),
        )
        .expect("tensor subtraction failed")
    }
}
2165
/// `a - b` (by value); delegates to the by-reference subtraction.
impl Sub<TensorDynLen> for TensorDynLen {
    type Output = TensorDynLen;

    fn sub(self, other: TensorDynLen) -> Self::Output {
        Sub::sub(&self, &other)
    }
}
2173
/// `&a - b`; delegates to the by-reference subtraction.
impl Sub<TensorDynLen> for &TensorDynLen {
    type Output = TensorDynLen;

    fn sub(self, other: TensorDynLen) -> Self::Output {
        Sub::sub(self, &other)
    }
}
2181
/// `a - &b`; delegates to the by-reference subtraction.
impl Sub<&TensorDynLen> for TensorDynLen {
    type Output = TensorDynLen;

    fn sub(self, other: &TensorDynLen) -> Self::Output {
        Sub::sub(&self, other)
    }
}
2189
/// `-a`, computed by scaling with -1.
///
/// # Panics
/// Panics if the scaling fails.
impl Neg for &TensorDynLen {
    type Output = TensorDynLen;

    fn neg(self) -> Self::Output {
        TensorDynLen::scale(self, AnyScalar::new_real(-1.0)).expect("tensor negation failed")
    }
}
2197
/// `-a` (by value); delegates to the by-reference negation.
impl Neg for TensorDynLen {
    type Output = TensorDynLen;

    fn neg(self) -> Self::Output {
        Neg::neg(&self)
    }
}
2205
2206impl TensorDynLen {
    /// Elementwise sum `self + other`, aligning `other`'s axes to `self`'s
    /// index order first.
    ///
    /// # Errors
    /// Errors if the index counts or index sets differ, or if dims still
    /// mismatch after alignment (the error lists each index's id, dim, and
    /// tags to aid debugging).
    pub fn add(&self, other: &Self) -> Result<Self> {
        if self.indices.len() != other.indices.len() {
            return Err(anyhow::anyhow!(
                "Index count mismatch: self has {} indices, other has {}",
                self.indices.len(),
                other.indices.len()
            ));
        }

        let self_set: HashSet<_> = self.indices.iter().collect();
        let other_set: HashSet<_> = other.indices.iter().collect();

        if self_set != other_set {
            return Err(anyhow::anyhow!(
                "Index set mismatch: tensors must have the same indices"
            ));
        }

        let other_aligned = other.permute_indices(&self.indices);

        let self_expected_dims = Self::expected_dims_from_indices(&self.indices);
        let other_expected_dims = Self::expected_dims_from_indices(&other_aligned.indices);
        if self_expected_dims != other_expected_dims {
            use crate::TagSetLike;
            // Verbose diagnostic: same index IDs but different dims usually
            // means index metadata diverged somewhere upstream.
            let fmt = |indices: &[DynIndex]| -> Vec<String> {
                indices
                    .iter()
                    .map(|idx| {
                        let tags: Vec<String> = idx.tags().iter().collect();
                        format!("{:?}(dim={},tags={:?})", idx.id(), idx.dim(), tags)
                    })
                    .collect()
            };
            return Err(anyhow::anyhow!(
                "Dimension mismatch after alignment.\n\
                 self: dims={:?}, indices(order)={:?}\n\
                 other_aligned: dims={:?}, indices(order)={:?}",
                self_expected_dims,
                fmt(&self.indices),
                other_expected_dims,
                fmt(&other_aligned.indices)
            ));
        }

        self.axpby(
            AnyScalar::new_real(1.0),
            &other_aligned,
            AnyScalar::new_real(1.0),
        )
    }
2294
    /// Linear combination `a * self + b * other` (with `other` aligned to
    /// `self`'s index order).
    ///
    /// Fast path: when both tensors share the same compact payload layout
    /// and nothing involved tracks gradients, the combination runs directly
    /// on storage. Mixed dtypes use the native helper; otherwise both sides
    /// are scaled and added through the eager backend.
    ///
    /// # Errors
    /// Errors on index count/set mismatch, dims mismatch after alignment, or
    /// backend failure.
    pub fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result<Self> {
        if self.indices.len() != other.indices.len() {
            return Err(anyhow::anyhow!(
                "Index count mismatch: self has {} indices, other has {}",
                self.indices.len(),
                other.indices.len()
            ));
        }

        let self_set: HashSet<_> = self.indices.iter().collect();
        let other_set: HashSet<_> = other.indices.iter().collect();
        if self_set != other_set {
            return Err(anyhow::anyhow!(
                "Index set mismatch: tensors must have the same indices"
            ));
        }

        let other_aligned = other.permute_indices(&self.indices);

        let self_expected_dims = Self::expected_dims_from_indices(&self.indices);
        let other_expected_dims = Self::expected_dims_from_indices(&other_aligned.indices);
        if self_expected_dims != other_expected_dims {
            return Err(anyhow::anyhow!(
                "Dimension mismatch after alignment: self={:?}, other_aligned={:?}",
                self_expected_dims,
                other_expected_dims
            ));
        }

        // Keep shared axis classes; fall back to dense if they disagree.
        let axis_classes = if self.storage.axis_classes() == other_aligned.storage.axis_classes() {
            self.storage.axis_classes().to_vec()
        } else {
            Self::dense_axis_classes(self.indices.len())
        };

        // Storage-level fast path: identical compact layout and no gradient
        // tracking anywhere (operands or scalar coefficients).
        let same_compact_layout = self.storage.payload_dims()
            == other_aligned.storage.payload_dims()
            && self.storage.payload_strides() == other_aligned.storage.payload_strides()
            && self.storage.axis_classes() == other_aligned.storage.axis_classes();
        if same_compact_layout
            && !self.tracks_grad()
            && !other_aligned.tracks_grad()
            && !a.tracks_grad()
            && !b.tracks_grad()
        {
            let combined = self
                .storage
                .axpby(
                    &a.to_backend_scalar(),
                    other_aligned.storage.as_ref(),
                    &b.to_backend_scalar(),
                )
                .map_err(|e| anyhow::anyhow!("storage axpby failed: {e}"))?;
            return Self::from_storage(self.indices.clone(), Arc::new(combined));
        }

        // Mixed dtypes anywhere: use the native combination helper.
        if self.as_native().dtype() != other_aligned.as_native().dtype()
            || self.as_native().dtype() != a.as_tensor().as_native().dtype()
            || other_aligned.as_native().dtype() != b.as_tensor().as_native().dtype()
        {
            let combined = axpby_native_tensor(
                self.as_native(),
                &a.to_backend_scalar(),
                other_aligned.as_native(),
                &b.to_backend_scalar(),
            )?;
            return Self::from_native_with_axis_classes(
                self.indices.clone(),
                combined,
                axis_classes,
            );
        }

        // General eager path: scale both sides, then add.
        let lhs = self.scale(a)?;
        let rhs = other_aligned.scale(b)?;
        let combined = lhs
            .materialized_inner()
            .add(rhs.materialized_inner())
            .map_err(|e| anyhow::anyhow!("tensor addition failed: {e}"))?;
        Self::from_inner_with_axis_classes(self.indices.clone(), combined, axis_classes)
    }
2401
    /// Multiplies every element by `scalar`.
    ///
    /// Mixed dtypes use the native scaling helper; same-dtype tensors go
    /// through the eager path (plain multiply for rank-0, einsum otherwise).
    ///
    /// # Errors
    /// Errors if the backend scaling fails.
    pub fn scale(&self, scalar: AnyScalar) -> Result<Self> {
        if self.as_native().dtype() != scalar.as_tensor().as_native().dtype() {
            let scaled = scale_native_tensor(self.as_native(), &scalar.to_backend_scalar())?;
            return Self::from_native_with_axis_classes(
                self.indices.clone(),
                scaled,
                self.storage.axis_classes().to_vec(),
            );
        }

        let scaled = if self.indices.is_empty() {
            self.materialized_inner()
                .mul(scalar.as_tensor().materialized_inner())
                .map_err(|e| anyhow::anyhow!("scalar multiplication failed: {e}"))?
        } else {
            let subscripts = Self::scale_subscripts(self.indices.len())?;
            eager_einsum_ad(
                &[
                    self.materialized_inner(),
                    scalar.as_tensor().materialized_inner(),
                ],
                &subscripts,
            )
            .map_err(|e| anyhow::anyhow!("tensor scaling failed: {e}"))?
        };
        Self::from_inner_with_axis_classes(
            self.indices.clone(),
            scaled,
            self.storage.axis_classes().to_vec(),
        )
    }
2447
    /// Inner product `sum(conj(self) * other)`, returned as a scalar.
    ///
    /// When both tensors have identical index sets, `other` is aligned and
    /// contracted directly; otherwise the multi-tensor contraction handles
    /// the overlap.
    pub fn inner_product(&self, other: &Self) -> Result<AnyScalar> {
        if self.indices.len() == other.indices.len() {
            let self_set: HashSet<_> = self.indices.iter().collect();
            let other_set: HashSet<_> = other.indices.iter().collect();
            if self_set == other_set {
                let other_aligned = other.permute_indices(&self.indices);
                let result = self.conj().contract(&other_aligned);
                return Ok(result.sum());
            }
        }

        let conj_self = self.conj();
        let result =
            super::contract::contract_multi(&[&conj_self, other], crate::AllowedPairs::All)?;
        Ok(result.sum())
    }
2483}
2484
2485impl TensorDynLen {
2490 pub fn replaceind(&self, old_index: &DynIndex, new_index: &DynIndex) -> Self {
2521 if old_index.dim() != new_index.dim() {
2523 panic!(
2524 "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
2525 old_index.dim(),
2526 new_index.dim()
2527 );
2528 }
2529
2530 let new_indices: Vec<_> = self
2531 .indices
2532 .iter()
2533 .map(|idx| {
2534 if *idx == *old_index {
2535 new_index.clone()
2536 } else {
2537 idx.clone()
2538 }
2539 })
2540 .collect();
2541
2542 Self {
2543 indices: new_indices,
2544 storage: Arc::clone(&self.storage),
2545 structured_ad: self.structured_ad.clone(),
2546 eager_cache: Arc::clone(&self.eager_cache),
2547 }
2548 }
2549
    /// Replaces every listed old index with its positional counterpart in
    /// `new_indices`, sharing storage, AD state, and the eager cache.
    ///
    /// # Panics
    /// Panics if the slices differ in length or any replacement would change
    /// an index's dimension.
    pub fn replaceinds(&self, old_indices: &[DynIndex], new_indices: &[DynIndex]) -> Self {
        assert_eq!(
            old_indices.len(),
            new_indices.len(),
            "old_indices and new_indices must have the same length"
        );

        for (old, new) in old_indices.iter().zip(new_indices.iter()) {
            if old.dim() != new.dim() {
                panic!(
                    "Index space mismatch: cannot replace index with dimension {} with index of dimension {}",
                    old.dim(),
                    new.dim()
                );
            }
        }

        let replacement_map: std::collections::HashMap<_, _> =
            old_indices.iter().zip(new_indices.iter()).collect();

        let new_indices_vec: Vec<_> = self
            .indices
            .iter()
            .map(|idx| {
                if let Some(new_idx) = replacement_map.get(idx) {
                    (*new_idx).clone()
                } else {
                    idx.clone()
                }
            })
            .collect();

        // Metadata-only change: the underlying data is untouched.
        Self {
            indices: new_indices_vec,
            storage: Arc::clone(&self.storage),
            structured_ad: self.structured_ad.clone(),
            eager_cache: Arc::clone(&self.eager_cache),
        }
    }
2625}
2626
impl TensorDynLen {
    /// Complex conjugate: conjugates each index (via `DynIndex::conj`) and
    /// the underlying data.
    ///
    /// # Panics
    /// Panics if the backend conjugation fails or the result is invalid.
    pub fn conj(&self) -> Self {
        let new_indices: Vec<DynIndex> = self.indices.iter().map(|idx| idx.conj()).collect();
        let conjugated = self
            .materialized_inner()
            .conj()
            .unwrap_or_else(|e| panic!("TensorDynLen::conj failed: {e}"));
        Self::from_inner_with_axis_classes(
            new_indices,
            conjugated,
            self.storage.axis_classes().to_vec(),
        )
        .expect("TensorDynLen::conj returned invalid tensor")
    }
}
2671
2672impl TensorDynLen {
    /// Squared Frobenius norm.
    ///
    /// Rank-0 tensors use `|value|^2` directly; otherwise the tensor is
    /// contracted with its conjugate and the real part is clamped at 0 to
    /// guard against round-off driving it slightly negative.
    pub fn norm_squared(&self) -> f64 {
        if self.indices.is_empty() {
            let value = self.sum();
            let abs_val = value.abs();
            return abs_val * abs_val;
        }

        let conj = self.conj();
        let scalar = self.contract(&conj);
        scalar.sum().real().max(0.0)
    }
2711
    /// Frobenius norm, `sqrt(norm_squared())`.
    pub fn norm(&self) -> f64 {
        self.norm_squared().sqrt()
    }
2728
    /// Largest absolute value among the stored elements; falls back to 0.0
    /// if storage retrieval fails (currently it cannot).
    pub fn maxabs(&self) -> f64 {
        self.to_storage()
            .map(|storage| storage.max_abs())
            .unwrap_or(0.0)
    }
2745
2746 pub fn distance(&self, other: &Self) -> f64 {
2777 let norm_self = self.norm();
2778
2779 let neg_other = other
2781 .scale(AnyScalar::new_real(-1.0))
2782 .expect("distance: tensor scaling failed");
2783 let diff = self
2784 .add(&neg_other)
2785 .expect("distance: tensors must have same indices");
2786 let norm_diff = diff.norm();
2787
2788 if norm_self > 0.0 {
2789 norm_diff / norm_self
2790 } else {
2791 norm_diff
2792 }
2793 }
2794}
2795
2796impl std::fmt::Debug for TensorDynLen {
2797 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2798 f.debug_struct("TensorDynLen")
2799 .field("indices", &self.indices)
2800 .field("dims", &self.dims())
2801 .field("is_diag", &self.is_diag())
2802 .finish()
2803 }
2804}
2805
2806pub fn diag_tensor_dyn_len(indices: Vec<DynIndex>, diag_data: Vec<f64>) -> TensorDynLen {
2831 TensorDynLen::from_diag(indices, diag_data)
2832 .unwrap_or_else(|err| panic!("diag_tensor_dyn_len failed: {err}"))
2833}
2834
/// Permutes `t` so the `left_inds` axes come first, then reshapes
/// (column-major) into an `m x n` matrix where `m` is the product of the
/// left dims and `n` the product of the remaining dims.
///
/// Returns `(matrix, left_len, m, n, left_indices, right_indices)`.
///
/// # Errors
/// Errors if the rank is < 2; if `left_inds` is empty, covers all indices,
/// contains duplicates, or references an index not in `t`; or if the
/// reshape fails.
#[allow(clippy::type_complexity)]
pub fn unfold_split(
    t: &TensorDynLen,
    left_inds: &[DynIndex],
) -> Result<(
    NativeTensor,
    usize,
    usize,
    usize,
    Vec<DynIndex>,
    Vec<DynIndex>,
)> {
    let rank = t.indices.len();

    anyhow::ensure!(rank >= 2, "Tensor must have rank >= 2, got rank {}", rank);

    let left_len = left_inds.len();

    anyhow::ensure!(
        left_len > 0 && left_len < rank,
        "Left indices must be a non-empty proper subset of tensor indices (0 < left_len < rank), got left_len={}, rank={}",
        left_len,
        rank
    );

    // Validate membership and uniqueness of the left indices in one pass.
    let tensor_set: HashSet<_> = t.indices.iter().collect();
    let mut left_set = HashSet::new();

    for left_idx in left_inds {
        anyhow::ensure!(
            tensor_set.contains(left_idx),
            "Index in left_inds not found in tensor"
        );
        anyhow::ensure!(left_set.insert(left_idx), "Duplicate index in left_inds");
    }

    // The right side is every remaining index, in original order.
    let mut right_inds = Vec::new();
    for idx in &t.indices {
        if !left_set.contains(idx) {
            right_inds.push(idx.clone());
        }
    }

    let mut new_indices = Vec::with_capacity(rank);
    new_indices.extend_from_slice(left_inds);
    new_indices.extend_from_slice(&right_inds);

    let unfolded = t.permute_indices(&new_indices);

    let unfolded_dims = unfolded.dims();
    let m: usize = unfolded_dims[..left_len].iter().product();
    let n: usize = unfolded_dims[left_len..].iter().product();

    let matrix_tensor = reshape_col_major_native_tensor(unfolded.as_native(), &[m, n])?;

    Ok((
        matrix_tensor,
        left_len,
        m,
        n,
        left_inds.to_vec(),
        right_inds,
    ))
}
2952
2953use crate::tensor_index::TensorIndex;
2958
/// Index introspection/renaming facade over the inherent methods.
impl TensorIndex for TensorDynLen {
    type Index = DynIndex;

    fn external_indices(&self) -> Vec<DynIndex> {
        self.indices.clone()
    }

    fn num_external_indices(&self) -> usize {
        self.indices.len()
    }

    // NOTE: the inherent method panics on dimension mismatch rather than
    // returning Err, so the Result here is always Ok.
    fn replaceind(&self, old_index: &DynIndex, new_index: &DynIndex) -> Result<Self> {
        Ok(TensorDynLen::replaceind(self, old_index, new_index))
    }

    fn replaceinds(&self, old_indices: &[DynIndex], new_indices: &[DynIndex]) -> Result<Self> {
        Ok(TensorDynLen::replaceinds(self, old_indices, new_indices))
    }
}
2981
2982use crate::tensor_like::{FactorizeError, FactorizeOptions, FactorizeResult, TensorLike};
2987
2988impl TensorLike for TensorDynLen {
    /// Delegates to the crate-level factorization routine.
    fn factorize(
        &self,
        left_inds: &[DynIndex],
        options: &FactorizeOptions,
    ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
        crate::factorize::factorize(self, left_inds, options)
    }
2996
    /// Delegates to the crate-level full-rank factorization routine.
    fn factorize_full_rank(
        &self,
        left_inds: &[DynIndex],
        alg: crate::FactorizeAlg,
        canonical: crate::Canonical,
    ) -> std::result::Result<FactorizeResult<Self>, FactorizeError> {
        crate::factorize::factorize_full_rank(self, left_inds, alg, canonical)
    }
3005
    /// Delegates to the inherent conjugation.
    fn conj(&self) -> Self {
        TensorDynLen::conj(self)
    }
3010
    /// Direct sum along the paired indices; wraps the crate helper's
    /// `(tensor, new_indices)` pair in a `DirectSumResult`.
    fn direct_sum(
        &self,
        other: &Self,
        pairs: &[(DynIndex, DynIndex)],
    ) -> Result<crate::tensor_like::DirectSumResult<Self>> {
        let (tensor, new_indices) = crate::direct_sum::direct_sum(self, other, pairs)?;
        Ok(crate::tensor_like::DirectSumResult {
            tensor,
            new_indices,
        })
    }
3022
    /// Delegates to the inherent outer product.
    fn outer_product(&self, other: &Self) -> Result<Self> {
        TensorDynLen::outer_product(self, other)
    }
3027
    /// Delegates to the inherent squared norm.
    fn norm_squared(&self) -> f64 {
        TensorDynLen::norm_squared(self)
    }
3032
    /// Delegates to the inherent max-abs.
    fn maxabs(&self) -> f64 {
        TensorDynLen::maxabs(self)
    }
3036
    // NOTE: the inherent method panics on an invalid permutation rather than
    // returning Err, so the Result here is always Ok.
    fn permuteinds(&self, new_order: &[DynIndex]) -> Result<Self> {
        Ok(TensorDynLen::permute_indices(self, new_order))
    }
3041
    /// Delegates to the inherent index fusion, preserving the given
    /// linearization order.
    fn fuse_indices(
        &self,
        old_indices: &[DynIndex],
        new_index: DynIndex,
        order: LinearizationOrder,
    ) -> Result<Self> {
        TensorDynLen::fuse_indices(self, old_indices, new_index, order)
    }
3050
    /// Delegates to the module-level multi-tensor contraction.
    fn contract(tensors: &[&Self], allowed: crate::AllowedPairs<'_>) -> Result<Self> {
        super::contract::contract_multi(tensors, allowed)
    }
3055
3056 fn contract_connected(tensors: &[&Self], allowed: crate::AllowedPairs<'_>) -> Result<Self> {
3057 super::contract::contract_connected(tensors, allowed)
3059 }
3060
3061 fn axpby(&self, a: crate::AnyScalar, other: &Self, b: crate::AnyScalar) -> Result<Self> {
3062 TensorDynLen::axpby(self, a, other, b)
3064 }
3065
3066 fn scale(&self, scalar: crate::AnyScalar) -> Result<Self> {
3067 TensorDynLen::scale(self, scalar)
3069 }
3070
3071 fn inner_product(&self, other: &Self) -> Result<crate::AnyScalar> {
3072 TensorDynLen::inner_product(self, other)
3074 }
3075
3076 fn diagonal(input_index: &DynIndex, output_index: &DynIndex) -> Result<Self> {
3077 let dim = input_index.dim();
3078 if dim != output_index.dim() {
3079 return Err(anyhow::anyhow!(
3080 "Dimension mismatch: input index has dim {}, output has dim {}",
3081 dim,
3082 output_index.dim(),
3083 ));
3084 }
3085
3086 TensorDynLen::from_diag(
3087 vec![input_index.clone(), output_index.clone()],
3088 vec![1.0_f64; dim],
3089 )
3090 }
3091
3092 fn scalar_one() -> Result<Self> {
3093 TensorDynLen::from_dense(vec![], vec![1.0_f64])
3094 }
3095
3096 fn ones(indices: &[DynIndex]) -> Result<Self> {
3097 if indices.is_empty() {
3098 return Self::scalar_one();
3099 }
3100 let dims: Vec<usize> = indices.iter().map(|idx| idx.size()).collect();
3101 let total_size = checked_total_size(&dims)?;
3102 TensorDynLen::from_dense(indices.to_vec(), vec![1.0_f64; total_size])
3103 }
3104
3105 fn onehot(index_vals: &[(DynIndex, usize)]) -> Result<Self> {
3106 if index_vals.is_empty() {
3107 return Self::scalar_one();
3108 }
3109 let indices: Vec<DynIndex> = index_vals.iter().map(|(idx, _)| idx.clone()).collect();
3110 let vals: Vec<usize> = index_vals.iter().map(|(_, v)| *v).collect();
3111 let dims: Vec<usize> = indices.iter().map(|idx| idx.size()).collect();
3112
3113 for (k, (&v, &d)) in vals.iter().zip(dims.iter()).enumerate() {
3114 if v >= d {
3115 return Err(anyhow::anyhow!(
3116 "onehot: value {} at position {} is >= dimension {}",
3117 v,
3118 k,
3119 d
3120 ));
3121 }
3122 }
3123
3124 let total_size = checked_total_size(&dims)?;
3125 let mut data = vec![0.0_f64; total_size];
3126
3127 let offset = column_major_offset(&dims, &vals)?;
3128 data[offset] = 1.0;
3129
3130 Self::from_dense(indices, data)
3131 }
3132
3133 }
3135
3136fn checked_total_size(dims: &[usize]) -> Result<usize> {
3137 dims.iter().try_fold(1_usize, |acc, &d| {
3138 if d == 0 {
3139 return Err(anyhow::anyhow!("invalid dimension 0"));
3140 }
3141 acc.checked_mul(d)
3142 .ok_or_else(|| anyhow::anyhow!("tensor size overflow"))
3143 })
3144}
3145
3146fn column_major_offset(dims: &[usize], vals: &[usize]) -> Result<usize> {
3147 if dims.len() != vals.len() {
3148 return Err(anyhow::anyhow!(
3149 "column_major_offset: dims.len() != vals.len()"
3150 ));
3151 }
3152 checked_total_size(dims)?;
3153
3154 let mut offset = 0usize;
3155 let mut stride = 1usize;
3156 for (k, (&v, &d)) in vals.iter().zip(dims.iter()).enumerate() {
3157 if d == 0 {
3158 return Err(anyhow::anyhow!("invalid dimension 0 at position {}", k));
3159 }
3160 if v >= d {
3161 return Err(anyhow::anyhow!(
3162 "column_major_offset: value {} at position {} is >= dimension {}",
3163 v,
3164 k,
3165 d
3166 ));
3167 }
3168 let term = v
3169 .checked_mul(stride)
3170 .ok_or_else(|| anyhow::anyhow!("column_major_offset: overflow"))?;
3171 offset = offset
3172 .checked_add(term)
3173 .ok_or_else(|| anyhow::anyhow!("column_major_offset: overflow"))?;
3174 stride = stride
3175 .checked_mul(d)
3176 .ok_or_else(|| anyhow::anyhow!("column_major_offset: overflow"))?;
3177 }
3178 Ok(offset)
3179}
3180
3181impl TensorDynLen {
3186 fn any_scalar_payload_to_complex(data: Vec<AnyScalar>) -> Vec<Complex64> {
3187 data.into_iter()
3188 .map(|value| {
3189 value
3190 .as_c64()
3191 .unwrap_or_else(|| Complex64::new(value.real(), 0.0))
3192 })
3193 .collect()
3194 }
3195
3196 fn any_scalar_payload_to_real(data: Vec<AnyScalar>) -> Vec<f64> {
3197 data.into_iter().map(|value| value.real()).collect()
3198 }
3199
3200 fn validate_dense_payload_len(data_len: usize, dims: &[usize]) -> Result<()> {
3201 let expected_len = checked_total_size(dims)?;
3202 anyhow::ensure!(
3203 data_len == expected_len,
3204 "dense payload length {} does not match dims {:?} (expected {})",
3205 data_len,
3206 dims,
3207 expected_len
3208 );
3209 Ok(())
3210 }
3211
3212 fn validate_diag_payload_len(data_len: usize, dims: &[usize]) -> Result<()> {
3213 anyhow::ensure!(
3214 !dims.is_empty(),
3215 "diagonal tensor construction requires at least one index"
3216 );
3217 Self::validate_diag_dims(dims)?;
3218 anyhow::ensure!(
3219 data_len == dims[0],
3220 "diagonal payload length {} does not match diagonal dimension {}",
3221 data_len,
3222 dims[0]
3223 );
3224 Ok(())
3225 }
3226
3227 pub fn from_dense<T: TensorElement>(indices: Vec<DynIndex>, data: Vec<T>) -> Result<Self> {
3254 let dims = Self::expected_dims_from_indices(&indices);
3255 Self::validate_indices(&indices);
3256 Self::validate_dense_payload_len(data.len(), &dims)?;
3257 let native = dense_native_tensor_from_col_major(&data, &dims)?;
3258 Self::from_native(indices, native)
3259 }
3260
3261 pub fn from_dense_any(indices: Vec<DynIndex>, data: Vec<AnyScalar>) -> Result<Self> {
3287 if data.iter().any(AnyScalar::is_complex) {
3288 Self::from_dense(indices, Self::any_scalar_payload_to_complex(data))
3289 } else {
3290 Self::from_dense(indices, Self::any_scalar_payload_to_real(data))
3291 }
3292 }
3293
3294 pub fn from_diag<T: TensorElement>(indices: Vec<DynIndex>, data: Vec<T>) -> Result<Self> {
3322 let dims = Self::expected_dims_from_indices(&indices);
3323 Self::validate_indices(&indices);
3324 Self::validate_diag_payload_len(data.len(), &dims)?;
3325 let native = diag_native_tensor_from_col_major(&data, dims.len())?;
3326 Self::from_native_with_axis_classes(indices, native, Self::diag_axis_classes(dims.len()))
3327 }
3328
3329 pub fn from_diag_any(indices: Vec<DynIndex>, data: Vec<AnyScalar>) -> Result<Self> {
3351 if data.iter().any(AnyScalar::is_complex) {
3352 Self::from_diag(indices, Self::any_scalar_payload_to_complex(data))
3353 } else {
3354 Self::from_diag(indices, Self::any_scalar_payload_to_real(data))
3355 }
3356 }
3357
3358 pub fn copy_tensor(indices: Vec<DynIndex>, value: AnyScalar) -> Result<Self> {
3379 if indices.is_empty() {
3380 return Self::from_dense_any(vec![], vec![value]);
3381 }
3382 let dim = indices[0].dim();
3383 let data = vec![value; dim];
3384 Self::from_diag_any(indices, data)
3385 }
3386
3387 pub fn fuse_indices(
3441 &self,
3442 old_indices: &[DynIndex],
3443 new_index: DynIndex,
3444 order: LinearizationOrder,
3445 ) -> Result<Self> {
3446 anyhow::ensure!(
3447 !old_indices.is_empty(),
3448 "fuse_indices requires at least one index to fuse"
3449 );
3450
3451 let old_dims = self.dims();
3452 let mut seen_indices = HashSet::new();
3453 let mut old_axes = Vec::with_capacity(old_indices.len());
3454 for old_index in old_indices {
3455 anyhow::ensure!(
3456 seen_indices.insert(old_index),
3457 "duplicate index in old_indices"
3458 );
3459 let axis = self
3460 .indices
3461 .iter()
3462 .position(|idx| idx == old_index)
3463 .ok_or_else(|| anyhow::anyhow!("index {:?} not found in tensor", old_index))?;
3464 anyhow::ensure!(
3465 old_index.dim() == old_dims[axis],
3466 "old index dimension does not match tensor axis dimension"
3467 );
3468 old_axes.push(axis);
3469 }
3470
3471 let fused_dims: Vec<usize> = old_axes.iter().map(|&axis| old_dims[axis]).collect();
3472 let fused_product = checked_product(&fused_dims)?;
3473 anyhow::ensure!(
3474 fused_product == new_index.dim(),
3475 "product of old index dimensions must match the replacement index dimension"
3476 );
3477
3478 let insertion_axis =
3479 old_axes.iter().copied().min().ok_or_else(|| {
3480 anyhow::anyhow!("fuse_indices requires at least one index to fuse")
3481 })?;
3482 let old_axis_set: HashSet<usize> = old_axes.iter().copied().collect();
3483
3484 let mut result_indices =
3485 Vec::with_capacity(self.indices.len() - old_indices.len() + 1usize);
3486 for (axis, index) in self.indices.iter().enumerate() {
3487 if axis == insertion_axis {
3488 result_indices.push(new_index.clone());
3489 }
3490 if !old_axis_set.contains(&axis) {
3491 result_indices.push(index.clone());
3492 }
3493 }
3494 let mut result_seen = HashSet::new();
3495 for index in &result_indices {
3496 anyhow::ensure!(
3497 result_seen.insert(index),
3498 "fuse_indices result would contain duplicate index"
3499 );
3500 }
3501 Self::validate_indices(&result_indices);
3502
3503 let mut new_dims = Vec::with_capacity(old_dims.len() - old_indices.len() + 1usize);
3504 for (axis, dim) in old_dims.iter().copied().enumerate() {
3505 if axis == insertion_axis {
3506 new_dims.push(new_index.dim());
3507 }
3508 if !old_axis_set.contains(&axis) {
3509 new_dims.push(dim);
3510 }
3511 }
3512
3513 let old_data = self.to_vec_any()?;
3514 let mut new_data = vec![AnyScalar::new_real(0.0); old_data.len()];
3515 for (old_linear, value) in old_data.into_iter().enumerate() {
3516 let old_multi = decode_col_major_linear(old_linear, &old_dims)?;
3517 let fused_multi: Vec<usize> = old_axes.iter().map(|&axis| old_multi[axis]).collect();
3518 let fused_linear = encode_linear_with_order(&fused_multi, &fused_dims, order)?;
3519
3520 let mut new_multi = Vec::with_capacity(new_dims.len());
3521 for (axis, old_coord) in old_multi.iter().copied().enumerate() {
3522 if axis == insertion_axis {
3523 new_multi.push(fused_linear);
3524 }
3525 if !old_axis_set.contains(&axis) {
3526 new_multi.push(old_coord);
3527 }
3528 }
3529 let new_linear = encode_col_major_linear(&new_multi, &new_dims)?;
3530 new_data[new_linear] = value;
3531 }
3532
3533 Self::from_dense_any(result_indices, new_data)
3534 }
3535
3536 pub fn unfuse_index(
3558 &self,
3559 old_index: &DynIndex,
3560 new_indices: &[DynIndex],
3561 order: LinearizationOrder,
3562 ) -> Result<Self> {
3563 anyhow::ensure!(
3564 !new_indices.is_empty(),
3565 "unfuse_index requires at least one replacement index"
3566 );
3567
3568 let axis = self
3569 .indices
3570 .iter()
3571 .position(|idx| idx == old_index)
3572 .ok_or_else(|| anyhow::anyhow!("index {:?} not found in tensor", old_index))?;
3573
3574 let replacement_dims: Vec<usize> = new_indices.iter().map(DynIndex::dim).collect();
3575 let replacement_product = checked_product(&replacement_dims)?;
3576 anyhow::ensure!(
3577 replacement_product == old_index.dim(),
3578 "product of new index dimensions must match the replaced index dimension"
3579 );
3580
3581 let mut result_indices =
3582 Vec::with_capacity(self.indices.len() - 1usize + new_indices.len());
3583 result_indices.extend_from_slice(&self.indices[..axis]);
3584 result_indices.extend(new_indices.iter().cloned());
3585 result_indices.extend_from_slice(&self.indices[axis + 1..]);
3586 Self::validate_indices(&result_indices);
3587
3588 let old_dims = self.dims();
3589 let mut new_dims = Vec::with_capacity(old_dims.len() - 1usize + replacement_dims.len());
3590 new_dims.extend_from_slice(&old_dims[..axis]);
3591 new_dims.extend_from_slice(&replacement_dims);
3592 new_dims.extend_from_slice(&old_dims[axis + 1..]);
3593
3594 let old_data = self.to_vec_any()?;
3595 let mut new_data = vec![AnyScalar::new_real(0.0); old_data.len()];
3596 for (old_linear, value) in old_data.into_iter().enumerate() {
3597 let old_multi = decode_col_major_linear(old_linear, &old_dims)?;
3598 let split_multi = decode_linear_with_order(old_multi[axis], &replacement_dims, order)?;
3599 let mut new_multi = Vec::with_capacity(new_dims.len());
3600 new_multi.extend_from_slice(&old_multi[..axis]);
3601 new_multi.extend_from_slice(&split_multi);
3602 new_multi.extend_from_slice(&old_multi[axis + 1..]);
3603 let new_linear = encode_col_major_linear(&new_multi, &new_dims)?;
3604 new_data[new_linear] = value;
3605 }
3606
3607 Self::from_dense_any(result_indices, new_data)
3608 }
3609
3610 pub fn scalar<T: TensorElement>(value: T) -> Result<Self> {
3621 Self::from_dense(vec![], vec![value])
3622 }
3623
3624 pub fn zeros<T: TensorElement + Zero + Clone>(indices: Vec<DynIndex>) -> Result<Self> {
3637 let dims: Vec<usize> = indices.iter().map(|idx| idx.dim()).collect();
3638 let size: usize = dims.iter().product();
3639 Self::from_dense(indices, vec![T::zero(); size])
3640 }
3641}
3642
3643impl TensorDynLen {
3648 pub fn to_vec<T: TensorElement>(&self) -> Result<Vec<T>> {
3670 native_tensor_primal_to_dense_col_major(self.as_native())
3671 }
3672
3673 fn to_vec_any(&self) -> Result<Vec<AnyScalar>> {
3674 if self.is_complex() {
3675 self.to_vec::<Complex64>().map(|data| {
3676 data.into_iter()
3677 .map(|value| AnyScalar::new_complex(value.re, value.im))
3678 .collect()
3679 })
3680 } else {
3681 self.to_vec::<f64>()
3682 .map(|data| data.into_iter().map(AnyScalar::new_real).collect())
3683 }
3684 }
3685
3686 pub fn as_slice_f64(&self) -> Result<Vec<f64>> {
3691 self.to_vec::<f64>()
3692 }
3693
3694 pub fn as_slice_c64(&self) -> Result<Vec<Complex64>> {
3699 self.to_vec::<Complex64>()
3700 }
3701
3702 pub fn is_f64(&self) -> bool {
3715 self.storage.is_f64()
3716 }
3717
3718 pub fn is_diag(&self) -> bool {
3744 self.storage.is_diag()
3745 }
3746
3747 pub fn is_complex(&self) -> bool {
3766 self.storage.is_complex()
3767 }
3768}
3769
3770fn checked_product(dims: &[usize]) -> Result<usize> {
3771 dims.iter().try_fold(1usize, |acc, &dim| {
3772 acc.checked_mul(dim)
3773 .ok_or_else(|| anyhow::anyhow!("dimension product overflow"))
3774 })
3775}
3776
3777fn decode_col_major_linear(linear: usize, dims: &[usize]) -> Result<Vec<usize>> {
3778 let total = checked_product(dims)?;
3779 anyhow::ensure!(
3780 linear < total,
3781 "linear offset {} out of bounds for dims {:?}",
3782 linear,
3783 dims
3784 );
3785 let mut remaining = linear;
3786 let mut out = Vec::with_capacity(dims.len());
3787 for &dim in dims {
3788 out.push(remaining % dim);
3789 remaining /= dim;
3790 }
3791 Ok(out)
3792}
3793
3794fn encode_col_major_linear(indices: &[usize], dims: &[usize]) -> Result<usize> {
3795 anyhow::ensure!(
3796 indices.len() == dims.len(),
3797 "index rank {} does not match dims {:?}",
3798 indices.len(),
3799 dims
3800 );
3801 let mut linear = 0usize;
3802 let mut stride = 1usize;
3803 for (&index, &dim) in indices.iter().zip(dims.iter()) {
3804 anyhow::ensure!(
3805 index < dim,
3806 "index {} out of bounds for dimension {}",
3807 index,
3808 dim
3809 );
3810 linear += index * stride;
3811 stride = stride
3812 .checked_mul(dim)
3813 .ok_or_else(|| anyhow::anyhow!("stride overflow"))?;
3814 }
3815 Ok(linear)
3816}
3817
3818fn decode_linear_with_order(
3819 linear: usize,
3820 dims: &[usize],
3821 order: LinearizationOrder,
3822) -> Result<Vec<usize>> {
3823 let total = checked_product(dims)?;
3824 anyhow::ensure!(
3825 linear < total,
3826 "linear offset {} out of bounds for dims {:?}",
3827 linear,
3828 dims
3829 );
3830
3831 let mut remaining = linear;
3832 let mut out = vec![0usize; dims.len()];
3833 match order {
3834 LinearizationOrder::ColumnMajor => {
3835 for (slot, &dim) in out.iter_mut().zip(dims.iter()) {
3836 *slot = remaining % dim;
3837 remaining /= dim;
3838 }
3839 }
3840 LinearizationOrder::RowMajor => {
3841 for (slot, &dim) in out.iter_mut().rev().zip(dims.iter().rev()) {
3842 *slot = remaining % dim;
3843 remaining /= dim;
3844 }
3845 }
3846 }
3847 Ok(out)
3848}
3849
3850fn encode_linear_with_order(
3851 indices: &[usize],
3852 dims: &[usize],
3853 order: LinearizationOrder,
3854) -> Result<usize> {
3855 match order {
3856 LinearizationOrder::ColumnMajor => encode_col_major_linear(indices, dims),
3857 LinearizationOrder::RowMajor => {
3858 anyhow::ensure!(
3859 indices.len() == dims.len(),
3860 "index rank {} does not match dims {:?}",
3861 indices.len(),
3862 dims
3863 );
3864 let mut linear = 0usize;
3865 let mut stride = 1usize;
3866 for (&index, &dim) in indices.iter().rev().zip(dims.iter().rev()) {
3867 anyhow::ensure!(
3868 index < dim,
3869 "index {} out of bounds for dimension {}",
3870 index,
3871 dim
3872 );
3873 linear += index * stride;
3874 stride = stride
3875 .checked_mul(dim)
3876 .ok_or_else(|| anyhow::anyhow!("stride overflow"))?;
3877 }
3878 Ok(linear)
3879 }
3880 }
3881}