use std::ops::Add;
use std::sync::Arc;

use num_complex::{Complex32, Complex64};
use num_traits::Zero;
use tenferro_algebra::Scalar;
use tenferro_internal_ad_core::{
    AdResult, AutodiffError, CheckpointHint, DynValue, LinearizableOp, LinearizedOp, Schema,
    SlotSchema,
};
use tenferro_internal_frontend_core::tensor_ops::{
    tensor_element, tensor_map_binary_typed, tensor_map_unary_typed,
};
use tenferro_internal_frontend_core::{DynTensor, DynTensorTyped, StructuredTensor};
use tenferro_tensor::{MemoryOrder, Tensor as DenseTensor};

use crate::math::{einsum_frule, einsum_primal, einsum_rrule};
use crate::{Error, Result};
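
/// Elementwise addition of two tensors with matching dtypes.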
#[derive(Clone, Copy)]
pub struct AddOp;
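
/// Elementwise exponential.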
#[derive(Clone, Copy)]
pub struct ExpOp;
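
/// Full reduction: sums every element into a rank-0 tensor.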
#[derive(Clone, Copy)]
pub struct SumOp;
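
/// Einstein-summation contraction, parameterized by a subscripts string
/// such as `"ij,jk->ik"`.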
#[derive(Clone)]
pub struct EinsumOp {
    subscripts: Arc<str>,
}
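
/// Addition is linear, so its linearization carries no state.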
#[doc(hidden)]
pub struct AddLinearized;
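
/// Caches the primal output: since d/dx exp(x) = exp(x), both the JVP and
/// the VJP only need the forward result.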
#[doc(hidden)]
pub struct ExpLinearized {
    output: DynTensor,
}
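
/// Keeps the primal input so the VJP can broadcast the scalar cotangent
/// back to the input's shape and dtype.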
#[doc(hidden)]
pub struct SumLinearized {
    input: DynTensor,
}
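
/// Keeps the subscripts and primal inputs; both the JVP and the VJP
/// re-contract against them.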
#[doc(hidden)]
pub struct EinsumLinearized {
    subscripts: Arc<str>,
    inputs: Vec<DynTensor>,
}
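
/// Builds a schema marking all `slots` slots as differentiable and
/// non-auxiliary.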
fn differentiable_schema(slots: usize) -> Schema {
    Schema {
        slots: (0..slots)
            .map(|_| SlotSchema {
                differentiable: true,
                auxiliary: false,
            })
            .collect(),
    }
}

fn invalid_argument(message: impl Into<String>) -> Error {
    AutodiffError::InvalidArgument(message.into()).into()
}

fn into_ad_error(error: Error) -> AutodiffError {
    match error {
        Error::Autodiff(error) => error,
        other => AutodiffError::InvalidArgument(other.to_string()),
    }
}
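
/// Elementwise binary map; the result inherits `lhs`'s structure metadata.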
fn structured_binary<T>(
    lhs: &StructuredTensor<T>,
    rhs: &StructuredTensor<T>,
    f: impl FnMut(T, T) -> T,
) -> Result<StructuredTensor<T>>
where
    T: Scalar + Copy,
{
    lhs.with_payload_like(tensor_map_binary_typed(lhs.payload(), rhs.payload(), f)?)
}
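
/// Elementwise unary map. The output dtype may differ from the input's, so
/// the result is rebuilt from the dense payload without the input's
/// structure metadata.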
fn structured_unary<T, U>(
    input: &StructuredTensor<T>,
    f: impl FnMut(T) -> U,
) -> Result<StructuredTensor<U>>
where
    T: Scalar + Copy,
    U: Scalar + Copy,
{
    let payload = tensor_map_unary_typed(input.payload(), f)?;
    Ok(StructuredTensor::from(payload))
}

fn dense_host_slice<'a, T>(tensor: &'a DenseTensor<T>, context: &str) -> Result<&'a [T]> {
    tensor.buffer().as_slice().ok_or_else(|| {
        invalid_argument(format!("{context} requires host-accessible dense payload"))
    })
}
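
/// Extracts the single element of a rank-0 tensor.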
fn scalar_from_rank0<T>(value: &StructuredTensor<T>, context: &str) -> Result<T>
where
    T: Scalar + Copy,
{
    if !value.logical_dims().is_empty() {
        return Err(invalid_argument(format!(
            "{context} requires a rank-0 tensor, got {:?}",
            value.logical_dims()
        )));
    }
    tensor_element(value.payload(), &[])
}
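
/// Sums every element on the host into a rank-0 tensor.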
fn structured_sum_all<T>(input: &StructuredTensor<T>) -> Result<StructuredTensor<T>>
where
    T: Scalar + Copy + Zero + Add<Output = T>,
{
    let dense = input.to_dense()?;
    let mut acc = T::zero();
    for &value in dense_host_slice(&dense, "sum")? {
        acc = acc + value;
    }
    let payload = DenseTensor::from_slice(&[acc], &[], MemoryOrder::ColumnMajor)?;
    Ok(StructuredTensor::from(payload))
}
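
/// Fills a tensor shaped like `like` with the value of a rank-0 `scalar`;
/// this is the adjoint of the full-sum reduction.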
fn structured_broadcast_scalar_like<T>(
    scalar: &StructuredTensor<T>,
    like: &StructuredTensor<T>,
) -> Result<StructuredTensor<T>>
where
    T: Scalar + Copy,
{
    let value = scalar_from_rank0(scalar, "broadcast_scalar_like")?;
    let total = like.logical_dims().iter().product();
    let payload = DenseTensor::from_slice(
        &vec![value; total],
        like.logical_dims(),
        MemoryOrder::ColumnMajor,
    )?;
    like.with_payload_like(payload)
}
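
// The `dyn_*` helpers dispatch an operation over the runtime dtype of a
// `DynTensor`, erroring when dtypes do not match.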
fn dyn_add(lhs: &DynTensor, rhs: &DynTensor) -> Result<DynTensor> {
    match (lhs, rhs) {
        (DynTensor::F32(lhs), DynTensor::F32(rhs)) => {
            Ok(DynTensor::F32(structured_binary(lhs, rhs, |x, y| x + y)?))
        }
        (DynTensor::F64(lhs), DynTensor::F64(rhs)) => {
            Ok(DynTensor::F64(structured_binary(lhs, rhs, |x, y| x + y)?))
        }
        (DynTensor::C32(lhs), DynTensor::C32(rhs)) => {
            Ok(DynTensor::C32(structured_binary(lhs, rhs, |x, y| x + y)?))
        }
        (DynTensor::C64(lhs), DynTensor::C64(rhs)) => {
            Ok(DynTensor::C64(structured_binary(lhs, rhs, |x, y| x + y)?))
        }
        _ => Err(invalid_argument(format!(
            "add requires matching dtypes, got lhs={:?}, rhs={:?}",
            lhs.scalar_type(),
            rhs.scalar_type()
        ))),
    }
}

fn dyn_mul(lhs: &DynTensor, rhs: &DynTensor) -> Result<DynTensor> {
    match (lhs, rhs) {
        (DynTensor::F32(lhs), DynTensor::F32(rhs)) => {
            Ok(DynTensor::F32(structured_binary(lhs, rhs, |x, y| x * y)?))
        }
        (DynTensor::F64(lhs), DynTensor::F64(rhs)) => {
            Ok(DynTensor::F64(structured_binary(lhs, rhs, |x, y| x * y)?))
        }
        (DynTensor::C32(lhs), DynTensor::C32(rhs)) => {
            Ok(DynTensor::C32(structured_binary(lhs, rhs, |x, y| x * y)?))
        }
        (DynTensor::C64(lhs), DynTensor::C64(rhs)) => {
            Ok(DynTensor::C64(structured_binary(lhs, rhs, |x, y| x * y)?))
        }
        _ => Err(invalid_argument(format!(
            "mul requires matching dtypes, got lhs={:?}, rhs={:?}",
            lhs.scalar_type(),
            rhs.scalar_type()
        ))),
    }
}

fn dyn_exp(input: &DynTensor) -> Result<DynTensor> {
    match input {
        DynTensor::F32(value) => Ok(DynTensor::F32(structured_unary(value, |x: f32| x.exp())?)),
        DynTensor::F64(value) => Ok(DynTensor::F64(structured_unary(value, |x: f64| x.exp())?)),
        DynTensor::C32(value) => Ok(DynTensor::C32(structured_unary(value, |z: Complex32| {
            z.exp()
        })?)),
        DynTensor::C64(value) => Ok(DynTensor::C64(structured_unary(value, |z: Complex64| {
            z.exp()
        })?)),
    }
}

fn dyn_sum_all(input: &DynTensor) -> Result<DynTensor> {
    match input {
        DynTensor::F32(value) => Ok(DynTensor::F32(structured_sum_all(value)?)),
        DynTensor::F64(value) => Ok(DynTensor::F64(structured_sum_all(value)?)),
        DynTensor::C32(value) => Ok(DynTensor::C32(structured_sum_all(value)?)),
        DynTensor::C64(value) => Ok(DynTensor::C64(structured_sum_all(value)?)),
    }
}

fn dyn_broadcast_scalar_like(scalar: &DynTensor, like: &DynTensor) -> Result<DynTensor> {
    match (scalar, like) {
        (DynTensor::F32(scalar), DynTensor::F32(like)) => Ok(DynTensor::F32(
            structured_broadcast_scalar_like(scalar, like)?,
        )),
        (DynTensor::F64(scalar), DynTensor::F64(like)) => Ok(DynTensor::F64(
            structured_broadcast_scalar_like(scalar, like)?,
        )),
        (DynTensor::C32(scalar), DynTensor::C32(like)) => Ok(DynTensor::C32(
            structured_broadcast_scalar_like(scalar, like)?,
        )),
        (DynTensor::C64(scalar), DynTensor::C64(like)) => Ok(DynTensor::C64(
            structured_broadcast_scalar_like(scalar, like)?,
        )),
        _ => Err(invalid_argument(format!(
            "broadcast requires matching dtypes, got scalar={:?}, like={:?}",
            scalar.scalar_type(),
            like.scalar_type()
        ))),
    }
}
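
/// Converts a `DynTensor` to a dense tensor of the statically-known scalar
/// type `T`, erroring if the runtime dtype does not match.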
fn dense_dyn_tensor_typed<T>(value: &DynTensor, context: &str) -> Result<DenseTensor<T>>
where
    T: DynTensorTyped + Copy,
{
    let structured = T::structured_ref(value)
        .ok_or_else(|| invalid_argument(format!("{context} requires matching dtypes")))?;
    structured.to_dense()
}

fn collect_dense_dyn_tensors<T>(values: &[&DynTensor], context: &str) -> Result<Vec<DenseTensor<T>>>
where
    T: DynTensorTyped + Copy,
{
    values
        .iter()
        .map(|value| dense_dyn_tensor_typed::<T>(value, context))
        .collect()
}

fn optional_dense_dyn_tensor_typed<T>(
    value: &Option<DynTensor>,
    context: &str,
) -> Result<Option<DenseTensor<T>>>
where
    T: DynTensorTyped + Copy,
{
    value
        .as_ref()
        .map(|tensor| dense_dyn_tensor_typed::<T>(tensor, context))
        .transpose()
}

fn collect_optional_dense_dyn_tensors<T>(
    values: &[Option<DynTensor>],
    context: &str,
) -> Result<Vec<Option<DenseTensor<T>>>>
where
    T: DynTensorTyped + Copy,
{
    values
        .iter()
        .map(|value| optional_dense_dyn_tensor_typed::<T>(value, context))
        .collect()
}

fn dyn_from_dense<T>(value: DenseTensor<T>) -> DynTensor
where
    T: DynTensorTyped + Copy,
{
    T::into_dyn(StructuredTensor::from(value))
}
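
/// Typed einsum forward pass: densify the inputs and contract.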
fn dyn_einsum_primal_t<T>(subscripts: &str, inputs: &[&DynTensor]) -> Result<DynTensor>
where
    T: crate::runtime::contracts::EinsumRuntimeValue + DynTensorTyped + Copy,
{
    let dense_inputs = collect_dense_dyn_tensors::<T>(inputs, "einsum")?;
    let input_refs: Vec<&DenseTensor<T>> = dense_inputs.iter().collect();
    let output = einsum_primal(subscripts, &input_refs)?;
    Ok(dyn_from_dense(output))
}
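
/// Typed einsum JVP. Returns `None` when every input tangent is `None`;
/// otherwise applies the forward rule over the stored primals.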
fn dyn_einsum_jvp_t<T>(
    subscripts: &str,
    primals: &[DynTensor],
    tangents: &[Option<DynTensor>],
) -> Result<Option<DynTensor>>
where
    T: crate::runtime::contracts::EinsumRuntimeValue + DynTensorTyped + Copy,
{
    if tangents.iter().all(Option::is_none) {
        return Ok(None);
    }
    let primal_refs: Vec<&DynTensor> = primals.iter().collect();
    let dense_primals = collect_dense_dyn_tensors::<T>(&primal_refs, "einsum_jvp")?;
    let dense_tangents = collect_optional_dense_dyn_tensors::<T>(tangents, "einsum_jvp")?;
    let primal_refs: Vec<&DenseTensor<T>> = dense_primals.iter().collect();
    let tangent_refs: Vec<Option<&DenseTensor<T>>> =
        dense_tangents.iter().map(Option::as_ref).collect();
    let tangent = einsum_frule(subscripts, &primal_refs, &tangent_refs)?;
    Ok(Some(dyn_from_dense(tangent)))
}
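
/// Typed einsum VJP: computes all input gradients via the reverse rule, then
/// drops those not requested by `input_grad_mask`.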
fn dyn_einsum_vjp_t<T>(
    subscripts: &str,
    inputs: &[DynTensor],
    cotangent: &DynTensor,
    input_grad_mask: &[bool],
) -> Result<Vec<Option<DynTensor>>>
where
    T: crate::runtime::contracts::EinsumRuntimeValue + DynTensorTyped + Copy,
{
    let input_refs: Vec<&DynTensor> = inputs.iter().collect();
    let dense_inputs = collect_dense_dyn_tensors::<T>(&input_refs, "einsum_vjp")?;
    let input_refs: Vec<&DenseTensor<T>> = dense_inputs.iter().collect();
    let dense_cotangent = dense_dyn_tensor_typed::<T>(cotangent, "einsum_vjp")?;
    let grads = einsum_rrule(subscripts, &input_refs, &dense_cotangent)?;
    Ok(grads
        .into_iter()
        .zip(input_grad_mask.iter().copied())
        .map(|(grad, needed)| needed.then(|| dyn_from_dense(grad)))
        .collect())
}
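
// Addition: d(x + y) = dx + dy, so the JVP adds the tangents and the VJP
// passes the cotangent through to each requested input.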
impl LinearizableOp<DynTensor> for AddOp {
    type Linearized = AddLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        Ok(vec![dyn_add(inputs[0], inputs[1]).map_err(into_ad_error)?])
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(2))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn linearize(
        &self,
        _inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(AddLinearized)
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::CheapReplay
    }
}

impl LinearizedOp<DynTensor> for AddLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        let tangent = match (&input_tangents[0], &input_tangents[1]) {
            (None, None) => None,
            (Some(lhs), None) => Some(lhs.clone()),
            (None, Some(rhs)) => Some(rhs.clone()),
            (Some(lhs), Some(rhs)) => Some(dyn_add(lhs, rhs).map_err(into_ad_error)?),
        };
        Ok(vec![tangent])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        let grad = output_cotangents[0].clone();
        Ok(vec![
            input_grad_mask[0].then(|| grad.clone()).flatten(),
            input_grad_mask[1].then_some(grad).flatten(),
        ])
    }
}
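
// Exponential: with y = exp(x), both the JVP and the VJP multiply by the
// cached forward output y.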
impl LinearizableOp<DynTensor> for ExpOp {
    type Linearized = ExpLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        Ok(vec![dyn_exp(inputs[0]).map_err(into_ad_error)?])
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn linearize(
        &self,
        _inputs: &[&DynTensor],
        outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(ExpLinearized {
            output: outputs[0].clone(),
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::CheapReplay
    }
}

impl LinearizedOp<DynTensor> for ExpLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        Ok(vec![match &input_tangents[0] {
            Some(tangent) => Some(dyn_mul(&self.output, tangent).map_err(into_ad_error)?),
            None => None,
        }])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        Ok(vec![if input_grad_mask[0] {
            match &output_cotangents[0] {
                Some(grad_out) => Some(dyn_mul(&self.output, grad_out).map_err(into_ad_error)?),
                None => None,
            }
        } else {
            None
        }])
    }
}
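
// Sum: the JVP sums the tangent; the VJP broadcasts the scalar cotangent
// back over the input's shape.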
impl LinearizableOp<DynTensor> for SumOp {
    type Linearized = SumLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        Ok(vec![dyn_sum_all(inputs[0]).map_err(into_ad_error)?])
    }

    fn input_schema(&self, _inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(SumLinearized {
            input: inputs[0].clone(),
        })
    }

    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::CheapReplay
    }
}

impl LinearizedOp<DynTensor> for SumLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        Ok(vec![match &input_tangents[0] {
            Some(tangent) => Some(dyn_sum_all(tangent).map_err(into_ad_error)?),
            None => None,
        }])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        Ok(vec![if input_grad_mask[0] {
            match &output_cotangents[0] {
                Some(grad_out) => {
                    Some(dyn_broadcast_scalar_like(grad_out, &self.input).map_err(into_ad_error)?)
                }
                None => None,
            }
        } else {
            None
        }])
    }
}
impl EinsumOp {
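    /// Creates an einsum op from a subscripts string such as `"ij,jk->ik"`.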
    pub fn new(subscripts: impl Into<String>) -> Self {
        Self {
            subscripts: Arc::<str>::from(subscripts.into()),
        }
    }
}
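
// Einsum: primal, JVP, and VJP all dispatch on the dtype of the first input;
// mixed-dtype input lists are rejected by the typed helpers.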
impl LinearizableOp<DynTensor> for EinsumOp {
    type Linearized = EinsumLinearized;

    fn primal(&self, inputs: &[&DynTensor]) -> AdResult<Vec<DynTensor>> {
        let output = match inputs.first() {
            Some(DynTensor::F32(_)) => dyn_einsum_primal_t::<f32>(&self.subscripts, inputs),
            Some(DynTensor::F64(_)) => dyn_einsum_primal_t::<f64>(&self.subscripts, inputs),
            Some(DynTensor::C32(_)) => dyn_einsum_primal_t::<Complex32>(&self.subscripts, inputs),
            Some(DynTensor::C64(_)) => dyn_einsum_primal_t::<Complex64>(&self.subscripts, inputs),
            None => Err(invalid_argument("einsum requires at least one input")),
        }
        .map_err(into_ad_error)?;
        Ok(vec![output])
    }

    fn input_schema(&self, inputs: &[&DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(inputs.len()))
    }

    fn output_schema(&self, _inputs: &[&DynTensor], _outputs: &[DynTensor]) -> AdResult<Schema> {
        Ok(differentiable_schema(1))
    }

    fn linearize(
        &self,
        inputs: &[&DynTensor],
        _outputs: &[DynTensor],
    ) -> AdResult<Self::Linearized> {
        Ok(EinsumLinearized {
            subscripts: self.subscripts.clone(),
            inputs: inputs.iter().map(|input| (*input).clone()).collect(),
        })
    }
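
    // Contraction is costly to recompute, so flag replay as expensive for
    // the checkpointing policy.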
    fn checkpoint_hint(&self) -> CheckpointHint {
        CheckpointHint::ExpensiveReplay
    }
}
impl LinearizedOp<DynTensor> for EinsumLinearized {
    fn jvp(&self, input_tangents: &[Option<DynTensor>]) -> AdResult<Vec<Option<DynTensor>>> {
        let tangent = match self.inputs.first() {
            Some(DynTensor::F32(_)) => {
                dyn_einsum_jvp_t::<f32>(&self.subscripts, &self.inputs, input_tangents)
            }
            Some(DynTensor::F64(_)) => {
                dyn_einsum_jvp_t::<f64>(&self.subscripts, &self.inputs, input_tangents)
            }
            Some(DynTensor::C32(_)) => {
                dyn_einsum_jvp_t::<Complex32>(&self.subscripts, &self.inputs, input_tangents)
            }
            Some(DynTensor::C64(_)) => {
                dyn_einsum_jvp_t::<Complex64>(&self.subscripts, &self.inputs, input_tangents)
            }
            None => Err(invalid_argument(
                "einsum linearization requires at least one input",
            )),
        }
        .map_err(into_ad_error)?;
        Ok(vec![tangent])
    }

    fn vjp(
        &self,
        output_cotangents: &[Option<DynTensor>],
        input_grad_mask: &[bool],
    ) -> AdResult<Vec<Option<DynTensor>>> {
        let Some(cotangent) = output_cotangents[0].as_ref() else {
            return Ok((0..self.inputs.len()).map(|_| None).collect());
        };
        match self.inputs.first() {
            Some(DynTensor::F32(_)) => {
                dyn_einsum_vjp_t::<f32>(&self.subscripts, &self.inputs, cotangent, input_grad_mask)
            }
            Some(DynTensor::F64(_)) => {
                dyn_einsum_vjp_t::<f64>(&self.subscripts, &self.inputs, cotangent, input_grad_mask)
            }
            Some(DynTensor::C32(_)) => dyn_einsum_vjp_t::<Complex32>(
                &self.subscripts,
                &self.inputs,
                cotangent,
                input_grad_mask,
            ),
            Some(DynTensor::C64(_)) => dyn_einsum_vjp_t::<Complex64>(
                &self.subscripts,
                &self.inputs,
                cotangent,
                input_grad_mask,
            ),
            None => Err(invalid_argument(
                "einsum linearization requires at least one input",
            )),
        }
        .map_err(into_ad_error)
    }
}
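
/// Elementwise addition of two `DynValue`s via `AddOp`.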
pub fn add_dyn_values(lhs: &DynValue, rhs: &DynValue) -> AdResult<DynValue> {
    AddOp.apply_one(&[lhs, rhs])
}
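
/// Elementwise exponential of a `DynValue` via `ExpOp`.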
pub fn exp_dyn_value(input: &DynValue) -> AdResult<DynValue> {
    ExpOp.apply_one(&[input])
}
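
/// Sums a `DynValue` to a rank-0 value via `SumOp`.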
pub fn sum_dyn_value(input: &DynValue) -> AdResult<DynValue> {
    SumOp.apply_one(&[input])
}
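
/// Einsum contraction of `DynValue`s via `EinsumOp`.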
pub fn einsum_dyn_values(subscripts: &str, inputs: &[&DynValue]) -> AdResult<DynValue> {
    EinsumOp::new(subscripts).apply_one(inputs)
}