//! tenferro_tensor/tensor/element_access.rs — element access (`get`, `get_mut`,
//! `set`) for [`Tensor`].

use super::Tensor;

3impl<T> Tensor<T> {
4    /// Access a single element by multi-dimensional index.
5    ///
6    /// Returns `None` if the index is out of bounds or the underlying buffer
7    /// is not CPU-accessible.
8    ///
9    /// # Examples
10    ///
11    /// ```
12    /// use tenferro_tensor::{MemoryOrder, Tensor};
13    ///
14    /// // Column-major: data is laid out column by column.
15    /// // from_slice with ColumnMajor and data [1,2,3,4] gives:
16    /// //   column 0 = [1, 2], column 1 = [3, 4]
17    /// //   matrix = [[1, 3],
18    /// //             [2, 4]]
19    /// let t = Tensor::<f64>::from_slice(
20    ///     &[1.0, 2.0, 3.0, 4.0], &[2, 2], MemoryOrder::ColumnMajor,
21    /// ).unwrap();
22    /// assert_eq!(t.get(&[0, 0]), Some(&1.0));
23    /// assert_eq!(t.get(&[1, 0]), Some(&2.0));
24    /// assert_eq!(t.get(&[0, 1]), Some(&3.0));
25    /// assert_eq!(t.get(&[1, 1]), Some(&4.0));
26    /// assert_eq!(t.get(&[2, 0]), None); // out of bounds
27    /// ```
28    pub fn get(&self, index: &[usize]) -> Option<&T> {
29        let pos = self.linear_offset(index)?;
30        self.buffer.as_slice().and_then(|s| s.get(pos))
31    }
32
33    /// Access a single element mutably by multi-dimensional index.
34    ///
35    /// Returns `None` if the index is out of bounds, the buffer is not
36    /// CPU-accessible, or the buffer is shared (Arc refcount > 1).
37    ///
38    /// # Examples
39    ///
40    /// ```
41    /// use tenferro_tensor::{MemoryOrder, Tensor};
42    ///
43    /// let mut t = Tensor::<f64>::from_slice(
44    ///     &[1.0, 2.0, 3.0, 4.0], &[2, 2], MemoryOrder::ColumnMajor,
45    /// ).unwrap();
46    /// *t.get_mut(&[0, 1]).unwrap() = 99.0;
47    /// assert_eq!(t.get(&[0, 1]), Some(&99.0));
48    /// // Out of bounds returns None:
49    /// assert!(t.get_mut(&[2, 0]).is_none());
50    /// ```
51    pub fn get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
52        let pos = self.linear_offset(index)?;
53        self.buffer.as_mut_slice().and_then(|s| s.get_mut(pos))
54    }
55
56    /// Write a value at the given multi-dimensional index.
57    ///
58    /// Returns `Ok(())` on success, or an error if the index is out of bounds,
59    /// the buffer is not CPU-accessible, or the buffer is shared
60    /// (Arc refcount > 1). Call [`deep_clone`](Tensor::deep_clone) first to
61    /// obtain an exclusively-owned copy.
62    ///
63    /// # Examples
64    ///
65    /// ```
66    /// use tenferro_tensor::{MemoryOrder, Tensor};
67    ///
68    /// let mut t = Tensor::<f64>::from_slice(
69    ///     &[1.0, 2.0, 3.0, 4.0], &[2, 2], MemoryOrder::ColumnMajor,
70    /// ).unwrap();
71    /// t.set(&[1, 0], 10.0).unwrap();
72    /// assert_eq!(t.get(&[1, 0]), Some(&10.0));
73    ///
74    /// // Shared buffers cannot be written:
75    /// let shared = t.clone(); // refcount == 2
76    /// // t.set(&[0, 0], 5.0) would fail here because buffer is shared
77    /// ```
78    pub fn set(&mut self, index: &[usize], value: T) -> tenferro_device::Result<()> {
79        // Collect error context before taking &mut self.
80        let dims_debug = format!("{:?}", &*self.dims);
81        let unique = self.buffer.is_unique();
82        let elem = self.get_mut(index).ok_or_else(|| {
83            tenferro_device::Error::InvalidArgument(format!(
84                "set: cannot write at index {index:?} (dims {dims_debug}, buffer {})",
85                if unique { "accessible" } else { "shared" },
86            ))
87        })?;
88        *elem = value;
89        Ok(())
90    }
91
92    /// Compute the linear buffer offset for a multi-dimensional index.
93    ///
94    /// Returns `None` if the index is out of bounds.
95    fn linear_offset(&self, index: &[usize]) -> Option<usize> {
96        if index.len() != self.dims.len() {
97            return None;
98        }
99        for (i, &idx) in index.iter().enumerate() {
100            if idx >= self.dims[i] {
101                return None;
102            }
103        }
104        let pos: isize = index.iter().zip(self.strides.iter()).try_fold(
105            self.offset,
106            |acc, (&idx, &stride)| {
107                (idx as isize)
108                    .checked_mul(stride)
109                    .and_then(|v| acc.checked_add(v))
110            },
111        )?;
112        usize::try_from(pos).ok()
113    }
114}
#[cfg(test)]
mod tests {
    use crate::{MemoryOrder, Tensor};

    /// Every test uses a column-major f64 tensor; build it in one place.
    fn col_major(data: &[f64], dims: &[usize]) -> Tensor<f64> {
        Tensor::<f64>::from_slice(data, dims, MemoryOrder::ColumnMajor).unwrap()
    }

    #[test]
    fn get_mut_and_set() {
        let mut t = col_major(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        *t.get_mut(&[0, 1]).unwrap() = 99.0;
        assert_eq!(t.get(&[0, 1]), Some(&99.0));

        t.set(&[1, 0], 42.0).unwrap();
        assert_eq!(t.get(&[1, 0]), Some(&42.0));
    }

    #[test]
    fn get_mut_out_of_bounds() {
        let mut t = col_major(&[1.0, 2.0], &[2]);
        assert!(t.get_mut(&[2]).is_none());
    }

    #[test]
    fn get_mut_wrong_rank() {
        let mut t = col_major(&[1.0, 2.0], &[2]);
        assert!(t.get_mut(&[0, 0]).is_none());
    }

    #[test]
    fn set_shared_buffer_fails() {
        let mut t = col_major(&[1.0, 2.0], &[2]);
        let _alias = t.clone(); // bumps the Arc refcount to 2
        assert!(t.set(&[0], 99.0).is_err());
    }

    #[test]
    fn set_out_of_bounds_fails() {
        let mut t = col_major(&[1.0, 2.0], &[2]);
        assert!(t.set(&[5], 99.0).is_err());
    }

    #[test]
    fn get_and_set_on_view() {
        let base = col_major(&[1.0, 2.0, 3.0, 4.0], &[4]);
        // `narrow` yields a view that shares the underlying buffer.
        let view = base.narrow(0, 1, 2).unwrap();
        assert_eq!(view.get(&[0]), Some(&2.0));
        assert_eq!(view.get(&[1]), Some(&3.0));

        // The view's buffer is shared, so deep_clone for exclusive ownership.
        let mut owned = view.deep_clone();
        owned.set(&[0], 99.0).unwrap();
        assert_eq!(owned.get(&[0]), Some(&99.0));
        // The base tensor must remain untouched.
        assert_eq!(base.get(&[1]), Some(&2.0));
    }

    #[test]
    fn deep_clone_is_independent() {
        let a = col_major(&[1.0, 2.0, 3.0], &[3]);
        let mut b = a.deep_clone();
        b.set(&[0], 99.0).unwrap();
        assert_eq!(b.get(&[0]), Some(&99.0));
        assert_eq!(a.get(&[0]), Some(&1.0));
    }

    #[test]
    fn deep_clone_empty_tensor() {
        let original = col_major(&[], &[0]);
        let copy = original.deep_clone();
        assert_eq!(copy.dims(), &[0]);
        assert!(copy.to_vec().is_empty());
    }
}