// Compile-time validation of Cargo feature combinations. Each guard below
// aborts the build with an explicit message when an unsupported combination
// of features is enabled.

// Neither CPU backend feature is enabled.
#[cfg(not(any(feature = "cpu-faer", feature = "cpu-blas")))]
compile_error!("enable at least one CPU backend: cpu-faer or cpu-blas");

// `provider-inject` is enabled without the `cpu-blas` backend.
#[cfg(all(feature = "provider-inject", not(feature = "cpu-blas")))]
compile_error!("provider-inject requires cpu-blas");

// Two of the explicit BLAS provider features are enabled together; every
// pairwise combination of the three is listed.
#[cfg(any(
    all(feature = "blas-openblas", feature = "blas-accelerate"),
    all(feature = "blas-openblas", feature = "blas-mkl"),
    all(feature = "blas-accelerate", feature = "blas-mkl"),
))]
compile_error!(
    "enable at most one explicit BLAS provider feature: blas-openblas, blas-accelerate, or blas-mkl"
);

// `provider-inject` is enabled together with any explicit BLAS provider.
#[cfg(all(
    feature = "provider-inject",
    any(
        feature = "blas-openblas",
        feature = "blas-accelerate",
        feature = "blas-mkl"
    )
))]
compile_error!("provider-inject cannot be combined with explicit BLAS provider features");
43
44pub mod affinity;
45mod analytic;
46pub mod backend;
47mod buffer_pool;
48pub mod context;
49mod elementwise;
50mod exec_session;
51mod gemm;
52mod indexing;
53mod indexing_alloc;
54#[cfg(feature = "provider-inject")]
55pub mod inject;
56mod reduction;
57mod structural;
58
59use strided_kernel::{col_major_strides as kernel_col_major_strides, StridedArray, StridedView};
60
61use crate::buffer_pool::{BufferPool, PoolScalar};
62pub(crate) use tenferro_tensor::*;
63
64#[cfg(feature = "provider-src")]
65extern crate blas_src as _;
66#[cfg(feature = "provider-inject")]
67extern crate cblas_inject as _;
68#[cfg(feature = "provider-src")]
69extern crate cblas_src as _;
70#[cfg(feature = "provider-inject")]
71extern crate lapack_inject as _;
72#[cfg(feature = "provider-src")]
73extern crate lapack_src as _;
74
75pub use affinity::{available_parallelism, process_cpu_affinity_count};
76pub use backend::{CpuBackend, CpuBackendKind};
77pub use buffer_pool::BufferPoolStats;
78pub use context::CpuContext;
79pub use elementwise::{
80 abs, add, clamp, compare, conj, div, maximum, minimum, mul, neg, select, sign,
81};
82pub use indexing::{dynamic_slice, dynamic_update_slice, gather, pad, scatter};
83pub use reduction::{reduce_max, reduce_min, reduce_prod, reduce_sum};
84pub use structural::{
85 broadcast_in_dim, convert, embed_diagonal, extract_diagonal, reshape, transpose, tril, triu,
86};
87
/// Public but documentation-hidden re-export of the buffer-pool types.
///
/// `buffer_pool` itself is a private module of this crate; this module makes
/// `BufferPool` and `PoolScalar` nameable from outside the crate.
// NOTE(review): the name suggests the intended consumer is a companion linalg
// crate — confirm before treating this as stable public API.
#[doc(hidden)]
pub mod linalg_interop {
    pub use crate::buffer_pool::{BufferPool, PoolScalar};
}
97
98pub(crate) fn cpu_backend_buffer_error(op: &'static str) -> crate::Error {
99 crate::Error::backend_failure(
100 op,
101 "CPU backend received backend buffer; download to host before CPU execution",
102 )
103}
104
/// Element-wise complex conjugation for the scalar types used by the CPU
/// kernels.
pub(crate) trait ConjElem {
    /// Returns the conjugate of `self`.
    fn conj_elem(self) -> Self;
}

/// Implements [`ConjElem`] as the identity for real scalar types, whose
/// conjugate is the value itself.
macro_rules! impl_conj_elem_for_real {
    ($($real:ty),+ $(,)?) => {
        $(
            impl ConjElem for $real {
                fn conj_elem(self) -> Self {
                    self
                }
            }
        )+
    };
}

impl_conj_elem_for_real!(f32, f64);
120
121impl ConjElem for num_complex::Complex32 {
122 fn conj_elem(self) -> Self {
123 self.conj()
124 }
125}
126
127impl ConjElem for num_complex::Complex64 {
128 fn conj_elem(self) -> Self {
129 self.conj()
130 }
131}
132
133pub(crate) fn typed_host_data<'a, T>(
134 op: &'static str,
135 tensor: &'a TypedTensor<T>,
136) -> crate::Result<&'a [T]> {
137 match tensor.buffer() {
138 Buffer::Host(data) => Ok(data.as_slice()),
139 Buffer::Backend(_) => Err(cpu_backend_buffer_error(op)),
140 }
141}
142
143pub(crate) fn typed_view<'a, T: Copy>(
144 op: &'static str,
145 tensor: &'a TypedTensor<T>,
146) -> crate::Result<StridedView<'a, T>> {
147 match tensor.buffer() {
148 Buffer::Host(data) => {
149 let strides = kernel_col_major_strides(tensor.shape());
150 StridedView::new(data.as_slice(), tensor.shape(), &strides, 0)
151 .map_err(|err| crate::Error::backend_failure(op, err))
152 }
153 Buffer::Backend(_) => Err(cpu_backend_buffer_error(op)),
154 }
155}
156
157pub(crate) fn typed_view_from_view<'a, T: Copy + 'static, R: TensorRank>(
158 op: &'static str,
159 view: &TypedTensorView<'a, T, R>,
160) -> crate::Result<StridedView<'a, T>> {
161 if view.backend_buffer().is_some() {
162 return Err(cpu_backend_buffer_error(op));
163 }
164 StridedView::new(
165 view.host_storage()?,
166 view.shape(),
167 view.strides(),
168 view.offset(),
169 )
170 .map_err(|err| crate::Error::backend_failure(op, err))
171}
172
173pub(crate) fn materialize_tensor_read(
174 op: &'static str,
175 input: TensorRead<'_>,
176) -> crate::Result<Tensor> {
177 match input {
178 TensorRead::Tensor(tensor) => clone_host_tensor_read(op, tensor),
179 TensorRead::View(view) => materialize_tensor_view(op, view),
180 }
181}
182
183fn clone_host_tensor_read(op: &'static str, tensor: &Tensor) -> crate::Result<Tensor> {
184 macro_rules! clone_host {
185 ($variant:ident, $tensor:expr) => {{
186 typed_host_data(op, $tensor)?;
187 Ok(Tensor::$variant($tensor.clone()))
188 }};
189 }
190
191 match tensor {
192 Tensor::F32(tensor) => clone_host!(F32, tensor),
193 Tensor::F64(tensor) => clone_host!(F64, tensor),
194 Tensor::I32(tensor) => clone_host!(I32, tensor),
195 Tensor::I64(tensor) => clone_host!(I64, tensor),
196 Tensor::Bool(tensor) => clone_host!(Bool, tensor),
197 Tensor::C32(tensor) => clone_host!(C32, tensor),
198 Tensor::C64(tensor) => clone_host!(C64, tensor),
199 }
200}
201
202fn materialize_tensor_view(op: &'static str, view: TensorView<'_>) -> crate::Result<Tensor> {
203 macro_rules! materialize {
204 ($variant:ident, $view:expr) => {{
205 if $view.backend_buffer().is_some() {
206 return Err(cpu_backend_buffer_error(op));
207 }
208 Ok(Tensor::$variant($view.to_contiguous()?))
209 }};
210 }
211
212 match view {
213 TensorView::F32(view) => materialize!(F32, view),
214 TensorView::F64(view) => materialize!(F64, view),
215 TensorView::I32(view) => materialize!(I32, view),
216 TensorView::I64(view) => materialize!(I64, view),
217 TensorView::Bool(view) => materialize!(Bool, view),
218 TensorView::C32(view) => materialize!(C32, view),
219 TensorView::C64(view) => materialize!(C64, view),
220 }
221}
222
/// Allocates a `StridedArray` of the given shape whose elements are left
/// uninitialized (test builds only).
///
/// The strides come from `kernel_col_major_strides(shape)` and the offset is 0.
///
/// # Safety
///
/// The backing `Vec` has its length set without any element being written.
/// The caller must initialize every element before reading it.
///
/// # Panics
///
/// Panics if `StridedArray::from_parts` rejects the shape/strides/data
/// combination.
// `uninit_vec` is allowed on purpose: handing out uninitialized storage is
// this function's documented contract.
#[allow(clippy::uninit_vec)]
#[cfg(test)]
pub(crate) unsafe fn typed_array_uninit<T>(shape: &[usize]) -> StridedArray<T> {
    let element_count = shape.iter().product::<usize>();
    let col_major = kernel_col_major_strides(shape);
    let mut storage = Vec::with_capacity(element_count);
    // SAFETY: `element_count` equals the capacity reserved on the line above;
    // initializing the elements before use is delegated to the caller by this
    // function's safety contract.
    unsafe { storage.set_len(element_count) };
    let array = StridedArray::from_parts(storage, shape, &col_major, 0);
    array.expect("column-major output array")
}
240
241pub(crate) unsafe fn typed_array_uninit_from_pool<T>(
247 buffers: &mut BufferPool,
248 shape: &[usize],
249) -> StridedArray<T>
250where
251 T: PoolScalar,
252{
253 let total: usize = shape.iter().product();
254 let strides = kernel_col_major_strides(shape);
255 let data = unsafe { T::pool_acquire(buffers, total) };
257 StridedArray::from_parts(data, shape, &strides, 0).expect("column-major output array")
260}
261
262pub(crate) fn tensor_from_array<T: Clone>(array: StridedArray<T>) -> TypedTensor<T> {
263 TypedTensor::from_vec_col_major(array.dims().to_vec(), array.into_data())
265 .expect("strided array dimensions match owned data length")
266}
267
268pub(crate) fn default_placement() -> Placement {
269 Placement {
270 memory_kind: MemoryKind::UnpinnedHost,
271 device: None,
272 }
273}
274
/// Decomposes a flat index into per-axis coordinates, with the first axis
/// varying fastest.
///
/// For each axis in order, the coordinate is `flat % dim` and `flat` is then
/// divided by `dim`. An axis with extent 0 gets coordinate 0 and leaves
/// `flat` untouched. Any quotient left over after the last axis is discarded.
///
/// # Panics
///
/// Panics if `shape` and `out` have different lengths.
pub(crate) fn flat_to_multi(mut flat: usize, shape: &[usize], out: &mut [usize]) {
    assert_eq!(shape.len(), out.len());
    for (slot, &extent) in out.iter_mut().zip(shape) {
        *slot = match extent {
            // A zero extent cannot be divided by; pin the coordinate to 0.
            0 => 0,
            _ => {
                let coordinate = flat % extent;
                flat /= extent;
                coordinate
            }
        };
    }
}
286
287#[cfg(test)]
288mod tests;