Skip to main content

tenferro_tensor/
backend.rs

1use crate::config::{
2    CompareDir, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig,
3};
4use crate::types::{TensorRank, TypedTensor, TypedTensorView, TypedTensorViewMut};
5use crate::validate::validate_convert_dtype;
6use crate::{RuntimeCacheControl, Tensor, TensorRead, TensorValue};
7
8fn read_boundary_error(op: &'static str) -> crate::Error {
9    crate::Error::backend_failure(
10        op,
11        "backend does not accept borrowed tensor views at this execution boundary",
12    )
13}
14
15fn read_tensor<'a>(op: &'static str, input: TensorRead<'a>) -> crate::Result<&'a Tensor> {
16    input.as_tensor().ok_or_else(|| read_boundary_error(op))
17}
18
/// Canonical elementwise fusion plan shared between segmented execution and backends.
///
/// All fields are private: they are set once through
/// [`ElementwiseFusionPlan::new`] and read back through the accessor methods
/// that share their names.
#[doc(hidden)]
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct ElementwiseFusionPlan {
    /// Scalar dtype expected by the plan.
    dtype: crate::DType,
    /// Number of input tensors expected by the plan.
    input_count: usize,
    /// Value ids selected as fusion outputs.
    outputs: Vec<usize>,
    /// Fused elementwise instruction sequence.
    ops: Vec<ElementwiseFusionInst>,
}
28
/// One node in a canonical elementwise fusion plan.
///
/// Fields are private; use [`ElementwiseFusionInst::new`] to build a value and
/// the accessor methods to read it.
#[doc(hidden)]
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct ElementwiseFusionInst {
    /// Elementwise op executed by this instruction.
    op: ElementwiseFusionOp,
    /// Value ids consumed by this instruction.
    inputs: Vec<usize>,
}
36
// NOTE(review): macro supplied by `tenferro_core_ops`. It presumably expands to
// the `ElementwiseFusionOp` type referenced by `ElementwiseFusionInst`, since no
// other definition of that type is visible in this part of the file — confirm
// against the macro definition.
tenferro_core_ops::define_elementwise_fusion_op!();
38
39impl ElementwiseFusionPlan {
40    /// Build a backend elementwise fusion plan.
41    ///
42    /// # Examples
43    ///
44    /// ```rust
45    /// use tenferro_tensor::backend::{
46    ///     ElementwiseFusionInst, ElementwiseFusionOp, ElementwiseFusionPlan,
47    /// };
48    /// use tenferro_tensor::DType;
49    ///
50    /// let plan = ElementwiseFusionPlan::new(
51    ///     DType::F64,
52    ///     2,
53    ///     vec![2],
54    ///     vec![ElementwiseFusionInst::new(ElementwiseFusionOp::Add, vec![0, 1])],
55    /// );
56    /// assert_eq!(plan.input_count(), 2);
57    /// ```
58    pub fn new(
59        dtype: crate::DType,
60        input_count: usize,
61        outputs: Vec<usize>,
62        ops: Vec<ElementwiseFusionInst>,
63    ) -> Self {
64        Self {
65            dtype,
66            input_count,
67            outputs,
68            ops,
69        }
70    }
71
72    /// Return the scalar dtype expected by this fusion plan.
73    ///
74    /// # Examples
75    ///
76    /// ```rust
77    /// use tenferro_tensor::backend::ElementwiseFusionPlan;
78    /// use tenferro_tensor::DType;
79    ///
80    /// let plan = ElementwiseFusionPlan::new(DType::F32, 0, Vec::new(), Vec::new());
81    /// assert_eq!(plan.dtype(), DType::F32);
82    /// ```
83    pub fn dtype(&self) -> crate::DType {
84        self.dtype
85    }
86
87    /// Return the number of input tensors expected by this plan.
88    ///
89    /// # Examples
90    ///
91    /// ```rust
92    /// use tenferro_tensor::backend::ElementwiseFusionPlan;
93    /// use tenferro_tensor::DType;
94    ///
95    /// let plan = ElementwiseFusionPlan::new(DType::F64, 3, Vec::new(), Vec::new());
96    /// assert_eq!(plan.input_count(), 3);
97    /// ```
98    pub fn input_count(&self) -> usize {
99        self.input_count
100    }
101
102    /// Return the value ids selected as fusion outputs.
103    ///
104    /// # Examples
105    ///
106    /// ```rust
107    /// use tenferro_tensor::backend::ElementwiseFusionPlan;
108    /// use tenferro_tensor::DType;
109    ///
110    /// let plan = ElementwiseFusionPlan::new(DType::F64, 0, vec![0], Vec::new());
111    /// assert_eq!(plan.outputs(), &[0]);
112    /// ```
113    pub fn outputs(&self) -> &[usize] {
114        &self.outputs
115    }
116
117    /// Return the fused elementwise instruction sequence.
118    ///
119    /// # Examples
120    ///
121    /// ```rust
122    /// use tenferro_tensor::backend::{
123    ///     ElementwiseFusionInst, ElementwiseFusionOp, ElementwiseFusionPlan,
124    /// };
125    /// use tenferro_tensor::DType;
126    ///
127    /// let inst = ElementwiseFusionInst::new(ElementwiseFusionOp::Negate, vec![0]);
128    /// let plan = ElementwiseFusionPlan::new(DType::F64, 1, vec![1], vec![inst]);
129    /// assert_eq!(plan.ops().len(), 1);
130    /// ```
131    pub fn ops(&self) -> &[ElementwiseFusionInst] {
132        &self.ops
133    }
134}
135
136impl ElementwiseFusionInst {
137    /// Build a backend elementwise fusion instruction.
138    ///
139    /// # Examples
140    ///
141    /// ```rust
142    /// use tenferro_tensor::backend::{ElementwiseFusionInst, ElementwiseFusionOp};
143    ///
144    /// let inst = ElementwiseFusionInst::new(ElementwiseFusionOp::Add, vec![0, 1]);
145    /// assert_eq!(inst.inputs(), &[0, 1]);
146    /// ```
147    pub fn new(op: ElementwiseFusionOp, inputs: Vec<usize>) -> Self {
148        Self { op, inputs }
149    }
150
151    /// Return the elementwise op executed by this instruction.
152    ///
153    /// # Examples
154    ///
155    /// ```rust
156    /// use tenferro_tensor::backend::{ElementwiseFusionInst, ElementwiseFusionOp};
157    ///
158    /// let inst = ElementwiseFusionInst::new(ElementwiseFusionOp::Negate, vec![0]);
159    /// assert_eq!(inst.op(), ElementwiseFusionOp::Negate);
160    /// ```
161    pub fn op(&self) -> ElementwiseFusionOp {
162        self.op
163    }
164
165    /// Return this instruction's input value ids.
166    ///
167    /// # Examples
168    ///
169    /// ```rust
170    /// use tenferro_tensor::backend::{ElementwiseFusionInst, ElementwiseFusionOp};
171    ///
172    /// let inst = ElementwiseFusionInst::new(ElementwiseFusionOp::Multiply, vec![2, 0]);
173    /// assert_eq!(inst.inputs(), &[2, 0]);
174    /// ```
175    pub fn inputs(&self) -> &[usize] {
176        &self.inputs
177    }
178}
179
180/// Elementwise tensor operations.
181///
182/// # Examples
183///
184/// ```rust
185/// use tenferro_tensor::TensorElementwise;
186///
187/// fn accepts_elementwise<B: TensorElementwise>(_backend: &mut B) {}
188/// ```
189pub trait TensorElementwise {
190    fn add(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
191
192    /// Elementwise addition accepting either owned tensors or borrowed views.
193    ///
194    /// Backends that implement this method must not silently move data across
195    /// devices. A backend that cannot consume views should return an explicit
196    /// backend error rather than materializing or transferring implicitly.
197    ///
198    /// # Examples
199    ///
200    /// ```rust
201    /// use tenferro_tensor::{Tensor, TensorElementwise, TensorRead};
202    ///
203    /// fn add_owned<B: TensorElementwise>(
204    ///     backend: &mut B,
205    ///     lhs: &Tensor,
206    ///     rhs: &Tensor,
207    /// ) -> tenferro_tensor::Result<Tensor> {
208    ///     backend.add_read(TensorRead::from_tensor(lhs), TensorRead::from_tensor(rhs))
209    /// }
210    /// ```
211    fn add_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
212        self.add(read_tensor("add", lhs)?, read_tensor("add", rhs)?)
213    }
214
215    fn mul(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
216    fn mul_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
217        self.mul(read_tensor("mul", lhs)?, read_tensor("mul", rhs)?)
218    }
219
220    fn neg(&mut self, input: &Tensor) -> crate::Result<Tensor>;
221    fn neg_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
222        self.neg(read_tensor("neg", input)?)
223    }
224
225    fn conj(&mut self, input: &Tensor) -> crate::Result<Tensor>;
226    fn conj_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
227        self.conj(read_tensor("conj", input)?)
228    }
229
230    fn div(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
231    fn div_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
232        self.div(read_tensor("div", lhs)?, read_tensor("div", rhs)?)
233    }
234
235    fn abs(&mut self, input: &Tensor) -> crate::Result<Tensor>;
236    fn abs_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
237        self.abs(read_tensor("abs", input)?)
238    }
239
240    fn sign(&mut self, input: &Tensor) -> crate::Result<Tensor>;
241    fn sign_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
242        self.sign(read_tensor("sign", input)?)
243    }
244
245    fn maximum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
246    fn maximum_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
247        self.maximum(read_tensor("maximum", lhs)?, read_tensor("maximum", rhs)?)
248    }
249
250    fn minimum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
251    fn minimum_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
252        self.minimum(read_tensor("minimum", lhs)?, read_tensor("minimum", rhs)?)
253    }
254
255    fn compare(&mut self, lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> crate::Result<Tensor>;
256    fn compare_read(
257        &mut self,
258        lhs: TensorRead<'_>,
259        rhs: TensorRead<'_>,
260        dir: &CompareDir,
261    ) -> crate::Result<Tensor> {
262        self.compare(
263            read_tensor("compare", lhs)?,
264            read_tensor("compare", rhs)?,
265            dir,
266        )
267    }
268
269    fn select(
270        &mut self,
271        pred: &Tensor,
272        on_true: &Tensor,
273        on_false: &Tensor,
274    ) -> crate::Result<Tensor>;
275    fn select_read(
276        &mut self,
277        pred: TensorRead<'_>,
278        on_true: TensorRead<'_>,
279        on_false: TensorRead<'_>,
280    ) -> crate::Result<Tensor> {
281        self.select(
282            read_tensor("select", pred)?,
283            read_tensor("select", on_true)?,
284            read_tensor("select", on_false)?,
285        )
286    }
287
288    fn clamp(&mut self, input: &Tensor, lower: &Tensor, upper: &Tensor) -> crate::Result<Tensor>;
289    fn clamp_read(
290        &mut self,
291        input: TensorRead<'_>,
292        lower: TensorRead<'_>,
293        upper: TensorRead<'_>,
294    ) -> crate::Result<Tensor> {
295        self.clamp(
296            read_tensor("clamp", input)?,
297            read_tensor("clamp", lower)?,
298            read_tensor("clamp", upper)?,
299        )
300    }
301}
302
303/// Analytic unary and binary tensor operations.
304///
305/// # Examples
306///
307/// ```rust
308/// use tenferro_tensor::TensorAnalytic;
309///
310/// fn accepts_analytic<B: TensorAnalytic>(_backend: &mut B) {}
311/// ```
312pub trait TensorAnalytic {
313    fn exp(&mut self, input: &Tensor) -> crate::Result<Tensor>;
314    fn exp_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
315        self.exp(read_tensor("exp", input)?)
316    }
317
318    fn log(&mut self, input: &Tensor) -> crate::Result<Tensor>;
319    fn log_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
320        self.log(read_tensor("log", input)?)
321    }
322
323    fn sin(&mut self, input: &Tensor) -> crate::Result<Tensor>;
324    fn sin_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
325        self.sin(read_tensor("sin", input)?)
326    }
327
328    fn cos(&mut self, input: &Tensor) -> crate::Result<Tensor>;
329    fn cos_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
330        self.cos(read_tensor("cos", input)?)
331    }
332
333    fn tanh(&mut self, input: &Tensor) -> crate::Result<Tensor>;
334    fn tanh_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
335        self.tanh(read_tensor("tanh", input)?)
336    }
337
338    fn sqrt(&mut self, input: &Tensor) -> crate::Result<Tensor>;
339    fn sqrt_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
340        self.sqrt(read_tensor("sqrt", input)?)
341    }
342
343    fn rsqrt(&mut self, input: &Tensor) -> crate::Result<Tensor>;
344    fn rsqrt_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
345        self.rsqrt(read_tensor("rsqrt", input)?)
346    }
347
348    fn pow(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
349    fn pow_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
350        self.pow(read_tensor("pow", lhs)?, read_tensor("pow", rhs)?)
351    }
352
353    fn expm1(&mut self, input: &Tensor) -> crate::Result<Tensor>;
354    fn expm1_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
355        self.expm1(read_tensor("expm1", input)?)
356    }
357
358    fn log1p(&mut self, input: &Tensor) -> crate::Result<Tensor>;
359    fn log1p_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
360        self.log1p(read_tensor("log1p", input)?)
361    }
362}
363
364/// Shape, layout, and dtype transformation operations.
365///
366/// # Examples
367///
368/// ```rust
369/// use tenferro_tensor::TensorStructural;
370///
371/// fn accepts_structural<B: TensorStructural>(_backend: &mut B) {}
372/// ```
373pub trait TensorStructural {
374    fn transpose(&mut self, input: &Tensor, perm: &[usize]) -> crate::Result<Tensor>;
375    fn transpose_read(&mut self, input: TensorRead<'_>, perm: &[usize]) -> crate::Result<Tensor> {
376        self.transpose(read_tensor("transpose", input)?, perm)
377    }
378
379    fn reshape(&mut self, input: &Tensor, shape: &[usize]) -> crate::Result<Tensor>;
380    fn reshape_read(&mut self, input: TensorRead<'_>, shape: &[usize]) -> crate::Result<Tensor> {
381        self.reshape(read_tensor("reshape", input)?, shape)
382    }
383
384    fn broadcast_in_dim(
385        &mut self,
386        input: &Tensor,
387        shape: &[usize],
388        dims: &[usize],
389    ) -> crate::Result<Tensor>;
390    fn broadcast_in_dim_read(
391        &mut self,
392        input: TensorRead<'_>,
393        shape: &[usize],
394        dims: &[usize],
395    ) -> crate::Result<Tensor> {
396        self.broadcast_in_dim(read_tensor("broadcast_in_dim", input)?, shape, dims)
397    }
398
399    /// Cast a tensor to another dtype using explicit dtype projection.
400    ///
401    /// Backends may truncate, narrow precision, project complex values, or use
402    /// boolean truthiness according to their documented cast support.
403    ///
404    /// # Examples
405    ///
406    /// ```rust
407    /// use tenferro_tensor::{DType, Tensor, TensorStructural};
408    ///
409    /// fn cast_to_i32<B: TensorStructural>(
410    ///     backend: &mut B,
411    ///     input: &Tensor,
412    /// ) -> tenferro_tensor::Result<Tensor> {
413    ///     backend.cast(input, DType::I32)
414    /// }
415    /// ```
416    fn cast(&mut self, input: &Tensor, to: crate::DType) -> crate::Result<Tensor>;
417
418    /// Convert a tensor to another dtype using checked dtype conversion.
419    ///
420    /// `convert` accepts only conversions allowed by tenferro's dtype-promotion
421    /// lattice. Use [`TensorStructural::cast`] for explicit lossy projection.
422    ///
423    /// # Examples
424    ///
425    /// ```rust
426    /// use tenferro_tensor::{DType, Tensor, TensorStructural};
427    ///
428    /// fn convert_to_f64<B: TensorStructural>(
429    ///     backend: &mut B,
430    ///     input: &Tensor,
431    /// ) -> tenferro_tensor::Result<Tensor> {
432    ///     backend.convert(input, DType::F64)
433    /// }
434    /// ```
435    fn convert(&mut self, input: &Tensor, to: crate::DType) -> crate::Result<Tensor> {
436        validate_convert_dtype("convert", input.dtype(), to)?;
437        self.cast(input, to)
438    }
439
440    fn extract_diagonal(
441        &mut self,
442        input: &Tensor,
443        axis_a: usize,
444        axis_b: usize,
445    ) -> crate::Result<Tensor>;
446    fn embed_diagonal(
447        &mut self,
448        input: &Tensor,
449        axis_a: usize,
450        axis_b: usize,
451    ) -> crate::Result<Tensor>;
452    fn tril(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor>;
453    fn triu(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor>;
454}
455
456/// Reduction operations.
457///
458/// # Examples
459///
460/// ```rust
461/// use tenferro_tensor::TensorReduction;
462///
463/// fn accepts_reduction<B: TensorReduction>(_backend: &mut B) {}
464/// ```
465pub trait TensorReduction {
466    fn reduce_sum(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
467
468    /// Sum elements across axes from an owned tensor or borrowed view.
469    ///
470    /// # Examples
471    ///
472    /// ```rust
473    /// use tenferro_tensor::{Tensor, TensorRead, TensorReduction};
474    ///
475    /// fn sum_owned<B: TensorReduction>(
476    ///     backend: &mut B,
477    ///     input: &Tensor,
478    /// ) -> tenferro_tensor::Result<Tensor> {
479    ///     backend.reduce_sum_read(TensorRead::from_tensor(input), &[0])
480    /// }
481    /// ```
482    fn reduce_sum_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
483        match input.as_tensor() {
484            Some(input) => self.reduce_sum(input, axes),
485            None => Err(crate::Error::backend_failure(
486                "reduce_sum",
487                "backend does not accept borrowed tensor views at this execution boundary",
488            )),
489        }
490    }
491
492    fn reduce_prod(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
493
494    /// Multiply elements across axes from an owned tensor or borrowed view.
495    ///
496    /// # Examples
497    ///
498    /// ```rust
499    /// use tenferro_tensor::{Tensor, TensorRead, TensorReduction};
500    ///
501    /// fn prod_owned<B: TensorReduction>(
502    ///     backend: &mut B,
503    ///     input: &Tensor,
504    /// ) -> tenferro_tensor::Result<Tensor> {
505    ///     backend.reduce_prod_read(TensorRead::from_tensor(input), &[0])
506    /// }
507    /// ```
508    fn reduce_prod_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
509        match input.as_tensor() {
510            Some(input) => self.reduce_prod(input, axes),
511            None => Err(crate::Error::backend_failure(
512                "reduce_prod",
513                "backend does not accept borrowed tensor views at this execution boundary",
514            )),
515        }
516    }
517
518    fn reduce_max(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
519
520    /// Take maximum values across axes from an owned tensor or borrowed view.
521    ///
522    /// # Examples
523    ///
524    /// ```rust
525    /// use tenferro_tensor::{Tensor, TensorRead, TensorReduction};
526    ///
527    /// fn max_owned<B: TensorReduction>(
528    ///     backend: &mut B,
529    ///     input: &Tensor,
530    /// ) -> tenferro_tensor::Result<Tensor> {
531    ///     backend.reduce_max_read(TensorRead::from_tensor(input), &[0])
532    /// }
533    /// ```
534    fn reduce_max_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
535        match input.as_tensor() {
536            Some(input) => self.reduce_max(input, axes),
537            None => Err(crate::Error::backend_failure(
538                "reduce_max",
539                "backend does not accept borrowed tensor views at this execution boundary",
540            )),
541        }
542    }
543
544    fn reduce_min(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
545
546    /// Take minimum values across axes from an owned tensor or borrowed view.
547    ///
548    /// # Examples
549    ///
550    /// ```rust
551    /// use tenferro_tensor::{Tensor, TensorRead, TensorReduction};
552    ///
553    /// fn min_owned<B: TensorReduction>(
554    ///     backend: &mut B,
555    ///     input: &Tensor,
556    /// ) -> tenferro_tensor::Result<Tensor> {
557    ///     backend.reduce_min_read(TensorRead::from_tensor(input), &[0])
558    /// }
559    /// ```
560    fn reduce_min_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
561        match input.as_tensor() {
562            Some(input) => self.reduce_min(input, axes),
563            None => Err(crate::Error::backend_failure(
564                "reduce_min",
565                "backend does not accept borrowed tensor views at this execution boundary",
566            )),
567        }
568    }
569}
570
571/// Dot-general operations.
572///
573/// # Examples
574///
575/// ```rust
576/// use tenferro_tensor::TensorDot;
577///
578/// fn accepts_dot<B: TensorDot>(_backend: &mut B) {}
579/// ```
580pub trait TensorDot: TensorElementwise {
581    fn dot_general(
582        &mut self,
583        lhs: &Tensor,
584        rhs: &Tensor,
585        config: &DotGeneralConfig,
586    ) -> crate::Result<Tensor>;
587
588    #[doc(hidden)]
589    fn dot_general_read(
590        &mut self,
591        lhs: TensorRead<'_>,
592        rhs: TensorRead<'_>,
593        config: &DotGeneralConfig,
594    ) -> crate::Result<Tensor> {
595        match (lhs.as_tensor(), rhs.as_tensor()) {
596            (Some(lhs), Some(rhs)) => self.dot_general(lhs, rhs, config),
597            _ => {
598                let lhs = lhs.to_tensor()?;
599                let rhs = rhs.to_tensor()?;
600                self.dot_general(&lhs, &rhs, config)
601            }
602        }
603    }
604
605    #[doc(hidden)]
606    fn dot_general_with_conj(
607        &mut self,
608        lhs: &Tensor,
609        rhs: &Tensor,
610        config: &DotGeneralConfig,
611        lhs_conj: bool,
612        rhs_conj: bool,
613    ) -> crate::Result<Tensor> {
614        if !lhs_conj && !rhs_conj {
615            return self.dot_general(lhs, rhs, config);
616        }
617
618        let lhs_tmp;
619        let lhs_ref = if lhs_conj {
620            lhs_tmp = self.conj(lhs)?;
621            &lhs_tmp
622        } else {
623            lhs
624        };
625        let rhs_tmp;
626        let rhs_ref = if rhs_conj {
627            rhs_tmp = self.conj(rhs)?;
628            &rhs_tmp
629        } else {
630            rhs
631        };
632        self.dot_general(lhs_ref, rhs_ref, config)
633    }
634
635    #[allow(clippy::too_many_arguments)]
636    #[doc(hidden)]
637    fn dot_general_with_conj_read(
638        &mut self,
639        lhs: TensorRead<'_>,
640        rhs: TensorRead<'_>,
641        config: &DotGeneralConfig,
642        lhs_conj: bool,
643        rhs_conj: bool,
644    ) -> crate::Result<Tensor> {
645        if !lhs_conj && !rhs_conj {
646            return self.dot_general_read(lhs, rhs, config);
647        }
648
649        let lhs_tmp;
650        let lhs_ref = if let Some(tensor) = lhs.as_tensor() {
651            tensor
652        } else {
653            lhs_tmp = lhs.to_tensor()?;
654            &lhs_tmp
655        };
656        let rhs_tmp;
657        let rhs_ref = if let Some(tensor) = rhs.as_tensor() {
658            tensor
659        } else {
660            rhs_tmp = rhs.to_tensor()?;
661            &rhs_tmp
662        };
663        self.dot_general_with_conj(lhs_ref, rhs_ref, config, lhs_conj, rhs_conj)
664    }
665}
666
667/// Session-scoped cached dot-general operations.
668///
669/// # Examples
670///
671/// ```rust
672/// use tenferro_tensor::BackendSession;
673///
674/// fn accepts_session_dot<S: BackendSession + ?Sized>(_session: &mut S) {}
675/// ```
676pub trait SessionCachedDot: TensorDot {
677    #[doc(hidden)]
678    fn dot_general_cached(
679        &mut self,
680        _cache_slot: Option<usize>,
681        lhs: &Tensor,
682        rhs: &Tensor,
683        config: &DotGeneralConfig,
684    ) -> crate::Result<Tensor> {
685        self.dot_general(lhs, rhs, config)
686    }
687
688    #[doc(hidden)]
689    fn dot_general_read_cached(
690        &mut self,
691        cache_slot: Option<usize>,
692        lhs: TensorRead<'_>,
693        rhs: TensorRead<'_>,
694        config: &DotGeneralConfig,
695    ) -> crate::Result<Tensor> {
696        match (lhs.as_tensor(), rhs.as_tensor()) {
697            (Some(lhs), Some(rhs)) => self.dot_general_cached(cache_slot, lhs, rhs, config),
698            _ => {
699                let lhs = lhs.to_tensor()?;
700                let rhs = rhs.to_tensor()?;
701                self.dot_general_cached(cache_slot, &lhs, &rhs, config)
702            }
703        }
704    }
705
706    // Mirrors the dot-general signature plus runtime-cache metadata.
707    #[allow(clippy::too_many_arguments)]
708    #[doc(hidden)]
709    fn dot_general_with_conj_cached(
710        &mut self,
711        _cache_slot: Option<usize>,
712        lhs: &Tensor,
713        rhs: &Tensor,
714        config: &DotGeneralConfig,
715        lhs_conj: bool,
716        rhs_conj: bool,
717    ) -> crate::Result<Tensor> {
718        self.dot_general_with_conj(lhs, rhs, config, lhs_conj, rhs_conj)
719    }
720
721    // Mirrors the dot-general read signature plus runtime-cache metadata.
722    #[allow(clippy::too_many_arguments)]
723    #[doc(hidden)]
724    fn dot_general_with_conj_read_cached(
725        &mut self,
726        cache_slot: Option<usize>,
727        lhs: TensorRead<'_>,
728        rhs: TensorRead<'_>,
729        config: &DotGeneralConfig,
730        lhs_conj: bool,
731        rhs_conj: bool,
732    ) -> crate::Result<Tensor> {
733        if !lhs_conj && !rhs_conj {
734            return self.dot_general_read_cached(cache_slot, lhs, rhs, config);
735        }
736
737        let lhs_tmp;
738        let lhs_ref = if let Some(tensor) = lhs.as_tensor() {
739            tensor
740        } else {
741            lhs_tmp = lhs.to_tensor()?;
742            &lhs_tmp
743        };
744        let rhs_tmp;
745        let rhs_ref = if let Some(tensor) = rhs.as_tensor() {
746            tensor
747        } else {
748            rhs_tmp = rhs.to_tensor()?;
749            &rhs_tmp
750        };
751        self.dot_general_with_conj_cached(cache_slot, lhs_ref, rhs_ref, config, lhs_conj, rhs_conj)
752    }
753}
754
/// Indexing, slicing, and padding operations.
///
/// Every method is required (there are no default bodies) and operates on owned
/// tensors only. The precise index semantics are defined by the config types
/// and the backend contract, which are not visible in this file.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::TensorIndexing;
///
/// fn accepts_indexing<B: TensorIndexing>(_backend: &mut B) {}
/// ```
pub trait TensorIndexing {
    /// Backend hook for `gather`; `config` carries the operation parameters
    /// (see [`GatherConfig`]).
    fn gather(
        &mut self,
        operand: &Tensor,
        start_indices: &Tensor,
        config: &GatherConfig,
    ) -> crate::Result<Tensor>;
    /// Backend hook for `scatter`; `config` carries the operation parameters
    /// (see [`ScatterConfig`]).
    fn scatter(
        &mut self,
        operand: &Tensor,
        scatter_indices: &Tensor,
        updates: &Tensor,
        config: &ScatterConfig,
    ) -> crate::Result<Tensor>;
    /// Backend hook for `slice`; `config` carries the operation parameters
    /// (see [`SliceConfig`]).
    fn slice(&mut self, input: &Tensor, config: &SliceConfig) -> crate::Result<Tensor>;
    /// Backend hook for `dynamic_slice`; the start positions arrive as a tensor
    /// (`starts`) while the slice sizes are plain `usize` values.
    fn dynamic_slice(
        &mut self,
        input: &Tensor,
        starts: &Tensor,
        slice_sizes: &[usize],
    ) -> crate::Result<Tensor>;
    /// Backend hook for `dynamic_update_slice`; the start positions arrive as a
    /// tensor (`starts`). The result is returned as a new tensor.
    fn dynamic_update_slice(
        &mut self,
        operand: &Tensor,
        update: &Tensor,
        starts: &Tensor,
    ) -> crate::Result<Tensor>;
    /// Backend hook for `pad`; `config` carries the operation parameters
    /// (see [`PadConfig`]).
    fn pad(&mut self, input: &Tensor, config: &PadConfig) -> crate::Result<Tensor>;
    /// Backend hook for `concatenate` over `inputs` along `axis`.
    fn concatenate(&mut self, inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor>;
    /// Backend hook for `reverse` over the listed `axes`.
    fn reverse(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
}
795
/// Backend-owned canonicalization for typed tensor views.
///
/// Implementations must preserve the input placement family. CPU backends
/// canonicalize host views through explicit host copies and reject backend
/// buffers with a diagnostic that asks the caller to download first. GPU
/// backends canonicalize GPU-resident views on the same device and reject host
/// buffers with an upload hint.
///
/// This trait is intentionally separate from [`BackendSession`] so generic
/// typed methods do not change the object-safety contract of `dyn BackendSession`.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::{DynRank, TensorViewCanonicalization, TypedTensor};
///
/// fn compact_i32<B: TensorViewCanonicalization<i32, DynRank>>(
///     backend: &mut B,
///     tensor: &TypedTensor<i32>,
/// ) -> tenferro_tensor::Result<TypedTensor<i32>> {
///     backend.to_contiguous(&tensor.as_view())
/// }
/// ```
pub trait TensorViewCanonicalization<T: Clone + 'static, R: TensorRank> {
    /// Produce an owned [`TypedTensor`] with the same element type and rank as
    /// `view`.
    ///
    /// NOTE(review): the name implies the result has a contiguous layout;
    /// confirm the exact layout guarantee against the backend implementations.
    fn to_contiguous(
        &mut self,
        view: &TypedTensorView<'_, T, R>,
    ) -> crate::Result<TypedTensor<T, R>>;

    /// Write the contents of the owned tensor `src` into the mutable view `dst`.
    ///
    /// NOTE(review): presumably `src` and `dst` must agree in shape; the
    /// requirement is not visible here, so verify it against the implementations.
    fn copy_from_contiguous(
        &mut self,
        src: &TypedTensor<T, R>,
        dst: &mut TypedTensorViewMut<'_, T, R>,
    ) -> crate::Result<()>;
}
831
832/// Optional elementwise fusion execution.
833///
834/// # Examples
835///
836/// ```rust
837/// use tenferro_tensor::TensorFusion;
838///
839/// fn accepts_fusion<B: TensorFusion>(_backend: &mut B) {}
840/// ```
841pub trait TensorFusion {
842    #[doc(hidden)]
843    fn execute_elementwise_fusion(
844        &mut self,
845        _inputs: &[&Tensor],
846        _plan: &ElementwiseFusionPlan,
847    ) -> crate::Result<Option<Vec<Tensor>>> {
848        Ok(None)
849    }
850
851    #[doc(hidden)]
852    #[allow(clippy::too_many_arguments)]
853    fn execute_broadcast_multiply(
854        &mut self,
855        _lhs: TensorRead<'_>,
856        _lhs_shape: &[usize],
857        _lhs_dims: &[usize],
858        _rhs: TensorRead<'_>,
859        _rhs_shape: &[usize],
860        _rhs_dims: &[usize],
861    ) -> crate::Result<Option<Tensor>> {
862        Ok(None)
863    }
864
865    #[doc(hidden)]
866    #[allow(clippy::too_many_arguments)]
867    fn execute_broadcast_multiply_value(
868        &mut self,
869        lhs: TensorRead<'_>,
870        lhs_shape: &[usize],
871        lhs_dims: &[usize],
872        rhs: TensorRead<'_>,
873        rhs_shape: &[usize],
874        rhs_dims: &[usize],
875    ) -> crate::Result<Option<TensorValue>> {
876        self.execute_broadcast_multiply(lhs, lhs_shape, lhs_dims, rhs, rhs_shape, rhs_dims)
877            .map(|tensor| tensor.map(TensorValue::from_tensor))
878    }
879}
880
/// Backend buffer lifecycle operations.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::TensorBuffer;
///
/// fn accepts_buffer<B: TensorBuffer>(_backend: &mut B) {}
/// ```
pub trait TensorBuffer {
    /// Hand a tensor back to the backend by value.
    ///
    /// The default body is empty, so the tensor is simply dropped when the call
    /// returns. Presumably backends override this to recycle the tensor's
    /// storage — confirm against the concrete backends.
    fn reclaim_buffer(&mut self, _tensor: Tensor) {}
}
893
894/// Device transfer operations on backend boundaries.
895///
896/// # Examples
897///
898/// ```rust
899/// use tenferro_tensor::TensorDeviceTransfer;
900///
901/// fn accepts_transfer<B: TensorDeviceTransfer>(_backend: &mut B) {}
902/// ```
903pub trait TensorDeviceTransfer {
904    fn download_to_host(&mut self, tensor: &Tensor) -> crate::Result<Tensor> {
905        Ok(tensor.clone())
906    }
907
908    fn upload_host_tensor(&mut self, tensor: &Tensor) -> crate::Result<Tensor> {
909        Ok(tensor.clone())
910    }
911}
912
/// Runtime cache associated with a backend.
///
/// The trait declares no methods; it only names the cache type that goes with a
/// backend.
///
/// # Examples
///
/// ```rust
/// use tenferro_tensor::BackendRuntimeCache;
///
/// fn accepts_runtime_cache<B: BackendRuntimeCache>(_backend: &B) {}
/// ```
pub trait BackendRuntimeCache {
    /// Cache type paired with this backend. It must implement
    /// [`RuntimeCacheControl`] and be `Send + Sync + 'static`.
    #[doc(hidden)]
    type RuntimeCache: RuntimeCacheControl + Send + Sync + 'static;
}
926
927/// Backend-owned cached dot-general operations.
928///
929/// # Examples
930///
931/// ```rust
932/// use tenferro_tensor::BackendCachedDot;
933///
934/// fn accepts_backend_cached_dot<B: BackendCachedDot>(_backend: &mut B) {}
935/// ```
936pub trait BackendCachedDot: BackendRuntimeCache + TensorDot {
937    #[doc(hidden)]
938    fn dot_general_cached(
939        &mut self,
940        _cache: &mut Self::RuntimeCache,
941        _cache_slot: Option<usize>,
942        lhs: &Tensor,
943        rhs: &Tensor,
944        config: &DotGeneralConfig,
945    ) -> crate::Result<Tensor> {
946        self.dot_general(lhs, rhs, config)
947    }
948
949    #[doc(hidden)]
950    fn dot_general_read_cached(
951        &mut self,
952        cache: &mut Self::RuntimeCache,
953        cache_slot: Option<usize>,
954        lhs: TensorRead<'_>,
955        rhs: TensorRead<'_>,
956        config: &DotGeneralConfig,
957    ) -> crate::Result<Tensor> {
958        match (lhs.as_tensor(), rhs.as_tensor()) {
959            (Some(lhs), Some(rhs)) => self.dot_general_cached(cache, cache_slot, lhs, rhs, config),
960            _ => {
961                let lhs = lhs.to_tensor()?;
962                let rhs = rhs.to_tensor()?;
963                self.dot_general_cached(cache, cache_slot, &lhs, &rhs, config)
964            }
965        }
966    }
967
968    // Mirrors the dot-general signature plus runtime-cache metadata.
969    #[allow(clippy::too_many_arguments)]
970    #[doc(hidden)]
971    fn dot_general_with_conj_cached(
972        &mut self,
973        _cache: &mut Self::RuntimeCache,
974        _cache_slot: Option<usize>,
975        lhs: &Tensor,
976        rhs: &Tensor,
977        config: &DotGeneralConfig,
978        lhs_conj: bool,
979        rhs_conj: bool,
980    ) -> crate::Result<Tensor> {
981        self.dot_general_with_conj(lhs, rhs, config, lhs_conj, rhs_conj)
982    }
983
984    // Mirrors the dot-general read signature plus runtime-cache metadata.
985    #[allow(clippy::too_many_arguments)]
986    #[doc(hidden)]
987    fn dot_general_with_conj_read_cached(
988        &mut self,
989        cache: &mut Self::RuntimeCache,
990        cache_slot: Option<usize>,
991        lhs: TensorRead<'_>,
992        rhs: TensorRead<'_>,
993        config: &DotGeneralConfig,
994        lhs_conj: bool,
995        rhs_conj: bool,
996    ) -> crate::Result<Tensor> {
997        if !lhs_conj && !rhs_conj {
998            return self.dot_general_read_cached(cache, cache_slot, lhs, rhs, config);
999        }
1000
1001        let lhs_tmp;
1002        let lhs_ref = if let Some(tensor) = lhs.as_tensor() {
1003            tensor
1004        } else {
1005            lhs_tmp = lhs.to_tensor()?;
1006            &lhs_tmp
1007        };
1008        let rhs_tmp;
1009        let rhs_ref = if let Some(tensor) = rhs.as_tensor() {
1010            tensor
1011        } else {
1012            rhs_tmp = rhs.to_tensor()?;
1013            &rhs_tmp
1014        };
1015        self.dot_general_with_conj_cached(
1016            cache, cache_slot, lhs_ref, rhs_ref, config, lhs_conj, rhs_conj,
1017        )
1018    }
1019}
1020
1021/// Backend execution-session entry points.
1022///
1023/// # Examples
1024///
1025/// ```rust
1026/// use tenferro_tensor::BackendSessionHost;
1027///
1028/// fn accepts_session_host<B: BackendSessionHost>(_backend: &mut B) {}
1029/// ```
1030pub trait BackendSessionHost: BackendRuntimeCache {
1031    fn with_backend_session<R: Send>(
1032        &mut self,
1033        f: impl FnOnce(&mut dyn BackendSession) -> R + Send,
1034    ) -> R
1035    where
1036        Self: TensorBackend + Sized,
1037    {
1038        default_backend_session(self, f)
1039    }
1040
1041    #[doc(hidden)]
1042    fn with_backend_session_cached<R: Send>(
1043        &mut self,
1044        _cache: &mut Self::RuntimeCache,
1045        f: impl FnOnce(&mut dyn BackendSession) -> R + Send,
1046    ) -> R
1047    where
1048        Self: TensorBackend + Sized,
1049    {
1050        self.with_backend_session(f)
1051    }
1052}
1053
1054/// Operation capabilities shared by backends and backend sessions.
1055#[doc(hidden)]
1056pub trait TensorBackendOps:
1057    TensorElementwise
1058    + TensorAnalytic
1059    + TensorStructural
1060    + TensorReduction
1061    + TensorIndexing
1062    + TensorDot
1063    + TensorFusion
1064    + TensorBuffer
1065{
1066}
1067
1068impl<T> TensorBackendOps for T where
1069    T: TensorElementwise
1070        + TensorAnalytic
1071        + TensorStructural
1072        + TensorReduction
1073        + TensorIndexing
1074        + TensorDot
1075        + TensorFusion
1076        + TensorBuffer
1077        + ?Sized
1078{
1079}
1080
1081/// Execution session surface for dense tensor backends.
1082///
1083/// All operations run within a backend-owned execution scope such as a CPU
1084/// thread policy or a GPU stream. Individual ops must not try to re-enter that
1085/// scope.
1086///
1087/// # Examples
1088///
1089/// ```rust
1090/// use tenferro_tensor::{BackendSessionHost, Tensor, TypedTensor};
1091///
1092/// fn add_in_session<B: BackendSessionHost>(
1093///     backend: &mut B,
1094///     a: &Tensor,
1095///     b: &Tensor,
1096/// ) -> tenferro_tensor::Result<Tensor>
1097/// where
1098///     B: tenferro_tensor::TensorBackend,
1099/// {
1100///     backend.with_backend_session(|exec| exec.add(a, b))
1101/// }
1102/// ```
1103pub trait BackendSession: TensorBackendOps + SessionCachedDot {}
1104
1105impl<T> BackendSession for T where T: TensorBackendOps + SessionCachedDot + ?Sized {}
1106
1107/// Standard runtime backend over dynamic [`Tensor`] values.
1108///
1109/// # Examples
1110///
1111/// ```rust
1112/// use tenferro_tensor::TensorBackend;
1113///
1114/// fn accepts_backend<B: TensorBackend>(_backend: &mut B) {}
1115/// ```
1116pub trait TensorBackend:
1117    BackendRuntimeCache
1118    + TensorBackendOps
1119    + BackendCachedDot
1120    + TensorDeviceTransfer
1121    + BackendSessionHost
1122{
1123}
1124
1125impl<T> SessionCachedDot for T where T: TensorBackend + ?Sized {}
1126
1127/// Run a closure using the backend itself as a default execution session.
1128///
1129/// This is suitable for backends whose individual ops already manage their own
1130/// execution context.
1131///
1132/// # Examples
1133///
1134/// ```rust
1135/// use tenferro_tensor::{default_backend_session, TensorBackend};
1136///
1137/// fn run_with_default_session<B: TensorBackend>(backend: &mut B) -> usize {
1138///     default_backend_session(backend, |_exec| 1usize)
1139/// }
1140/// ```
1141pub fn default_backend_session<B: TensorBackend, R: Send>(
1142    backend: &mut B,
1143    f: impl FnOnce(&mut dyn BackendSession) -> R + Send,
1144) -> R {
1145    f(backend)
1146}