1use crate::config::{
2 CompareDir, DotGeneralConfig, GatherConfig, PadConfig, ScatterConfig, SliceConfig,
3};
4use crate::types::{TensorRank, TypedTensor, TypedTensorView, TypedTensorViewMut};
5use crate::validate::validate_convert_dtype;
6use crate::{RuntimeCacheControl, Tensor, TensorRead, TensorValue};
7
8fn read_boundary_error(op: &'static str) -> crate::Error {
9 crate::Error::backend_failure(
10 op,
11 "backend does not accept borrowed tensor views at this execution boundary",
12 )
13}
14
15fn read_tensor<'a>(op: &'static str, input: TensorRead<'a>) -> crate::Result<&'a Tensor> {
16 input.as_tensor().ok_or_else(|| read_boundary_error(op))
17}
18
/// Plain-data description of a fused elementwise computation.
///
/// The plan stores a dtype, an input count, a list of output indices and a
/// list of instructions; it is handed to
/// `TensorFusion::execute_elementwise_fusion`.
///
/// NOTE(review): the index space used by `outputs` (and by instruction
/// inputs) is not defined in this file — confirm against the code that
/// builds and executes plans.
#[doc(hidden)]
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct ElementwiseFusionPlan {
    /// Dtype associated with the plan.
    dtype: crate::DType,
    /// Number of inputs the plan expects.
    input_count: usize,
    /// Indices selecting the plan's outputs.
    outputs: Vec<usize>,
    /// Instructions of the plan, in stored order.
    ops: Vec<ElementwiseFusionInst>,
}
28
/// One instruction of an [`ElementwiseFusionPlan`]: an operation plus the
/// indices of its inputs.
///
/// NOTE(review): what the input indices refer to (plan inputs, earlier
/// instruction results, or both) is not visible here — confirm.
#[doc(hidden)]
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct ElementwiseFusionInst {
    /// Operation performed by this instruction.
    op: ElementwiseFusionOp,
    /// Indices of the instruction's operands.
    inputs: Vec<usize>,
}
36
// Expands the `define_elementwise_fusion_op!` macro from `tenferro_core_ops`.
// NOTE(review): presumably this expansion is what defines the
// `ElementwiseFusionOp` type used by `ElementwiseFusionInst` — confirm
// against the macro definition.
tenferro_core_ops::define_elementwise_fusion_op!();
38
39impl ElementwiseFusionPlan {
40 pub fn new(
59 dtype: crate::DType,
60 input_count: usize,
61 outputs: Vec<usize>,
62 ops: Vec<ElementwiseFusionInst>,
63 ) -> Self {
64 Self {
65 dtype,
66 input_count,
67 outputs,
68 ops,
69 }
70 }
71
72 pub fn dtype(&self) -> crate::DType {
84 self.dtype
85 }
86
87 pub fn input_count(&self) -> usize {
99 self.input_count
100 }
101
102 pub fn outputs(&self) -> &[usize] {
114 &self.outputs
115 }
116
117 pub fn ops(&self) -> &[ElementwiseFusionInst] {
132 &self.ops
133 }
134}
135
136impl ElementwiseFusionInst {
137 pub fn new(op: ElementwiseFusionOp, inputs: Vec<usize>) -> Self {
148 Self { op, inputs }
149 }
150
151 pub fn op(&self) -> ElementwiseFusionOp {
162 self.op
163 }
164
165 pub fn inputs(&self) -> &[usize] {
176 &self.inputs
177 }
178}
179
180pub trait TensorElementwise {
190 fn add(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
191
192 fn add_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
212 self.add(read_tensor("add", lhs)?, read_tensor("add", rhs)?)
213 }
214
215 fn mul(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
216 fn mul_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
217 self.mul(read_tensor("mul", lhs)?, read_tensor("mul", rhs)?)
218 }
219
220 fn neg(&mut self, input: &Tensor) -> crate::Result<Tensor>;
221 fn neg_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
222 self.neg(read_tensor("neg", input)?)
223 }
224
225 fn conj(&mut self, input: &Tensor) -> crate::Result<Tensor>;
226 fn conj_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
227 self.conj(read_tensor("conj", input)?)
228 }
229
230 fn div(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
231 fn div_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
232 self.div(read_tensor("div", lhs)?, read_tensor("div", rhs)?)
233 }
234
235 fn abs(&mut self, input: &Tensor) -> crate::Result<Tensor>;
236 fn abs_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
237 self.abs(read_tensor("abs", input)?)
238 }
239
240 fn sign(&mut self, input: &Tensor) -> crate::Result<Tensor>;
241 fn sign_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
242 self.sign(read_tensor("sign", input)?)
243 }
244
245 fn maximum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
246 fn maximum_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
247 self.maximum(read_tensor("maximum", lhs)?, read_tensor("maximum", rhs)?)
248 }
249
250 fn minimum(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
251 fn minimum_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
252 self.minimum(read_tensor("minimum", lhs)?, read_tensor("minimum", rhs)?)
253 }
254
255 fn compare(&mut self, lhs: &Tensor, rhs: &Tensor, dir: &CompareDir) -> crate::Result<Tensor>;
256 fn compare_read(
257 &mut self,
258 lhs: TensorRead<'_>,
259 rhs: TensorRead<'_>,
260 dir: &CompareDir,
261 ) -> crate::Result<Tensor> {
262 self.compare(
263 read_tensor("compare", lhs)?,
264 read_tensor("compare", rhs)?,
265 dir,
266 )
267 }
268
269 fn select(
270 &mut self,
271 pred: &Tensor,
272 on_true: &Tensor,
273 on_false: &Tensor,
274 ) -> crate::Result<Tensor>;
275 fn select_read(
276 &mut self,
277 pred: TensorRead<'_>,
278 on_true: TensorRead<'_>,
279 on_false: TensorRead<'_>,
280 ) -> crate::Result<Tensor> {
281 self.select(
282 read_tensor("select", pred)?,
283 read_tensor("select", on_true)?,
284 read_tensor("select", on_false)?,
285 )
286 }
287
288 fn clamp(&mut self, input: &Tensor, lower: &Tensor, upper: &Tensor) -> crate::Result<Tensor>;
289 fn clamp_read(
290 &mut self,
291 input: TensorRead<'_>,
292 lower: TensorRead<'_>,
293 upper: TensorRead<'_>,
294 ) -> crate::Result<Tensor> {
295 self.clamp(
296 read_tensor("clamp", input)?,
297 read_tensor("clamp", lower)?,
298 read_tensor("clamp", upper)?,
299 )
300 }
301}
302
303pub trait TensorAnalytic {
313 fn exp(&mut self, input: &Tensor) -> crate::Result<Tensor>;
314 fn exp_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
315 self.exp(read_tensor("exp", input)?)
316 }
317
318 fn log(&mut self, input: &Tensor) -> crate::Result<Tensor>;
319 fn log_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
320 self.log(read_tensor("log", input)?)
321 }
322
323 fn sin(&mut self, input: &Tensor) -> crate::Result<Tensor>;
324 fn sin_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
325 self.sin(read_tensor("sin", input)?)
326 }
327
328 fn cos(&mut self, input: &Tensor) -> crate::Result<Tensor>;
329 fn cos_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
330 self.cos(read_tensor("cos", input)?)
331 }
332
333 fn tanh(&mut self, input: &Tensor) -> crate::Result<Tensor>;
334 fn tanh_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
335 self.tanh(read_tensor("tanh", input)?)
336 }
337
338 fn sqrt(&mut self, input: &Tensor) -> crate::Result<Tensor>;
339 fn sqrt_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
340 self.sqrt(read_tensor("sqrt", input)?)
341 }
342
343 fn rsqrt(&mut self, input: &Tensor) -> crate::Result<Tensor>;
344 fn rsqrt_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
345 self.rsqrt(read_tensor("rsqrt", input)?)
346 }
347
348 fn pow(&mut self, lhs: &Tensor, rhs: &Tensor) -> crate::Result<Tensor>;
349 fn pow_read(&mut self, lhs: TensorRead<'_>, rhs: TensorRead<'_>) -> crate::Result<Tensor> {
350 self.pow(read_tensor("pow", lhs)?, read_tensor("pow", rhs)?)
351 }
352
353 fn expm1(&mut self, input: &Tensor) -> crate::Result<Tensor>;
354 fn expm1_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
355 self.expm1(read_tensor("expm1", input)?)
356 }
357
358 fn log1p(&mut self, input: &Tensor) -> crate::Result<Tensor>;
359 fn log1p_read(&mut self, input: TensorRead<'_>) -> crate::Result<Tensor> {
360 self.log1p(read_tensor("log1p", input)?)
361 }
362}
363
364pub trait TensorStructural {
374 fn transpose(&mut self, input: &Tensor, perm: &[usize]) -> crate::Result<Tensor>;
375 fn transpose_read(&mut self, input: TensorRead<'_>, perm: &[usize]) -> crate::Result<Tensor> {
376 self.transpose(read_tensor("transpose", input)?, perm)
377 }
378
379 fn reshape(&mut self, input: &Tensor, shape: &[usize]) -> crate::Result<Tensor>;
380 fn reshape_read(&mut self, input: TensorRead<'_>, shape: &[usize]) -> crate::Result<Tensor> {
381 self.reshape(read_tensor("reshape", input)?, shape)
382 }
383
384 fn broadcast_in_dim(
385 &mut self,
386 input: &Tensor,
387 shape: &[usize],
388 dims: &[usize],
389 ) -> crate::Result<Tensor>;
390 fn broadcast_in_dim_read(
391 &mut self,
392 input: TensorRead<'_>,
393 shape: &[usize],
394 dims: &[usize],
395 ) -> crate::Result<Tensor> {
396 self.broadcast_in_dim(read_tensor("broadcast_in_dim", input)?, shape, dims)
397 }
398
399 fn cast(&mut self, input: &Tensor, to: crate::DType) -> crate::Result<Tensor>;
417
418 fn convert(&mut self, input: &Tensor, to: crate::DType) -> crate::Result<Tensor> {
436 validate_convert_dtype("convert", input.dtype(), to)?;
437 self.cast(input, to)
438 }
439
440 fn extract_diagonal(
441 &mut self,
442 input: &Tensor,
443 axis_a: usize,
444 axis_b: usize,
445 ) -> crate::Result<Tensor>;
446 fn embed_diagonal(
447 &mut self,
448 input: &Tensor,
449 axis_a: usize,
450 axis_b: usize,
451 ) -> crate::Result<Tensor>;
452 fn tril(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor>;
453 fn triu(&mut self, input: &Tensor, k: i64) -> crate::Result<Tensor>;
454}
455
456pub trait TensorReduction {
466 fn reduce_sum(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
467
468 fn reduce_sum_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
483 match input.as_tensor() {
484 Some(input) => self.reduce_sum(input, axes),
485 None => Err(crate::Error::backend_failure(
486 "reduce_sum",
487 "backend does not accept borrowed tensor views at this execution boundary",
488 )),
489 }
490 }
491
492 fn reduce_prod(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
493
494 fn reduce_prod_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
509 match input.as_tensor() {
510 Some(input) => self.reduce_prod(input, axes),
511 None => Err(crate::Error::backend_failure(
512 "reduce_prod",
513 "backend does not accept borrowed tensor views at this execution boundary",
514 )),
515 }
516 }
517
518 fn reduce_max(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
519
520 fn reduce_max_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
535 match input.as_tensor() {
536 Some(input) => self.reduce_max(input, axes),
537 None => Err(crate::Error::backend_failure(
538 "reduce_max",
539 "backend does not accept borrowed tensor views at this execution boundary",
540 )),
541 }
542 }
543
544 fn reduce_min(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
545
546 fn reduce_min_read(&mut self, input: TensorRead<'_>, axes: &[usize]) -> crate::Result<Tensor> {
561 match input.as_tensor() {
562 Some(input) => self.reduce_min(input, axes),
563 None => Err(crate::Error::backend_failure(
564 "reduce_min",
565 "backend does not accept borrowed tensor views at this execution boundary",
566 )),
567 }
568 }
569}
570
571pub trait TensorDot: TensorElementwise {
581 fn dot_general(
582 &mut self,
583 lhs: &Tensor,
584 rhs: &Tensor,
585 config: &DotGeneralConfig,
586 ) -> crate::Result<Tensor>;
587
588 #[doc(hidden)]
589 fn dot_general_read(
590 &mut self,
591 lhs: TensorRead<'_>,
592 rhs: TensorRead<'_>,
593 config: &DotGeneralConfig,
594 ) -> crate::Result<Tensor> {
595 match (lhs.as_tensor(), rhs.as_tensor()) {
596 (Some(lhs), Some(rhs)) => self.dot_general(lhs, rhs, config),
597 _ => {
598 let lhs = lhs.to_tensor()?;
599 let rhs = rhs.to_tensor()?;
600 self.dot_general(&lhs, &rhs, config)
601 }
602 }
603 }
604
605 #[doc(hidden)]
606 fn dot_general_with_conj(
607 &mut self,
608 lhs: &Tensor,
609 rhs: &Tensor,
610 config: &DotGeneralConfig,
611 lhs_conj: bool,
612 rhs_conj: bool,
613 ) -> crate::Result<Tensor> {
614 if !lhs_conj && !rhs_conj {
615 return self.dot_general(lhs, rhs, config);
616 }
617
618 let lhs_tmp;
619 let lhs_ref = if lhs_conj {
620 lhs_tmp = self.conj(lhs)?;
621 &lhs_tmp
622 } else {
623 lhs
624 };
625 let rhs_tmp;
626 let rhs_ref = if rhs_conj {
627 rhs_tmp = self.conj(rhs)?;
628 &rhs_tmp
629 } else {
630 rhs
631 };
632 self.dot_general(lhs_ref, rhs_ref, config)
633 }
634
635 #[allow(clippy::too_many_arguments)]
636 #[doc(hidden)]
637 fn dot_general_with_conj_read(
638 &mut self,
639 lhs: TensorRead<'_>,
640 rhs: TensorRead<'_>,
641 config: &DotGeneralConfig,
642 lhs_conj: bool,
643 rhs_conj: bool,
644 ) -> crate::Result<Tensor> {
645 if !lhs_conj && !rhs_conj {
646 return self.dot_general_read(lhs, rhs, config);
647 }
648
649 let lhs_tmp;
650 let lhs_ref = if let Some(tensor) = lhs.as_tensor() {
651 tensor
652 } else {
653 lhs_tmp = lhs.to_tensor()?;
654 &lhs_tmp
655 };
656 let rhs_tmp;
657 let rhs_ref = if let Some(tensor) = rhs.as_tensor() {
658 tensor
659 } else {
660 rhs_tmp = rhs.to_tensor()?;
661 &rhs_tmp
662 };
663 self.dot_general_with_conj(lhs_ref, rhs_ref, config, lhs_conj, rhs_conj)
664 }
665}
666
667pub trait SessionCachedDot: TensorDot {
677 #[doc(hidden)]
678 fn dot_general_cached(
679 &mut self,
680 _cache_slot: Option<usize>,
681 lhs: &Tensor,
682 rhs: &Tensor,
683 config: &DotGeneralConfig,
684 ) -> crate::Result<Tensor> {
685 self.dot_general(lhs, rhs, config)
686 }
687
688 #[doc(hidden)]
689 fn dot_general_read_cached(
690 &mut self,
691 cache_slot: Option<usize>,
692 lhs: TensorRead<'_>,
693 rhs: TensorRead<'_>,
694 config: &DotGeneralConfig,
695 ) -> crate::Result<Tensor> {
696 match (lhs.as_tensor(), rhs.as_tensor()) {
697 (Some(lhs), Some(rhs)) => self.dot_general_cached(cache_slot, lhs, rhs, config),
698 _ => {
699 let lhs = lhs.to_tensor()?;
700 let rhs = rhs.to_tensor()?;
701 self.dot_general_cached(cache_slot, &lhs, &rhs, config)
702 }
703 }
704 }
705
706 #[allow(clippy::too_many_arguments)]
708 #[doc(hidden)]
709 fn dot_general_with_conj_cached(
710 &mut self,
711 _cache_slot: Option<usize>,
712 lhs: &Tensor,
713 rhs: &Tensor,
714 config: &DotGeneralConfig,
715 lhs_conj: bool,
716 rhs_conj: bool,
717 ) -> crate::Result<Tensor> {
718 self.dot_general_with_conj(lhs, rhs, config, lhs_conj, rhs_conj)
719 }
720
721 #[allow(clippy::too_many_arguments)]
723 #[doc(hidden)]
724 fn dot_general_with_conj_read_cached(
725 &mut self,
726 cache_slot: Option<usize>,
727 lhs: TensorRead<'_>,
728 rhs: TensorRead<'_>,
729 config: &DotGeneralConfig,
730 lhs_conj: bool,
731 rhs_conj: bool,
732 ) -> crate::Result<Tensor> {
733 if !lhs_conj && !rhs_conj {
734 return self.dot_general_read_cached(cache_slot, lhs, rhs, config);
735 }
736
737 let lhs_tmp;
738 let lhs_ref = if let Some(tensor) = lhs.as_tensor() {
739 tensor
740 } else {
741 lhs_tmp = lhs.to_tensor()?;
742 &lhs_tmp
743 };
744 let rhs_tmp;
745 let rhs_ref = if let Some(tensor) = rhs.as_tensor() {
746 tensor
747 } else {
748 rhs_tmp = rhs.to_tensor()?;
749 &rhs_tmp
750 };
751 self.dot_general_with_conj_cached(cache_slot, lhs_ref, rhs_ref, config, lhs_conj, rhs_conj)
752 }
753}
754
/// Indexing and slicing operations a backend provides.
///
/// Every method is required (there are no default bodies) and returns a new
/// `Tensor` or an error.
///
/// NOTE(review): the exact semantics of each operation and of the config
/// types (`GatherConfig`, `ScatterConfig`, `SliceConfig`, `PadConfig`) are
/// defined outside this file — confirm there before relying on details.
pub trait TensorIndexing {
    /// Required backend entry point for `gather`, configured by `config`.
    fn gather(
        &mut self,
        operand: &Tensor,
        start_indices: &Tensor,
        config: &GatherConfig,
    ) -> crate::Result<Tensor>;
    /// Required backend entry point for `scatter`, configured by `config`.
    fn scatter(
        &mut self,
        operand: &Tensor,
        scatter_indices: &Tensor,
        updates: &Tensor,
        config: &ScatterConfig,
    ) -> crate::Result<Tensor>;
    /// Required backend entry point for `slice`, configured by `config`.
    fn slice(&mut self, input: &Tensor, config: &SliceConfig) -> crate::Result<Tensor>;
    /// Required backend entry point for `dynamic_slice`; the start positions
    /// are supplied as a tensor (`starts`), the sizes as a plain slice.
    fn dynamic_slice(
        &mut self,
        input: &Tensor,
        starts: &Tensor,
        slice_sizes: &[usize],
    ) -> crate::Result<Tensor>;
    /// Required backend entry point for `dynamic_update_slice`; the start
    /// positions are supplied as a tensor (`starts`).
    fn dynamic_update_slice(
        &mut self,
        operand: &Tensor,
        update: &Tensor,
        starts: &Tensor,
    ) -> crate::Result<Tensor>;
    /// Required backend entry point for `pad`, configured by `config`.
    fn pad(&mut self, input: &Tensor, config: &PadConfig) -> crate::Result<Tensor>;
    /// Required backend entry point for `concatenate` of `inputs` along `axis`.
    fn concatenate(&mut self, inputs: &[&Tensor], axis: usize) -> crate::Result<Tensor>;
    /// Required backend entry point for `reverse` over `axes`.
    fn reverse(&mut self, input: &Tensor, axes: &[usize]) -> crate::Result<Tensor>;
}
795
/// Conversions between typed tensor views and owned typed tensors, generic
/// over the element type `T` and the rank marker `R`.
///
/// NOTE(review): "contiguous" is taken from the method names; the layout
/// guarantees are not visible in this file — confirm against
/// `TypedTensor` / `TypedTensorView`.
pub trait TensorViewCanonicalization<T: Clone + 'static, R: TensorRank> {
    /// Produces an owned [`TypedTensor`] from the borrowed `view`.
    fn to_contiguous(
        &mut self,
        view: &TypedTensorView<'_, T, R>,
    ) -> crate::Result<TypedTensor<T, R>>;

    /// Writes into the mutable view `dst` using the owned tensor `src` as the
    /// source; returns `Ok(())` on success.
    fn copy_from_contiguous(
        &mut self,
        src: &TypedTensor<T, R>,
        dst: &mut TypedTensorViewMut<'_, T, R>,
    ) -> crate::Result<()>;
}
831
832pub trait TensorFusion {
842 #[doc(hidden)]
843 fn execute_elementwise_fusion(
844 &mut self,
845 _inputs: &[&Tensor],
846 _plan: &ElementwiseFusionPlan,
847 ) -> crate::Result<Option<Vec<Tensor>>> {
848 Ok(None)
849 }
850
851 #[doc(hidden)]
852 #[allow(clippy::too_many_arguments)]
853 fn execute_broadcast_multiply(
854 &mut self,
855 _lhs: TensorRead<'_>,
856 _lhs_shape: &[usize],
857 _lhs_dims: &[usize],
858 _rhs: TensorRead<'_>,
859 _rhs_shape: &[usize],
860 _rhs_dims: &[usize],
861 ) -> crate::Result<Option<Tensor>> {
862 Ok(None)
863 }
864
865 #[doc(hidden)]
866 #[allow(clippy::too_many_arguments)]
867 fn execute_broadcast_multiply_value(
868 &mut self,
869 lhs: TensorRead<'_>,
870 lhs_shape: &[usize],
871 lhs_dims: &[usize],
872 rhs: TensorRead<'_>,
873 rhs_shape: &[usize],
874 rhs_dims: &[usize],
875 ) -> crate::Result<Option<TensorValue>> {
876 self.execute_broadcast_multiply(lhs, lhs_shape, lhs_dims, rhs, rhs_shape, rhs_dims)
877 .map(|tensor| tensor.map(TensorValue::from_tensor))
878 }
879}
880
/// Hook that lets a backend take back ownership of a tensor.
pub trait TensorBuffer {
    /// Receives `_tensor` by value. The default body is empty, so the tensor
    /// is simply dropped.
    ///
    /// NOTE(review): presumably backends override this to recycle the
    /// tensor's storage — confirm against the implementations.
    fn reclaim_buffer(&mut self, _tensor: Tensor) {}
}
893
894pub trait TensorDeviceTransfer {
904 fn download_to_host(&mut self, tensor: &Tensor) -> crate::Result<Tensor> {
905 Ok(tensor.clone())
906 }
907
908 fn upload_host_tensor(&mut self, tensor: &Tensor) -> crate::Result<Tensor> {
909 Ok(tensor.clone())
910 }
911}
912
/// Associates a backend with the type of its runtime cache.
pub trait BackendRuntimeCache {
    /// Cache type threaded through the `*_cached` backend methods; it must
    /// implement [`RuntimeCacheControl`] and be `Send + Sync + 'static`.
    #[doc(hidden)]
    type RuntimeCache: RuntimeCacheControl + Send + Sync + 'static;
}
926
927pub trait BackendCachedDot: BackendRuntimeCache + TensorDot {
937 #[doc(hidden)]
938 fn dot_general_cached(
939 &mut self,
940 _cache: &mut Self::RuntimeCache,
941 _cache_slot: Option<usize>,
942 lhs: &Tensor,
943 rhs: &Tensor,
944 config: &DotGeneralConfig,
945 ) -> crate::Result<Tensor> {
946 self.dot_general(lhs, rhs, config)
947 }
948
949 #[doc(hidden)]
950 fn dot_general_read_cached(
951 &mut self,
952 cache: &mut Self::RuntimeCache,
953 cache_slot: Option<usize>,
954 lhs: TensorRead<'_>,
955 rhs: TensorRead<'_>,
956 config: &DotGeneralConfig,
957 ) -> crate::Result<Tensor> {
958 match (lhs.as_tensor(), rhs.as_tensor()) {
959 (Some(lhs), Some(rhs)) => self.dot_general_cached(cache, cache_slot, lhs, rhs, config),
960 _ => {
961 let lhs = lhs.to_tensor()?;
962 let rhs = rhs.to_tensor()?;
963 self.dot_general_cached(cache, cache_slot, &lhs, &rhs, config)
964 }
965 }
966 }
967
968 #[allow(clippy::too_many_arguments)]
970 #[doc(hidden)]
971 fn dot_general_with_conj_cached(
972 &mut self,
973 _cache: &mut Self::RuntimeCache,
974 _cache_slot: Option<usize>,
975 lhs: &Tensor,
976 rhs: &Tensor,
977 config: &DotGeneralConfig,
978 lhs_conj: bool,
979 rhs_conj: bool,
980 ) -> crate::Result<Tensor> {
981 self.dot_general_with_conj(lhs, rhs, config, lhs_conj, rhs_conj)
982 }
983
984 #[allow(clippy::too_many_arguments)]
986 #[doc(hidden)]
987 fn dot_general_with_conj_read_cached(
988 &mut self,
989 cache: &mut Self::RuntimeCache,
990 cache_slot: Option<usize>,
991 lhs: TensorRead<'_>,
992 rhs: TensorRead<'_>,
993 config: &DotGeneralConfig,
994 lhs_conj: bool,
995 rhs_conj: bool,
996 ) -> crate::Result<Tensor> {
997 if !lhs_conj && !rhs_conj {
998 return self.dot_general_read_cached(cache, cache_slot, lhs, rhs, config);
999 }
1000
1001 let lhs_tmp;
1002 let lhs_ref = if let Some(tensor) = lhs.as_tensor() {
1003 tensor
1004 } else {
1005 lhs_tmp = lhs.to_tensor()?;
1006 &lhs_tmp
1007 };
1008 let rhs_tmp;
1009 let rhs_ref = if let Some(tensor) = rhs.as_tensor() {
1010 tensor
1011 } else {
1012 rhs_tmp = rhs.to_tensor()?;
1013 &rhs_tmp
1014 };
1015 self.dot_general_with_conj_cached(
1016 cache, cache_slot, lhs_ref, rhs_ref, config, lhs_conj, rhs_conj,
1017 )
1018 }
1019}
1020
/// Entry points for running a closure against a [`BackendSession`].
///
/// Both methods are provided and are only callable when `Self` is a sized
/// [`TensorBackend`].
pub trait BackendSessionHost: BackendRuntimeCache {
    /// Runs `f` with a `&mut dyn BackendSession` and returns its result.
    ///
    /// The default delegates to [`default_backend_session`], which passes the
    /// backend itself as the session.
    fn with_backend_session<R: Send>(
        &mut self,
        f: impl FnOnce(&mut dyn BackendSession) -> R + Send,
    ) -> R
    where
        Self: TensorBackend + Sized,
    {
        default_backend_session(self, f)
    }

    /// Variant of [`Self::with_backend_session`] that also receives the
    /// runtime cache.
    ///
    /// The default ignores `_cache` and calls
    /// [`Self::with_backend_session`].
    #[doc(hidden)]
    fn with_backend_session_cached<R: Send>(
        &mut self,
        _cache: &mut Self::RuntimeCache,
        f: impl FnOnce(&mut dyn BackendSession) -> R + Send,
    ) -> R
    where
        Self: TensorBackend + Sized,
    {
        self.with_backend_session(f)
    }
}
1053
/// Umbrella trait bundling the operation traits listed as supertraits below.
///
/// It declares no methods of its own.
#[doc(hidden)]
pub trait TensorBackendOps:
    TensorElementwise
    + TensorAnalytic
    + TensorStructural
    + TensorReduction
    + TensorIndexing
    + TensorDot
    + TensorFusion
    + TensorBuffer
{
}
1067
// Blanket impl: any type (sized or not) that implements all eight operation
// traits is automatically a `TensorBackendOps`.
impl<T> TensorBackendOps for T where
    T: TensorElementwise
        + TensorAnalytic
        + TensorStructural
        + TensorReduction
        + TensorIndexing
        + TensorDot
        + TensorFusion
        + TensorBuffer
        + ?Sized
{
}
1080
/// Interface handed to closures run through
/// `BackendSessionHost::with_backend_session`: the full operation set plus
/// the slot-aware dot entry points. It declares no methods of its own.
pub trait BackendSession: TensorBackendOps + SessionCachedDot {}
1104
// Blanket impl: any type (sized or not) providing both supertraits is a
// `BackendSession`.
impl<T> BackendSession for T where T: TensorBackendOps + SessionCachedDot + ?Sized {}
1106
/// Top-level backend trait: the combination of the runtime-cache, operation,
/// cached-dot, device-transfer and session-host traits listed as supertraits.
///
/// It declares no methods of its own.
pub trait TensorBackend:
    BackendRuntimeCache
    + TensorBackendOps
    + BackendCachedDot
    + TensorDeviceTransfer
    + BackendSessionHost
{
}
1124
// Blanket impl: every `TensorBackend` is a `SessionCachedDot`. The impl body
// is empty, so backends get the trait's default methods, which ignore the
// cache slot and forward to the uncached `TensorDot` methods.
impl<T> SessionCachedDot for T where T: TensorBackend + ?Sized {}
1126
1127pub fn default_backend_session<B: TensorBackend, R: Send>(
1142 backend: &mut B,
1143 f: impl FnOnce(&mut dyn BackendSession) -> R + Send,
1144) -> R {
1145 f(backend)
1146}