1use std::os::raw::c_char;
2use std::panic::{catch_unwind, AssertUnwindSafe};
3
4use tenferro_algebra::Conjugate;
5use tenferro_capi::{tfe_status_t, TfeTensorF64};
6use tenferro_einsum::{einsum, EinsumBackend};
7use tenferro_ext_tropical::ad::{
8 extract_inner, promote_to_tropical, tropical_einsum_frule, tropical_einsum_rrule,
9 TropicalScalar,
10};
11use tenferro_ext_tropical::{
12 MaxMul, MaxMulAlgebra, MaxPlus, MaxPlusAlgebra, MinPlus, MinPlusAlgebra,
13};
14use tenferro_prims::{CpuBackend, CpuContext, TensorSemiringCore, TensorSemiringFastPath};
15use tenferro_tensor::Tensor;
16
17use crate::ffi_utils::{
18 collect_operand_handles, collect_optional_tangent_handles, cpu_context, parse_subscripts,
19};
20use crate::handle::{handle_to_ref, tensor_to_handle};
21use crate::status::{finalize_ptr, finalize_void, map_device_error};
22
23unsafe fn tropical_einsum_impl<T, Alg>(
24 subscripts: *const c_char,
25 operands: *const *const TfeTensorF64,
26 num_operands: usize,
27) -> std::result::Result<*mut TfeTensorF64, tfe_status_t>
28where
29 T: TropicalScalar<Inner = f64> + Conjugate + tenferro_algebra::HasAlgebra<Algebra = Alg>,
30 Alg: tenferro_algebra::Semiring<Scalar = T>,
31 CpuBackend: EinsumBackend<Alg>
32 + TensorSemiringCore<Alg, Context = CpuContext>
33 + TensorSemiringFastPath<
34 Alg,
35 Context = CpuContext,
36 Plan = <CpuBackend as TensorSemiringCore<Alg>>::Plan,
37 >,
38{
39 let subscripts = parse_subscripts(subscripts)?;
40 let operands = collect_operand_handles(operands, num_operands)?;
41 let tropical_operands: Vec<Tensor<T>> = operands
42 .iter()
43 .map(|tensor| promote_to_tropical::<T>(*tensor).map_err(|err| map_device_error(&err)))
44 .collect::<std::result::Result<_, _>>()?;
45 let tropical_refs: Vec<&Tensor<T>> = tropical_operands.iter().collect();
46 let mut ctx = cpu_context()?;
47 let output = einsum::<Alg, CpuBackend>(&mut ctx, subscripts, &tropical_refs, None)
48 .map_err(|err| map_device_error(&err))?;
49 let output = extract_inner::<T>(&output).map_err(|err| map_device_error(&err))?;
50 Ok(tensor_to_handle(output))
51}
52
53unsafe fn tropical_einsum_rrule_impl<T, Alg>(
54 subscripts: *const c_char,
55 operands: *const *const TfeTensorF64,
56 num_operands: usize,
57 cotangent: *const TfeTensorF64,
58 grads_out: *mut *mut TfeTensorF64,
59) -> std::result::Result<(), tfe_status_t>
60where
61 T: TropicalScalar<Inner = f64> + Conjugate + tenferro_algebra::HasAlgebra<Algebra = Alg>,
62 Alg: tenferro_algebra::Semiring<Scalar = T>,
63 CpuBackend: EinsumBackend<Alg>
64 + TensorSemiringCore<Alg, Context = CpuContext>
65 + TensorSemiringFastPath<
66 Alg,
67 Context = CpuContext,
68 Plan = <CpuBackend as TensorSemiringCore<Alg>>::Plan,
69 >,
70{
71 if cotangent.is_null() || grads_out.is_null() {
72 return Err(tenferro_capi::TFE_INVALID_ARGUMENT);
73 }
74
75 let subscripts = parse_subscripts(subscripts)?;
76 let operands = collect_operand_handles(operands, num_operands)?;
77 let tropical_operands: Vec<Tensor<T>> = operands
78 .iter()
79 .map(|tensor| promote_to_tropical::<T>(*tensor).map_err(|err| map_device_error(&err)))
80 .collect::<std::result::Result<_, _>>()?;
81 let tropical_refs: Vec<&Tensor<T>> = tropical_operands.iter().collect();
82 let cotangent = handle_to_ref(cotangent);
83 let mut ctx = cpu_context()?;
84 let grads = tropical_einsum_rrule::<T, Alg, CpuBackend>(
85 &mut ctx,
86 subscripts,
87 &tropical_refs,
88 cotangent,
89 )
90 .map_err(|err| map_device_error(&err))?;
91
92 if grads.len() != num_operands {
93 return Err(tenferro_capi::TFE_INTERNAL_ERROR);
94 }
95
96 let out = std::slice::from_raw_parts_mut(grads_out, num_operands);
97 for (slot, grad) in out.iter_mut().zip(grads.into_iter()) {
98 *slot = tensor_to_handle(grad);
99 }
100 Ok(())
101}
102
103unsafe fn tropical_einsum_frule_impl<T, Alg>(
104 subscripts: *const c_char,
105 primals: *const *const TfeTensorF64,
106 num_operands: usize,
107 tangents: *const *const TfeTensorF64,
108) -> std::result::Result<*mut TfeTensorF64, tfe_status_t>
109where
110 T: TropicalScalar<Inner = f64> + Conjugate + tenferro_algebra::HasAlgebra<Algebra = Alg>,
111 Alg: tenferro_algebra::Semiring<Scalar = T>,
112 CpuBackend: EinsumBackend<Alg>
113 + TensorSemiringCore<Alg, Context = CpuContext>
114 + TensorSemiringFastPath<
115 Alg,
116 Context = CpuContext,
117 Plan = <CpuBackend as TensorSemiringCore<Alg>>::Plan,
118 >,
119{
120 let subscripts = parse_subscripts(subscripts)?;
121 let primals = collect_operand_handles(primals, num_operands)?;
122 let tangents = collect_optional_tangent_handles(tangents, num_operands)?;
123 let tropical_primals: Vec<Tensor<T>> = primals
124 .iter()
125 .map(|tensor| promote_to_tropical::<T>(*tensor).map_err(|err| map_device_error(&err)))
126 .collect::<std::result::Result<_, _>>()?;
127 let tropical_refs: Vec<&Tensor<T>> = tropical_primals.iter().collect();
128 let mut ctx = cpu_context()?;
129 let output = tropical_einsum_frule::<T, Alg, CpuBackend>(
130 &mut ctx,
131 subscripts,
132 &tropical_refs,
133 &tangents,
134 )
135 .map_err(|err| map_device_error(&err))?;
136 Ok(tensor_to_handle(output))
137}
138
// Generates the three `extern "C"` entry points (einsum, reverse rule, forward
// rule) for one tropical algebra.
//
// Each generated function runs the matching generic `*_impl` helper inside
// `catch_unwind` and passes the outcome, together with the caller's `status`
// pointer, to `finalize_ptr` / `finalize_void`.
//
// The three string literals are used only in the generated documentation; the
// two types select the monomorphization; the three identifiers become the
// exported symbol names.
macro_rules! define_tropical_entrypoints {
    (
        algebra = $label:literal,
        combine = $oplus:literal,
        multiply = $otimes:literal,
        tropical_ty = $scalar:ty,
        algebra_ty = $semiring:ty,
        einsum_fn = $forward:ident,
        rrule_fn = $reverse:ident,
        frule_fn = $jvp:ident
    ) => {
        #[doc = concat!(
            "Tropical einsum under the ", $label, " algebra (⊕=", $oplus, ", ⊗=", $otimes, ").\n\n",
            "Operands are ordinary `TfeTensorF64` handles whose entries are treated as tropical scalars internally.\n",
            "The result is a new tensor; the caller must release it with `tfe_tensor_f64_release`.\n\n",
            "# Safety\n\n",
            "- `subscripts` must be a valid null-terminated C string.\n",
            "- `operands` must point to an array of `num_operands` valid tensor pointers.\n",
            "- `status` must be a valid, non-null pointer.\n\n",
            "# Examples (C)\n\n",
            "```c\n",
            "const tfe_tensor_f64 *ops[] = {a, b};\n",
            "tfe_status_t status;\n",
            "tfe_tensor_f64 *c = ", stringify!($forward), "(\"ij,jk->ik\", ops, 2, &status);\n",
            "tfe_tensor_f64_release(c);\n",
            "```"
        )]
        #[no_mangle]
        pub unsafe extern "C" fn $forward(
            subscripts: *const c_char,
            operands: *const *const TfeTensorF64,
            num_operands: usize,
            status: *mut tfe_status_t,
        ) -> *mut TfeTensorF64 {
            // Unwinding must not cross the C boundary, so the call is wrapped in `catch_unwind`.
            let run = || tropical_einsum_impl::<$scalar, $semiring>(subscripts, operands, num_operands);
            finalize_ptr(catch_unwind(AssertUnwindSafe(run)), status)
        }

        #[doc = concat!(
            "Reverse-mode rule (VJP) of the ", $label, " tropical einsum.\n\n",
            "Given the output cotangent, produces one gradient tensor for each input operand and stores it in `grads_out`.\n\n",
            "# Safety\n\n",
            "- `subscripts` must be a valid null-terminated C string.\n",
            "- `operands` must point to an array of `num_operands` valid tensor pointers.\n",
            "- `cotangent` must be a valid, non-null tensor pointer.\n",
            "- `grads_out` must point to a caller-allocated array of `num_operands` mutable output pointers.\n",
            "- `status` must be a valid, non-null pointer.\n\n",
            "# Examples (C)\n\n",
            "```c\n",
            "tfe_tensor_f64 *grads[2];\n",
            "tfe_status_t status;\n",
            "const tfe_tensor_f64 *ops[] = {a, b};\n",
            stringify!($reverse), "(\"ij,jk->ik\", ops, 2, grad_c, grads, &status);\n",
            "tfe_tensor_f64_release(grads[0]);\n",
            "tfe_tensor_f64_release(grads[1]);\n",
            "```"
        )]
        #[no_mangle]
        pub unsafe extern "C" fn $reverse(
            subscripts: *const c_char,
            operands: *const *const TfeTensorF64,
            num_operands: usize,
            cotangent: *const TfeTensorF64,
            grads_out: *mut *mut TfeTensorF64,
            status: *mut tfe_status_t,
        ) {
            // Unwinding must not cross the C boundary, so the call is wrapped in `catch_unwind`.
            let run = || {
                tropical_einsum_rrule_impl::<$scalar, $semiring>(
                    subscripts,
                    operands,
                    num_operands,
                    cotangent,
                    grads_out,
                )
            };
            finalize_void(catch_unwind(AssertUnwindSafe(run)), status)
        }

        #[doc = concat!(
            "Forward-mode rule (JVP) of the ", $label, " tropical einsum.\n\n",
            "Produces the output tangent. A null entry in `tangents` denotes a zero tangent for that operand.\n\n",
            "# Safety\n\n",
            "- `subscripts` must be a valid null-terminated C string.\n",
            "- `primals` must point to an array of `num_operands` valid tensor pointers.\n",
            "- `tangents` must point to an array of `num_operands` tensor pointers (elements may be null).\n",
            "- `status` must be a valid, non-null pointer.\n\n",
            "# Examples (C)\n\n",
            "```c\n",
            "const tfe_tensor_f64 *primals[] = {a, b};\n",
            "const tfe_tensor_f64 *tangents[] = {da, NULL};\n",
            "tfe_status_t status;\n",
            "tfe_tensor_f64 *dc = ", stringify!($jvp), "(\"ij,jk->ik\", primals, 2, tangents, &status);\n",
            "tfe_tensor_f64_release(dc);\n",
            "```"
        )]
        #[no_mangle]
        pub unsafe extern "C" fn $jvp(
            subscripts: *const c_char,
            primals: *const *const TfeTensorF64,
            num_operands: usize,
            tangents: *const *const TfeTensorF64,
            status: *mut tfe_status_t,
        ) -> *mut TfeTensorF64 {
            // Unwinding must not cross the C boundary, so the call is wrapped in `catch_unwind`.
            let run = || {
                tropical_einsum_frule_impl::<$scalar, $semiring>(
                    subscripts,
                    primals,
                    num_operands,
                    tangents,
                )
            };
            finalize_ptr(catch_unwind(AssertUnwindSafe(run)), status)
        }
    };
}
266
// MaxPlus over f64 (⊕ = max, ⊗ = +): exports `tfe_tropical_einsum_maxplus_f64`
// together with its `rrule` and `frule` companions.
define_tropical_entrypoints!(
    algebra = "MaxPlus",
    combine = "max",
    multiply = "+",
    tropical_ty = MaxPlus<f64>,
    algebra_ty = MaxPlusAlgebra<f64>,
    einsum_fn = tfe_tropical_einsum_maxplus_f64,
    rrule_fn = tfe_tropical_einsum_rrule_maxplus_f64,
    frule_fn = tfe_tropical_einsum_frule_maxplus_f64
);
277
// MinPlus over f64 (⊕ = min, ⊗ = +): exports `tfe_tropical_einsum_minplus_f64`
// together with its `rrule` and `frule` companions.
define_tropical_entrypoints!(
    algebra = "MinPlus",
    combine = "min",
    multiply = "+",
    tropical_ty = MinPlus<f64>,
    algebra_ty = MinPlusAlgebra<f64>,
    einsum_fn = tfe_tropical_einsum_minplus_f64,
    rrule_fn = tfe_tropical_einsum_rrule_minplus_f64,
    frule_fn = tfe_tropical_einsum_frule_minplus_f64
);
288
// MaxMul over f64 (⊕ = max, ⊗ = ×): exports `tfe_tropical_einsum_maxmul_f64`
// together with its `rrule` and `frule` companions.
define_tropical_entrypoints!(
    algebra = "MaxMul",
    combine = "max",
    multiply = "×",
    tropical_ty = MaxMul<f64>,
    algebra_ty = MaxMulAlgebra<f64>,
    einsum_fn = tfe_tropical_einsum_maxmul_f64,
    rrule_fn = tfe_tropical_einsum_rrule_maxmul_f64,
    frule_fn = tfe_tropical_einsum_frule_maxmul_f64
);