1use std::cell::RefCell;
4use std::cmp::Reverse;
5use std::collections::{HashMap, HashSet};
6use std::env;
7use std::time::{Duration, Instant};
8
9use anyhow::{anyhow, ensure, Result};
10use num_complex::{Complex32, Complex64};
11use omeco::ScoreFunction;
12use tenferro::traced_tensor::{einsum_subscripts_with, EinsumOptimize};
13use tenferro::{
14 DType, EinsumSubscripts, Tensor as NativeTensor, TensorBackend, TensorRead, TensorView,
15 TracedTensor,
16};
17use tenferro_einsum::{ContractionOptimizerOptions, ContractionTree, Subscripts};
18
19use crate::any_scalar::promote_scalar_native;
20use crate::context::{
21 default_engine_buffer_pool_stats, reset_default_engine, reset_default_engine_buffer_pool,
22 with_default_backend, with_default_engine,
23};
24use crate::memory::release_process_allocator_cached_memory;
25use crate::storage::Storage;
26#[cfg(test)]
27use crate::storage::StorageRepr;
28use crate::tensor_element::TensorElement;
29use crate::AnyScalar;
30
/// Input for a native read-only operation: either a borrowed read handle over
/// existing data, or a tensor that had to be materialized (owned) first.
pub enum NativeTensorReadInput<'a> {
    /// Zero-copy read handle borrowing data that lives for `'a`.
    Borrowed(TensorRead<'a>),
    /// Owned tensor, used when no borrowed view of the data was available.
    Owned(NativeTensor),
}
39
40impl<'a> NativeTensorReadInput<'a> {
41 pub fn as_read(&'a self) -> TensorRead<'a> {
43 match self {
44 Self::Borrowed(read) => *read,
45 Self::Owned(tensor) => TensorRead::from_tensor(tensor),
46 }
47 }
48
49 pub fn dtype(&self) -> DType {
51 match self {
52 Self::Borrowed(read) => read.dtype(),
53 Self::Owned(tensor) => tensor.dtype(),
54 }
55 }
56
57 pub fn shape(&self) -> &[usize] {
59 match self {
60 Self::Borrowed(read) => read.shape(),
61 Self::Owned(tensor) => tensor.shape(),
62 }
63 }
64}
65
66#[cfg(test)]
67use std::cell::Cell;
68
/// Which native einsum entry path handled a call.
///
/// Part of [`NativeEinsumSignature`], so profiling and tracing data for the same
/// einsum reached through different paths are kept separate.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
enum NativeEinsumPath {
    // presumably: inputs arrived as owned native tensors — confirm at call sites
    Owned,
    // presumably: inputs arrived as borrowed read views — confirm at call sites
    Borrowed,
    // presumably: borrowed inputs that needed dtype conversion first — confirm
    BorrowedWithConversions,
}
75
/// Shape, einsum labels, and dtype of one einsum operand; one component of the
/// key used to aggregate profiling and tracing data.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct NativeOperandSignature {
    /// Operand dimensions.
    shape: Vec<usize>,
    /// Einsum label id for each axis of the operand.
    ids: Vec<u32>,
    /// Operand element dtype.
    dtype: DType,
}
82
/// Hashable key identifying an einsum call for profiling and path tracing:
/// the entry path, every operand's signature, and the output labels.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct NativeEinsumSignature {
    /// Entry path that handled the call.
    path: NativeEinsumPath,
    /// Per-operand shape/labels/dtype, in operand order.
    operands: Vec<NativeOperandSignature>,
    /// Einsum label ids of the output axes.
    output_ids: Vec<u32>,
}
89
/// Accumulated profiling data for one [`NativeEinsumSignature`].
#[derive(Debug, Default, Clone)]
struct NativeEinsumProfileEntry {
    /// Number of recorded calls.
    calls: usize,
    /// Sum of the elapsed times of all recorded calls.
    total_time: Duration,
}
95
thread_local! {
    // Per-thread aggregated call counts and wall time, keyed by einsum signature.
    // Filled by `record_native_einsum_profile`, drained by the profile printer.
    static NATIVE_EINSUM_PROFILE_STATE: RefCell<HashMap<NativeEinsumSignature, NativeEinsumProfileEntry>> =
        RefCell::new(HashMap::new());
    // Per-thread set of signatures whose contraction plan has already been
    // printed by the path tracer, so each plan is printed at most once.
    static NATIVE_EINSUM_TRACE_STATE: RefCell<HashSet<NativeEinsumSignature>> =
        RefCell::new(HashSet::new());
}
102
#[cfg(test)]
thread_local! {
    // Test-only override: when true, `native_einsum_profile_enabled` reports
    // profiling as enabled on this thread without consulting the environment.
    static FORCE_NATIVE_EINSUM_PROFILE: Cell<bool> = const { Cell::new(false) };
}
107
/// Whether per-signature einsum profiling is active.
///
/// True when `T4A_PROFILE_NATIVE_EINSUM` is set (to valid Unicode). In test
/// builds, the thread-local test override also enables it.
fn native_einsum_profile_enabled() -> bool {
    #[cfg(test)]
    let forced = FORCE_NATIVE_EINSUM_PROFILE.with(Cell::get);
    #[cfg(not(test))]
    let forced = false;
    forced || env::var("T4A_PROFILE_NATIVE_EINSUM").is_ok()
}
115
/// Whether contraction-path tracing is requested via `T4A_TRACE_NATIVE_EINSUM_PATHS`.
fn native_einsum_path_trace_enabled() -> bool {
    matches!(env::var("T4A_TRACE_NATIVE_EINSUM_PATHS"), Ok(_))
}
119
/// Minimum peak-intermediate size (bytes) for a traced plan to be printed.
///
/// Read from `T4A_TRACE_NATIVE_EINSUM_MIN_BYTES`; unset or unparsable means 0.
fn native_einsum_path_trace_min_bytes() -> usize {
    env::var("T4A_TRACE_NATIVE_EINSUM_MIN_BYTES").map_or(0, |raw| raw.parse().unwrap_or(0))
}
126
/// Maximum number of distinct signatures the path tracer prints per thread.
///
/// Read from `T4A_TRACE_NATIVE_EINSUM_MAX_SIGNATURES`; unset or unparsable means 64.
fn native_einsum_path_trace_max_signatures() -> usize {
    match env::var("T4A_TRACE_NATIVE_EINSUM_MAX_SIGNATURES") {
        Ok(raw) => raw.parse().unwrap_or(64),
        Err(_) => 64,
    }
}
133
/// Whether buffer-pool tracing is requested via `T4A_TRACE_NATIVE_EINSUM_POOL`.
fn native_einsum_pool_trace_enabled() -> bool {
    matches!(env::var("T4A_TRACE_NATIVE_EINSUM_POOL"), Ok(_))
}
137
/// Output-size threshold (bytes) used by the buffer-pool trace filter.
///
/// Read from `T4A_TRACE_NATIVE_EINSUM_POOL_MIN_OUTPUT_BYTES`; unset or
/// unparsable means 0.
fn native_einsum_pool_trace_min_output_bytes() -> usize {
    match env::var("T4A_TRACE_NATIVE_EINSUM_POOL_MIN_OUTPUT_BYTES") {
        Ok(raw) => raw.parse().unwrap_or(0),
        Err(_) => 0,
    }
}
144
/// Retained-capacity threshold (bytes) used by the buffer-pool trace filters.
///
/// Read from `T4A_TRACE_NATIVE_EINSUM_POOL_MIN_RETAINED_BYTES`; unset or
/// unparsable means 0.
fn native_einsum_pool_trace_min_retained_bytes() -> usize {
    env::var("T4A_TRACE_NATIVE_EINSUM_POOL_MIN_RETAINED_BYTES")
        .map_or(0, |raw| raw.parse().unwrap_or(0))
}
151
/// Whether the default engine is reset after each native einsum call
/// (`T4A_RESET_NATIVE_EINSUM_ENGINE_AFTER_CALL`).
fn reset_native_einsum_engine_after_call() -> bool {
    matches!(env::var("T4A_RESET_NATIVE_EINSUM_ENGINE_AFTER_CALL"), Ok(_))
}
155
/// Whether the default engine's buffer pool is cleared after each native einsum
/// call (`T4A_RESET_NATIVE_EINSUM_BUFFER_POOL_AFTER_CALL`).
fn reset_native_einsum_buffer_pool_after_call() -> bool {
    matches!(env::var("T4A_RESET_NATIVE_EINSUM_BUFFER_POOL_AFTER_CALL"), Ok(_))
}
159
/// Whether the process allocator is asked to release cached memory after each
/// native einsum call (`T4A_RELEASE_ALLOCATOR_AFTER_NATIVE_EINSUM_CALL`).
fn release_allocator_after_native_einsum_call() -> bool {
    matches!(env::var("T4A_RELEASE_ALLOCATOR_AFTER_NATIVE_EINSUM_CALL"), Ok(_))
}
163
/// Test hook: forces einsum profiling on (or back off) for the current thread,
/// independent of `T4A_PROFILE_NATIVE_EINSUM`.
#[cfg(test)]
pub(crate) fn set_native_einsum_profile_enabled_for_tests(enabled: bool) {
    FORCE_NATIVE_EINSUM_PROFILE.with(|flag| flag.set(enabled));
}
168
169fn native_einsum_signature(
170 path: NativeEinsumPath,
171 operands: &[(&NativeTensor, &[usize])],
172 output_ids: &[u32],
173) -> NativeEinsumSignature {
174 NativeEinsumSignature {
175 path,
176 operands: operands
177 .iter()
178 .map(|(tensor, ids)| NativeOperandSignature {
179 shape: tensor.shape().to_vec(),
180 ids: ids.iter().map(|&id| id as u32).collect(),
181 dtype: tensor.dtype(),
182 })
183 .collect(),
184 output_ids: output_ids.to_vec(),
185 }
186}
187
188fn record_native_einsum_profile(
189 path: NativeEinsumPath,
190 operands: &[(&NativeTensor, &[usize])],
191 output_ids: &[u32],
192 elapsed: Duration,
193) {
194 if !native_einsum_profile_enabled() {
195 return;
196 }
197 let signature = native_einsum_signature(path, operands, output_ids);
198 NATIVE_EINSUM_PROFILE_STATE.with(|state| {
199 let mut state = state.borrow_mut();
200 let entry = state.entry(signature).or_default();
201 entry.calls += 1;
202 entry.total_time += elapsed;
203 });
204}
205
206fn dtype_size_bytes(dtype: DType) -> usize {
207 match dtype {
208 DType::F32 => 4,
209 DType::F64 => 8,
210 DType::C32 => 8,
211 DType::C64 => 16,
212 DType::I64 => 8,
213 }
214}
215
216fn native_tensor_bytes(tensor: &NativeTensor) -> usize {
217 tensor
218 .shape()
219 .iter()
220 .copied()
221 .fold(1usize, usize::saturating_mul)
222 .saturating_mul(dtype_size_bytes(tensor.dtype()))
223}
224
225fn format_label(label: u32) -> String {
226 char::from_u32(label).map_or_else(|| label.to_string(), |label| label.to_string())
227}
228
229fn format_labels(labels: &[u32]) -> String {
230 if labels.is_empty() {
231 "scalar".to_string()
232 } else {
233 labels
234 .iter()
235 .map(|&label| format_label(label))
236 .collect::<Vec<_>>()
237 .join("")
238 }
239}
240
241fn label_dims(subscripts: &Subscripts, shapes: &[Vec<usize>]) -> Result<HashMap<u32, usize>> {
242 let mut dims = HashMap::new();
243 for (labels, shape) in subscripts.inputs.iter().zip(shapes.iter()) {
244 ensure!(
245 labels.len() == shape.len(),
246 "einsum labels {:?} do not match shape {:?}",
247 labels,
248 shape
249 );
250 for (&label, &dim) in labels.iter().zip(shape.iter()) {
251 if let Some(previous) = dims.insert(label, dim) {
252 ensure!(
253 previous == dim,
254 "inconsistent dimension for einsum label {}: {} vs {}",
255 format_label(label),
256 previous,
257 dim
258 );
259 }
260 }
261 }
262 Ok(dims)
263}
264
/// Product of the dimensions of `labels` (saturating at `usize::MAX`).
///
/// Labels missing from `dims` count as dimension 1; an empty list gives 1.
fn labels_size(labels: &[u32], dims: &HashMap<u32, usize>) -> usize {
    let mut total = 1usize;
    for label in labels {
        let extent = match dims.get(label) {
            Some(&dim) => dim,
            None => 1,
        };
        total = total.saturating_mul(extent);
    }
    total
}
270
/// Returns the distinct labels of `lhs` followed by `rhs`, in first-occurrence
/// order.
fn union_labels(lhs: &[u32], rhs: &[u32]) -> Vec<u32> {
    let mut seen = HashSet::with_capacity(lhs.len() + rhs.len());
    lhs.iter()
        .chain(rhs)
        .copied()
        .filter(|label| seen.insert(*label))
        .collect()
}
281
/// Rendered contraction plan for one einsum signature.
#[derive(Debug)]
struct NativeEinsumPlanReport {
    /// Pre-formatted report lines: a header, one line per pairwise contraction
    /// step, and a final peak-intermediate summary.
    lines: Vec<String>,
    /// Largest step output in the plan, in bytes (at least one element),
    /// sized with the first operand's dtype.
    peak_intermediate_bytes: usize,
}
287
288fn time_optimized_contraction_options() -> ContractionOptimizerOptions {
289 ContractionOptimizerOptions {
290 score: ScoreFunction::time_optimized(),
291 ..ContractionOptimizerOptions::default()
292 }
293}
294
295fn native_einsum_plan_report_with_options(
296 signature: &NativeEinsumSignature,
297 optimizer_name: &'static str,
298 options: &ContractionOptimizerOptions,
299) -> Result<NativeEinsumPlanReport> {
300 let input_ids = signature
301 .operands
302 .iter()
303 .map(|operand| operand.ids.as_slice())
304 .collect::<Vec<_>>();
305 let subscripts_string = build_einsum_subscripts(&input_ids, &signature.output_ids)?;
306 let subscripts = Subscripts {
307 inputs: input_ids.iter().map(|ids| ids.to_vec()).collect(),
308 output: signature.output_ids.clone(),
309 };
310 let shapes = signature
311 .operands
312 .iter()
313 .map(|operand| operand.shape.clone())
314 .collect::<Vec<_>>();
315 let shape_refs = shapes.iter().map(Vec::as_slice).collect::<Vec<_>>();
316 let tree = ContractionTree::optimize_with_options(&subscripts, &shape_refs, options)
317 .map_err(|e| anyhow!("failed to optimize native einsum path: {e}"))?;
318 let dims = label_dims(&subscripts, &shapes)?;
319 let dtype = signature
320 .operands
321 .first()
322 .map(|operand| operand.dtype)
323 .unwrap_or(DType::F64);
324 let dtype_size = dtype_size_bytes(dtype);
325
326 let mut lines = Vec::new();
327 lines.push(format!(
328 "optimizer={optimizer_name} subscripts={subscripts_string} dtype={dtype:?} steps={}",
329 tree.step_count()
330 ));
331 let mut peak_intermediate_elems = 1usize;
332 for step in 0..tree.step_count() {
333 let Some((left, right)) = tree.step_pair(step) else {
334 continue;
335 };
336 let Some((lhs, rhs, out)) = tree.step_subscripts(step) else {
337 continue;
338 };
339 let lhs_elems = labels_size(lhs, &dims);
340 let rhs_elems = labels_size(rhs, &dims);
341 let out_elems = labels_size(out, &dims);
342 let flop_index_elems = labels_size(&union_labels(lhs, rhs), &dims);
343 peak_intermediate_elems = peak_intermediate_elems.max(out_elems);
344 lines.push(format!(
345 " step {step:02}: pair=({left},{right}) {}[{}] x {}[{}] -> {}[{}] flop_index={} intermediate={} elems ({:.3} MiB)",
346 format_labels(lhs),
347 lhs_elems,
348 format_labels(rhs),
349 rhs_elems,
350 format_labels(out),
351 out_elems,
352 flop_index_elems,
353 out_elems,
354 out_elems as f64 * dtype_size as f64 / (1024.0 * 1024.0),
355 ));
356 }
357 let peak_intermediate_bytes = peak_intermediate_elems.saturating_mul(dtype_size);
358 lines.push(format!(
359 " peak_intermediate={} elems ({:.3} MiB)",
360 peak_intermediate_elems,
361 peak_intermediate_bytes as f64 / (1024.0 * 1024.0)
362 ));
363
364 Ok(NativeEinsumPlanReport {
365 lines,
366 peak_intermediate_bytes,
367 })
368}
369
370fn native_einsum_time_optimized_plan_report(
371 signature: &NativeEinsumSignature,
372) -> Result<NativeEinsumPlanReport> {
373 native_einsum_plan_report_with_options(
374 signature,
375 "time_optimized",
376 &time_optimized_contraction_options(),
377 )
378}
379
380fn native_einsum_balanced_plan_report(
381 signature: &NativeEinsumSignature,
382) -> Result<NativeEinsumPlanReport> {
383 native_einsum_plan_report_with_options(
384 signature,
385 "balanced_default",
386 &ContractionOptimizerOptions::default(),
387 )
388}
389
390fn maybe_trace_native_einsum_path(
391 path: NativeEinsumPath,
392 operands: &[(&NativeTensor, &[usize])],
393 output_ids: &[u32],
394) {
395 if !native_einsum_path_trace_enabled() {
396 return;
397 }
398 let signature = native_einsum_signature(path, operands, output_ids);
399 let report = match native_einsum_time_optimized_plan_report(&signature) {
400 Ok(report) if report.peak_intermediate_bytes >= native_einsum_path_trace_min_bytes() => {
401 report
402 }
403 Ok(_) => return,
404 Err(err) => {
405 eprintln!("native_einsum path trace failed: {err:#}");
406 return;
407 }
408 };
409
410 let max_signatures = native_einsum_path_trace_max_signatures();
411 let should_trace = NATIVE_EINSUM_TRACE_STATE.with(|state| {
412 let mut state = state.borrow_mut();
413 if state.len() >= max_signatures || state.contains(&signature) {
414 false
415 } else {
416 state.insert(signature.clone());
417 true
418 }
419 });
420 if !should_trace {
421 return;
422 }
423
424 eprintln!("=== native_einsum Path Trace ===");
425 eprintln!(
426 "path={:?} output_ids={:?}",
427 signature.path, signature.output_ids
428 );
429 for operand in &signature.operands {
430 eprintln!(
431 " operand shape={:?} ids={:?} dtype={:?}",
432 operand.shape, operand.ids, operand.dtype
433 );
434 }
435 for line in report.lines {
436 eprintln!("{line}");
437 }
438 if env::var("T4A_TRACE_NATIVE_EINSUM_COMPARE_BALANCED").is_ok() {
439 match native_einsum_balanced_plan_report(&signature) {
440 Ok(balanced) => {
441 for line in balanced.lines {
442 eprintln!("{line}");
443 }
444 }
445 Err(err) => eprintln!("balanced native_einsum path trace failed: {err:#}"),
446 }
447 }
448}
449
450pub fn reset_native_einsum_profile() {
452 NATIVE_EINSUM_PROFILE_STATE.with(|state| state.borrow_mut().clear());
453 NATIVE_EINSUM_TRACE_STATE.with(|state| state.borrow_mut().clear());
454}
455
456pub fn print_and_reset_native_einsum_profile() {
458 if !native_einsum_profile_enabled() {
459 return;
460 }
461 NATIVE_EINSUM_PROFILE_STATE.with(|state| {
462 let mut entries: Vec<_> = state
463 .borrow()
464 .iter()
465 .map(|(k, v)| (k.clone(), v.clone()))
466 .collect();
467 state.borrow_mut().clear();
468 entries.sort_by_key(|(_, entry)| Reverse(entry.total_time));
469
470 eprintln!("=== native_einsum Profile ===");
471 for (idx, (signature, entry)) in entries.into_iter().take(20).enumerate() {
472 eprintln!(
473 "#{idx:02} path={:?} calls={} total={:.3}s per_call={:.3}us output_ids={:?}",
474 signature.path,
475 entry.calls,
476 entry.total_time.as_secs_f64(),
477 entry.total_time.as_secs_f64() * 1e6 / entry.calls as f64,
478 signature.output_ids,
479 );
480 for operand in &signature.operands {
481 eprintln!(
482 " shape={:?} ids={:?} dtype={:?}",
483 operand.shape, operand.ids, operand.dtype
484 );
485 }
486 match native_einsum_time_optimized_plan_report(&signature) {
487 Ok(report) => {
488 for line in report.lines {
489 eprintln!(" {line}");
490 }
491 }
492 Err(err) => eprintln!(" path report failed: {err:#}"),
493 }
494 }
495 });
496}
497
498fn common_dtype(dtypes: &[DType]) -> DType {
499 let has_f64 = dtypes.contains(&DType::F64);
500 let has_c64 = dtypes.contains(&DType::C64);
501 let has_c32 = dtypes.contains(&DType::C32);
502 let has_i64 = dtypes.contains(&DType::I64);
503 let has_complex = has_c64 || has_c32;
504 if has_c64 || (has_f64 && has_complex) {
505 DType::C64
506 } else if has_c32 {
507 DType::C32
508 } else if has_f64 || has_i64 {
509 DType::F64
510 } else {
511 DType::F32
512 }
513}
514
515fn convert_tensor(tensor: &NativeTensor, to: DType) -> Result<NativeTensor> {
516 if tensor.dtype() == to {
517 return Ok(tensor.clone());
518 }
519 with_default_backend(|backend| backend.with_exec_session(|exec| exec.convert(tensor, to)))
520 .map_err(|e| anyhow!("tensor conversion to {to:?} failed: {e}"))
521}
522
523fn ids_to_subscript(ids: &[u32]) -> Result<String> {
524 const LETTERS: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
525 let mut out = String::with_capacity(ids.len());
526 for &id in ids {
527 let idx = usize::try_from(id).unwrap_or(usize::MAX);
528 let letter = LETTERS
529 .get(idx)
530 .ok_or_else(|| anyhow!("einsum label {id} exceeds supported label range"))?;
531 out.push(char::from(*letter));
532 }
533 Ok(out)
534}
535
536fn build_einsum_subscripts(operands: &[&[u32]], output_ids: &[u32]) -> Result<String> {
537 let inputs = operands
538 .iter()
539 .map(|ids| ids_to_subscript(ids))
540 .collect::<Result<Vec<_>>>()?;
541 Ok(format!(
542 "{}->{}",
543 inputs.join(","),
544 ids_to_subscript(output_ids)?
545 ))
546}
547
/// Runs an einsum over owned native tensors through the default engine.
///
/// One concrete-shape placeholder is created per input. `einsum_subscripts_with`
/// is applied to the placeholders with `EinsumOptimize::default()`, and the
/// traced result is evaluated with each real tensor bound to its placeholder.
///
/// After evaluation, these environment-controlled steps run in this order:
/// - `T4A_TRACE_NATIVE_EINSUM_POOL`: print default-engine buffer-pool stats
///   from before and after the call, subject to the `..._MIN_RETAINED_BYTES`
///   and `..._MIN_OUTPUT_BYTES` thresholds;
/// - `T4A_RESET_NATIVE_EINSUM_ENGINE_AFTER_CALL`: reset the default engine.
///   Otherwise, `T4A_RESET_NATIVE_EINSUM_BUFFER_POOL_AFTER_CALL` clears only its
///   buffer pool;
/// - `T4A_RELEASE_ALLOCATOR_AFTER_NATIVE_EINSUM_CALL`: ask the process
///   allocator to release cached memory.
///
/// # Errors
/// Returns an error if building or evaluating the traced einsum fails.
fn cached_einsum_native_tensors(
    inputs: &[&NativeTensor],
    subscripts: &EinsumSubscripts,
) -> Result<NativeTensor> {
    let placeholders = inputs
        .iter()
        .map(|tensor| TracedTensor::input_concrete_shape(tensor.dtype(), tensor.shape()))
        .collect::<Vec<_>>();
    let placeholder_refs = placeholders.iter().collect::<Vec<_>>();
    // Pair each placeholder with the concrete tensor it stands for.
    let bindings = placeholders
        .iter()
        .zip(inputs.iter())
        .map(|(placeholder, tensor)| (placeholder, *tensor))
        .collect::<Vec<_>>();

    let trace_pool = native_einsum_pool_trace_enabled();
    // Pool stats are snapshotted only when pool tracing is on.
    let pool_before = trace_pool.then(default_engine_buffer_pool_stats);
    let result = with_default_engine(|engine| {
        let mut result = einsum_subscripts_with(
            engine,
            &placeholder_refs,
            subscripts,
            EinsumOptimize::default(),
        )
        .map_err(|e| anyhow!("native einsum failed: {e}"))?;
        result
            .eval_with_inputs(engine, &bindings)
            .cloned()
            .map_err(|e| anyhow!("native einsum failed: {e}"))
    })?;
    if trace_pool {
        let pool_after = default_engine_buffer_pool_stats();
        let output_bytes = native_tensor_bytes(&result);
        let retained_threshold = native_einsum_pool_trace_min_retained_bytes();
        // NOTE(review): `&&` binds tighter than `||`, so this condition reads as
        // `(pool changed && capacity >= retained threshold) || output >= output threshold`.
        // With the default output threshold of 0 the right operand is always true, so
        // every call is logged while pool tracing is on — confirm this grouping is intended.
        if pool_after != pool_before.unwrap_or_default()
            && pool_after.capacity_bytes >= retained_threshold
            || output_bytes >= native_einsum_pool_trace_min_output_bytes()
        {
            let before = pool_before.unwrap_or_default();
            eprintln!(
                "native_einsum pool subscripts={subscripts:?} before_buffers={} before_capacity={:.3} MiB after_buffers={} after_capacity={:.3} MiB output_shape={:?} output_bytes={:.3} MiB",
                before.buffers,
                before.capacity_bytes as f64 / (1024.0 * 1024.0),
                pool_after.buffers,
                pool_after.capacity_bytes as f64 / (1024.0 * 1024.0),
                result.shape(),
                output_bytes as f64 / (1024.0 * 1024.0),
            );
        }
    }
    // A full engine reset takes precedence over clearing only the buffer pool.
    if reset_native_einsum_engine_after_call() {
        let before_reset = trace_pool.then(default_engine_buffer_pool_stats);
        reset_default_engine();
        if trace_pool
            && before_reset.unwrap_or_default().capacity_bytes
                >= native_einsum_pool_trace_min_retained_bytes()
        {
            let before = before_reset.unwrap_or_default();
            let after = default_engine_buffer_pool_stats();
            eprintln!(
                "native_einsum engine_reset before_buffers={} before_capacity={:.3} MiB after_buffers={} after_capacity={:.3} MiB",
                before.buffers,
                before.capacity_bytes as f64 / (1024.0 * 1024.0),
                after.buffers,
                after.capacity_bytes as f64 / (1024.0 * 1024.0),
            );
        }
    } else if reset_native_einsum_buffer_pool_after_call() {
        let before_clear = trace_pool.then(default_engine_buffer_pool_stats);
        reset_default_engine_buffer_pool();
        if trace_pool
            && before_clear.unwrap_or_default().capacity_bytes
                >= native_einsum_pool_trace_min_retained_bytes()
        {
            let before = before_clear.unwrap_or_default();
            let after = default_engine_buffer_pool_stats();
            eprintln!(
                "native_einsum pool_reset before_buffers={} before_capacity={:.3} MiB after_buffers={} after_capacity={:.3} MiB",
                before.buffers,
                before.capacity_bytes as f64 / (1024.0 * 1024.0),
                after.buffers,
                after.capacity_bytes as f64 / (1024.0 * 1024.0),
            );
        }
    }
    if release_allocator_after_native_einsum_call() {
        let report = release_process_allocator_cached_memory();
        // Only report when the release did something observable.
        if trace_pool && (report.released_bytes.unwrap_or(0) > 0 || report.success == Some(true)) {
            eprintln!(
                "native_einsum allocator_pressure_relief supported={} released_bytes={:?} success={:?}",
                report.supported,
                report.released_bytes,
                report.success,
            );
        }
    }
    Ok(result)
}
646
647fn cached_einsum_native_reads(
648 inputs: &[TensorRead<'_>],
649 subscripts: &Subscripts,
650) -> Result<NativeTensor> {
651 with_default_backend(|backend| {
652 tenferro_einsum::eager_einsum_read_subscripts(backend, inputs, subscripts)
653 .map_err(|e| anyhow!("native read einsum failed: {e}"))
654 })
655}
656
657pub(crate) fn build_binary_einsum_ids(
659 lhs_rank: usize,
660 axes_a: &[usize],
661 rhs_rank: usize,
662 axes_b: &[usize],
663) -> Result<(Vec<u32>, Vec<u32>, Vec<u32>)> {
664 ensure!(
665 axes_a.len() == axes_b.len(),
666 "contract axis length mismatch: lhs {:?}, rhs {:?}",
667 axes_a,
668 axes_b
669 );
670
671 let mut lhs_ids = vec![u32::MAX; lhs_rank];
672 let mut rhs_ids = vec![u32::MAX; rhs_rank];
673 let mut next_id = 0u32;
674
675 let mut seen_lhs = vec![false; lhs_rank];
676 let mut seen_rhs = vec![false; rhs_rank];
677
678 for (&lhs_axis, &rhs_axis) in axes_a.iter().zip(axes_b.iter()) {
679 ensure!(
680 lhs_axis < lhs_rank,
681 "lhs contract axis {lhs_axis} out of range"
682 );
683 ensure!(
684 rhs_axis < rhs_rank,
685 "rhs contract axis {rhs_axis} out of range"
686 );
687 ensure!(
688 !seen_lhs[lhs_axis],
689 "duplicate lhs contract axis {lhs_axis}"
690 );
691 ensure!(
692 !seen_rhs[rhs_axis],
693 "duplicate rhs contract axis {rhs_axis}"
694 );
695 seen_lhs[lhs_axis] = true;
696 seen_rhs[rhs_axis] = true;
697 lhs_ids[lhs_axis] = next_id;
698 rhs_ids[rhs_axis] = next_id;
699 next_id += 1;
700 }
701
702 let mut output_ids = Vec::with_capacity(lhs_rank + rhs_rank - 2 * axes_a.len());
703 for (axis, slot) in lhs_ids.iter_mut().enumerate() {
704 if *slot == u32::MAX {
705 *slot = next_id;
706 output_ids.push(next_id);
707 next_id += 1;
708 } else {
709 let _ = axis;
710 }
711 }
712 for slot in &mut rhs_ids {
713 if *slot == u32::MAX {
714 *slot = next_id;
715 output_ids.push(next_id);
716 next_id += 1;
717 }
718 }
719
720 Ok((lhs_ids, rhs_ids, output_ids))
721}
722
723pub fn dense_native_tensor_from_col_major<T: TensorElement>(
725 data: &[T],
726 logical_dims: &[usize],
727) -> Result<NativeTensor> {
728 T::dense_native_tensor_from_col_major(data, logical_dims)
729}
730
731pub fn diag_native_tensor_from_col_major<T: TensorElement>(
733 data: &[T],
734 logical_rank: usize,
735) -> Result<NativeTensor> {
736 T::diag_native_tensor_from_col_major(data, logical_rank)
737}
738
739pub fn storage_to_native_tensor(storage: &Storage, logical_dims: &[usize]) -> Result<NativeTensor> {
741 if storage.is_c64() {
742 dense_native_tensor_from_col_major(
743 &storage
744 .to_dense_c64_col_major_vec(logical_dims)
745 .map_err(|e| anyhow!("dense c64 materialization failed: {e}"))?,
746 logical_dims,
747 )
748 } else {
749 dense_native_tensor_from_col_major(
750 &storage
751 .to_dense_f64_col_major_vec(logical_dims)
752 .map_err(|e| anyhow!("dense f64 materialization failed: {e}"))?,
753 logical_dims,
754 )
755 }
756}
757
758pub fn storage_payload_native_read_input(storage: &Storage) -> Result<NativeTensorReadInput<'_>> {
763 if storage.is_f64() {
764 if let Some(view) = storage
765 .payload_f64_col_major_view_if_contiguous()
766 .map_err(anyhow::Error::msg)?
767 {
768 return Ok(NativeTensorReadInput::Borrowed(TensorRead::from_view(
769 TensorView::f64(storage.payload_dims(), view)?,
770 )));
771 }
772 Ok(NativeTensorReadInput::Owned(NativeTensor::from_vec(
773 storage.payload_dims().to_vec(),
774 storage
775 .payload_f64_col_major_vec()
776 .map_err(anyhow::Error::msg)?,
777 )))
778 } else if storage.is_c64() {
779 if let Some(view) = storage
780 .payload_c64_col_major_view_if_contiguous()
781 .map_err(anyhow::Error::msg)?
782 {
783 return Ok(NativeTensorReadInput::Borrowed(TensorRead::from_view(
784 TensorView::c64(storage.payload_dims(), view)?,
785 )));
786 }
787 Ok(NativeTensorReadInput::Owned(NativeTensor::from_vec(
788 storage.payload_dims().to_vec(),
789 storage
790 .payload_c64_col_major_vec()
791 .map_err(anyhow::Error::msg)?,
792 )))
793 } else {
794 Err(anyhow!("unsupported storage scalar type"))
795 }
796}
797
798pub fn native_tensor_primal_to_storage(tensor: &NativeTensor) -> Result<Storage> {
800 match tensor.dtype() {
801 DType::F32 => Storage::from_dense_col_major(
802 tensor
803 .as_slice::<f32>()
804 .ok_or_else(|| anyhow!("failed to read f32 native tensor"))?
805 .iter()
806 .map(|&value| value as f64)
807 .collect::<Vec<_>>(),
808 tensor.shape(),
809 ),
810 DType::F64 => Storage::from_dense_col_major(
811 tensor
812 .as_slice::<f64>()
813 .ok_or_else(|| anyhow!("failed to read f64 native tensor"))?
814 .to_vec(),
815 tensor.shape(),
816 ),
817 DType::I64 => Storage::from_dense_col_major(
818 tensor
819 .as_slice::<i64>()
820 .ok_or_else(|| anyhow!("failed to read i64 native tensor"))?
821 .iter()
822 .map(|&value| value as f64)
823 .collect::<Vec<_>>(),
824 tensor.shape(),
825 ),
826 DType::C32 => Storage::from_dense_col_major(
827 tensor
828 .as_slice::<Complex32>()
829 .ok_or_else(|| anyhow!("failed to read c32 native tensor"))?
830 .iter()
831 .map(|&value| Complex64::new(value.re as f64, value.im as f64))
832 .collect::<Vec<_>>(),
833 tensor.shape(),
834 ),
835 DType::C64 => Storage::from_dense_col_major(
836 tensor
837 .as_slice::<Complex64>()
838 .ok_or_else(|| anyhow!("failed to read c64 native tensor"))?
839 .to_vec(),
840 tensor.shape(),
841 ),
842 }
843 .map_err(|e| anyhow!("native tensor snapshot materialization failed: {e}"))
844}
845
846pub fn native_tensor_primal_to_dense_f64_col_major(tensor: &NativeTensor) -> Result<Vec<f64>> {
848 match tensor.dtype() {
849 DType::F32 => Ok(tensor
850 .as_slice::<f32>()
851 .ok_or_else(|| anyhow!("failed to read f32 native tensor"))?
852 .iter()
853 .map(|&value| value as f64)
854 .collect()),
855 DType::F64 => <f64 as TensorElement>::dense_values_from_native_col_major(tensor),
856 DType::I64 => Ok(tensor
857 .as_slice::<i64>()
858 .ok_or_else(|| anyhow!("failed to read i64 native tensor"))?
859 .iter()
860 .map(|&value| value as f64)
861 .collect()),
862 other => Err(anyhow!("expected real native tensor, got dtype {other:?}")),
863 }
864}
865
866pub fn native_tensor_primal_to_dense_c64_col_major(
868 tensor: &NativeTensor,
869) -> Result<Vec<Complex64>> {
870 match tensor.dtype() {
871 DType::C32 => Ok(tensor
872 .as_slice::<Complex32>()
873 .ok_or_else(|| anyhow!("failed to read c32 native tensor"))?
874 .iter()
875 .map(|&value| Complex64::new(value.re as f64, value.im as f64))
876 .collect()),
877 DType::C64 => <Complex64 as TensorElement>::dense_values_from_native_col_major(tensor),
878 other => Err(anyhow!(
879 "expected complex native tensor, got dtype {other:?}"
880 )),
881 }
882}
883
884pub fn native_tensor_primal_to_dense_col_major<T: TensorElement>(
886 tensor: &NativeTensor,
887) -> Result<Vec<T>> {
888 T::dense_values_from_native_col_major(tensor)
889}
890
891pub fn native_tensor_primal_to_diag_f64(tensor: &NativeTensor) -> Result<Vec<f64>> {
893 match tensor.dtype() {
894 DType::F32 => {
895 let promoted = convert_tensor(tensor, DType::F64)?;
896 <f64 as TensorElement>::diag_values_from_native_temp(&promoted)
897 }
898 DType::F64 => <f64 as TensorElement>::diag_values_from_native_temp(tensor),
899 DType::I64 => {
900 let promoted = convert_tensor(tensor, DType::F64)?;
901 <f64 as TensorElement>::diag_values_from_native_temp(&promoted)
902 }
903 other => Err(anyhow!("expected real native tensor, got dtype {other:?}")),
904 }
905}
906
907pub fn native_tensor_primal_to_diag_c64(tensor: &NativeTensor) -> Result<Vec<Complex64>> {
909 match tensor.dtype() {
910 DType::C32 => {
911 let promoted = convert_tensor(tensor, DType::C64)?;
912 <Complex64 as TensorElement>::diag_values_from_native_temp(&promoted)
913 }
914 DType::C64 => <Complex64 as TensorElement>::diag_values_from_native_temp(tensor),
915 other => Err(anyhow!(
916 "expected complex native tensor, got dtype {other:?}"
917 )),
918 }
919}
920
921pub fn reshape_col_major_native_tensor(
923 tensor: &NativeTensor,
924 logical_dims: &[usize],
925) -> Result<NativeTensor> {
926 with_default_backend(|backend| tensor.reshape(logical_dims, backend))
927 .map_err(|e| anyhow!("native reshape failed: {e}"))
928}
929
930pub fn qr_native_tensor(tensor: &NativeTensor) -> Result<(NativeTensor, NativeTensor)> {
932 with_default_backend(|backend| tensor.qr(backend)).map_err(|e| anyhow!("native QR failed: {e}"))
933}
934
935pub fn svd_native_tensor(
937 tensor: &NativeTensor,
938) -> Result<(NativeTensor, NativeTensor, NativeTensor)> {
939 with_default_backend(|backend| tensor.svd(backend))
940 .map_err(|e| anyhow!("native SVD failed: {e}"))
941}
942
943pub fn sum_native_tensor(tensor: &NativeTensor) -> Result<AnyScalar> {
945 let reduced = if tensor.shape().is_empty() {
946 tensor.clone()
947 } else {
948 let axes: Vec<usize> = (0..tensor.shape().len()).collect();
949 with_default_backend(|backend| tensor.reduce_sum(&axes, backend))
950 .map_err(|e| anyhow!("native sum failed: {e}"))?
951 };
952 AnyScalar::from_native(reduced)
953}
954
955pub fn tangent_native_tensor(_tensor: &NativeTensor) -> Option<NativeTensor> {
960 None
961}
962
963pub fn scale_native_tensor(tensor: &NativeTensor, scalar: &AnyScalar) -> Result<NativeTensor> {
965 let target = common_dtype(&[tensor.dtype(), scalar.as_native().dtype()]);
966 let tensor = convert_tensor(tensor, target)?;
967 let scalar = promote_scalar_native(scalar.as_native(), target)?;
968
969 match target {
970 DType::F32 => {
971 let factor = scalar
972 .as_slice::<f32>()
973 .and_then(|values| values.first().copied())
974 .ok_or_else(|| anyhow!("failed to read promoted f32 scalar"))?;
975 let values = tensor
976 .as_slice::<f32>()
977 .ok_or_else(|| anyhow!("failed to read promoted f32 tensor"))?
978 .iter()
979 .map(|&value| value * factor)
980 .collect::<Vec<_>>();
981 Ok(NativeTensor::from_vec(tensor.shape().to_vec(), values))
982 }
983 DType::F64 => {
984 let factor = scalar
985 .as_slice::<f64>()
986 .and_then(|values| values.first().copied())
987 .ok_or_else(|| anyhow!("failed to read promoted f64 scalar"))?;
988 let values = tensor
989 .as_slice::<f64>()
990 .ok_or_else(|| anyhow!("failed to read promoted f64 tensor"))?
991 .iter()
992 .map(|&value| value * factor)
993 .collect::<Vec<_>>();
994 Ok(NativeTensor::from_vec(tensor.shape().to_vec(), values))
995 }
996 DType::C32 => {
997 let factor = scalar
998 .as_slice::<Complex32>()
999 .and_then(|values| values.first().copied())
1000 .ok_or_else(|| anyhow!("failed to read promoted c32 scalar"))?;
1001 let values = tensor
1002 .as_slice::<Complex32>()
1003 .ok_or_else(|| anyhow!("failed to read promoted c32 tensor"))?
1004 .iter()
1005 .map(|&value| value * factor)
1006 .collect::<Vec<_>>();
1007 Ok(NativeTensor::from_vec(tensor.shape().to_vec(), values))
1008 }
1009 DType::C64 => {
1010 let factor = scalar
1011 .as_slice::<Complex64>()
1012 .and_then(|values| values.first().copied())
1013 .ok_or_else(|| anyhow!("failed to read promoted c64 scalar"))?;
1014 let values = tensor
1015 .as_slice::<Complex64>()
1016 .ok_or_else(|| anyhow!("failed to read promoted c64 tensor"))?
1017 .iter()
1018 .map(|&value| value * factor)
1019 .collect::<Vec<_>>();
1020 Ok(NativeTensor::from_vec(tensor.shape().to_vec(), values))
1021 }
1022 DType::I64 => Err(anyhow!("scale_native_tensor does not support i64 tensors")),
1023 }
1024}
1025
1026pub fn axpby_native_tensor(
1028 lhs: &NativeTensor,
1029 a: &AnyScalar,
1030 rhs: &NativeTensor,
1031 b: &AnyScalar,
1032) -> Result<NativeTensor> {
1033 ensure!(
1034 lhs.shape() == rhs.shape(),
1035 "axpby requires matching tensor shapes, got lhs {:?} and rhs {:?}",
1036 lhs.shape(),
1037 rhs.shape()
1038 );
1039
1040 let target = common_dtype(&[
1041 lhs.dtype(),
1042 rhs.dtype(),
1043 a.as_native().dtype(),
1044 b.as_native().dtype(),
1045 ]);
1046 let lhs = convert_tensor(lhs, target)?;
1047 let rhs = convert_tensor(rhs, target)?;
1048 let a = promote_scalar_native(a.as_native(), target)?;
1049 let b = promote_scalar_native(b.as_native(), target)?;
1050
1051 match target {
1052 DType::F32 => {
1053 let a = a
1054 .as_slice::<f32>()
1055 .and_then(|values| values.first().copied())
1056 .ok_or_else(|| anyhow!("failed to read promoted f32 scalar a"))?;
1057 let b = b
1058 .as_slice::<f32>()
1059 .and_then(|values| values.first().copied())
1060 .ok_or_else(|| anyhow!("failed to read promoted f32 scalar b"))?;
1061 let lhs_values = lhs
1062 .as_slice::<f32>()
1063 .ok_or_else(|| anyhow!("failed to read promoted f32 lhs"))?;
1064 let rhs_values = rhs
1065 .as_slice::<f32>()
1066 .ok_or_else(|| anyhow!("failed to read promoted f32 rhs"))?;
1067 let values = lhs_values
1068 .iter()
1069 .zip(rhs_values.iter())
1070 .map(|(&x, &y)| a * x + b * y)
1071 .collect::<Vec<_>>();
1072 Ok(NativeTensor::from_vec(lhs.shape().to_vec(), values))
1073 }
1074 DType::F64 => {
1075 let a = a
1076 .as_slice::<f64>()
1077 .and_then(|values| values.first().copied())
1078 .ok_or_else(|| anyhow!("failed to read promoted f64 scalar a"))?;
1079 let b = b
1080 .as_slice::<f64>()
1081 .and_then(|values| values.first().copied())
1082 .ok_or_else(|| anyhow!("failed to read promoted f64 scalar b"))?;
1083 let lhs_values = lhs
1084 .as_slice::<f64>()
1085 .ok_or_else(|| anyhow!("failed to read promoted f64 lhs"))?;
1086 let rhs_values = rhs
1087 .as_slice::<f64>()
1088 .ok_or_else(|| anyhow!("failed to read promoted f64 rhs"))?;
1089 let values = lhs_values
1090 .iter()
1091 .zip(rhs_values.iter())
1092 .map(|(&x, &y)| a * x + b * y)
1093 .collect::<Vec<_>>();
1094 Ok(NativeTensor::from_vec(lhs.shape().to_vec(), values))
1095 }
1096 DType::C32 => {
1097 let a = a
1098 .as_slice::<Complex32>()
1099 .and_then(|values| values.first().copied())
1100 .ok_or_else(|| anyhow!("failed to read promoted c32 scalar a"))?;
1101 let b = b
1102 .as_slice::<Complex32>()
1103 .and_then(|values| values.first().copied())
1104 .ok_or_else(|| anyhow!("failed to read promoted c32 scalar b"))?;
1105 let lhs_values = lhs
1106 .as_slice::<Complex32>()
1107 .ok_or_else(|| anyhow!("failed to read promoted c32 lhs"))?;
1108 let rhs_values = rhs
1109 .as_slice::<Complex32>()
1110 .ok_or_else(|| anyhow!("failed to read promoted c32 rhs"))?;
1111 let values = lhs_values
1112 .iter()
1113 .zip(rhs_values.iter())
1114 .map(|(&x, &y)| a * x + b * y)
1115 .collect::<Vec<_>>();
1116 Ok(NativeTensor::from_vec(lhs.shape().to_vec(), values))
1117 }
1118 DType::C64 => {
1119 let a = a
1120 .as_slice::<Complex64>()
1121 .and_then(|values| values.first().copied())
1122 .ok_or_else(|| anyhow!("failed to read promoted c64 scalar a"))?;
1123 let b = b
1124 .as_slice::<Complex64>()
1125 .and_then(|values| values.first().copied())
1126 .ok_or_else(|| anyhow!("failed to read promoted c64 scalar b"))?;
1127 let lhs_values = lhs
1128 .as_slice::<Complex64>()
1129 .ok_or_else(|| anyhow!("failed to read promoted c64 lhs"))?;
1130 let rhs_values = rhs
1131 .as_slice::<Complex64>()
1132 .ok_or_else(|| anyhow!("failed to read promoted c64 rhs"))?;
1133 let values = lhs_values
1134 .iter()
1135 .zip(rhs_values.iter())
1136 .map(|(&x, &y)| a * x + b * y)
1137 .collect::<Vec<_>>();
1138 Ok(NativeTensor::from_vec(lhs.shape().to_vec(), values))
1139 }
1140 DType::I64 => Err(anyhow!("axpby_native_tensor does not support i64 tensors")),
1141 }
1142}
1143
1144pub fn einsum_native_tensors_owned(
1176 operands: Vec<(NativeTensor, Vec<usize>)>,
1177 output_ids: &[usize],
1178) -> Result<NativeTensor> {
1179 ensure!(
1180 !operands.is_empty(),
1181 "native einsum requires at least one operand"
1182 );
1183
1184 let target = common_dtype(
1185 &operands
1186 .iter()
1187 .map(|(tensor, _)| tensor.dtype())
1188 .collect::<Vec<_>>(),
1189 );
1190
1191 let mut converted = Vec::with_capacity(operands.len());
1192 let mut input_ids = Vec::with_capacity(operands.len());
1193 for (tensor, ids) in operands {
1194 ensure!(
1195 tensor.shape().len() == ids.len(),
1196 "einsum id list {:?} does not match tensor shape {:?}",
1197 ids,
1198 tensor.shape()
1199 );
1200 let tensor = if tensor.dtype() == target {
1201 tensor
1202 } else {
1203 convert_tensor(&tensor, target)?
1204 };
1205 input_ids.push(ids.into_iter().map(|id| id as u32).collect::<Vec<_>>());
1206 converted.push(tensor);
1207 }
1208
1209 let input_slices = input_ids.iter().map(Vec::as_slice).collect::<Vec<_>>();
1210 let output_ids_u32 = output_ids.iter().map(|&id| id as u32).collect::<Vec<_>>();
1211 let subscripts = EinsumSubscripts::new(&input_slices, &output_ids_u32);
1212
1213 let input_refs = converted.iter().collect::<Vec<_>>();
1214 let trace_ids = input_ids
1215 .iter()
1216 .map(|ids| ids.iter().map(|&id| id as usize).collect::<Vec<_>>())
1217 .collect::<Vec<_>>();
1218 let trace_operands = input_refs
1219 .iter()
1220 .zip(trace_ids.iter())
1221 .map(|(tensor, ids)| (*tensor, ids.as_slice()))
1222 .collect::<Vec<_>>();
1223 maybe_trace_native_einsum_path(NativeEinsumPath::Owned, &trace_operands, &output_ids_u32);
1224 let started = Instant::now();
1225 let result = cached_einsum_native_tensors(&input_refs, &subscripts)?;
1226 record_native_einsum_profile(
1227 NativeEinsumPath::Owned,
1228 &trace_operands,
1229 &output_ids_u32,
1230 started.elapsed(),
1231 );
1232 Ok(result)
1233}
1234
1235pub fn einsum_native_tensors(
1269 operands: &[(&NativeTensor, &[usize])],
1270 output_ids: &[usize],
1271) -> Result<NativeTensor> {
1272 ensure!(
1273 !operands.is_empty(),
1274 "native einsum requires at least one operand"
1275 );
1276
1277 let target = common_dtype(
1278 &operands
1279 .iter()
1280 .map(|(tensor, _)| tensor.dtype())
1281 .collect::<Vec<_>>(),
1282 );
1283 let mut converted = Vec::with_capacity(operands.len());
1284 let mut input_ids = Vec::with_capacity(operands.len());
1285 let mut has_conversions = false;
1286 let started = Instant::now();
1287
1288 for (tensor, ids) in operands {
1289 ensure!(
1290 tensor.shape().len() == ids.len(),
1291 "einsum id list {:?} does not match tensor shape {:?}",
1292 ids,
1293 tensor.shape()
1294 );
1295 input_ids.push(ids.iter().map(|&id| id as u32).collect::<Vec<_>>());
1296 if tensor.dtype() == target {
1297 converted.push(None);
1298 } else {
1299 converted.push(Some(convert_tensor(tensor, target)?));
1300 has_conversions = true;
1301 }
1302 }
1303
1304 let input_slices = input_ids.iter().map(Vec::as_slice).collect::<Vec<_>>();
1305 let output_ids_u32 = output_ids.iter().map(|&id| id as u32).collect::<Vec<_>>();
1306 let subscripts = EinsumSubscripts::new(&input_slices, &output_ids_u32);
1307 let input_refs = operands
1308 .iter()
1309 .zip(converted.iter())
1310 .map(|((tensor, _), converted)| converted.as_ref().unwrap_or(*tensor))
1311 .collect::<Vec<_>>();
1312 let trace_path = if has_conversions {
1313 NativeEinsumPath::BorrowedWithConversions
1314 } else {
1315 NativeEinsumPath::Borrowed
1316 };
1317 maybe_trace_native_einsum_path(trace_path, operands, &output_ids_u32);
1318 let result = cached_einsum_native_tensors(&input_refs, &subscripts)?;
1319 record_native_einsum_profile(trace_path, operands, &output_ids_u32, started.elapsed());
1320 Ok(result)
1321}
1322
1323pub fn einsum_native_tensor_reads(
1329 operands: &[(&NativeTensorReadInput<'_>, &[usize])],
1330 output_ids: &[usize],
1331) -> Result<NativeTensor> {
1332 ensure!(
1333 !operands.is_empty(),
1334 "native einsum requires at least one operand"
1335 );
1336
1337 let target = common_dtype(
1338 &operands
1339 .iter()
1340 .map(|(tensor, _)| tensor.dtype())
1341 .collect::<Vec<_>>(),
1342 );
1343 let mut converted = Vec::with_capacity(operands.len());
1344 let mut input_ids = Vec::with_capacity(operands.len());
1345 let mut read_inputs = Vec::with_capacity(operands.len());
1346
1347 for (tensor, ids) in operands {
1348 ensure!(
1349 tensor.shape().len() == ids.len(),
1350 "einsum id list {:?} does not match tensor shape {:?}",
1351 ids,
1352 tensor.shape()
1353 );
1354 input_ids.push(ids.iter().map(|&id| id as u32).collect::<Vec<_>>());
1355 if tensor.dtype() == target {
1356 converted.push(None);
1357 } else {
1358 converted.push(Some(convert_tensor(&tensor.as_read().to_tensor(), target)?));
1359 }
1360 }
1361
1362 for (tensor, converted) in operands
1363 .iter()
1364 .map(|(tensor, _)| *tensor)
1365 .zip(converted.iter())
1366 {
1367 if let Some(converted) = converted {
1368 read_inputs.push(TensorRead::from_tensor(converted));
1369 } else {
1370 read_inputs.push(tensor.as_read());
1371 }
1372 }
1373
1374 let output_ids_u32 = output_ids.iter().map(|&id| id as u32).collect::<Vec<_>>();
1375 let subscripts = Subscripts {
1376 inputs: input_ids,
1377 output: output_ids_u32,
1378 };
1379 cached_einsum_native_reads(&read_inputs, &subscripts)
1380}
1381
1382pub fn permute_native_tensor(tensor: &NativeTensor, perm: &[usize]) -> Result<NativeTensor> {
1384 with_default_backend(|backend| tensor.transpose(perm, backend))
1385 .map_err(|e| anyhow!("native permute failed: {e}"))
1386}
1387
1388pub fn contract_native_tensor(
1390 lhs: &NativeTensor,
1391 axes_a: &[usize],
1392 rhs: &NativeTensor,
1393 axes_b: &[usize],
1394) -> Result<NativeTensor> {
1395 let (lhs_ids, rhs_ids, output_ids) =
1396 build_binary_einsum_ids(lhs.shape().len(), axes_a, rhs.shape().len(), axes_b)?;
1397 let lhs_ids_usize = lhs_ids.iter().map(|&id| id as usize).collect::<Vec<_>>();
1398 let rhs_ids_usize = rhs_ids.iter().map(|&id| id as usize).collect::<Vec<_>>();
1399 let output_ids_usize = output_ids.iter().map(|&id| id as usize).collect::<Vec<_>>();
1400 let operands = [
1401 (lhs, lhs_ids_usize.as_slice()),
1402 (rhs, rhs_ids_usize.as_slice()),
1403 ];
1404 einsum_native_tensors(&operands, &output_ids_usize)
1405}
1406
1407pub fn outer_product_native_tensor(lhs: &NativeTensor, rhs: &NativeTensor) -> Result<NativeTensor> {
1409 contract_native_tensor(lhs, &[], rhs, &[])
1410}
1411
1412pub fn conj_native_tensor(tensor: &NativeTensor) -> Result<NativeTensor> {
1414 match tensor.dtype() {
1415 DType::F32 | DType::F64 | DType::I64 => Ok(tensor.clone()),
1416 DType::C32 => Ok(NativeTensor::from_vec(
1417 tensor.shape().to_vec(),
1418 tensor
1419 .as_slice::<Complex32>()
1420 .ok_or_else(|| anyhow!("failed to read c32 native tensor"))?
1421 .iter()
1422 .map(|&value| value.conj())
1423 .collect::<Vec<_>>(),
1424 )),
1425 DType::C64 => Ok(NativeTensor::from_vec(
1426 tensor.shape().to_vec(),
1427 tensor
1428 .as_slice::<Complex64>()
1429 .ok_or_else(|| anyhow!("failed to read c64 native tensor"))?
1430 .iter()
1431 .map(|&value| value.conj())
1432 .collect::<Vec<_>>(),
1433 )),
1434 }
1435}
1436
1437pub fn permute_storage_native(
1439 storage: &Storage,
1440 logical_dims: &[usize],
1441 perm: &[usize],
1442) -> Result<Storage> {
1443 let native = storage_to_native_tensor(storage, logical_dims)?;
1444 let permuted = permute_native_tensor(&native, perm)?;
1445 native_tensor_primal_to_storage(&permuted)
1446}
1447
1448pub fn contract_storage_native(
1450 storage_a: &Storage,
1451 dims_a: &[usize],
1452 axes_a: &[usize],
1453 storage_b: &Storage,
1454 dims_b: &[usize],
1455 axes_b: &[usize],
1456 _result_dims: &[usize],
1457) -> Result<Storage> {
1458 let lhs = storage_to_native_tensor(storage_a, dims_a)?;
1459 let rhs = storage_to_native_tensor(storage_b, dims_b)?;
1460 let result = contract_native_tensor(&lhs, axes_a, &rhs, axes_b)?;
1461 native_tensor_primal_to_storage(&result)
1462}
1463
1464pub fn outer_product_storage_native(
1466 lhs: &Storage,
1467 lhs_dims: &[usize],
1468 rhs: &Storage,
1469 rhs_dims: &[usize],
1470 _result_dims: &[usize],
1471) -> Result<Storage> {
1472 let lhs = storage_to_native_tensor(lhs, lhs_dims)?;
1473 let rhs = storage_to_native_tensor(rhs, rhs_dims)?;
1474 let result = outer_product_native_tensor(&lhs, &rhs)?;
1475 native_tensor_primal_to_storage(&result)
1476}
1477
1478pub fn scale_storage_native(
1480 storage: &Storage,
1481 logical_dims: &[usize],
1482 scalar: &AnyScalar,
1483) -> Result<Storage> {
1484 let native = storage_to_native_tensor(storage, logical_dims)?;
1485 let scaled = scale_native_tensor(&native, scalar)?;
1486 native_tensor_primal_to_storage(&scaled)
1487}
1488
1489pub fn axpby_storage_native(
1491 lhs: &Storage,
1492 lhs_dims: &[usize],
1493 a: &AnyScalar,
1494 rhs: &Storage,
1495 rhs_dims: &[usize],
1496 b: &AnyScalar,
1497) -> Result<Storage> {
1498 let lhs = storage_to_native_tensor(lhs, lhs_dims)?;
1499 let rhs = storage_to_native_tensor(rhs, rhs_dims)?;
1500 let combined = axpby_native_tensor(&lhs, a, &rhs, b)?;
1501 native_tensor_primal_to_storage(&combined)
1502}
1503
// Unit tests for this module live in an out-of-line `tests` module file and
// are compiled only for test builds.
#[cfg(test)]
mod tests;