tenferro_tensor/tensor/
constructors_special.rs1use num_traits::{Float, NumCast};
2use tenferro_algebra::Scalar;
3use tenferro_device::{Error, LogicalMemorySpace, Result};
4
5use super::Tensor;
6use crate::MemoryOrder;
7
8impl<T> Tensor<T>
9where
10 T: Scalar + Float + NumCast,
11{
12 pub fn eye(n: usize, memory_space: LogicalMemorySpace, order: MemoryOrder) -> Result<Self> {
28 let dims = [n, n];
29 let storage_len = n
30 .checked_mul(n)
31 .ok_or_else(|| Error::StrideError(format!("eye: storage length overflow for n={n}")))?;
32 let mut data = vec![T::zero(); storage_len];
33 let diag_stride = n + 1;
34 for i in 0..n {
35 data[i * diag_stride] = T::one();
36 }
37 Self::finish_allocation(
38 Self::main_memory_contiguous(data, &dims, order),
39 memory_space,
40 )
41 }
42
43 pub fn arange(
61 start: T,
62 end: T,
63 step: T,
64 memory_space: LogicalMemorySpace,
65 order: MemoryOrder,
66 ) -> Result<Self> {
67 if step.is_zero() {
68 return Err(Error::InvalidArgument(
69 "arange: step must be non-zero".into(),
70 ));
71 }
72
73 let mut data = Vec::new();
74 let zero = T::zero();
75 if step > zero {
76 let mut current = start;
77 while current < end {
78 data.push(current);
79 current = current + step;
80 }
81 } else {
82 let mut current = start;
83 while current > end {
84 data.push(current);
85 current = current + step;
86 }
87 }
88
89 let dims = [data.len()];
90 let tensor = Self::main_memory_contiguous(data, &dims, order);
91 Self::finish_allocation(tensor, memory_space)
92 }
93
94 pub fn linspace(
114 start: T,
115 end: T,
116 n_samples: isize,
117 memory_space: LogicalMemorySpace,
118 order: MemoryOrder,
119 ) -> Result<Self> {
120 if n_samples < 0 {
121 return Err(Error::InvalidArgument(format!(
122 "linspace: steps must be non-negative, got {n_samples}"
123 )));
124 }
125
126 let n_samples = n_samples as usize;
127 let mut data = Vec::with_capacity(n_samples);
128 match n_samples {
129 0 => {}
130 1 => data.push(start),
131 _ => {
132 let denom = <T as NumCast>::from(n_samples - 1).ok_or_else(|| {
133 Error::InvalidArgument(format!(
134 "linspace: sample count {} cannot be represented in target scalar type",
135 n_samples
136 ))
137 })?;
138 let step = (end - start) / denom;
139 let mut current = start;
140 for _ in 0..n_samples {
141 data.push(current);
142 current = current + step;
143 }
144 data[n_samples - 1] = end;
145 }
146 }
147
148 let dims = [data.len()];
149 let tensor = Self::main_memory_contiguous(data, &dims, order);
150 Self::finish_allocation(tensor, memory_space)
151 }
152}
153
154#[cfg(test)]
155mod tests;