//! Backend traits and shared helpers for `tenferro_tensor` (backend.rs).

1use strided_kernel::{col_major_strides, reduce_axis, zip_map2_into, StridedArray, StridedView};
2use tenferro_algebra::Semiring;
3
4use crate::config::{
5    CompareDir, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig,
6};
7use crate::{Buffer, Tensor, TypedTensor};
8
/// Canonical elementwise fusion plan shared between segmented execution and backends.
#[doc(hidden)]
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct ElementwiseFusionPlan {
    /// Element dtype every fused op operates on.
    pub dtype: crate::DType,
    /// Number of external input tensors the plan consumes.
    pub n_inputs: usize,
    /// Value indices to return as the plan's outputs.
    /// NOTE(review): assumed value-numbering — inputs occupy slots `0..n_inputs`
    /// and op results follow in order; confirm against the fusion executor.
    pub outputs: Vec<usize>,
    /// Op nodes in execution order; each may reference earlier values.
    pub ops: Vec<ElementwiseFusionInst>,
}
18
/// One node in a canonical elementwise fusion plan.
#[doc(hidden)]
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct ElementwiseFusionInst {
    /// Elementwise operation applied at this node.
    pub op: ElementwiseFusionOp,
    /// Value indices consumed by this node (external inputs and/or earlier op results).
    pub inputs: Vec<usize>,
}
26
/// Elementwise op kinds supported by backend fusion implementations.
//
// NOTE: variant order is load-bearing for the `Hash` derive (plans are
// hashed/compared); append new ops at the end rather than reordering.
#[doc(hidden)]
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum ElementwiseFusionOp {
    // Binary/unary arithmetic.
    Add,
    Multiply,
    Negate,
    Conj,
    Divide,
    Abs,
    Maximum,
    Minimum,
    // Comparison and selection; `Compare` carries its direction.
    Compare(CompareDir),
    Select,
    Clamp,
    // Transcendental functions.
    Exp,
    Log,
    Sin,
    Cos,
    Tanh,
    Sqrt,
    Rsqrt,
    Pow,
    Expm1,
    Log1p,
}
53
54pub(crate) fn typed_view<T: Copy>(tensor: &TypedTensor<T>) -> StridedView<'_, T> {
55    match &tensor.buffer {
56        Buffer::Host(data) => {
57            let strides = col_major_strides(&tensor.shape);
58            StridedView::new(data, &tensor.shape, &strides, 0).expect("contiguous host tensor")
59        }
60        Buffer::Backend(_) => todo!("typed_view for backend buffers"),
61        #[cfg(feature = "cubecl")]
62        Buffer::Cubecl(_) => panic!("GPU tensor (Buffer::Cubecl) passed to CPU backend. Use cubecl::download_tensor() to transfer to CPU first."),
63    }
64}
65
66pub(crate) fn typed_array<T: Clone>(shape: &[usize], fill: T) -> StridedArray<T> {
67    let total: usize = shape.iter().product();
68    let strides = col_major_strides(shape);
69    StridedArray::from_parts(vec![fill; total], shape, &strides, 0)
70        .expect("column-major output array")
71}
72
73pub(crate) fn tensor_from_array<T: Clone>(array: StridedArray<T>) -> TypedTensor<T> {
74    TypedTensor::from_vec(array.dims().to_vec(), array.into_data())
75}
76
77fn backend_failure(op: &'static str, err: impl ToString) -> crate::Error {
78    crate::Error::BackendFailure {
79        op,
80        message: err.to_string(),
81    }
82}
83
84fn validate_axis_list(
85    op: &'static str,
86    role: &'static str,
87    axes: &[usize],
88    rank: usize,
89) -> crate::Result<()> {
90    let mut seen = vec![false; rank];
91    for &axis in axes {
92        if axis >= rank {
93            return Err(crate::Error::AxisOutOfBounds { op, axis, rank });
94        }
95        if seen[axis] {
96            return Err(crate::Error::DuplicateAxis { op, axis, role });
97        }
98        seen[axis] = true;
99    }
100    Ok(())
101}
102
103fn validate_binary_shapes(op: &'static str, lhs: &[usize], rhs: &[usize]) -> crate::Result<()> {
104    if lhs != rhs {
105        return Err(crate::Error::ShapeMismatch {
106            op,
107            lhs: lhs.to_vec(),
108            rhs: rhs.to_vec(),
109        });
110    }
111    Ok(())
112}
113
114/// Execution session surface for dense tensor backends.
115///
116/// All operations run within a backend-owned execution scope such as a CPU
117/// rayon pool or a GPU stream. Individual ops must not try to re-enter that
118/// scope.
119///
120/// # Examples
121///
122/// ```ignore
123/// use tenferro_tensor::{cpu::CpuBackend, Tensor, TensorBackend, TypedTensor};
124///
125/// let mut backend = CpuBackend::new();
126/// let a = Tensor::F64(TypedTensor::from_vec(vec![2], vec![1.0, 2.0]));
127/// let b = Tensor::F64(TypedTensor::from_vec(vec![2], vec![3.0, 4.0]));
128/// let sum = backend
129///     .with_exec_session(|exec| exec.add(&a, &b))
130///     .unwrap();
131/// assert_eq!(sum.shape(), &[2]);
132/// ```
pub trait TensorExec {
    // --- Elementwise binary/unary arithmetic ---
    fn add(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn mul(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn neg(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn conj(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn div(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn abs(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn sign(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn maximum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn minimum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    /// Elementwise comparison of `lhs` against `rhs` in direction `dir`.
    fn compare(&mut self, lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> crate::Result<Tensor>;
    /// Elementwise choice between `on_true` and `on_false` driven by `pred`.
    /// NOTE(review): presumably StableHLO `select` semantics — confirm.
    fn select(
        &mut self,
        pred: &Tensor,
        on_true: &Tensor,
        on_false: &Tensor,
    ) -> crate::Result<Tensor>;
    /// Elementwise clamp of `input` to the `lower`/`upper` tensors.
    fn clamp(&mut self, input: &Tensor, lower: &Tensor, upper: &Tensor) -> crate::Result<Tensor>;

    // --- Elementwise transcendental functions ---
    fn exp(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn log(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn sin(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn cos(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn tanh(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn sqrt(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn rsqrt(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn pow(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn expm1(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn log1p(&mut self, input: &Tensor) -> crate::Result<Tensor>;

    // --- Shape, layout, and dtype ---
    fn transpose(&mut self, input: &Tensor, perm: &[usize]) -> crate::Result<Tensor>;
    fn reshape(&mut self, input: &Tensor, shape: &[usize]) -> crate::Result<Tensor>;
    /// Broadcast `input` into `shape`, mapping its axes to `dims`.
    /// NOTE(review): presumably StableHLO `broadcast_in_dim` semantics — confirm.
    fn broadcast_in_dim(
        &mut self,
        input: &Tensor,
        shape: &[usize],
        dims: &[usize],
    ) -> crate::Result<Tensor>;
    fn convert(&mut self, input: &Tensor, to: crate::DType) -> crate::Result<Tensor>;
    fn extract_diagonal(
        &mut self,
        input: &Tensor,
        axis_a: usize,
        axis_b: usize,
    ) -> crate::Result<Tensor>;
    fn embed_diagonal(
        &mut self,
        input: &Tensor,
        axis_a: usize,
        axis_b: usize,
    ) -> crate::Result<Tensor>;
    /// Lower-triangular part with diagonal offset `k` (may be negative).
    fn tril(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor>;
    /// Upper-triangular part with diagonal offset `k` (may be negative).
    fn triu(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor>;

    // --- Reductions over the listed axes ---
    fn reduce_sum(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
    fn reduce_prod(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
    fn reduce_max(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
    fn reduce_min(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;

    // --- Contraction ---
    /// Generalized dot product configured by `config` (batch/contracting dims).
    fn dot_general(
        &mut self,
        lhs: &Tensor,
        rhs: &Tensor,
        config: &DotGeneralConfig,
    ) -> crate::Result<Tensor>;

    // --- Indexed data movement ---
    fn gather(
        &mut self,
        operand: &Tensor,
        start_indices: &Tensor,
        config: &GatherConfig,
    ) -> crate::Result<Tensor>;
    fn scatter(
        &mut self,
        operand: &Tensor,
        scatter_indices: &Tensor,
        updates: &Tensor,
        config: &ScatterConfig,
    ) -> crate::Result<Tensor>;
    fn slice(&mut self, input: &Tensor, config: &SliceConfig) -> crate::Result<Tensor>;
    /// Slice whose start offsets come from the runtime tensor `starts`.
    fn dynamic_slice(
        &mut self,
        input: &Tensor,
        starts: &Tensor,
        slice_sizes: &[usize],
    ) -> crate::Result<Tensor>;
    fn pad(&mut self, input: &Tensor, config: &PadConfig) -> crate::Result<Tensor>;
    fn concatenate(&mut self, inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor>;
    fn reverse(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;

    // --- Dense linear algebra (decompositions return their factors as a Vec) ---
    fn cholesky(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn triangular_solve(
        &mut self,
        a: &Tensor,
        b: &Tensor,
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    ) -> crate::Result<Tensor>;
    fn lu(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn svd(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn qr(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn eigh(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn eig(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;

    /// Hand a no-longer-needed tensor back to the session for buffer reuse.
    fn reclaim_buffer(&mut self, tensor: Tensor);

    /// Try to execute a fused elementwise plan in a single pass.
    ///
    /// The default returns `Ok(None)`, signalling "fusion unsupported" so the
    /// caller can fall back to issuing the individual ops.
    #[doc(hidden)]
    fn execute_elementwise_fusion(
        &mut self,
        _inputs: &[&Tensor],
        _plan: &ElementwiseFusionPlan,
    ) -> crate::Result<Option<Vec<Tensor>>> {
        Ok(None)
    }
}
250
/// Adapter that implements [`TensorExec`] by forwarding every call to the
/// wrapped [`TensorBackend`]; used by [`default_exec_session`].
struct BackendExecAdapter<'a, B: TensorBackend + ?Sized> {
    backend: &'a mut B,
}
254
/// Generates `TensorExec` methods that delegate to `self.backend`.
///
/// Each `name(arg: Ty, ...) -> Ret;` entry expands to a method with that
/// exact signature whose body is `self.backend.name(arg, ...)`.
macro_rules! forward_exec_to_backend {
    ($($name:ident($($arg:ident : $argty:ty),*) -> $ret:ty;)+) => {
        $(
            fn $name(&mut self, $($arg: $argty),*) -> $ret {
                self.backend.$name($($arg),*)
            }
        )+
    };
}
264
// Every `TensorExec` method forwards 1:1 to the corresponding
// `TensorBackend` method; keep this list in sync with both traits.
impl<B: TensorBackend + ?Sized> TensorExec for BackendExecAdapter<'_, B> {
    forward_exec_to_backend! {
        add(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
        mul(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
        neg(input: &Tensor) -> crate::Result<Tensor>;
        conj(input: &Tensor) -> crate::Result<Tensor>;
        div(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
        abs(input: &Tensor) -> crate::Result<Tensor>;
        sign(input: &Tensor) -> crate::Result<Tensor>;
        maximum(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
        minimum(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
        compare(lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> crate::Result<Tensor>;
        select(pred: &Tensor, on_true: &Tensor, on_false: &Tensor) -> crate::Result<Tensor>;
        clamp(input: &Tensor, lower: &Tensor, upper: &Tensor) -> crate::Result<Tensor>;
        exp(input: &Tensor) -> crate::Result<Tensor>;
        log(input: &Tensor) -> crate::Result<Tensor>;
        sin(input: &Tensor) -> crate::Result<Tensor>;
        cos(input: &Tensor) -> crate::Result<Tensor>;
        tanh(input: &Tensor) -> crate::Result<Tensor>;
        sqrt(input: &Tensor) -> crate::Result<Tensor>;
        rsqrt(input: &Tensor) -> crate::Result<Tensor>;
        pow(lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
        expm1(input: &Tensor) -> crate::Result<Tensor>;
        log1p(input: &Tensor) -> crate::Result<Tensor>;
        transpose(input: &Tensor, perm: &[usize]) -> crate::Result<Tensor>;
        reshape(input: &Tensor, shape: &[usize]) -> crate::Result<Tensor>;
        broadcast_in_dim(input: &Tensor, shape: &[usize], dims: &[usize]) -> crate::Result<Tensor>;
        convert(input: &Tensor, to: crate::DType) -> crate::Result<Tensor>;
        extract_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> crate::Result<Tensor>;
        embed_diagonal(input: &Tensor, axis_a: usize, axis_b: usize) -> crate::Result<Tensor>;
        tril(input: &Tensor, k: i64) -> crate::Result<Tensor>;
        triu(input: &Tensor, k: i64) -> crate::Result<Tensor>;
        reduce_sum(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
        reduce_prod(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
        reduce_max(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
        reduce_min(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
        dot_general(lhs: &Tensor, rhs: &Tensor, config: &DotGeneralConfig) -> crate::Result<Tensor>;
        gather(operand: &Tensor, start_indices: &Tensor, config: &GatherConfig) -> crate::Result<Tensor>;
        scatter(
            operand: &Tensor,
            scatter_indices: &Tensor,
            updates: &Tensor,
            config: &ScatterConfig
        ) -> crate::Result<Tensor>;
        slice(input: &Tensor, config: &SliceConfig) -> crate::Result<Tensor>;
        dynamic_slice(input: &Tensor, starts: &Tensor, slice_sizes: &[usize]) -> crate::Result<Tensor>;
        pad(input: &Tensor, config: &PadConfig) -> crate::Result<Tensor>;
        concatenate(inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor>;
        reverse(input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
        cholesky(input: &Tensor) -> crate::Result<Tensor>;
        triangular_solve(
            a: &Tensor,
            b: &Tensor,
            left_side: bool,
            lower: bool,
            transpose_a: bool,
            unit_diagonal: bool
        ) -> crate::Result<Tensor>;
        lu(input: &Tensor) -> crate::Result<Vec<Tensor>>;
        svd(input: &Tensor) -> crate::Result<Vec<Tensor>>;
        qr(input: &Tensor) -> crate::Result<Vec<Tensor>>;
        eigh(input: &Tensor) -> crate::Result<Vec<Tensor>>;
        eig(input: &Tensor) -> crate::Result<Vec<Tensor>>;
        reclaim_buffer(tensor: Tensor) -> ();
        execute_elementwise_fusion(
            inputs: &[&Tensor],
            plan: &ElementwiseFusionPlan
        ) -> crate::Result<Option<Vec<Tensor>>>;
    }
}
335
336/// Run a closure using the default execution-session adapter.
337///
338/// This forwards [`TensorExec`] calls back to the backend's existing
339/// [`TensorBackend`] methods, which is suitable for backends whose individual
340/// ops already manage their own execution context.
341///
342/// # Examples
343///
344/// ```ignore
345/// use tenferro_tensor::{cpu::CpuBackend, default_exec_session};
346///
347/// let mut backend = CpuBackend::new();
348/// let _ = default_exec_session(&mut backend, |_exec| 1usize);
349/// ```
350pub fn default_exec_session<B: TensorBackend + ?Sized, R: Send>(
351    backend: &mut B,
352    f: impl FnOnce(&mut dyn TensorExec) -> R + Send,
353) -> R {
354    let mut adapter = BackendExecAdapter { backend };
355    f(&mut adapter)
356}
357
358/// Standard runtime backend over dynamic [`Tensor`] values.
359///
360/// # Examples
361///
362/// ```ignore
363/// use tenferro_tensor::{cpu::CpuBackend, TensorBackend};
364///
365/// let mut backend = CpuBackend::new();
366/// ```
pub trait TensorBackend {
    // --- Elementwise binary/unary arithmetic ---
    fn add(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn mul(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn neg(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn conj(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn div(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn abs(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn sign(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn maximum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn minimum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    /// Elementwise comparison of `lhs` against `rhs` in direction `dir`.
    fn compare(&mut self, lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> crate::Result<Tensor>;
    /// Elementwise choice between `on_true` and `on_false` driven by `pred`.
    fn select(
        &mut self,
        pred: &Tensor,
        on_true: &Tensor,
        on_false: &Tensor,
    ) -> crate::Result<Tensor>;
    /// Elementwise clamp of `input` to the `lower`/`upper` tensors.
    fn clamp(&mut self, input: &Tensor, lower: &Tensor, upper: &Tensor) -> crate::Result<Tensor>;

    // --- Elementwise transcendental functions ---
    fn exp(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn log(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn sin(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn cos(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn tanh(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn sqrt(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn rsqrt(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn pow(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
    fn expm1(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn log1p(&mut self, input: &Tensor) -> crate::Result<Tensor>;

    // --- Shape, layout, and dtype ---
    fn transpose(&mut self, input: &Tensor, perm: &[usize]) -> crate::Result<Tensor>;
    fn reshape(&mut self, input: &Tensor, shape: &[usize]) -> crate::Result<Tensor>;
    /// Broadcast `input` into `shape`, mapping its axes to `dims`.
    fn broadcast_in_dim(
        &mut self,
        input: &Tensor,
        shape: &[usize],
        dims: &[usize],
    ) -> crate::Result<Tensor>;
    fn convert(&mut self, input: &Tensor, to: crate::DType) -> crate::Result<Tensor>;
    fn extract_diagonal(
        &mut self,
        input: &Tensor,
        axis_a: usize,
        axis_b: usize,
    ) -> crate::Result<Tensor>;
    fn embed_diagonal(
        &mut self,
        input: &Tensor,
        axis_a: usize,
        axis_b: usize,
    ) -> crate::Result<Tensor>;
    /// Lower-triangular part with diagonal offset `k` (may be negative).
    fn tril(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor>;
    /// Upper-triangular part with diagonal offset `k` (may be negative).
    fn triu(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor>;

    // --- Reductions over the listed axes ---
    fn reduce_sum(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
    fn reduce_prod(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
    fn reduce_max(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
    fn reduce_min(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;

    // --- Contraction ---
    /// Generalized dot product configured by `config` (batch/contracting dims).
    fn dot_general(
        &mut self,
        lhs: &Tensor,
        rhs: &Tensor,
        config: &DotGeneralConfig,
    ) -> crate::Result<Tensor>;

    // --- Indexed data movement ---
    fn gather(
        &mut self,
        operand: &Tensor,
        start_indices: &Tensor,
        config: &GatherConfig,
    ) -> crate::Result<Tensor>;
    fn scatter(
        &mut self,
        operand: &Tensor,
        scatter_indices: &Tensor,
        updates: &Tensor,
        config: &ScatterConfig,
    ) -> crate::Result<Tensor>;
    fn slice(&mut self, input: &Tensor, config: &SliceConfig) -> crate::Result<Tensor>;
    /// Slice whose start offsets come from the runtime tensor `starts`.
    fn dynamic_slice(
        &mut self,
        input: &Tensor,
        starts: &Tensor,
        slice_sizes: &[usize],
    ) -> crate::Result<Tensor>;
    fn pad(&mut self, input: &Tensor, config: &PadConfig) -> crate::Result<Tensor>;
    fn concatenate(&mut self, inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor>;
    fn reverse(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;

    // --- Dense linear algebra (decompositions return their factors as a Vec) ---
    fn cholesky(&mut self, input: &Tensor) -> crate::Result<Tensor>;
    fn triangular_solve(
        &mut self,
        a: &Tensor,
        b: &Tensor,
        left_side: bool,
        lower: bool,
        transpose_a: bool,
        unit_diagonal: bool,
    ) -> crate::Result<Tensor>;
    fn lu(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn svd(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn qr(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn eigh(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    fn eig(&mut self, input: &Tensor) -> crate::Result<Vec<Tensor>>;
    /// General linear solve. NOTE(review): presumably solves `a · x = b` for
    /// `x` — not part of [`TensorExec`]; confirm intended semantics.
    fn solve(&mut self, a: &Tensor, b: &Tensor) -> crate::Result<Tensor>;

    /// Execute a batch of operations inside the backend's execution context.
    ///
    /// Backends can override this to establish one shared scope for many ops,
    /// such as a rayon pool install on CPU.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::{cpu::CpuBackend, TensorBackend};
    ///
    /// let mut backend = CpuBackend::new();
    /// let _value = backend.with_exec_session(|_exec| 1usize);
    /// ```
    fn with_exec_session<R: Send>(&mut self, f: impl FnOnce(&mut dyn TensorExec) -> R + Send) -> R {
        default_exec_session(self, f)
    }

    /// Materialize a backend tensor into host memory.
    ///
    /// Backends that already operate on host tensors can keep the default
    /// implementation, which clones the input tensor.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{cpu::CpuBackend, Tensor, TensorBackend, TypedTensor};
    ///
    /// let mut backend = CpuBackend::new();
    /// let tensor = Tensor::F64(TypedTensor::from_vec(vec![2], vec![1.0, 2.0]));
    /// let host = backend.download_to_host(&tensor).unwrap();
    /// assert_eq!(host.shape(), &[2]);
    /// ```
    fn download_to_host(&mut self, tensor: &Tensor) -> crate::Result<Tensor> {
        Ok(tensor.clone())
    }

    /// Upload a host tensor into backend-owned storage when needed.
    ///
    /// Backends that already use host tensors can keep the default
    /// implementation, which clones the input tensor.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro_tensor::{cpu::CpuBackend, Tensor, TensorBackend, TypedTensor};
    ///
    /// let mut backend = CpuBackend::new();
    /// let tensor = Tensor::F64(TypedTensor::from_vec(vec![2], vec![1.0, 2.0]));
    /// let uploaded = backend.upload_host_tensor(&tensor).unwrap();
    /// assert_eq!(uploaded.shape(), &[2]);
    /// ```
    fn upload_host_tensor(&mut self, tensor: &Tensor) -> crate::Result<Tensor> {
        Ok(tensor.clone())
    }

    /// Reclaim a tensor buffer for backend-specific reuse.
    ///
    /// Backends that do not pool buffers can ignore the tensor and let it drop.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use tenferro_tensor::{cpu::CpuBackend, Tensor, TensorBackend, TypedTensor};
    ///
    /// let mut backend = CpuBackend::new();
    /// let tensor = Tensor::F64(TypedTensor::from_vec(vec![2], vec![1.0, 2.0]));
    /// backend.reclaim_buffer(tensor);
    /// ```
    fn reclaim_buffer(&mut self, _tensor: Tensor) {}

    /// Try to execute a fused elementwise plan in a single pass.
    ///
    /// The default returns `Ok(None)`, signalling "fusion unsupported" so the
    /// caller can fall back to issuing the individual ops.
    #[doc(hidden)]
    fn execute_elementwise_fusion(
        &mut self,
        _inputs: &[&Tensor],
        _plan: &ElementwiseFusionPlan,
    ) -> crate::Result<Option<Vec<Tensor>>> {
        Ok(None)
    }
}
553
554/// Algebra-generic backend over typed tensors.
555///
556/// # Examples
557///
558/// ```ignore
559/// use tenferro_algebra::Standard;
560/// use tenferro_tensor::{cpu::CpuBackend, SemiringBackend};
561///
562/// fn needs_semiring_backend<B: SemiringBackend<Standard<f64>>>(_backend: &mut B) {}
563/// let mut backend = CpuBackend::new();
564/// needs_semiring_backend(&mut backend);
565/// ```
566pub trait SemiringBackend<Alg: Semiring> {
567    fn batched_gemm(
568        &mut self,
569        lhs: &TypedTensor<Alg::Scalar>,
570        rhs: &TypedTensor<Alg::Scalar>,
571        config: &DotGeneralConfig,
572    ) -> crate::Result<TypedTensor<Alg::Scalar>>;
573
574    fn add(
575        &mut self,
576        lhs: &TypedTensor<Alg::Scalar>,
577        rhs: &TypedTensor<Alg::Scalar>,
578    ) -> crate::Result<TypedTensor<Alg::Scalar>> {
579        validate_binary_shapes("add", &lhs.shape, &rhs.shape)?;
580        let mut out = typed_array(&lhs.shape, Alg::zero());
581        zip_map2_into(
582            &mut out.view_mut(),
583            &typed_view(lhs),
584            &typed_view(rhs),
585            |x, y| Alg::add(x, y),
586        )
587        .map_err(|err| backend_failure("add", err))?;
588        Ok(tensor_from_array(out))
589    }
590
591    fn mul(
592        &mut self,
593        lhs: &TypedTensor<Alg::Scalar>,
594        rhs: &TypedTensor<Alg::Scalar>,
595    ) -> crate::Result<TypedTensor<Alg::Scalar>> {
596        validate_binary_shapes("mul", &lhs.shape, &rhs.shape)?;
597        let mut out = typed_array(&lhs.shape, Alg::zero());
598        zip_map2_into(
599            &mut out.view_mut(),
600            &typed_view(lhs),
601            &typed_view(rhs),
602            |x, y| Alg::mul(x, y),
603        )
604        .map_err(|err| backend_failure("mul", err))?;
605        Ok(tensor_from_array(out))
606    }
607
608    fn reduce_sum(
609        &mut self,
610        input: &TypedTensor<Alg::Scalar>,
611        axes: &[usize],
612    ) -> crate::Result<TypedTensor<Alg::Scalar>> {
613        validate_axis_list("reduce_sum", "axes", axes, input.shape.len())?;
614        if axes.is_empty() {
615            return Ok(input.clone());
616        }
617
618        let output_shape: Vec<usize> = input
619            .shape
620            .iter()
621            .enumerate()
622            .filter(|(axis, _)| !axes.contains(axis))
623            .map(|(_, &dim)| dim)
624            .collect();
625
626        let strides = col_major_strides(&input.shape);
627        let mut current =
628            StridedArray::from_parts(input.host_data().to_vec(), &input.shape, &strides, 0)
629                .map_err(|err| backend_failure("reduce_sum", err))?;
630
631        let mut sorted_axes = axes.to_vec();
632        sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
633        for axis in sorted_axes {
634            current = reduce_axis(
635                &current.view(),
636                axis,
637                |x| x,
638                |a, b| Alg::add(a, b),
639                Alg::zero(),
640            )
641            .map_err(|err| backend_failure("reduce_sum", err))?;
642        }
643        Ok(TypedTensor::from_vec(output_shape, current.into_data()))
644    }
645}