tenferro_ext_tropical_capi/
einsum_api.rs

1use std::os::raw::c_char;
2use std::panic::{catch_unwind, AssertUnwindSafe};
3
4use tenferro_algebra::Conjugate;
5use tenferro_capi::{tfe_status_t, TfeTensorF64};
6use tenferro_einsum::{einsum, EinsumBackend};
7use tenferro_ext_tropical::ad::{
8    extract_inner, promote_to_tropical, tropical_einsum_frule, tropical_einsum_rrule,
9    TropicalScalar,
10};
11use tenferro_ext_tropical::{
12    MaxMul, MaxMulAlgebra, MaxPlus, MaxPlusAlgebra, MinPlus, MinPlusAlgebra,
13};
14use tenferro_prims::{CpuBackend, CpuContext, TensorSemiringCore, TensorSemiringFastPath};
15use tenferro_tensor::Tensor;
16
17use crate::ffi_utils::{
18    collect_operand_handles, collect_optional_tangent_handles, cpu_context, parse_subscripts,
19};
20use crate::handle::{handle_to_ref, tensor_to_handle};
21use crate::status::{finalize_ptr, finalize_void, map_device_error};
22
23unsafe fn tropical_einsum_impl<T, Alg>(
24    subscripts: *const c_char,
25    operands: *const *const TfeTensorF64,
26    num_operands: usize,
27) -> std::result::Result<*mut TfeTensorF64, tfe_status_t>
28where
29    T: TropicalScalar<Inner = f64> + Conjugate + tenferro_algebra::HasAlgebra<Algebra = Alg>,
30    Alg: tenferro_algebra::Semiring<Scalar = T>,
31    CpuBackend: EinsumBackend<Alg>
32        + TensorSemiringCore<Alg, Context = CpuContext>
33        + TensorSemiringFastPath<
34            Alg,
35            Context = CpuContext,
36            Plan = <CpuBackend as TensorSemiringCore<Alg>>::Plan,
37        >,
38{
39    let subscripts = parse_subscripts(subscripts)?;
40    let operands = collect_operand_handles(operands, num_operands)?;
41    let tropical_operands: Vec<Tensor<T>> = operands
42        .iter()
43        .map(|tensor| promote_to_tropical::<T>(*tensor).map_err(|err| map_device_error(&err)))
44        .collect::<std::result::Result<_, _>>()?;
45    let tropical_refs: Vec<&Tensor<T>> = tropical_operands.iter().collect();
46    let mut ctx = cpu_context()?;
47    let output = einsum::<Alg, CpuBackend>(&mut ctx, subscripts, &tropical_refs, None)
48        .map_err(|err| map_device_error(&err))?;
49    let output = extract_inner::<T>(&output).map_err(|err| map_device_error(&err))?;
50    Ok(tensor_to_handle(output))
51}
52
53unsafe fn tropical_einsum_rrule_impl<T, Alg>(
54    subscripts: *const c_char,
55    operands: *const *const TfeTensorF64,
56    num_operands: usize,
57    cotangent: *const TfeTensorF64,
58    grads_out: *mut *mut TfeTensorF64,
59) -> std::result::Result<(), tfe_status_t>
60where
61    T: TropicalScalar<Inner = f64> + Conjugate + tenferro_algebra::HasAlgebra<Algebra = Alg>,
62    Alg: tenferro_algebra::Semiring<Scalar = T>,
63    CpuBackend: EinsumBackend<Alg>
64        + TensorSemiringCore<Alg, Context = CpuContext>
65        + TensorSemiringFastPath<
66            Alg,
67            Context = CpuContext,
68            Plan = <CpuBackend as TensorSemiringCore<Alg>>::Plan,
69        >,
70{
71    if cotangent.is_null() || grads_out.is_null() {
72        return Err(tenferro_capi::TFE_INVALID_ARGUMENT);
73    }
74
75    let subscripts = parse_subscripts(subscripts)?;
76    let operands = collect_operand_handles(operands, num_operands)?;
77    let tropical_operands: Vec<Tensor<T>> = operands
78        .iter()
79        .map(|tensor| promote_to_tropical::<T>(*tensor).map_err(|err| map_device_error(&err)))
80        .collect::<std::result::Result<_, _>>()?;
81    let tropical_refs: Vec<&Tensor<T>> = tropical_operands.iter().collect();
82    let cotangent = handle_to_ref(cotangent);
83    let mut ctx = cpu_context()?;
84    let grads = tropical_einsum_rrule::<T, Alg, CpuBackend>(
85        &mut ctx,
86        subscripts,
87        &tropical_refs,
88        cotangent,
89    )
90    .map_err(|err| map_device_error(&err))?;
91
92    if grads.len() != num_operands {
93        return Err(tenferro_capi::TFE_INTERNAL_ERROR);
94    }
95
96    let out = std::slice::from_raw_parts_mut(grads_out, num_operands);
97    for (slot, grad) in out.iter_mut().zip(grads.into_iter()) {
98        *slot = tensor_to_handle(grad);
99    }
100    Ok(())
101}
102
103unsafe fn tropical_einsum_frule_impl<T, Alg>(
104    subscripts: *const c_char,
105    primals: *const *const TfeTensorF64,
106    num_operands: usize,
107    tangents: *const *const TfeTensorF64,
108) -> std::result::Result<*mut TfeTensorF64, tfe_status_t>
109where
110    T: TropicalScalar<Inner = f64> + Conjugate + tenferro_algebra::HasAlgebra<Algebra = Alg>,
111    Alg: tenferro_algebra::Semiring<Scalar = T>,
112    CpuBackend: EinsumBackend<Alg>
113        + TensorSemiringCore<Alg, Context = CpuContext>
114        + TensorSemiringFastPath<
115            Alg,
116            Context = CpuContext,
117            Plan = <CpuBackend as TensorSemiringCore<Alg>>::Plan,
118        >,
119{
120    let subscripts = parse_subscripts(subscripts)?;
121    let primals = collect_operand_handles(primals, num_operands)?;
122    let tangents = collect_optional_tangent_handles(tangents, num_operands)?;
123    let tropical_primals: Vec<Tensor<T>> = primals
124        .iter()
125        .map(|tensor| promote_to_tropical::<T>(*tensor).map_err(|err| map_device_error(&err)))
126        .collect::<std::result::Result<_, _>>()?;
127    let tropical_refs: Vec<&Tensor<T>> = tropical_primals.iter().collect();
128    let mut ctx = cpu_context()?;
129    let output = tropical_einsum_frule::<T, Alg, CpuBackend>(
130        &mut ctx,
131        subscripts,
132        &tropical_refs,
133        &tangents,
134    )
135    .map_err(|err| map_device_error(&err))?;
136    Ok(tensor_to_handle(output))
137}
138
/// Generates the three exported C entry points (einsum, rrule, frule) for one
/// tropical algebra.
///
/// Each generated `extern "C"` function runs the matching generic `*_impl`
/// helper inside `catch_unwind`, so a Rust panic is never unwound across the
/// FFI boundary. The outcome, including a caught panic, is then passed to
/// `finalize_ptr` / `finalize_void` together with the caller's `status`
/// pointer to produce the C-side result.
///
/// Macro arguments:
/// - `algebra`, `combine`, `multiply`: display strings used only in the
///   generated rustdoc.
/// - `tropical_ty`, `algebra_ty`: scalar and algebra types the helpers are
///   instantiated with.
/// - `einsum_fn`, `rrule_fn`, `frule_fn`: exported (`#[no_mangle]`) symbol names.
macro_rules! define_tropical_entrypoints {
    (
        algebra = $algebra_name:literal,
        combine = $combine:literal,
        multiply = $multiply:literal,
        tropical_ty = $tropical_ty:ty,
        algebra_ty = $algebra_ty:ty,
        einsum_fn = $einsum_fn:ident,
        rrule_fn = $rrule_fn:ident,
        frule_fn = $frule_fn:ident
    ) => {
        #[doc = concat!(
            "Execute tropical einsum under ", $algebra_name, " algebra (⊕=", $combine, ", ⊗=", $multiply, ").\n\n",
            "Accepts standard `TfeTensorF64` handles and interprets them as tropical scalars internally.\n",
            "Returns a new tensor that the caller must release with `tfe_tensor_f64_release`.\n\n",
            "# Safety\n\n",
            "- `subscripts` must be a valid null-terminated C string.\n",
            "- `operands` must point to an array of `num_operands` valid tensor pointers.\n",
            "- `status` must be a valid, non-null pointer.\n\n",
            "# Examples (C)\n\n",
            "```c\n",
            "const tfe_tensor_f64 *ops[] = {a, b};\n",
            "tfe_status_t status;\n",
            "tfe_tensor_f64 *c = ", stringify!($einsum_fn), "(\"ij,jk->ik\", ops, 2, &status);\n",
            "tfe_tensor_f64_release(c);\n",
            "```"
        )]
        #[no_mangle]
        pub unsafe extern "C" fn $einsum_fn(
            subscripts: *const c_char,
            operands: *const *const TfeTensorF64,
            num_operands: usize,
            status: *mut tfe_status_t,
        ) -> *mut TfeTensorF64 {
            // Panics must not cross the `extern "C"` boundary; capture them and
            // let `finalize_ptr` turn the outcome into a return value + status.
            let result = catch_unwind(AssertUnwindSafe(
                || -> std::result::Result<*mut TfeTensorF64, tfe_status_t> {
                    tropical_einsum_impl::<$tropical_ty, $algebra_ty>(
                        subscripts,
                        operands,
                        num_operands,
                    )
                },
            ));
            finalize_ptr(result, status)
        }

        #[doc = concat!(
            "Reverse-mode rule (VJP) for ", $algebra_name, " tropical einsum.\n\n",
            "Computes one gradient tensor per input operand given the output cotangent.\n",
            "On success, each of the `num_operands` slots in `grads_out` receives a new tensor ",
            "that the caller must release with `tfe_tensor_f64_release`.\n\n",
            "# Safety\n\n",
            "- `subscripts` must be a valid null-terminated C string.\n",
            "- `operands` must point to an array of `num_operands` valid tensor pointers.\n",
            "- `cotangent` must be a valid, non-null tensor pointer.\n",
            "- `grads_out` must point to a caller-allocated array of `num_operands` mutable output pointers.\n",
            "- `status` must be a valid, non-null pointer.\n\n",
            "# Examples (C)\n\n",
            "```c\n",
            "tfe_tensor_f64 *grads[2];\n",
            "tfe_status_t status;\n",
            "const tfe_tensor_f64 *ops[] = {a, b};\n",
            stringify!($rrule_fn), "(\"ij,jk->ik\", ops, 2, grad_c, grads, &status);\n",
            "tfe_tensor_f64_release(grads[0]);\n",
            "tfe_tensor_f64_release(grads[1]);\n",
            "```"
        )]
        #[no_mangle]
        pub unsafe extern "C" fn $rrule_fn(
            subscripts: *const c_char,
            operands: *const *const TfeTensorF64,
            num_operands: usize,
            cotangent: *const TfeTensorF64,
            grads_out: *mut *mut TfeTensorF64,
            status: *mut tfe_status_t,
        ) {
            // Same panic-containment pattern as the forward entry point; the
            // gradients are written through `grads_out`, so only status is reported.
            let result = catch_unwind(AssertUnwindSafe(
                || -> std::result::Result<(), tfe_status_t> {
                    tropical_einsum_rrule_impl::<$tropical_ty, $algebra_ty>(
                        subscripts,
                        operands,
                        num_operands,
                        cotangent,
                        grads_out,
                    )
                },
            ));
            finalize_void(result, status)
        }

        #[doc = concat!(
            "Forward-mode rule (JVP) for ", $algebra_name, " tropical einsum.\n\n",
            "Returns the output tangent as a new tensor that the caller must release with ",
            "`tfe_tensor_f64_release`. Elements of `tangents` may be null to denote zero tangents.\n\n",
            "# Safety\n\n",
            "- `subscripts` must be a valid null-terminated C string.\n",
            "- `primals` must point to an array of `num_operands` valid tensor pointers.\n",
            "- `tangents` must point to an array of `num_operands` tensor pointers (elements may be null).\n",
            "- `status` must be a valid, non-null pointer.\n\n",
            "# Examples (C)\n\n",
            "```c\n",
            "const tfe_tensor_f64 *primals[] = {a, b};\n",
            "const tfe_tensor_f64 *tangents[] = {da, NULL};\n",
            "tfe_status_t status;\n",
            "tfe_tensor_f64 *dc = ", stringify!($frule_fn), "(\"ij,jk->ik\", primals, 2, tangents, &status);\n",
            "tfe_tensor_f64_release(dc);\n",
            "```"
        )]
        #[no_mangle]
        pub unsafe extern "C" fn $frule_fn(
            subscripts: *const c_char,
            primals: *const *const TfeTensorF64,
            num_operands: usize,
            tangents: *const *const TfeTensorF64,
            status: *mut tfe_status_t,
        ) -> *mut TfeTensorF64 {
            // Panics must not cross the `extern "C"` boundary.
            let result = catch_unwind(AssertUnwindSafe(
                || -> std::result::Result<*mut TfeTensorF64, tfe_status_t> {
                    tropical_einsum_frule_impl::<$tropical_ty, $algebra_ty>(
                        subscripts,
                        primals,
                        num_operands,
                        tangents,
                    )
                },
            ));
            finalize_ptr(result, status)
        }
    };
}
266
267define_tropical_entrypoints!(
268    algebra = "MaxPlus",
269    combine = "max",
270    multiply = "+",
271    tropical_ty = MaxPlus<f64>,
272    algebra_ty = MaxPlusAlgebra<f64>,
273    einsum_fn = tfe_tropical_einsum_maxplus_f64,
274    rrule_fn = tfe_tropical_einsum_rrule_maxplus_f64,
275    frule_fn = tfe_tropical_einsum_frule_maxplus_f64
276);
277
278define_tropical_entrypoints!(
279    algebra = "MinPlus",
280    combine = "min",
281    multiply = "+",
282    tropical_ty = MinPlus<f64>,
283    algebra_ty = MinPlusAlgebra<f64>,
284    einsum_fn = tfe_tropical_einsum_minplus_f64,
285    rrule_fn = tfe_tropical_einsum_rrule_minplus_f64,
286    frule_fn = tfe_tropical_einsum_frule_minplus_f64
287);
288
289define_tropical_entrypoints!(
290    algebra = "MaxMul",
291    combine = "max",
292    multiply = "×",
293    tropical_ty = MaxMul<f64>,
294    algebra_ty = MaxMulAlgebra<f64>,
295    einsum_fn = tfe_tropical_einsum_maxmul_f64,
296    rrule_fn = tfe_tropical_einsum_rrule_maxmul_f64,
297    frule_fn = tfe_tropical_einsum_frule_maxmul_f64
298);