Skip to main content

tensor4all_tensorbackend/
tenferro_bridge.rs

1//! Bridge helpers between tensor4all storage snapshots and tenferro tensors.
2
3use std::cell::RefCell;
4use std::cmp::Reverse;
5use std::collections::{HashMap, HashSet};
6use std::env;
7use std::time::{Duration, Instant};
8
9use anyhow::{anyhow, ensure, Result};
10use num_complex::{Complex32, Complex64};
11use omeco::ScoreFunction;
12use tenferro::traced_tensor::{einsum_subscripts_with, EinsumOptimize};
13use tenferro::{
14    DType, EinsumSubscripts, Tensor as NativeTensor, TensorBackend, TensorRead, TensorView,
15    TracedTensor,
16};
17use tenferro_einsum::{ContractionOptimizerOptions, ContractionTree, Subscripts};
18
19use crate::any_scalar::promote_scalar_native;
20use crate::context::{
21    default_engine_buffer_pool_stats, reset_default_engine, reset_default_engine_buffer_pool,
22    with_default_backend, with_default_engine,
23};
24use crate::memory::release_process_allocator_cached_memory;
25use crate::storage::Storage;
26#[cfg(test)]
27use crate::storage::StorageRepr;
28use crate::tensor_element::TensorElement;
29use crate::AnyScalar;
30
/// Read-only native tensor input that can either borrow external payload data
/// or own a temporary materialized tensor.
///
/// Both variants can be handed out as a [`TensorRead`] through
/// [`NativeTensorReadInput::as_read`], so code consuming the input does not
/// need to know which one it received.
pub enum NativeTensorReadInput<'a> {
    /// Borrowed read-only tensor input.
    Borrowed(TensorRead<'a>),
    /// Owned temporary tensor input.
    Owned(NativeTensor),
}
39
40impl<'a> NativeTensorReadInput<'a> {
41    /// Return this input as a read-only tenferro tensor input.
42    pub fn as_read(&'a self) -> TensorRead<'a> {
43        match self {
44            Self::Borrowed(read) => *read,
45            Self::Owned(tensor) => TensorRead::from_tensor(tensor),
46        }
47    }
48
49    /// Return the scalar dtype of this input.
50    pub fn dtype(&self) -> DType {
51        match self {
52            Self::Borrowed(read) => read.dtype(),
53            Self::Owned(tensor) => tensor.dtype(),
54        }
55    }
56
57    /// Return the tensor shape of this input.
58    pub fn shape(&self) -> &[usize] {
59        match self {
60            Self::Borrowed(read) => read.shape(),
61            Self::Owned(tensor) => tensor.shape(),
62        }
63    }
64}
65
66#[cfg(test)]
67use std::cell::Cell;
68
/// Tag stored in a [`NativeEinsumSignature`] to tell apart the einsum entry
/// points that feed the profile and the path trace.
///
/// NOTE(review): the call sites that pick a variant are outside this part of
/// the file; the variant descriptions below are inferred from the names only.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
enum NativeEinsumPath {
    /// Presumably: all operands are owned native tensors — confirm at call site.
    Owned,
    /// Presumably: operands are borrowed read-only inputs — confirm at call site.
    Borrowed,
    /// Presumably: borrowed inputs that needed a dtype conversion — confirm at call site.
    BorrowedWithConversions,
}
75
/// Hashable description of one einsum operand, used as part of a
/// [`NativeEinsumSignature`].
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct NativeOperandSignature {
    /// Shape of the operand tensor.
    shape: Vec<usize>,
    /// Einsum label id of each operand axis.
    ids: Vec<u32>,
    /// Scalar dtype of the operand tensor.
    dtype: DType,
}
82
/// Hashable key identifying one kind of einsum call: the entry point, every
/// operand's shape/labels/dtype, and the output labels.
///
/// Used as the key of the per-thread profile map and trace set.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct NativeEinsumSignature {
    /// Entry point that issued the call.
    path: NativeEinsumPath,
    /// One signature per operand, in call order.
    operands: Vec<NativeOperandSignature>,
    /// Einsum label ids of the output axes.
    output_ids: Vec<u32>,
}
89
/// Aggregated timing for all recorded calls that share one signature.
#[derive(Debug, Default, Clone)]
struct NativeEinsumProfileEntry {
    /// Number of recorded calls.
    calls: usize,
    /// Sum of the elapsed times of those calls.
    total_time: Duration,
}
95
thread_local! {
    // Per-thread profile: call count and accumulated time for each einsum
    // signature, updated by `record_native_einsum_profile`.
    static NATIVE_EINSUM_PROFILE_STATE: RefCell<HashMap<NativeEinsumSignature, NativeEinsumProfileEntry>> =
        RefCell::new(HashMap::new());
    // Per-thread set of signatures whose contraction path has already been
    // printed, so each signature is traced at most once until
    // `reset_native_einsum_profile` clears the set.
    static NATIVE_EINSUM_TRACE_STATE: RefCell<HashSet<NativeEinsumSignature>> =
        RefCell::new(HashSet::new());
}
102
#[cfg(test)]
thread_local! {
    // Test-only, per-thread switch that forces profiling on regardless of the
    // `T4A_PROFILE_NATIVE_EINSUM` environment variable.
    static FORCE_NATIVE_EINSUM_PROFILE: Cell<bool> = const { Cell::new(false) };
}
107
/// Report whether native einsum profiling is active.
///
/// In test builds the per-thread `FORCE_NATIVE_EINSUM_PROFILE` override wins;
/// otherwise profiling is on whenever `T4A_PROFILE_NATIVE_EINSUM` is set to a
/// valid Unicode value (the value itself is not inspected).
fn native_einsum_profile_enabled() -> bool {
    const FLAG: &str = "T4A_PROFILE_NATIVE_EINSUM";

    #[cfg(test)]
    {
        let forced = FORCE_NATIVE_EINSUM_PROFILE.with(|flag| flag.get());
        if forced {
            return true;
        }
    }
    env::var(FLAG).is_ok()
}
115
/// Report whether contraction-path tracing is requested.
///
/// True whenever `T4A_TRACE_NATIVE_EINSUM_PATHS` is set to a valid Unicode
/// value; the value itself is not inspected.
fn native_einsum_path_trace_enabled() -> bool {
    const FLAG: &str = "T4A_TRACE_NATIVE_EINSUM_PATHS";
    env::var(FLAG).is_ok()
}
119
/// Minimum peak-intermediate size, in bytes, for a path trace to be printed.
///
/// Read from `T4A_TRACE_NATIVE_EINSUM_MIN_BYTES`; an unset or unparsable
/// value yields `0`.
fn native_einsum_path_trace_min_bytes() -> usize {
    const VAR: &str = "T4A_TRACE_NATIVE_EINSUM_MIN_BYTES";
    match env::var(VAR) {
        Ok(raw) => raw.parse().unwrap_or(0),
        Err(_) => 0,
    }
}
126
/// Maximum number of distinct signatures that will be path-traced per thread.
///
/// Read from `T4A_TRACE_NATIVE_EINSUM_MAX_SIGNATURES`; an unset or unparsable
/// value yields `64`.
fn native_einsum_path_trace_max_signatures() -> usize {
    const VAR: &str = "T4A_TRACE_NATIVE_EINSUM_MAX_SIGNATURES";
    match env::var(VAR) {
        Ok(raw) => raw.parse().unwrap_or(64),
        Err(_) => 64,
    }
}
133
/// Report whether buffer-pool tracing is requested.
///
/// True whenever `T4A_TRACE_NATIVE_EINSUM_POOL` is set to a valid Unicode
/// value; the value itself is not inspected.
fn native_einsum_pool_trace_enabled() -> bool {
    const FLAG: &str = "T4A_TRACE_NATIVE_EINSUM_POOL";
    env::var(FLAG).is_ok()
}
137
/// Output-size threshold, in bytes, used by the buffer-pool trace.
///
/// Read from `T4A_TRACE_NATIVE_EINSUM_POOL_MIN_OUTPUT_BYTES`; an unset or
/// unparsable value yields `0`.
fn native_einsum_pool_trace_min_output_bytes() -> usize {
    const VAR: &str = "T4A_TRACE_NATIVE_EINSUM_POOL_MIN_OUTPUT_BYTES";
    match env::var(VAR) {
        Ok(raw) => raw.parse().unwrap_or(0),
        Err(_) => 0,
    }
}
144
/// Retained-capacity threshold, in bytes, used by the buffer-pool trace.
///
/// Read from `T4A_TRACE_NATIVE_EINSUM_POOL_MIN_RETAINED_BYTES`; an unset or
/// unparsable value yields `0`.
fn native_einsum_pool_trace_min_retained_bytes() -> usize {
    const VAR: &str = "T4A_TRACE_NATIVE_EINSUM_POOL_MIN_RETAINED_BYTES";
    match env::var(VAR) {
        Ok(raw) => raw.parse().unwrap_or(0),
        Err(_) => 0,
    }
}
151
/// Report whether the default engine should be reset after each einsum call.
///
/// True whenever `T4A_RESET_NATIVE_EINSUM_ENGINE_AFTER_CALL` is set to a
/// valid Unicode value; the value itself is not inspected.
fn reset_native_einsum_engine_after_call() -> bool {
    const FLAG: &str = "T4A_RESET_NATIVE_EINSUM_ENGINE_AFTER_CALL";
    env::var(FLAG).is_ok()
}
155
/// Report whether the default engine's buffer pool should be cleared after
/// each einsum call.
///
/// True whenever `T4A_RESET_NATIVE_EINSUM_BUFFER_POOL_AFTER_CALL` is set to a
/// valid Unicode value; the value itself is not inspected.
fn reset_native_einsum_buffer_pool_after_call() -> bool {
    const FLAG: &str = "T4A_RESET_NATIVE_EINSUM_BUFFER_POOL_AFTER_CALL";
    env::var(FLAG).is_ok()
}
159
/// Report whether cached allocator memory should be released after each
/// einsum call.
///
/// True whenever `T4A_RELEASE_ALLOCATOR_AFTER_NATIVE_EINSUM_CALL` is set to a
/// valid Unicode value; the value itself is not inspected.
fn release_allocator_after_native_einsum_call() -> bool {
    const FLAG: &str = "T4A_RELEASE_ALLOCATOR_AFTER_NATIVE_EINSUM_CALL";
    env::var(FLAG).is_ok()
}
163
/// Test-only switch: force native einsum profiling on (or back off) for the
/// current thread, independent of the environment.
#[cfg(test)]
pub(crate) fn set_native_einsum_profile_enabled_for_tests(enabled: bool) {
    FORCE_NATIVE_EINSUM_PROFILE.with(|flag| {
        flag.set(enabled);
    });
}
168
169fn native_einsum_signature(
170    path: NativeEinsumPath,
171    operands: &[(&NativeTensor, &[usize])],
172    output_ids: &[u32],
173) -> NativeEinsumSignature {
174    NativeEinsumSignature {
175        path,
176        operands: operands
177            .iter()
178            .map(|(tensor, ids)| NativeOperandSignature {
179                shape: tensor.shape().to_vec(),
180                ids: ids.iter().map(|&id| id as u32).collect(),
181                dtype: tensor.dtype(),
182            })
183            .collect(),
184        output_ids: output_ids.to_vec(),
185    }
186}
187
188fn record_native_einsum_profile(
189    path: NativeEinsumPath,
190    operands: &[(&NativeTensor, &[usize])],
191    output_ids: &[u32],
192    elapsed: Duration,
193) {
194    if !native_einsum_profile_enabled() {
195        return;
196    }
197    let signature = native_einsum_signature(path, operands, output_ids);
198    NATIVE_EINSUM_PROFILE_STATE.with(|state| {
199        let mut state = state.borrow_mut();
200        let entry = state.entry(signature).or_default();
201        entry.calls += 1;
202        entry.total_time += elapsed;
203    });
204}
205
206fn dtype_size_bytes(dtype: DType) -> usize {
207    match dtype {
208        DType::F32 => 4,
209        DType::F64 => 8,
210        DType::C32 => 8,
211        DType::C64 => 16,
212        DType::I64 => 8,
213    }
214}
215
216fn native_tensor_bytes(tensor: &NativeTensor) -> usize {
217    tensor
218        .shape()
219        .iter()
220        .copied()
221        .fold(1usize, usize::saturating_mul)
222        .saturating_mul(dtype_size_bytes(tensor.dtype()))
223}
224
/// Render a single einsum label for diagnostics.
///
/// Labels handled in this file are small dense ids starting at 0, so ids
/// `0..52` are rendered with the same `a..z`, `A..Z` alphabet that
/// `ids_to_subscript` uses. Interpreting such ids as Unicode scalars would
/// print NUL and other control characters.
///
/// Larger labels are printed as their Unicode character when that character
/// exists and is not a control character, and as a decimal number otherwise.
fn format_label(label: u32) -> String {
    const LETTERS: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";

    let from_alphabet = usize::try_from(label)
        .ok()
        .and_then(|idx| LETTERS.get(idx))
        .map(|&byte| char::from(byte));
    let printable =
        from_alphabet.or_else(|| char::from_u32(label).filter(|ch| !ch.is_control()));
    match printable {
        Some(ch) => ch.to_string(),
        None => label.to_string(),
    }
}
228
229fn format_labels(labels: &[u32]) -> String {
230    if labels.is_empty() {
231        "scalar".to_string()
232    } else {
233        labels
234            .iter()
235            .map(|&label| format_label(label))
236            .collect::<Vec<_>>()
237            .join("")
238    }
239}
240
241fn label_dims(subscripts: &Subscripts, shapes: &[Vec<usize>]) -> Result<HashMap<u32, usize>> {
242    let mut dims = HashMap::new();
243    for (labels, shape) in subscripts.inputs.iter().zip(shapes.iter()) {
244        ensure!(
245            labels.len() == shape.len(),
246            "einsum labels {:?} do not match shape {:?}",
247            labels,
248            shape
249        );
250        for (&label, &dim) in labels.iter().zip(shape.iter()) {
251            if let Some(previous) = dims.insert(label, dim) {
252                ensure!(
253                    previous == dim,
254                    "inconsistent dimension for einsum label {}: {} vs {}",
255                    format_label(label),
256                    previous,
257                    dim
258                );
259            }
260        }
261    }
262    Ok(dims)
263}
264
/// Number of elements spanned by `labels`: the saturating product of their
/// dimensions. A label missing from `dims` counts as dimension 1.
fn labels_size(labels: &[u32], dims: &HashMap<u32, usize>) -> usize {
    let mut elements = 1usize;
    for label in labels {
        let dim = match dims.get(label) {
            Some(&dim) => dim,
            None => 1,
        };
        elements = elements.saturating_mul(dim);
    }
    elements
}
270
/// Ordered union of two label lists: every distinct label once, in order of
/// first appearance across `lhs` followed by `rhs`.
fn union_labels(lhs: &[u32], rhs: &[u32]) -> Vec<u32> {
    let mut already_seen = HashSet::new();
    lhs.iter()
        .chain(rhs)
        .copied()
        .filter(|&label| already_seen.insert(label))
        .collect()
}
281
/// Printable description of a pairwise contraction plan.
#[derive(Debug)]
struct NativeEinsumPlanReport {
    /// Report text, one entry per output line (header, one line per step,
    /// peak summary).
    lines: Vec<String>,
    /// Size in bytes of the largest step output in the plan.
    peak_intermediate_bytes: usize,
}
287
288fn time_optimized_contraction_options() -> ContractionOptimizerOptions {
289    ContractionOptimizerOptions {
290        score: ScoreFunction::time_optimized(),
291        ..ContractionOptimizerOptions::default()
292    }
293}
294
295fn native_einsum_plan_report_with_options(
296    signature: &NativeEinsumSignature,
297    optimizer_name: &'static str,
298    options: &ContractionOptimizerOptions,
299) -> Result<NativeEinsumPlanReport> {
300    let input_ids = signature
301        .operands
302        .iter()
303        .map(|operand| operand.ids.as_slice())
304        .collect::<Vec<_>>();
305    let subscripts_string = build_einsum_subscripts(&input_ids, &signature.output_ids)?;
306    let subscripts = Subscripts {
307        inputs: input_ids.iter().map(|ids| ids.to_vec()).collect(),
308        output: signature.output_ids.clone(),
309    };
310    let shapes = signature
311        .operands
312        .iter()
313        .map(|operand| operand.shape.clone())
314        .collect::<Vec<_>>();
315    let shape_refs = shapes.iter().map(Vec::as_slice).collect::<Vec<_>>();
316    let tree = ContractionTree::optimize_with_options(&subscripts, &shape_refs, options)
317        .map_err(|e| anyhow!("failed to optimize native einsum path: {e}"))?;
318    let dims = label_dims(&subscripts, &shapes)?;
319    let dtype = signature
320        .operands
321        .first()
322        .map(|operand| operand.dtype)
323        .unwrap_or(DType::F64);
324    let dtype_size = dtype_size_bytes(dtype);
325
326    let mut lines = Vec::new();
327    lines.push(format!(
328        "optimizer={optimizer_name} subscripts={subscripts_string} dtype={dtype:?} steps={}",
329        tree.step_count()
330    ));
331    let mut peak_intermediate_elems = 1usize;
332    for step in 0..tree.step_count() {
333        let Some((left, right)) = tree.step_pair(step) else {
334            continue;
335        };
336        let Some((lhs, rhs, out)) = tree.step_subscripts(step) else {
337            continue;
338        };
339        let lhs_elems = labels_size(lhs, &dims);
340        let rhs_elems = labels_size(rhs, &dims);
341        let out_elems = labels_size(out, &dims);
342        let flop_index_elems = labels_size(&union_labels(lhs, rhs), &dims);
343        peak_intermediate_elems = peak_intermediate_elems.max(out_elems);
344        lines.push(format!(
345            "  step {step:02}: pair=({left},{right}) {}[{}] x {}[{}] -> {}[{}]  flop_index={}  intermediate={} elems ({:.3} MiB)",
346            format_labels(lhs),
347            lhs_elems,
348            format_labels(rhs),
349            rhs_elems,
350            format_labels(out),
351            out_elems,
352            flop_index_elems,
353            out_elems,
354            out_elems as f64 * dtype_size as f64 / (1024.0 * 1024.0),
355        ));
356    }
357    let peak_intermediate_bytes = peak_intermediate_elems.saturating_mul(dtype_size);
358    lines.push(format!(
359        "  peak_intermediate={} elems ({:.3} MiB)",
360        peak_intermediate_elems,
361        peak_intermediate_bytes as f64 / (1024.0 * 1024.0)
362    ));
363
364    Ok(NativeEinsumPlanReport {
365        lines,
366        peak_intermediate_bytes,
367    })
368}
369
370fn native_einsum_time_optimized_plan_report(
371    signature: &NativeEinsumSignature,
372) -> Result<NativeEinsumPlanReport> {
373    native_einsum_plan_report_with_options(
374        signature,
375        "time_optimized",
376        &time_optimized_contraction_options(),
377    )
378}
379
380fn native_einsum_balanced_plan_report(
381    signature: &NativeEinsumSignature,
382) -> Result<NativeEinsumPlanReport> {
383    native_einsum_plan_report_with_options(
384        signature,
385        "balanced_default",
386        &ContractionOptimizerOptions::default(),
387    )
388}
389
390fn maybe_trace_native_einsum_path(
391    path: NativeEinsumPath,
392    operands: &[(&NativeTensor, &[usize])],
393    output_ids: &[u32],
394) {
395    if !native_einsum_path_trace_enabled() {
396        return;
397    }
398    let signature = native_einsum_signature(path, operands, output_ids);
399    let report = match native_einsum_time_optimized_plan_report(&signature) {
400        Ok(report) if report.peak_intermediate_bytes >= native_einsum_path_trace_min_bytes() => {
401            report
402        }
403        Ok(_) => return,
404        Err(err) => {
405            eprintln!("native_einsum path trace failed: {err:#}");
406            return;
407        }
408    };
409
410    let max_signatures = native_einsum_path_trace_max_signatures();
411    let should_trace = NATIVE_EINSUM_TRACE_STATE.with(|state| {
412        let mut state = state.borrow_mut();
413        if state.len() >= max_signatures || state.contains(&signature) {
414            false
415        } else {
416            state.insert(signature.clone());
417            true
418        }
419    });
420    if !should_trace {
421        return;
422    }
423
424    eprintln!("=== native_einsum Path Trace ===");
425    eprintln!(
426        "path={:?} output_ids={:?}",
427        signature.path, signature.output_ids
428    );
429    for operand in &signature.operands {
430        eprintln!(
431            "  operand shape={:?} ids={:?} dtype={:?}",
432            operand.shape, operand.ids, operand.dtype
433        );
434    }
435    for line in report.lines {
436        eprintln!("{line}");
437    }
438    if env::var("T4A_TRACE_NATIVE_EINSUM_COMPARE_BALANCED").is_ok() {
439        match native_einsum_balanced_plan_report(&signature) {
440            Ok(balanced) => {
441                for line in balanced.lines {
442                    eprintln!("{line}");
443                }
444            }
445            Err(err) => eprintln!("balanced native_einsum path trace failed: {err:#}"),
446        }
447    }
448}
449
450/// Reset the aggregated native einsum profile.
451pub fn reset_native_einsum_profile() {
452    NATIVE_EINSUM_PROFILE_STATE.with(|state| state.borrow_mut().clear());
453    NATIVE_EINSUM_TRACE_STATE.with(|state| state.borrow_mut().clear());
454}
455
456/// Print and clear the aggregated native einsum profile.
457pub fn print_and_reset_native_einsum_profile() {
458    if !native_einsum_profile_enabled() {
459        return;
460    }
461    NATIVE_EINSUM_PROFILE_STATE.with(|state| {
462        let mut entries: Vec<_> = state
463            .borrow()
464            .iter()
465            .map(|(k, v)| (k.clone(), v.clone()))
466            .collect();
467        state.borrow_mut().clear();
468        entries.sort_by_key(|(_, entry)| Reverse(entry.total_time));
469
470        eprintln!("=== native_einsum Profile ===");
471        for (idx, (signature, entry)) in entries.into_iter().take(20).enumerate() {
472            eprintln!(
473                "#{idx:02} path={:?} calls={} total={:.3}s per_call={:.3}us output_ids={:?}",
474                signature.path,
475                entry.calls,
476                entry.total_time.as_secs_f64(),
477                entry.total_time.as_secs_f64() * 1e6 / entry.calls as f64,
478                signature.output_ids,
479            );
480            for operand in &signature.operands {
481                eprintln!(
482                    "     shape={:?} ids={:?} dtype={:?}",
483                    operand.shape, operand.ids, operand.dtype
484                );
485            }
486            match native_einsum_time_optimized_plan_report(&signature) {
487                Ok(report) => {
488                    for line in report.lines {
489                        eprintln!("     {line}");
490                    }
491                }
492                Err(err) => eprintln!("     path report failed: {err:#}"),
493            }
494        }
495    });
496}
497
498fn common_dtype(dtypes: &[DType]) -> DType {
499    let has_f64 = dtypes.contains(&DType::F64);
500    let has_c64 = dtypes.contains(&DType::C64);
501    let has_c32 = dtypes.contains(&DType::C32);
502    let has_i64 = dtypes.contains(&DType::I64);
503    let has_complex = has_c64 || has_c32;
504    if has_c64 || (has_f64 && has_complex) {
505        DType::C64
506    } else if has_c32 {
507        DType::C32
508    } else if has_f64 || has_i64 {
509        DType::F64
510    } else {
511        DType::F32
512    }
513}
514
515fn convert_tensor(tensor: &NativeTensor, to: DType) -> Result<NativeTensor> {
516    if tensor.dtype() == to {
517        return Ok(tensor.clone());
518    }
519    with_default_backend(|backend| backend.with_exec_session(|exec| exec.convert(tensor, to)))
520        .map_err(|e| anyhow!("tensor conversion to {to:?} failed: {e}"))
521}
522
523fn ids_to_subscript(ids: &[u32]) -> Result<String> {
524    const LETTERS: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
525    let mut out = String::with_capacity(ids.len());
526    for &id in ids {
527        let idx = usize::try_from(id).unwrap_or(usize::MAX);
528        let letter = LETTERS
529            .get(idx)
530            .ok_or_else(|| anyhow!("einsum label {id} exceeds supported label range"))?;
531        out.push(char::from(*letter));
532    }
533    Ok(out)
534}
535
536fn build_einsum_subscripts(operands: &[&[u32]], output_ids: &[u32]) -> Result<String> {
537    let inputs = operands
538        .iter()
539        .map(|ids| ids_to_subscript(ids))
540        .collect::<Result<Vec<_>>>()?;
541    Ok(format!(
542        "{}->{}",
543        inputs.join(","),
544        ids_to_subscript(output_ids)?
545    ))
546}
547
/// Run an einsum over native tensors through the default engine.
///
/// Each input is represented by a traced placeholder with the same dtype and
/// concrete shape. The traced einsum is built with `EinsumOptimize::default()`
/// and then evaluated with the real tensors bound to the placeholders; the
/// evaluated tensor is cloned out of the engine closure.
///
/// After the evaluation, optional diagnostics and cleanup steps run, each
/// controlled by an environment variable read through the helpers in this
/// file: buffer-pool tracing, engine reset (or, if that is not requested,
/// buffer-pool reset), and releasing cached allocator memory.
///
/// # Errors
/// Fails when building or evaluating the traced einsum fails.
fn cached_einsum_native_tensors(
    inputs: &[&NativeTensor],
    subscripts: &EinsumSubscripts,
) -> Result<NativeTensor> {
    // One placeholder per input, carrying only dtype and shape.
    let placeholders = inputs
        .iter()
        .map(|tensor| TracedTensor::input_concrete_shape(tensor.dtype(), tensor.shape()))
        .collect::<Vec<_>>();
    let placeholder_refs = placeholders.iter().collect::<Vec<_>>();
    // Pair every placeholder with the concrete tensor it stands for.
    let bindings = placeholders
        .iter()
        .zip(inputs.iter())
        .map(|(placeholder, tensor)| (placeholder, *tensor))
        .collect::<Vec<_>>();

    let trace_pool = native_einsum_pool_trace_enabled();
    // Pool statistics are only sampled when pool tracing is on.
    let pool_before = trace_pool.then(default_engine_buffer_pool_stats);
    let result = with_default_engine(|engine| {
        let mut result = einsum_subscripts_with(
            engine,
            &placeholder_refs,
            subscripts,
            EinsumOptimize::default(),
        )
        .map_err(|e| anyhow!("native einsum failed: {e}"))?;
        result
            .eval_with_inputs(engine, &bindings)
            .cloned()
            .map_err(|e| anyhow!("native einsum failed: {e}"))
    })?;
    if trace_pool {
        let pool_after = default_engine_buffer_pool_stats();
        let output_bytes = native_tensor_bytes(&result);
        let retained_threshold = native_einsum_pool_trace_min_retained_bytes();
        // NOTE(review): `&&` binds tighter than `||`, so this reads as
        // `(pool changed && retained >= threshold) || output >= min_output`.
        // With the default min-output threshold of 0 the right-hand side is
        // always true — confirm this is the intended filter.
        if pool_after != pool_before.unwrap_or_default()
            && pool_after.capacity_bytes >= retained_threshold
            || output_bytes >= native_einsum_pool_trace_min_output_bytes()
        {
            let before = pool_before.unwrap_or_default();
            eprintln!(
                "native_einsum pool subscripts={subscripts:?} before_buffers={} before_capacity={:.3} MiB after_buffers={} after_capacity={:.3} MiB output_shape={:?} output_bytes={:.3} MiB",
                before.buffers,
                before.capacity_bytes as f64 / (1024.0 * 1024.0),
                pool_after.buffers,
                pool_after.capacity_bytes as f64 / (1024.0 * 1024.0),
                result.shape(),
                output_bytes as f64 / (1024.0 * 1024.0),
            );
        }
    }
    // Engine reset takes precedence over the buffer-pool reset; at most one
    // of the two runs per call.
    if reset_native_einsum_engine_after_call() {
        let before_reset = trace_pool.then(default_engine_buffer_pool_stats);
        reset_default_engine();
        // Only report resets that dropped at least the configured capacity.
        if trace_pool
            && before_reset.unwrap_or_default().capacity_bytes
                >= native_einsum_pool_trace_min_retained_bytes()
        {
            let before = before_reset.unwrap_or_default();
            let after = default_engine_buffer_pool_stats();
            eprintln!(
                "native_einsum engine_reset before_buffers={} before_capacity={:.3} MiB after_buffers={} after_capacity={:.3} MiB",
                before.buffers,
                before.capacity_bytes as f64 / (1024.0 * 1024.0),
                after.buffers,
                after.capacity_bytes as f64 / (1024.0 * 1024.0),
            );
        }
    } else if reset_native_einsum_buffer_pool_after_call() {
        let before_clear = trace_pool.then(default_engine_buffer_pool_stats);
        reset_default_engine_buffer_pool();
        if trace_pool
            && before_clear.unwrap_or_default().capacity_bytes
                >= native_einsum_pool_trace_min_retained_bytes()
        {
            let before = before_clear.unwrap_or_default();
            let after = default_engine_buffer_pool_stats();
            eprintln!(
                "native_einsum pool_reset before_buffers={} before_capacity={:.3} MiB after_buffers={} after_capacity={:.3} MiB",
                before.buffers,
                before.capacity_bytes as f64 / (1024.0 * 1024.0),
                after.buffers,
                after.capacity_bytes as f64 / (1024.0 * 1024.0),
            );
        }
    }
    if release_allocator_after_native_einsum_call() {
        let report = release_process_allocator_cached_memory();
        // Only log when the release reported success or freed some bytes.
        if trace_pool && (report.released_bytes.unwrap_or(0) > 0 || report.success == Some(true)) {
            eprintln!(
                "native_einsum allocator_pressure_relief supported={} released_bytes={:?} success={:?}",
                report.supported,
                report.released_bytes,
                report.success,
            );
        }
    }
    Ok(result)
}
646
647fn cached_einsum_native_reads(
648    inputs: &[TensorRead<'_>],
649    subscripts: &Subscripts,
650) -> Result<NativeTensor> {
651    with_default_backend(|backend| {
652        tenferro_einsum::eager_einsum_read_subscripts(backend, inputs, subscripts)
653            .map_err(|e| anyhow!("native read einsum failed: {e}"))
654    })
655}
656
657/// Build native einsum ids for a binary contraction.
658pub(crate) fn build_binary_einsum_ids(
659    lhs_rank: usize,
660    axes_a: &[usize],
661    rhs_rank: usize,
662    axes_b: &[usize],
663) -> Result<(Vec<u32>, Vec<u32>, Vec<u32>)> {
664    ensure!(
665        axes_a.len() == axes_b.len(),
666        "contract axis length mismatch: lhs {:?}, rhs {:?}",
667        axes_a,
668        axes_b
669    );
670
671    let mut lhs_ids = vec![u32::MAX; lhs_rank];
672    let mut rhs_ids = vec![u32::MAX; rhs_rank];
673    let mut next_id = 0u32;
674
675    let mut seen_lhs = vec![false; lhs_rank];
676    let mut seen_rhs = vec![false; rhs_rank];
677
678    for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
679        ensure!(
680            lhs_axis < lhs_rank,
681            "lhs contract axis {lhs_axis} out of range"
682        );
683        ensure!(
684            rhs_axis < rhs_rank,
685            "rhs contract axis {rhs_axis} out of range"
686        );
687        ensure!(
688            !seen_lhs[lhs_axis],
689            "duplicate lhs contract axis {lhs_axis}"
690        );
691        ensure!(
692            !seen_rhs[rhs_axis],
693            "duplicate rhs contract axis {rhs_axis}"
694        );
695        seen_lhs[lhs_axis] = true;
696        seen_rhs[rhs_axis] = true;
697        lhs_ids[lhs_axis] = next_id;
698        rhs_ids[rhs_axis] = next_id;
699        next_id += 1;
700    }
701
702    let mut output_ids = Vec::with_capacity(lhs_rank + rhs_rank - 2 * axes_a.len());
703    for (axis, slot) in lhs_ids.iter_mut().enumerate() {
704        if *slot == u32::MAX {
705            *slot = next_id;
706            output_ids.push(next_id);
707            next_id += 1;
708        } else {
709            let _ = axis;
710        }
711    }
712    for slot in &mut rhs_ids {
713        if *slot == u32::MAX {
714            *slot = next_id;
715            output_ids.push(next_id);
716            next_id += 1;
717        }
718    }
719
720    Ok((lhs_ids, rhs_ids, output_ids))
721}
722
723/// Build a dense native tensor from column-major data.
724pub fn dense_native_tensor_from_col_major<T: TensorElement>(
725    data: &[T],
726    logical_dims: &[usize],
727) -> Result<NativeTensor> {
728    T::dense_native_tensor_from_col_major(data, logical_dims)
729}
730
731/// Build a dense native tensor whose logical values are diagonal.
732pub fn diag_native_tensor_from_col_major<T: TensorElement>(
733    data: &[T],
734    logical_rank: usize,
735) -> Result<NativeTensor> {
736    T::diag_native_tensor_from_col_major(data, logical_rank)
737}
738
739/// Convert storage to a dense native tensor.
740pub fn storage_to_native_tensor(storage: &Storage, logical_dims: &[usize]) -> Result<NativeTensor> {
741    if storage.is_c64() {
742        dense_native_tensor_from_col_major(
743            &storage
744                .to_dense_c64_col_major_vec(logical_dims)
745                .map_err(|e| anyhow!("dense c64 materialization failed: {e}"))?,
746            logical_dims,
747        )
748    } else {
749        dense_native_tensor_from_col_major(
750            &storage
751                .to_dense_f64_col_major_vec(logical_dims)
752                .map_err(|e| anyhow!("dense f64 materialization failed: {e}"))?,
753            logical_dims,
754        )
755    }
756}
757
758/// Build a read-only native tensor input over the compact storage payload.
759///
760/// Contiguous payloads are borrowed without copying. Non-contiguous payloads
761/// are materialized into an owned native tensor.
762pub fn storage_payload_native_read_input(storage: &Storage) -> Result<NativeTensorReadInput<'_>> {
763    if storage.is_f64() {
764        if let Some(view) = storage
765            .payload_f64_col_major_view_if_contiguous()
766            .map_err(anyhow::Error::msg)?
767        {
768            return Ok(NativeTensorReadInput::Borrowed(TensorRead::from_view(
769                TensorView::f64(storage.payload_dims(), view)?,
770            )));
771        }
772        Ok(NativeTensorReadInput::Owned(NativeTensor::from_vec(
773            storage.payload_dims().to_vec(),
774            storage
775                .payload_f64_col_major_vec()
776                .map_err(anyhow::Error::msg)?,
777        )))
778    } else if storage.is_c64() {
779        if let Some(view) = storage
780            .payload_c64_col_major_view_if_contiguous()
781            .map_err(anyhow::Error::msg)?
782        {
783            return Ok(NativeTensorReadInput::Borrowed(TensorRead::from_view(
784                TensorView::c64(storage.payload_dims(), view)?,
785            )));
786        }
787        Ok(NativeTensorReadInput::Owned(NativeTensor::from_vec(
788            storage.payload_dims().to_vec(),
789            storage
790                .payload_c64_col_major_vec()
791                .map_err(anyhow::Error::msg)?,
792        )))
793    } else {
794        Err(anyhow!("unsupported storage scalar type"))
795    }
796}
797
798/// Materialize a native tensor into dense storage.
799pub fn native_tensor_primal_to_storage(tensor: &NativeTensor) -> Result<Storage> {
800    match tensor.dtype() {
801        DType::F32 => Storage::from_dense_col_major(
802            tensor
803                .as_slice::<f32>()
804                .ok_or_else(|| anyhow!("failed to read f32 native tensor"))?
805                .iter()
806                .map(|&value| value as f64)
807                .collect::<Vec<_>>(),
808            tensor.shape(),
809        ),
810        DType::F64 => Storage::from_dense_col_major(
811            tensor
812                .as_slice::<f64>()
813                .ok_or_else(|| anyhow!("failed to read f64 native tensor"))?
814                .to_vec(),
815            tensor.shape(),
816        ),
817        DType::I64 => Storage::from_dense_col_major(
818            tensor
819                .as_slice::<i64>()
820                .ok_or_else(|| anyhow!("failed to read i64 native tensor"))?
821                .iter()
822                .map(|&value| value as f64)
823                .collect::<Vec<_>>(),
824            tensor.shape(),
825        ),
826        DType::C32 => Storage::from_dense_col_major(
827            tensor
828                .as_slice::<Complex32>()
829                .ok_or_else(|| anyhow!("failed to read c32 native tensor"))?
830                .iter()
831                .map(|&value| Complex64::new(value.re as f64, value.im as f64))
832                .collect::<Vec<_>>(),
833            tensor.shape(),
834        ),
835        DType::C64 => Storage::from_dense_col_major(
836            tensor
837                .as_slice::<Complex64>()
838                .ok_or_else(|| anyhow!("failed to read c64 native tensor"))?
839                .to_vec(),
840            tensor.shape(),
841        ),
842    }
843    .map_err(|e| anyhow!("native tensor snapshot materialization failed: {e}"))
844}
845
846/// Materialize dense f64 values from a native tensor.
847pub fn native_tensor_primal_to_dense_f64_col_major(tensor: &NativeTensor) -> Result<Vec<f64>> {
848    match tensor.dtype() {
849        DType::F32 => Ok(tensor
850            .as_slice::<f32>()
851            .ok_or_else(|| anyhow!("failed to read f32 native tensor"))?
852            .iter()
853            .map(|&value| value as f64)
854            .collect()),
855        DType::F64 => <f64 as TensorElement>::dense_values_from_native_col_major(tensor),
856        DType::I64 => Ok(tensor
857            .as_slice::<i64>()
858            .ok_or_else(|| anyhow!("failed to read i64 native tensor"))?
859            .iter()
860            .map(|&value| value as f64)
861            .collect()),
862        other => Err(anyhow!("expected real native tensor, got dtype {other:?}")),
863    }
864}
865
866/// Materialize dense Complex64 values from a native tensor.
867pub fn native_tensor_primal_to_dense_c64_col_major(
868    tensor: &NativeTensor,
869) -> Result<Vec<Complex64>> {
870    match tensor.dtype() {
871        DType::C32 => Ok(tensor
872            .as_slice::<Complex32>()
873            .ok_or_else(|| anyhow!("failed to read c32 native tensor"))?
874            .iter()
875            .map(|&value| Complex64::new(value.re as f64, value.im as f64))
876            .collect()),
877        DType::C64 => <Complex64 as TensorElement>::dense_values_from_native_col_major(tensor),
878        other => Err(anyhow!(
879            "expected complex native tensor, got dtype {other:?}"
880        )),
881    }
882}
883
884/// Materialize dense column-major values from a native tensor.
885pub fn native_tensor_primal_to_dense_col_major<T: TensorElement>(
886    tensor: &NativeTensor,
887) -> Result<Vec<T>> {
888    T::dense_values_from_native_col_major(tensor)
889}
890
891/// Extract the diagonal payload from a real native tensor.
892pub fn native_tensor_primal_to_diag_f64(tensor: &NativeTensor) -> Result<Vec<f64>> {
893    match tensor.dtype() {
894        DType::F32 => {
895            let promoted = convert_tensor(tensor, DType::F64)?;
896            <f64 as TensorElement>::diag_values_from_native_temp(&promoted)
897        }
898        DType::F64 => <f64 as TensorElement>::diag_values_from_native_temp(tensor),
899        DType::I64 => {
900            let promoted = convert_tensor(tensor, DType::F64)?;
901            <f64 as TensorElement>::diag_values_from_native_temp(&promoted)
902        }
903        other => Err(anyhow!("expected real native tensor, got dtype {other:?}")),
904    }
905}
906
907/// Extract the diagonal payload from a complex native tensor.
908pub fn native_tensor_primal_to_diag_c64(tensor: &NativeTensor) -> Result<Vec<Complex64>> {
909    match tensor.dtype() {
910        DType::C32 => {
911            let promoted = convert_tensor(tensor, DType::C64)?;
912            <Complex64 as TensorElement>::diag_values_from_native_temp(&promoted)
913        }
914        DType::C64 => <Complex64 as TensorElement>::diag_values_from_native_temp(tensor),
915        other => Err(anyhow!(
916            "expected complex native tensor, got dtype {other:?}"
917        )),
918    }
919}
920
921/// Reshape a native tensor without changing its column-major linearization.
922pub fn reshape_col_major_native_tensor(
923    tensor: &NativeTensor,
924    logical_dims: &[usize],
925) -> Result<NativeTensor> {
926    with_default_backend(|backend| tensor.reshape(logical_dims, backend))
927        .map_err(|e| anyhow!("native reshape failed: {e}"))
928}
929
930/// Compute a QR decomposition on a native tensor.
931pub fn qr_native_tensor(tensor: &NativeTensor) -> Result<(NativeTensor, NativeTensor)> {
932    with_default_backend(|backend| tensor.qr(backend)).map_err(|e| anyhow!("native QR failed: {e}"))
933}
934
935/// Compute an SVD on a native tensor.
936pub fn svd_native_tensor(
937    tensor: &NativeTensor,
938) -> Result<(NativeTensor, NativeTensor, NativeTensor)> {
939    with_default_backend(|backend| tensor.svd(backend))
940        .map_err(|e| anyhow!("native SVD failed: {e}"))
941}
942
943/// Sum all elements of a native tensor, returning a dynamic scalar.
944pub fn sum_native_tensor(tensor: &NativeTensor) -> Result<AnyScalar> {
945    let reduced = if tensor.shape().is_empty() {
946        tensor.clone()
947    } else {
948        let axes: Vec<usize> = (0..tensor.shape().len()).collect();
949        with_default_backend(|backend| tensor.reduce_sum(&axes, backend))
950            .map_err(|e| anyhow!("native sum failed: {e}"))?
951    };
952    AnyScalar::from_native(reduced)
953}
954
955/// Return the tangent tensor when present.
956///
957/// Plain `Tensor` values do not carry tangent storage, so this bridge returns
958/// `None`.
959pub fn tangent_native_tensor(_tensor: &NativeTensor) -> Option<NativeTensor> {
960    None
961}
962
963/// Multiply a native tensor by a dynamic scalar.
964pub fn scale_native_tensor(tensor: &NativeTensor, scalar: &AnyScalar) -> Result<NativeTensor> {
965    let target = common_dtype(&[tensor.dtype(), scalar.as_native().dtype()]);
966    let tensor = convert_tensor(tensor, target)?;
967    let scalar = promote_scalar_native(scalar.as_native(), target)?;
968
969    match target {
970        DType::F32 => {
971            let factor = scalar
972                .as_slice::<f32>()
973                .and_then(|values| values.first().copied())
974                .ok_or_else(|| anyhow!("failed to read promoted f32 scalar"))?;
975            let values = tensor
976                .as_slice::<f32>()
977                .ok_or_else(|| anyhow!("failed to read promoted f32 tensor"))?
978                .iter()
979                .map(|&value| value * factor)
980                .collect::<Vec<_>>();
981            Ok(NativeTensor::from_vec(tensor.shape().to_vec(), values))
982        }
983        DType::F64 => {
984            let factor = scalar
985                .as_slice::<f64>()
986                .and_then(|values| values.first().copied())
987                .ok_or_else(|| anyhow!("failed to read promoted f64 scalar"))?;
988            let values = tensor
989                .as_slice::<f64>()
990                .ok_or_else(|| anyhow!("failed to read promoted f64 tensor"))?
991                .iter()
992                .map(|&value| value * factor)
993                .collect::<Vec<_>>();
994            Ok(NativeTensor::from_vec(tensor.shape().to_vec(), values))
995        }
996        DType::C32 => {
997            let factor = scalar
998                .as_slice::<Complex32>()
999                .and_then(|values| values.first().copied())
1000                .ok_or_else(|| anyhow!("failed to read promoted c32 scalar"))?;
1001            let values = tensor
1002                .as_slice::<Complex32>()
1003                .ok_or_else(|| anyhow!("failed to read promoted c32 tensor"))?
1004                .iter()
1005                .map(|&value| value * factor)
1006                .collect::<Vec<_>>();
1007            Ok(NativeTensor::from_vec(tensor.shape().to_vec(), values))
1008        }
1009        DType::C64 => {
1010            let factor = scalar
1011                .as_slice::<Complex64>()
1012                .and_then(|values| values.first().copied())
1013                .ok_or_else(|| anyhow!("failed to read promoted c64 scalar"))?;
1014            let values = tensor
1015                .as_slice::<Complex64>()
1016                .ok_or_else(|| anyhow!("failed to read promoted c64 tensor"))?
1017                .iter()
1018                .map(|&value| value * factor)
1019                .collect::<Vec<_>>();
1020            Ok(NativeTensor::from_vec(tensor.shape().to_vec(), values))
1021        }
1022        DType::I64 => Err(anyhow!("scale_native_tensor does not support i64 tensors")),
1023    }
1024}
1025
1026/// Compute `a * lhs + b * rhs`.
1027pub fn axpby_native_tensor(
1028    lhs: &NativeTensor,
1029    a: &AnyScalar,
1030    rhs: &NativeTensor,
1031    b: &AnyScalar,
1032) -> Result<NativeTensor> {
1033    ensure!(
1034        lhs.shape() == rhs.shape(),
1035        "axpby requires matching tensor shapes, got lhs {:?} and rhs {:?}",
1036        lhs.shape(),
1037        rhs.shape()
1038    );
1039
1040    let target = common_dtype(&[
1041        lhs.dtype(),
1042        rhs.dtype(),
1043        a.as_native().dtype(),
1044        b.as_native().dtype(),
1045    ]);
1046    let lhs = convert_tensor(lhs, target)?;
1047    let rhs = convert_tensor(rhs, target)?;
1048    let a = promote_scalar_native(a.as_native(), target)?;
1049    let b = promote_scalar_native(b.as_native(), target)?;
1050
1051    match target {
1052        DType::F32 => {
1053            let a = a
1054                .as_slice::<f32>()
1055                .and_then(|values| values.first().copied())
1056                .ok_or_else(|| anyhow!("failed to read promoted f32 scalar a"))?;
1057            let b = b
1058                .as_slice::<f32>()
1059                .and_then(|values| values.first().copied())
1060                .ok_or_else(|| anyhow!("failed to read promoted f32 scalar b"))?;
1061            let lhs_values = lhs
1062                .as_slice::<f32>()
1063                .ok_or_else(|| anyhow!("failed to read promoted f32 lhs"))?;
1064            let rhs_values = rhs
1065                .as_slice::<f32>()
1066                .ok_or_else(|| anyhow!("failed to read promoted f32 rhs"))?;
1067            let values = lhs_values
1068                .iter()
1069                .zip(rhs_values.iter())
1070                .map(|(&x, &y)| a * x + b * y)
1071                .collect::<Vec<_>>();
1072            Ok(NativeTensor::from_vec(lhs.shape().to_vec(), values))
1073        }
1074        DType::F64 => {
1075            let a = a
1076                .as_slice::<f64>()
1077                .and_then(|values| values.first().copied())
1078                .ok_or_else(|| anyhow!("failed to read promoted f64 scalar a"))?;
1079            let b = b
1080                .as_slice::<f64>()
1081                .and_then(|values| values.first().copied())
1082                .ok_or_else(|| anyhow!("failed to read promoted f64 scalar b"))?;
1083            let lhs_values = lhs
1084                .as_slice::<f64>()
1085                .ok_or_else(|| anyhow!("failed to read promoted f64 lhs"))?;
1086            let rhs_values = rhs
1087                .as_slice::<f64>()
1088                .ok_or_else(|| anyhow!("failed to read promoted f64 rhs"))?;
1089            let values = lhs_values
1090                .iter()
1091                .zip(rhs_values.iter())
1092                .map(|(&x, &y)| a * x + b * y)
1093                .collect::<Vec<_>>();
1094            Ok(NativeTensor::from_vec(lhs.shape().to_vec(), values))
1095        }
1096        DType::C32 => {
1097            let a = a
1098                .as_slice::<Complex32>()
1099                .and_then(|values| values.first().copied())
1100                .ok_or_else(|| anyhow!("failed to read promoted c32 scalar a"))?;
1101            let b = b
1102                .as_slice::<Complex32>()
1103                .and_then(|values| values.first().copied())
1104                .ok_or_else(|| anyhow!("failed to read promoted c32 scalar b"))?;
1105            let lhs_values = lhs
1106                .as_slice::<Complex32>()
1107                .ok_or_else(|| anyhow!("failed to read promoted c32 lhs"))?;
1108            let rhs_values = rhs
1109                .as_slice::<Complex32>()
1110                .ok_or_else(|| anyhow!("failed to read promoted c32 rhs"))?;
1111            let values = lhs_values
1112                .iter()
1113                .zip(rhs_values.iter())
1114                .map(|(&x, &y)| a * x + b * y)
1115                .collect::<Vec<_>>();
1116            Ok(NativeTensor::from_vec(lhs.shape().to_vec(), values))
1117        }
1118        DType::C64 => {
1119            let a = a
1120                .as_slice::<Complex64>()
1121                .and_then(|values| values.first().copied())
1122                .ok_or_else(|| anyhow!("failed to read promoted c64 scalar a"))?;
1123            let b = b
1124                .as_slice::<Complex64>()
1125                .and_then(|values| values.first().copied())
1126                .ok_or_else(|| anyhow!("failed to read promoted c64 scalar b"))?;
1127            let lhs_values = lhs
1128                .as_slice::<Complex64>()
1129                .ok_or_else(|| anyhow!("failed to read promoted c64 lhs"))?;
1130            let rhs_values = rhs
1131                .as_slice::<Complex64>()
1132                .ok_or_else(|| anyhow!("failed to read promoted c64 rhs"))?;
1133            let values = lhs_values
1134                .iter()
1135                .zip(rhs_values.iter())
1136                .map(|(&x, &y)| a * x + b * y)
1137                .collect::<Vec<_>>();
1138            Ok(NativeTensor::from_vec(lhs.shape().to_vec(), values))
1139        }
1140        DType::I64 => Err(anyhow!("axpby_native_tensor does not support i64 tensors")),
1141    }
1142}
1143
1144/// Execute a cached einsum over owned native tensors.
1145///
1146/// This is the consuming bridge used by higher-level owned contraction APIs.
1147/// Inputs are promoted to a common dtype before tenferro evaluates the
1148/// contraction. Repeated calls with the same equation and shapes reuse
1149/// tenferro's process-global contraction path cache.
1150///
1151/// # Arguments
1152/// * `operands` - Native tensors paired with numeric einsum labels for each axis.
1153/// * `output_ids` - Numeric labels to keep in the result, in output axis order.
1154///
1155/// # Returns
1156/// The contracted native tensor in the promoted common dtype.
1157///
1158/// # Errors
1159/// Returns an error if the operand list is empty, any label list length does
1160/// not match its tensor rank, label generation exceeds the supported range, or
1161/// the backend contraction fails.
1162///
1163/// # Examples
1164/// ```
1165/// use tensor4all_tensorbackend::einsum_native_tensors_owned;
1166/// use tenferro::Tensor as NativeTensor;
1167///
1168/// let lhs = NativeTensor::from_vec(vec![2, 3], vec![1.0_f64; 6]);
1169/// let rhs = NativeTensor::from_vec(vec![3, 2], vec![1.0_f64; 6]);
1170/// let result = einsum_native_tensors_owned(vec![(lhs, vec![0, 1]), (rhs, vec![1, 2])], &[0, 2]).unwrap();
1171///
1172/// assert_eq!(result.shape(), &[2, 2]);
1173/// assert_eq!(result.as_slice::<f64>().unwrap(), &[3.0, 3.0, 3.0, 3.0]);
1174/// ```
1175pub fn einsum_native_tensors_owned(
1176    operands: Vec<(NativeTensor, Vec<usize>)>,
1177    output_ids: &[usize],
1178) -> Result<NativeTensor> {
1179    ensure!(
1180        !operands.is_empty(),
1181        "native einsum requires at least one operand"
1182    );
1183
1184    let target = common_dtype(
1185        &operands
1186            .iter()
1187            .map(|(tensor, _)| tensor.dtype())
1188            .collect::<Vec<_>>(),
1189    );
1190
1191    let mut converted = Vec::with_capacity(operands.len());
1192    let mut input_ids = Vec::with_capacity(operands.len());
1193    for (tensor, ids) in operands {
1194        ensure!(
1195            tensor.shape().len() == ids.len(),
1196            "einsum id list {:?} does not match tensor shape {:?}",
1197            ids,
1198            tensor.shape()
1199        );
1200        let tensor = if tensor.dtype() == target {
1201            tensor
1202        } else {
1203            convert_tensor(&tensor, target)?
1204        };
1205        input_ids.push(ids.into_iter().map(|id| id as u32).collect::<Vec<_>>());
1206        converted.push(tensor);
1207    }
1208
1209    let input_slices = input_ids.iter().map(Vec::as_slice).collect::<Vec<_>>();
1210    let output_ids_u32 = output_ids.iter().map(|&id| id as u32).collect::<Vec<_>>();
1211    let subscripts = EinsumSubscripts::new(&input_slices, &output_ids_u32);
1212
1213    let input_refs = converted.iter().collect::<Vec<_>>();
1214    let trace_ids = input_ids
1215        .iter()
1216        .map(|ids| ids.iter().map(|&id| id as usize).collect::<Vec<_>>())
1217        .collect::<Vec<_>>();
1218    let trace_operands = input_refs
1219        .iter()
1220        .zip(trace_ids.iter())
1221        .map(|(tensor, ids)| (*tensor, ids.as_slice()))
1222        .collect::<Vec<_>>();
1223    maybe_trace_native_einsum_path(NativeEinsumPath::Owned, &trace_operands, &output_ids_u32);
1224    let started = Instant::now();
1225    let result = cached_einsum_native_tensors(&input_refs, &subscripts)?;
1226    record_native_einsum_profile(
1227        NativeEinsumPath::Owned,
1228        &trace_operands,
1229        &output_ids_u32,
1230        started.elapsed(),
1231    );
1232    Ok(result)
1233}
1234
1235/// Execute a cached einsum over borrowed native tensors.
1236///
1237/// Inputs are promoted to a common dtype before contraction. Operands that
1238/// already have the target dtype are passed to the backend by reference;
1239/// operands with another dtype are converted into temporary native tensors and
1240/// then borrowed for the contraction. Repeated calls with the same equation
1241/// and shapes reuse tenferro's process-global contraction path cache.
1242///
1243/// # Arguments
1244/// * `operands` - Native tensors paired with numeric einsum labels for each axis.
1245///   Each label slice must have the same length as the corresponding tensor rank.
1246/// * `output_ids` - Numeric labels to keep in the result, in output axis order.
1247///
1248/// # Returns
1249/// The contracted native tensor in the promoted common dtype.
1250///
1251/// # Errors
1252/// Returns an error if the operand list is empty, any label list length does
1253/// not match its tensor rank, label generation exceeds the supported range,
1254/// dtype conversion fails, or the backend contraction fails.
1255///
1256/// # Examples
1257/// ```
1258/// use tensor4all_tensorbackend::einsum_native_tensors;
1259/// use tenferro::Tensor as NativeTensor;
1260///
1261/// let lhs = NativeTensor::from_vec(vec![2, 3], vec![1.0_f64; 6]);
1262/// let rhs = NativeTensor::from_vec(vec![3, 2], vec![1.0_f64; 6]);
1263/// let result = einsum_native_tensors(&[(&lhs, &[0, 1]), (&rhs, &[1, 2])], &[0, 2]).unwrap();
1264///
1265/// assert_eq!(result.shape(), &[2, 2]);
1266/// assert_eq!(result.as_slice::<f64>().unwrap(), &[3.0, 3.0, 3.0, 3.0]);
1267/// ```
1268pub fn einsum_native_tensors(
1269    operands: &[(&NativeTensor, &[usize])],
1270    output_ids: &[usize],
1271) -> Result<NativeTensor> {
1272    ensure!(
1273        !operands.is_empty(),
1274        "native einsum requires at least one operand"
1275    );
1276
1277    let target = common_dtype(
1278        &operands
1279            .iter()
1280            .map(|(tensor, _)| tensor.dtype())
1281            .collect::<Vec<_>>(),
1282    );
1283    let mut converted = Vec::with_capacity(operands.len());
1284    let mut input_ids = Vec::with_capacity(operands.len());
1285    let mut has_conversions = false;
1286    let started = Instant::now();
1287
1288    for (tensor, ids) in operands {
1289        ensure!(
1290            tensor.shape().len() == ids.len(),
1291            "einsum id list {:?} does not match tensor shape {:?}",
1292            ids,
1293            tensor.shape()
1294        );
1295        input_ids.push(ids.iter().map(|&id| id as u32).collect::<Vec<_>>());
1296        if tensor.dtype() == target {
1297            converted.push(None);
1298        } else {
1299            converted.push(Some(convert_tensor(tensor, target)?));
1300            has_conversions = true;
1301        }
1302    }
1303
1304    let input_slices = input_ids.iter().map(Vec::as_slice).collect::<Vec<_>>();
1305    let output_ids_u32 = output_ids.iter().map(|&id| id as u32).collect::<Vec<_>>();
1306    let subscripts = EinsumSubscripts::new(&input_slices, &output_ids_u32);
1307    let input_refs = operands
1308        .iter()
1309        .zip(converted.iter())
1310        .map(|((tensor, _), converted)| converted.as_ref().unwrap_or(*tensor))
1311        .collect::<Vec<_>>();
1312    let trace_path = if has_conversions {
1313        NativeEinsumPath::BorrowedWithConversions
1314    } else {
1315        NativeEinsumPath::Borrowed
1316    };
1317    maybe_trace_native_einsum_path(trace_path, operands, &output_ids_u32);
1318    let result = cached_einsum_native_tensors(&input_refs, &subscripts)?;
1319    record_native_einsum_profile(trace_path, operands, &output_ids_u32, started.elapsed());
1320    Ok(result)
1321}
1322
1323/// Execute a cached einsum over read-only native tensor inputs.
1324///
1325/// Backends may consume borrowed host views directly or materialize/upload them
1326/// inside their execution session. Mixed dtypes are promoted by materializing
1327/// only the operands that require conversion.
1328pub fn einsum_native_tensor_reads(
1329    operands: &[(&NativeTensorReadInput<'_>, &[usize])],
1330    output_ids: &[usize],
1331) -> Result<NativeTensor> {
1332    ensure!(
1333        !operands.is_empty(),
1334        "native einsum requires at least one operand"
1335    );
1336
1337    let target = common_dtype(
1338        &operands
1339            .iter()
1340            .map(|(tensor, _)| tensor.dtype())
1341            .collect::<Vec<_>>(),
1342    );
1343    let mut converted = Vec::with_capacity(operands.len());
1344    let mut input_ids = Vec::with_capacity(operands.len());
1345    let mut read_inputs = Vec::with_capacity(operands.len());
1346
1347    for (tensor, ids) in operands {
1348        ensure!(
1349            tensor.shape().len() == ids.len(),
1350            "einsum id list {:?} does not match tensor shape {:?}",
1351            ids,
1352            tensor.shape()
1353        );
1354        input_ids.push(ids.iter().map(|&id| id as u32).collect::<Vec<_>>());
1355        if tensor.dtype() == target {
1356            converted.push(None);
1357        } else {
1358            converted.push(Some(convert_tensor(&tensor.as_read().to_tensor(), target)?));
1359        }
1360    }
1361
1362    for (tensor, converted) in operands
1363        .iter()
1364        .map(|(tensor, _)| *tensor)
1365        .zip(converted.iter())
1366    {
1367        if let Some(converted) = converted {
1368            read_inputs.push(TensorRead::from_tensor(converted));
1369        } else {
1370            read_inputs.push(tensor.as_read());
1371        }
1372    }
1373
1374    let output_ids_u32 = output_ids.iter().map(|&id| id as u32).collect::<Vec<_>>();
1375    let subscripts = Subscripts {
1376        inputs: input_ids,
1377        output: output_ids_u32,
1378    };
1379    cached_einsum_native_reads(&read_inputs, &subscripts)
1380}
1381
1382/// Permute axes of a native tensor.
1383pub fn permute_native_tensor(tensor: &NativeTensor, perm: &[usize]) -> Result<NativeTensor> {
1384    with_default_backend(|backend| tensor.transpose(perm, backend))
1385        .map_err(|e| anyhow!("native permute failed: {e}"))
1386}
1387
1388/// Contract two native tensors along matching axes.
1389pub fn contract_native_tensor(
1390    lhs: &NativeTensor,
1391    axes_a: &[usize],
1392    rhs: &NativeTensor,
1393    axes_b: &[usize],
1394) -> Result<NativeTensor> {
1395    let (lhs_ids, rhs_ids, output_ids) =
1396        build_binary_einsum_ids(lhs.shape().len(), axes_a, rhs.shape().len(), axes_b)?;
1397    let lhs_ids_usize = lhs_ids.iter().map(|&id| id as usize).collect::<Vec<_>>();
1398    let rhs_ids_usize = rhs_ids.iter().map(|&id| id as usize).collect::<Vec<_>>();
1399    let output_ids_usize = output_ids.iter().map(|&id| id as usize).collect::<Vec<_>>();
1400    let operands = [
1401        (lhs, lhs_ids_usize.as_slice()),
1402        (rhs, rhs_ids_usize.as_slice()),
1403    ];
1404    einsum_native_tensors(&operands, &output_ids_usize)
1405}
1406
1407/// Compute the outer product of two native tensors.
1408pub fn outer_product_native_tensor(lhs: &NativeTensor, rhs: &NativeTensor) -> Result<NativeTensor> {
1409    contract_native_tensor(lhs, &[], rhs, &[])
1410}
1411
1412/// Conjugate a native tensor.
1413pub fn conj_native_tensor(tensor: &NativeTensor) -> Result<NativeTensor> {
1414    match tensor.dtype() {
1415        DType::F32 | DType::F64 | DType::I64 => Ok(tensor.clone()),
1416        DType::C32 => Ok(NativeTensor::from_vec(
1417            tensor.shape().to_vec(),
1418            tensor
1419                .as_slice::<Complex32>()
1420                .ok_or_else(|| anyhow!("failed to read c32 native tensor"))?
1421                .iter()
1422                .map(|&value| value.conj())
1423                .collect::<Vec<_>>(),
1424        )),
1425        DType::C64 => Ok(NativeTensor::from_vec(
1426            tensor.shape().to_vec(),
1427            tensor
1428                .as_slice::<Complex64>()
1429                .ok_or_else(|| anyhow!("failed to read c64 native tensor"))?
1430                .iter()
1431                .map(|&value| value.conj())
1432                .collect::<Vec<_>>(),
1433        )),
1434    }
1435}
1436
1437/// Permute storage by round-tripping through native tensors.
1438pub fn permute_storage_native(
1439    storage: &Storage,
1440    logical_dims: &[usize],
1441    perm: &[usize],
1442) -> Result<Storage> {
1443    let native = storage_to_native_tensor(storage, logical_dims)?;
1444    let permuted = permute_native_tensor(&native, perm)?;
1445    native_tensor_primal_to_storage(&permuted)
1446}
1447
1448/// Contract storages via native tensors.
1449pub fn contract_storage_native(
1450    storage_a: &Storage,
1451    dims_a: &[usize],
1452    axes_a: &[usize],
1453    storage_b: &Storage,
1454    dims_b: &[usize],
1455    axes_b: &[usize],
1456    _result_dims: &[usize],
1457) -> Result<Storage> {
1458    let lhs = storage_to_native_tensor(storage_a, dims_a)?;
1459    let rhs = storage_to_native_tensor(storage_b, dims_b)?;
1460    let result = contract_native_tensor(&lhs, axes_a, &rhs, axes_b)?;
1461    native_tensor_primal_to_storage(&result)
1462}
1463
1464/// Outer-product storages via native tensors.
1465pub fn outer_product_storage_native(
1466    lhs: &Storage,
1467    lhs_dims: &[usize],
1468    rhs: &Storage,
1469    rhs_dims: &[usize],
1470    _result_dims: &[usize],
1471) -> Result<Storage> {
1472    let lhs = storage_to_native_tensor(lhs, lhs_dims)?;
1473    let rhs = storage_to_native_tensor(rhs, rhs_dims)?;
1474    let result = outer_product_native_tensor(&lhs, &rhs)?;
1475    native_tensor_primal_to_storage(&result)
1476}
1477
1478/// Scale storage by a scalar via native tensors.
1479pub fn scale_storage_native(
1480    storage: &Storage,
1481    logical_dims: &[usize],
1482    scalar: &AnyScalar,
1483) -> Result<Storage> {
1484    let native = storage_to_native_tensor(storage, logical_dims)?;
1485    let scaled = scale_native_tensor(&native, scalar)?;
1486    native_tensor_primal_to_storage(&scaled)
1487}
1488
1489/// Compute `a * lhs + b * rhs` over storages via native tensors.
1490pub fn axpby_storage_native(
1491    lhs: &Storage,
1492    lhs_dims: &[usize],
1493    a: &AnyScalar,
1494    rhs: &Storage,
1495    rhs_dims: &[usize],
1496    b: &AnyScalar,
1497) -> Result<Storage> {
1498    let lhs = storage_to_native_tensor(lhs, lhs_dims)?;
1499    let rhs = storage_to_native_tensor(rhs, rhs_dims)?;
1500    let combined = axpby_native_tensor(&lhs, a, &rhs, b)?;
1501    native_tensor_primal_to_storage(&combined)
1502}
1503
1504#[cfg(test)]
1505mod tests;