tenferro_internal_ad_core/
value.rs

use tenferro_internal_frontend_core::DynTensor;

pub type DynValue = tidu::Value<DynTensor>;

/// Wraps a primal tensor as a detached value (no gradient tracking).
pub fn new_dyn_value(primal: DynTensor) -> DynValue {
    tidu::Value::new(primal)
}

/// Wraps a primal tensor as a reverse-mode leaf that requests gradient tracking.
pub fn new_reverse_leaf(primal: DynTensor) -> DynValue {
    tidu::Value::new(primal).with_requires_grad(true)
}

#[cfg(test)]
mod tests {
    use super::*;
    use tenferro_internal_frontend_core::{DynTensor, ScalarType, StructuredTensor};
    use tenferro_tensor::{MemoryOrder, Tensor};

    fn dyn_tensor_from_slice(data: &[f64], dims: &[usize]) -> DynTensor {
        let tensor = Tensor::<f64>::from_slice(data, dims, MemoryOrder::ColumnMajor).unwrap();
        StructuredTensor::from(tensor).into()
    }

    #[test]
    fn new_dyn_value_stays_detached() {
        let primal = dyn_tensor_from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let value = new_dyn_value(primal);

        assert!(!value.requires_grad());
        assert_eq!(value.primal().scalar_type(), ScalarType::F64);
        assert_eq!(value.primal().dims(), &[2, 2]);
    }

    #[test]
    fn new_reverse_leaf_enables_grad_tracking() {
        let primal = dyn_tensor_from_slice(&[5.0, 6.0], &[2]);
        let value = new_reverse_leaf(primal);

        assert!(value.requires_grad());
        assert_eq!(value.primal().scalar_type(), ScalarType::F64);
        assert_eq!(value.primal().dims(), &[2]);
    }
}
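
A minimal call-site sketch of the two constructors, using the same tensor-construction path the tests exercise (Tensor -> StructuredTensor -> DynTensor). The crate and module import paths below are assumptions inferred from the directory layout; only the calls already shown in this file are relied on.

// Hypothetical call site; the `tenferro_internal_ad_core::value` path is an assumption.
use tenferro_internal_ad_core::value::{new_dyn_value, new_reverse_leaf, DynValue};
use tenferro_internal_frontend_core::{DynTensor, StructuredTensor};
use tenferro_tensor::{MemoryOrder, Tensor};

fn leaf_from_slice(data: &[f64], dims: &[usize]) -> DynValue {
    // Same construction path as the tests above.
    let tensor = Tensor::<f64>::from_slice(data, dims, MemoryOrder::ColumnMajor).unwrap();
    let primal: DynTensor = StructuredTensor::from(tensor).into();
    new_reverse_leaf(primal)
}

fn main() {
    // A reverse-mode leaf: gradient tracking is requested for this value.
    let x = leaf_from_slice(&[1.0, 2.0, 3.0], &[3]);
    assert!(x.requires_grad());

    // A detached value: participates in computation but records no gradient.
    let tensor = Tensor::<f64>::from_slice(&[4.0], &[1], MemoryOrder::ColumnMajor).unwrap();
    let c = new_dyn_value(StructuredTensor::from(tensor).into());
    assert!(!c.requires_grad());
}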