tenferro_tensor/tensor/element_access.rs
1use super::Tensor;
2
3impl<T> Tensor<T> {
4 /// Access a single element by multi-dimensional index.
5 ///
6 /// Returns `None` if the index is out of bounds or the underlying buffer
7 /// is not CPU-accessible.
8 ///
9 /// # Examples
10 ///
11 /// ```
12 /// use tenferro_tensor::{MemoryOrder, Tensor};
13 ///
14 /// // Column-major: data is laid out column by column.
15 /// // from_slice with ColumnMajor and data [1,2,3,4] gives:
16 /// // column 0 = [1, 2], column 1 = [3, 4]
17 /// // matrix = [[1, 3],
18 /// // [2, 4]]
19 /// let t = Tensor::<f64>::from_slice(
20 /// &[1.0, 2.0, 3.0, 4.0], &[2, 2], MemoryOrder::ColumnMajor,
21 /// ).unwrap();
22 /// assert_eq!(t.get(&[0, 0]), Some(&1.0));
23 /// assert_eq!(t.get(&[1, 0]), Some(&2.0));
24 /// assert_eq!(t.get(&[0, 1]), Some(&3.0));
25 /// assert_eq!(t.get(&[1, 1]), Some(&4.0));
26 /// assert_eq!(t.get(&[2, 0]), None); // out of bounds
27 /// ```
28 pub fn get(&self, index: &[usize]) -> Option<&T> {
29 let pos = self.linear_offset(index)?;
30 self.buffer.as_slice().and_then(|s| s.get(pos))
31 }
32
33 /// Access a single element mutably by multi-dimensional index.
34 ///
35 /// Returns `None` if the index is out of bounds, the buffer is not
36 /// CPU-accessible, or the buffer is shared (Arc refcount > 1).
37 ///
38 /// # Examples
39 ///
40 /// ```
41 /// use tenferro_tensor::{MemoryOrder, Tensor};
42 ///
43 /// let mut t = Tensor::<f64>::from_slice(
44 /// &[1.0, 2.0, 3.0, 4.0], &[2, 2], MemoryOrder::ColumnMajor,
45 /// ).unwrap();
46 /// *t.get_mut(&[0, 1]).unwrap() = 99.0;
47 /// assert_eq!(t.get(&[0, 1]), Some(&99.0));
48 /// // Out of bounds returns None:
49 /// assert!(t.get_mut(&[2, 0]).is_none());
50 /// ```
51 pub fn get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
52 let pos = self.linear_offset(index)?;
53 self.buffer.as_mut_slice().and_then(|s| s.get_mut(pos))
54 }
55
56 /// Write a value at the given multi-dimensional index.
57 ///
58 /// Returns `Ok(())` on success, or an error if the index is out of bounds,
59 /// the buffer is not CPU-accessible, or the buffer is shared
60 /// (Arc refcount > 1). Call [`deep_clone`](Tensor::deep_clone) first to
61 /// obtain an exclusively-owned copy.
62 ///
63 /// # Examples
64 ///
65 /// ```
66 /// use tenferro_tensor::{MemoryOrder, Tensor};
67 ///
68 /// let mut t = Tensor::<f64>::from_slice(
69 /// &[1.0, 2.0, 3.0, 4.0], &[2, 2], MemoryOrder::ColumnMajor,
70 /// ).unwrap();
71 /// t.set(&[1, 0], 10.0).unwrap();
72 /// assert_eq!(t.get(&[1, 0]), Some(&10.0));
73 ///
74 /// // Shared buffers cannot be written:
75 /// let shared = t.clone(); // refcount == 2
76 /// // t.set(&[0, 0], 5.0) would fail here because buffer is shared
77 /// ```
78 pub fn set(&mut self, index: &[usize], value: T) -> tenferro_device::Result<()> {
79 // Collect error context before taking &mut self.
80 let dims_debug = format!("{:?}", &*self.dims);
81 let unique = self.buffer.is_unique();
82 let elem = self.get_mut(index).ok_or_else(|| {
83 tenferro_device::Error::InvalidArgument(format!(
84 "set: cannot write at index {index:?} (dims {dims_debug}, buffer {})",
85 if unique { "accessible" } else { "shared" },
86 ))
87 })?;
88 *elem = value;
89 Ok(())
90 }
91
92 /// Compute the linear buffer offset for a multi-dimensional index.
93 ///
94 /// Returns `None` if the index is out of bounds.
95 fn linear_offset(&self, index: &[usize]) -> Option<usize> {
96 if index.len() != self.dims.len() {
97 return None;
98 }
99 for (i, &idx) in index.iter().enumerate() {
100 if idx >= self.dims[i] {
101 return None;
102 }
103 }
104 let pos: isize = index.iter().zip(self.strides.iter()).try_fold(
105 self.offset,
106 |acc, (&idx, &stride)| {
107 (idx as isize)
108 .checked_mul(stride)
109 .and_then(|v| acc.checked_add(v))
110 },
111 )?;
112 usize::try_from(pos).ok()
113 }
114}
115
#[cfg(test)]
mod tests {
    use crate::{MemoryOrder, Tensor};

    #[test]
    fn get_mut_and_set() {
        let mut t =
            Tensor::<f64>::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2], MemoryOrder::ColumnMajor)
                .unwrap();
        *t.get_mut(&[0, 1]).unwrap() = 99.0;
        assert_eq!(t.get(&[0, 1]), Some(&99.0));

        t.set(&[1, 0], 42.0).unwrap();
        assert_eq!(t.get(&[1, 0]), Some(&42.0));
    }

    #[test]
    fn get_out_of_bounds_and_wrong_rank() {
        let t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
        assert_eq!(t.get(&[2]), None);
        assert_eq!(t.get(&[0, 0]), None);
        assert_eq!(t.get(&[]), None);
    }

    #[test]
    fn get_mut_out_of_bounds() {
        let mut t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
        assert!(t.get_mut(&[2]).is_none());
    }

    #[test]
    fn get_mut_wrong_rank() {
        let mut t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
        assert!(t.get_mut(&[0, 0]).is_none());
    }

    #[test]
    fn set_shared_buffer_fails() {
        let mut t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
        let shared = t.clone(); // refcount == 2
        assert!(t.set(&[0], 99.0).is_err());
        // A rejected write must leave the data untouched.
        assert_eq!(t.get(&[0]), Some(&1.0));

        // Once the other handle is gone the buffer is exclusive again.
        drop(shared);
        t.set(&[0], 99.0).unwrap();
        assert_eq!(t.get(&[0]), Some(&99.0));
    }

    #[test]
    fn set_out_of_bounds_fails() {
        let mut t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
        assert!(t.set(&[5], 99.0).is_err());
    }

    #[test]
    fn set_wrong_rank_fails() {
        let mut t = Tensor::<f64>::from_slice(&[1.0, 2.0], &[2], MemoryOrder::ColumnMajor).unwrap();
        assert!(t.set(&[0, 0], 99.0).is_err());
        assert_eq!(t.get(&[0]), Some(&1.0));
    }

    #[test]
    fn get_and_set_on_view() {
        let t = Tensor::<f64>::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4], MemoryOrder::ColumnMajor)
            .unwrap();
        // narrow creates a view sharing the buffer
        let mut view = t.narrow(0, 1, 2).unwrap();
        assert_eq!(view.get(&[0]), Some(&2.0));
        assert_eq!(view.get(&[1]), Some(&3.0));

        // The view shares `t`'s buffer, so writing through it is rejected.
        assert!(view.set(&[0], 99.0).is_err());

        // deep_clone gives exclusive ownership, after which writes succeed.
        let mut owned = view.deep_clone();
        owned.set(&[0], 99.0).unwrap();
        assert_eq!(owned.get(&[0]), Some(&99.0));
        // original unchanged
        assert_eq!(t.get(&[1]), Some(&2.0));
    }

    #[test]
    fn deep_clone_is_independent() {
        let a =
            Tensor::<f64>::from_slice(&[1.0, 2.0, 3.0], &[3], MemoryOrder::ColumnMajor).unwrap();
        let mut b = a.deep_clone();
        b.set(&[0], 99.0).unwrap();
        assert_eq!(b.get(&[0]), Some(&99.0));
        assert_eq!(a.get(&[0]), Some(&1.0));
    }

    #[test]
    fn deep_clone_empty_tensor() {
        let a = Tensor::<f64>::from_slice(&[], &[0], MemoryOrder::ColumnMajor).unwrap();
        let b = a.deep_clone();
        assert_eq!(b.dims(), &[0]);
        assert_eq!(b.to_vec(), Vec::<f64>::new());
    }
}