// tenferro_capi/einsum_api.rs

1use std::panic::{catch_unwind, AssertUnwindSafe};
2
3use tenferro_algebra::Standard;
4use tenferro_einsum::{einsum, einsum_frule, einsum_rrule};
5use tenferro_prims::CpuBackend;
6
7use crate::ffi_utils::{cpu_context, read_c_str, read_optional_tensor_refs, read_tensor_refs};
8use crate::handle::{ensure_col_major, handle_to_ref, tensor_to_handle, TfeTensorF64};
9use crate::status::{
10    finalize_ptr, finalize_void, map_device_error, tfe_status_t, TFE_INVALID_ARGUMENT,
11};
12
13/// Execute einsum using string notation.
14///
15/// Returns a new tensor. The caller must release it with
16/// `tfe_tensor_f64_release`.
17///
18/// # Safety
19///
20/// - `subscripts` must be a valid null-terminated C string.
21/// - `operands` must point to an array of `num_operands` valid tensor pointers.
22/// - `status` must be a valid, non-null pointer.
23///
24/// # Examples (C)
25///
26/// ```c
27/// const tfe_tensor_f64 *ops[] = {a, b};
28/// tfe_status_t status;
29/// tfe_tensor_f64 *c = tfe_einsum_f64("ij,jk->ik", ops, 2, &status);
30/// tfe_tensor_f64_release(c);
31/// ```
32#[no_mangle]
33pub unsafe extern "C" fn tfe_einsum_f64(
34    subscripts: *const std::os::raw::c_char,
35    operands: *const *const TfeTensorF64,
36    num_operands: usize,
37    status: *mut tfe_status_t,
38) -> *mut TfeTensorF64 {
39    let result = catch_unwind(AssertUnwindSafe(|| {
40        let subs = read_c_str(subscripts, "einsum subscripts")?;
41        let ops = read_tensor_refs(operands, num_operands, "einsum operands")?;
42
43        let mut ctx = cpu_context()?;
44        einsum::<Standard<f64>, CpuBackend>(&mut ctx, subs, &ops, None)
45            .and_then(|t| ensure_col_major(&mut ctx, t))
46            .map(tensor_to_handle)
47            .map_err(|e| map_device_error(&e))
48    }));
49
50    finalize_ptr(result, status)
51}
52
53/// Reverse-mode rule (VJP) for einsum.
54///
55/// # Safety
56///
57/// - `subscripts` must be a valid null-terminated C string.
58/// - `operands` must point to an array of `num_operands` valid tensor pointers.
59/// - `cotangent` must be a valid, non-null tensor pointer.
60/// - `grads_out` must point to a caller-allocated array of `num_operands`
61///   mutable `*mut TfeTensorF64` pointers.
62/// - `status` must be a valid, non-null pointer.
63///
64/// # Examples (C)
65///
66/// ```c
67/// tfe_tensor_f64 *grads[2];
68/// tfe_status_t status;
69/// const tfe_tensor_f64 *ops[] = {a, b};
70/// tfe_einsum_rrule_f64("ij,jk->ik", ops, 2, grad_c, grads, &status);
71/// ```
72#[no_mangle]
73pub unsafe extern "C" fn tfe_einsum_rrule_f64(
74    subscripts: *const std::os::raw::c_char,
75    operands: *const *const TfeTensorF64,
76    num_operands: usize,
77    cotangent: *const TfeTensorF64,
78    grads_out: *mut *mut TfeTensorF64,
79    status: *mut tfe_status_t,
80) {
81    let result = catch_unwind(AssertUnwindSafe(|| {
82        if cotangent.is_null() || grads_out.is_null() {
83            return Err(TFE_INVALID_ARGUMENT);
84        }
85
86        let subs = read_c_str(subscripts, "einsum_rrule subscripts")?;
87        let ops = read_tensor_refs(operands, num_operands, "einsum_rrule operands")?;
88        let cot = handle_to_ref(cotangent);
89
90        let mut ctx = cpu_context()?;
91        let grads = einsum_rrule::<Standard<f64>, CpuBackend>(&mut ctx, subs, &ops, cot)
92            .map_err(|e| map_device_error(&e))?;
93
94        let out_slice = std::slice::from_raw_parts_mut(grads_out, num_operands);
95        for (i, g) in grads.into_iter().enumerate() {
96            let g = ensure_col_major(&mut ctx, g).map_err(|e| map_device_error(&e))?;
97            out_slice[i] = tensor_to_handle(g);
98        }
99
100        Ok(())
101    }));
102
103    finalize_void(result, status)
104}
105
106/// Forward-mode rule (JVP) for einsum.
107///
108/// Returns the output tangent. Elements of `tangents` may be null
109/// (interpreted as zero tangent for that operand).
110///
111/// # Safety
112///
113/// - `subscripts` must be a valid null-terminated C string.
114/// - `primals` must point to an array of `num_operands` valid tensor pointers.
115/// - `tangents` must point to an array of `num_operands` tensor pointers
116///   (elements may be null).
117/// - `status` must be a valid, non-null pointer.
118///
119/// # Examples (C)
120///
121/// ```c
122/// const tfe_tensor_f64 *primals[] = {a, b};
123/// const tfe_tensor_f64 *tangents[] = {da, NULL};
124/// tfe_status_t status;
125/// tfe_tensor_f64 *dc = tfe_einsum_frule_f64("ij,jk->ik", primals, 2, tangents, &status);
126/// tfe_tensor_f64_release(dc);
127/// ```
128#[no_mangle]
129pub unsafe extern "C" fn tfe_einsum_frule_f64(
130    subscripts: *const std::os::raw::c_char,
131    primals: *const *const TfeTensorF64,
132    num_operands: usize,
133    tangents: *const *const TfeTensorF64,
134    status: *mut tfe_status_t,
135) -> *mut TfeTensorF64 {
136    let result = catch_unwind(AssertUnwindSafe(|| {
137        let subs = read_c_str(subscripts, "einsum_frule subscripts")?;
138        let primal_refs = read_tensor_refs(primals, num_operands, "einsum_frule primals")?;
139        let tangent_refs =
140            read_optional_tensor_refs(tangents, num_operands, "einsum_frule tangents")?;
141
142        let mut ctx = cpu_context()?;
143        einsum_frule::<Standard<f64>, CpuBackend>(&mut ctx, subs, &primal_refs, &tangent_refs)
144            .and_then(|t| ensure_col_major(&mut ctx, t))
145            .map(tensor_to_handle)
146            .map_err(|e| map_device_error(&e))
147    }));
148
149    finalize_ptr(result, status)
150}