//! Eager tensors with reverse-mode autodiff (tenferro/eager.rs).
1use std::collections::HashMap;
2use std::sync::{Arc, Mutex, Weak};
3
4use computegraph::fragment::Fragment;
5use computegraph::{GlobalOpKey, GlobalValKey, OpMode, ValRef};
6use tenferro_ops::std_tensor_op::StdTensorOp;
7use tenferro_ops::ShapeGuardContext;
8use tenferro_tensor::cpu::CpuBackend;
9use tenferro_tensor::{Tensor, TensorBackend};
10use tidu::{backward_dag, topo_sort_grad_dag, BackwardCallbacks, GradNode, LinearFragment};
11
12use crate::eager_emitter::EagerEmitter;
13use crate::eager_exec::exec_op_on_tensors;
14use crate::error::{Error, Result};
15use crate::traced::next_input_key;
16
/// Shared, mutable cell holding a tensor's accumulated gradient (None until
/// a `backward()` pass stores into it).
pub(crate) type GradSlot = Arc<Mutex<Option<Arc<Tensor>>>>;
/// Non-owning handle to a `GradSlot`; lets the context observe slots without
/// keeping dropped tensors' gradients alive.
pub(crate) type WeakGradSlot = Weak<Mutex<Option<Arc<Tensor>>>>;
19
/// Shared eager execution context for tensors on a backend.
///
/// Reusing one context lets eager tensors share backend state and gradient
/// storage across a computation.
///
/// # Examples
///
/// ```
/// use tenferro::{CpuBackend, EagerContext, EagerTensor, Tensor};
///
/// let ctx = EagerContext::with_backend(CpuBackend::new());
/// let x = EagerTensor::from_tensor_in(Tensor::from_vec(vec![1], vec![1.0_f64]), ctx.clone());
/// let y = EagerTensor::from_tensor_in(Tensor::from_vec(vec![1], vec![2.0_f64]), ctx);
/// let z = &x + &y;
///
/// assert_eq!(z.data().as_slice::<f64>().unwrap(), &[3.0]);
/// ```
pub struct EagerContext<B: TensorBackend> {
    // Mutex because a context is shared via `Arc` while backend calls need
    // `&mut B`.
    pub(crate) backend: Mutex<B>,
    // Weak references so a slot dies with its tensor; dead entries are pruned
    // lazily in `clear_grads` / `store_grads`.
    grad_slots: Mutex<HashMap<GlobalValKey<StdTensorOp>, WeakGradSlot>>,
}
41
impl<B: TensorBackend> EagerContext<B> {
    /// Build a context owning `backend`, with no gradient slots registered.
    fn new(backend: B) -> Self {
        Self {
            backend: Mutex::new(backend),
            grad_slots: Mutex::new(HashMap::new()),
        }
    }

    /// Create a shared eager execution context for the provided backend.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro::{CpuBackend, EagerContext};
    ///
    /// let ctx = EagerContext::with_backend(CpuBackend::new());
    /// assert_eq!(std::sync::Arc::strong_count(&ctx), 1);
    /// ```
    pub fn with_backend(backend: B) -> Arc<Self> {
        Arc::new(Self::new(backend))
    }

    /// Track `slot` as the gradient storage for `key`.
    ///
    /// Only a weak reference is kept, so the slot's lifetime stays owned by
    /// its tensor; re-registering a key overwrites the old entry.
    pub(crate) fn register_grad_slot(&self, key: &GlobalValKey<StdTensorOp>, slot: &GradSlot) {
        self.grad_slots
            .lock()
            .unwrap()
            .insert(key.clone(), Arc::downgrade(slot));
    }

    /// Merge gradient-slot registrations from `other` into this context.
    ///
    /// On a key collision, the registration already present in `self` wins.
    pub(crate) fn absorb_from(&self, other: &Self) {
        // Lock order: other's map first, then ours. Callers must not invoke
        // this concurrently in both directions on the same pair of contexts.
        let other_slots = other.grad_slots.lock().unwrap();
        let mut slots = self.grad_slots.lock().unwrap();
        for (key, slot) in other_slots.iter() {
            slots.entry(key.clone()).or_insert_with(|| slot.clone());
        }
    }

    /// Clear all live gradient slots tracked by this context.
    ///
    /// This resets the stored gradients to `None` without unregistering the
    /// tensors, so future `backward()` calls can accumulate again.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro::{CpuBackend, EagerContext, EagerTensor, Tensor};
    ///
    /// let ctx = EagerContext::with_backend(CpuBackend::new());
    /// let x = EagerTensor::requires_grad_in(Tensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]), ctx.clone());
    /// let y = EagerTensor::requires_grad_in(Tensor::from_vec(vec![3], vec![4.0_f64, 5.0, 6.0]), ctx.clone());
    /// let loss = (&x * &y).reduce_sum(&[0]).unwrap();
    /// let _ = loss.backward().unwrap();
    ///
    /// ctx.clear_grads();
    ///
    /// assert!(x.grad().is_none());
    /// assert!(y.grad().is_none());
    /// ```
    pub fn clear_grads(&self) {
        // Piggyback pruning on the sweep: slots whose tensors are gone fail to
        // upgrade and are dropped from the map.
        self.grad_slots.lock().unwrap().retain(|_, slot| {
            if let Some(slot) = slot.upgrade() {
                *slot.lock().unwrap() = None;
                true
            } else {
                false
            }
        });
    }

    /// Accumulate `cotangents` into every live registered gradient slot.
    ///
    /// Runs in three phases so that (a) per-slot locks are never taken while
    /// the `grad_slots` map lock is held, and (b) no slot is written unless
    /// every accumulation succeeded — a failed `add` leaves all gradients
    /// untouched.
    fn store_grads(
        &self,
        cotangents: &HashMap<GlobalValKey<StdTensorOp>, Arc<Tensor>>,
        backend: &mut B,
    ) -> Result<()> {
        let mut updates = Vec::new();
        let mut staged = Vec::new();

        // Phase 1: under the map lock, upgrade each weak slot (pruning dead
        // entries) and collect the (slot, incoming) pairs that have a
        // cotangent to apply.
        {
            let mut slots = self.grad_slots.lock().unwrap();
            slots.retain(|key, slot| {
                let Some(slot) = slot.upgrade() else {
                    return false;
                };

                if let Some(incoming) = cotangents.get(key) {
                    updates.push((slot, Arc::clone(incoming)));
                }

                true
            });
        }

        // Phase 2: compute each accumulated gradient. Any error here aborts
        // before anything has been written back.
        for (slot, incoming) in updates {
            let next = {
                let current = slot.lock().unwrap();
                match current.as_ref() {
                    Some(existing) => Arc::new(existing.as_ref().add(incoming.as_ref(), backend)?),
                    None => incoming,
                }
            };
            staged.push((slot, next));
        }

        // Phase 3: commit all results.
        for (slot, next) in staged {
            *slot.lock().unwrap() = Some(next);
        }

        Ok(())
    }
}
152
/// Eager tensor with reverse-mode autodiff over concrete tensor values.
///
/// This executes each primitive immediately and records a lightweight reverse
/// DAG for `backward()`. Gradients accumulate across repeated `backward()`
/// calls until they are cleared explicitly.
///
/// # Examples
///
/// ```
/// use tenferro::{EagerTensor, Tensor};
///
/// let x = EagerTensor::requires_grad(Tensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]));
/// let loss = (&x * &x).reduce_sum(&[0]).unwrap();
/// let _cotangents = loss.backward().unwrap();
/// let loss = (&x * &x).reduce_sum(&[0]).unwrap();
/// let _cotangents = loss.backward().unwrap();
///
/// assert_eq!(x.grad().unwrap().as_slice::<f64>().unwrap(), &[4.0, 8.0, 12.0]);
/// x.clear_grad();
///
/// assert!(x.grad().is_none());
/// ```
#[derive(Clone)]
pub struct EagerTensor<B: TensorBackend = CpuBackend> {
    // Concrete value; Arc so clones and detached views share storage.
    pub(crate) data: Arc<Tensor>,
    // Global identity of this value in the compute graph / cotangent maps.
    pub(crate) key: GlobalValKey<StdTensorOp>,
    // Reverse-DAG node that produced this tensor; None for leaves.
    pub(crate) grad_node: Option<Arc<GradNode<StdTensorOp>>>,
    pub(crate) requires_grad: bool,
    // Owned gradient storage; the context only holds a Weak to it.
    grad_slot: GradSlot,
    pub(crate) ctx: Arc<EagerContext<B>>,
}
184
185impl<B: TensorBackend> std::ops::Add for &EagerTensor<B> {
186    type Output = EagerTensor<B>;
187
188    fn add(self, rhs: &EagerTensor<B>) -> Self::Output {
189        EagerTensor::add(self, rhs).unwrap_or_else(|err| panic!("eager add failed: {}", err))
190    }
191}
192
193impl<B: TensorBackend> std::ops::Mul for &EagerTensor<B> {
194    type Output = EagerTensor<B>;
195
196    fn mul(self, rhs: &EagerTensor<B>) -> Self::Output {
197        EagerTensor::mul(self, rhs).unwrap_or_else(|err| panic!("eager mul failed: {}", err))
198    }
199}
200
201impl<B: TensorBackend> std::ops::Neg for &EagerTensor<B> {
202    type Output = EagerTensor<B>;
203
204    fn neg(self) -> Self::Output {
205        EagerTensor::neg(self).unwrap_or_else(|err| panic!("eager neg failed: {}", err))
206    }
207}
208
209impl EagerTensor<CpuBackend> {
210    /// Create an untracked eager tensor on the default CPU backend.
211    ///
212    /// # Examples
213    ///
214    /// ```
215    /// use tenferro::{EagerTensor, Tensor};
216    ///
217    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]));
218    /// assert_eq!(x.data().as_slice::<f64>().unwrap(), &[1.0, 2.0]);
219    /// assert!(x.grad().is_none());
220    /// ```
221    pub fn from_tensor(tensor: Tensor) -> Self {
222        Self::from_tensor_in(tensor, EagerContext::with_backend(CpuBackend::new()))
223    }
224
225    /// Create a tracked eager leaf on the default CPU backend.
226    ///
227    /// # Examples
228    ///
229    /// ```
230    /// use tenferro::{EagerTensor, Tensor};
231    ///
232    /// let x = EagerTensor::requires_grad(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]));
233    /// assert!(x.grad().is_none());
234    /// ```
235    pub fn requires_grad(tensor: Tensor) -> Self {
236        Self::requires_grad_in(tensor, EagerContext::with_backend(CpuBackend::new()))
237    }
238}
239
impl<B: TensorBackend> EagerTensor<B> {
    /// Create an untracked eager tensor inside an existing eager context.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro::{CpuBackend, EagerContext, EagerTensor, Tensor};
    ///
    /// let ctx = EagerContext::with_backend(CpuBackend::new());
    /// let x = EagerTensor::from_tensor_in(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]), ctx);
    ///
    /// assert_eq!(x.data().as_slice::<f64>().unwrap(), &[1.0, 2.0]);
    /// ```
    pub fn from_tensor_in(tensor: Tensor, ctx: Arc<EagerContext<B>>) -> Self {
        Self::new_leaf(ctx, tensor, false)
    }

    /// Create a tracked eager leaf inside an existing eager context.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro::{CpuBackend, EagerContext, EagerTensor, Tensor};
    ///
    /// let ctx = EagerContext::with_backend(CpuBackend::new());
    /// let x = EagerTensor::requires_grad_in(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]), ctx);
    ///
    /// assert!(x.grad().is_none());
    /// ```
    pub fn requires_grad_in(tensor: Tensor, ctx: Arc<EagerContext<B>>) -> Self {
        Self::new_leaf(ctx, tensor, true)
    }

    /// Build a leaf tensor (no producing op, so `grad_node` is `None`) with a
    /// freshly minted input key. The gradient slot is registered with the
    /// context only when tracking is requested.
    pub(crate) fn new_leaf(ctx: Arc<EagerContext<B>>, tensor: Tensor, requires_grad: bool) -> Self {
        let key = eager_val_key();
        let grad_slot = Arc::new(Mutex::new(None));
        if requires_grad {
            ctx.register_grad_slot(&key, &grad_slot);
        }

        Self {
            data: Arc::new(tensor),
            key,
            grad_node: None,
            requires_grad,
            grad_slot,
            ctx,
        }
    }

    /// Build an op-result tensor whose `key` and `grad_node` were produced by
    /// the eager op machinery. Like `new_leaf`, the gradient slot is only
    /// registered when the result is tracked.
    pub(crate) fn new_result(
        ctx: Arc<EagerContext<B>>,
        key: GlobalValKey<StdTensorOp>,
        tensor: Tensor,
        requires_grad: bool,
        grad_node: Option<Arc<GradNode<StdTensorOp>>>,
    ) -> Self {
        let grad_slot = Arc::new(Mutex::new(None));
        if requires_grad {
            ctx.register_grad_slot(&key, &grad_slot);
        }

        Self {
            data: Arc::new(tensor),
            key,
            grad_node,
            requires_grad,
            grad_slot,
            ctx,
        }
    }

    /// Detach this tensor from the reverse graph.
    ///
    /// The returned tensor keeps the concrete value but no longer contributes
    /// gradients to the original graph.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro::{EagerTensor, Tensor};
    ///
    /// let x = EagerTensor::requires_grad(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]));
    /// let y = x.detach();
    ///
    /// assert_eq!(y.data().as_slice::<f64>().unwrap(), &[1.0, 2.0]);
    /// assert!(y.grad().is_none());
    /// ```
    pub fn detach(&self) -> Self {
        Self::new_leaf(self.ctx.clone(), self.data.as_ref().clone(), false)
    }

    /// Borrow the concrete tensor value.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro::{EagerTensor, Tensor};
    ///
    /// let x = EagerTensor::from_tensor(Tensor::from_vec(vec![1], vec![3.0_f64]));
    /// assert_eq!(x.data().as_slice::<f64>().unwrap(), &[3.0]);
    /// ```
    pub fn data(&self) -> &Tensor {
        self.data.as_ref()
    }

    /// Return the accumulated gradient currently stored for this tensor.
    ///
    /// The stored gradient accumulates across repeated `backward()` calls
    /// until it is cleared explicitly.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro::{EagerTensor, Tensor};
    ///
    /// let x = EagerTensor::requires_grad(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]));
    /// let loss = x.exp().unwrap().reduce_sum(&[0]).unwrap();
    /// let _cotangents = loss.backward().unwrap();
    ///
    /// let grad = x.grad().unwrap();
    /// assert_eq!(grad.shape(), &[2]);
    /// ```
    pub fn grad(&self) -> Option<Arc<Tensor>> {
        self.grad_slot.lock().unwrap().clone()
    }

    /// Clear the accumulated gradient stored for this tensor.
    ///
    /// This only affects this tensor's gradient slot. Other tensors in the
    /// same context retain their gradients until they are cleared explicitly or
    /// overwritten by later accumulation.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro::{CpuBackend, EagerContext, EagerTensor, Tensor};
    ///
    /// let ctx = EagerContext::with_backend(CpuBackend::new());
    /// let x = EagerTensor::requires_grad_in(Tensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]), ctx.clone());
    /// let y = EagerTensor::requires_grad_in(Tensor::from_vec(vec![3], vec![4.0_f64, 5.0, 6.0]), ctx);
    /// let loss = (&x * &y).reduce_sum(&[0]).unwrap();
    /// let _ = loss.backward().unwrap();
    ///
    /// x.clear_grad();
    ///
    /// assert!(x.grad().is_none());
    /// assert!(y.grad().is_some());
    /// ```
    pub fn clear_grad(&self) {
        *self.grad_slot.lock().unwrap() = None;
    }

    /// Report whether this tensor participates in gradient tracking.
    ///
    /// Tracked tensors keep a gradient slot in their eager context; untracked
    /// tensors and detached tensors do not.
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro::{CpuBackend, EagerContext, EagerTensor, Tensor};
    ///
    /// let ctx = EagerContext::with_backend(CpuBackend::new());
    /// let plain = EagerTensor::from_tensor_in(Tensor::from_vec(vec![2], vec![1.0_f64, 2.0]), ctx.clone());
    /// let tracked = EagerTensor::requires_grad_in(Tensor::from_vec(vec![2], vec![3.0_f64, 4.0]), ctx.clone());
    /// let detached = tracked.detach();
    ///
    /// assert!(!plain.tracks_grad());
    /// assert!(tracked.tracks_grad());
    /// assert!(!detached.tracks_grad());
    /// ```
    pub fn tracks_grad(&self) -> bool {
        self.requires_grad
    }

    /// Run reverse-mode AD from this scalar output.
    ///
    /// Returns the full cotangent map produced by the reverse pass and also
    /// accumulates into `grad()` for tracked eager tensors reachable from this
    /// output.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NonScalarGrad`] if this tensor is not a scalar
    /// (i.e. its shape is non-empty).
    ///
    /// # Examples
    ///
    /// ```
    /// use tenferro::{EagerTensor, Tensor};
    ///
    /// let x = EagerTensor::requires_grad(Tensor::from_vec(vec![3], vec![1.0_f64, 2.0, 3.0]));
    /// let loss = (&x + &x).reduce_sum(&[0]).unwrap();
    /// let _cotangents = loss.backward().unwrap();
    /// let loss = (&x + &x).reduce_sum(&[0]).unwrap();
    /// let _cotangents = loss.backward().unwrap();
    ///
    /// assert_eq!(x.grad().unwrap().as_slice::<f64>().unwrap(), &[4.0, 4.0, 4.0]);
    /// ```
    pub fn backward(&self) -> Result<HashMap<GlobalValKey<StdTensorOp>, Arc<Tensor>>> {
        // Only scalars (empty shape) may seed a reverse pass.
        if !self.data.shape().is_empty() {
            return Err(Error::NonScalarGrad {
                shape: self.data.shape().to_vec(),
            });
        }

        let sorted = topo_sort_grad_dag(&self.grad_node);
        // The backend lock is held for the entire reverse pass; callbacks and
        // grad storage below borrow it rather than re-locking.
        let mut backend = self.ctx.backend.lock().unwrap();
        // Seed cotangent d(out)/d(out) = 1, shaped like this scalar output.
        let seed = Arc::new(one_like_tensor(self.data.as_ref(), &mut *backend));
        let mut callbacks = TenferroBackwardCallbacks {
            backend: &mut *backend,
        };
        let mut ad_ctx = ShapeGuardContext::default();
        let cotangents = backward_dag(&sorted, &self.key, seed, &mut callbacks, &mut ad_ctx);
        // Accumulate into every live, tracked gradient slot of the context.
        self.ctx.store_grads(&cotangents, &mut *backend)?;
        Ok(cotangents)
    }
}
454
/// Adapter giving tidu's backward pass mutable access to the tensor backend.
pub(crate) struct TenferroBackwardCallbacks<'a, B: TensorBackend> {
    backend: &'a mut B,
}
458
impl<B: TensorBackend> BackwardCallbacks<StdTensorOp> for TenferroBackwardCallbacks<'_, B> {
    /// Replay `fragment` eagerly, returning every value it produces keyed by
    /// its global key.
    ///
    /// Fragment inputs missing from `initial_data` are expected to be tangent
    /// inputs and are seeded with zeros shaped like their base value; any
    /// other missing value is a bug and panics.
    fn execute_forward(
        &mut self,
        fragment: &Fragment<StdTensorOp>,
        initial_data: &HashMap<GlobalValKey<StdTensorOp>, Arc<Tensor>>,
    ) -> HashMap<GlobalValKey<StdTensorOp>, Arc<Tensor>> {
        let mut all_values = initial_data.clone();

        // Seed absent inputs: only `Tangent` input keys may be missing, and
        // they default to a zero tensor with the base value's shape/dtype.
        for &input_id in fragment.inputs() {
            let key = fragment.vals()[input_id].key.clone();
            all_values.entry(key.clone()).or_insert_with(|| {
                let GlobalValKey::Input(tangent_key) = &key else {
                    panic!("expected input key for eager forward: {:?}", key);
                };
                let tenferro_ops::input_key::TensorInputKey::Tangent { of, .. } = tangent_key
                else {
                    panic!("missing concrete eager value for {:?}", key);
                };
                let base_key = GlobalValKey::Input((**of).clone());
                let base = initial_data
                    .get(&base_key)
                    .unwrap_or_else(|| panic!("missing base eager value for {:?}", base_key));
                Arc::new(zero_like_tensor(base.as_ref(), self.backend))
            });
        }

        // Execute ops in fragment order; each op's operands must already be in
        // `all_values` (provided, seeded, or produced by an earlier op).
        for op_node in fragment.ops() {
            let resolved_inputs: Vec<&Tensor> = op_node
                .inputs
                .iter()
                .map(|input| match input {
                    ValRef::Local(local_id) => {
                        let key = &fragment.vals()[*local_id].key;
                        all_values
                            .get(key)
                            .unwrap_or_else(|| panic!("missing eager value for local {:?}", key))
                            .as_ref()
                    }
                    ValRef::External(key) => all_values
                        .get(key)
                        .unwrap_or_else(|| panic!("missing eager value for external {:?}", key))
                        .as_ref(),
                })
                .collect();
            let outputs = exec_op_on_tensors(&op_node.op, &resolved_inputs, self.backend)
                .unwrap_or_else(|err| {
                    panic!("eager forward exec failed for {:?}: {}", op_node.op, err)
                });

            for (output_id, output) in op_node.outputs.iter().zip(outputs.into_iter()) {
                let key = fragment.vals()[*output_id].key.clone();
                all_values.insert(key, Arc::new(output));
            }
        }

        all_values
    }

    /// Transpose a linear fragment eagerly, mapping output cotangents back to
    /// input cotangents. `None` entries mean no cotangent flows through that
    /// position.
    fn eager_transpose(
        &mut self,
        linear: &LinearFragment<StdTensorOp>,
        cotangent_out: &[Option<Arc<Tensor>>],
        external_data: &HashMap<GlobalValKey<StdTensorOp>, Arc<Tensor>>,
        ctx: &mut ShapeGuardContext,
    ) -> Vec<Option<Arc<Tensor>>> {
        let mut emitter = EagerEmitter::new(self.backend);
        emitter.external_data = external_data.clone();
        // Register each present seed with the emitter, keeping `None` gaps.
        let cotangent_seed_ids = cotangent_out
            .iter()
            .map(|maybe_seed| {
                maybe_seed
                    .as_ref()
                    .map(|seed| emitter.push_tensor(Arc::clone(seed)))
            })
            .collect::<Vec<_>>();

        tidu::eager_transpose_fragment(linear, &mut emitter, &cotangent_seed_ids, ctx)
            .into_iter()
            .map(|maybe_id| maybe_id.map(|id| emitter.tensor(id)))
            .collect()
    }

    /// Accumulate two cotangents for the same value; panics on backend failure.
    fn add_operands(&mut self, a: &Arc<Tensor>, b: &Arc<Tensor>) -> Arc<Tensor> {
        Arc::new(
            a.as_ref()
                .add(b.as_ref(), self.backend)
                .unwrap_or_else(|err| panic!("eager cotangent add failed: {}", err)),
        )
    }
}
549
550pub(crate) fn eager_val_key() -> GlobalValKey<StdTensorOp> {
551    GlobalValKey::Input(next_input_key())
552}
553
554pub(crate) fn saved_forward_values(
555    op: &StdTensorOp,
556    input_keys: &[GlobalValKey<StdTensorOp>],
557    inputs: &[Arc<Tensor>],
558    output: Arc<Tensor>,
559) -> HashMap<GlobalValKey<StdTensorOp>, Arc<Tensor>> {
560    let mut saved = HashMap::with_capacity(input_keys.len() + 1);
561    for (key, value) in input_keys.iter().zip(inputs.iter()) {
562        saved.insert(key.clone(), Arc::clone(value));
563    }
564    saved.insert(derived_output_key(op, input_keys, 0), output);
565    saved
566}
567
568pub(crate) fn saved_forward_values_multi(
569    op: &StdTensorOp,
570    input_keys: &[GlobalValKey<StdTensorOp>],
571    inputs: &[Arc<Tensor>],
572    num_outputs: usize,
573    outputs: &[Arc<Tensor>],
574) -> HashMap<GlobalValKey<StdTensorOp>, Arc<Tensor>> {
575    let mut saved = HashMap::with_capacity(input_keys.len() + num_outputs);
576    for (key, value) in input_keys.iter().zip(inputs.iter()) {
577        saved.insert(key.clone(), Arc::clone(value));
578    }
579    for slot in 0..num_outputs {
580        saved.insert(
581            derived_output_key(op, input_keys, slot),
582            Arc::clone(&outputs[slot]),
583        );
584    }
585    saved
586}
587
588pub(crate) fn derived_output_key(
589    op: &StdTensorOp,
590    input_keys: &[GlobalValKey<StdTensorOp>],
591    output_slot: usize,
592) -> GlobalValKey<StdTensorOp> {
593    GlobalValKey::Derived {
594        op: GlobalOpKey {
595            primitive: op.clone(),
596            inputs: input_keys.to_vec(),
597            mode: OpMode::Primal,
598        },
599        output_slot: output_slot as u8,
600    }
601}
602
603pub(crate) fn exec_single_output<B: TensorBackend>(
604    op: &StdTensorOp,
605    inputs: &[&Tensor],
606    ctx: &EagerContext<B>,
607) -> Result<Tensor> {
608    let mut backend = ctx.backend.lock().unwrap();
609    let mut outputs = exec_op_on_tensors(op, inputs, &mut *backend)?;
610    if outputs.len() != 1 {
611        return Err(Error::Internal(format!(
612            "expected one eager output for {:?}, got {}",
613            op,
614            outputs.len()
615        )));
616    }
617    Ok(outputs.remove(0))
618}
619
620pub(crate) fn zero_like_tensor<B: TensorBackend>(input: &Tensor, backend: &mut B) -> Tensor {
621    let neg = input
622        .neg(backend)
623        .unwrap_or_else(|err| panic!("zero_like neg failed: {}", err));
624    input
625        .add(&neg, backend)
626        .unwrap_or_else(|err| panic!("zero_like add failed: {}", err))
627}
628
629pub(crate) fn one_like_tensor<B: TensorBackend>(input: &Tensor, backend: &mut B) -> Tensor {
630    let zero = zero_like_tensor(input, backend);
631    backend
632        .exp(&zero)
633        .unwrap_or_else(|err| panic!("one_like exp failed: {}", err))
634}