tenferro_capi/
einsum_api.rs1use std::panic::{catch_unwind, AssertUnwindSafe};
2
3use tenferro_algebra::Standard;
4use tenferro_einsum::{einsum, einsum_frule, einsum_rrule};
5use tenferro_prims::CpuBackend;
6
7use crate::ffi_utils::{cpu_context, read_c_str, read_optional_tensor_refs, read_tensor_refs};
8use crate::handle::{ensure_col_major, handle_to_ref, tensor_to_handle, TfeTensorF64};
9use crate::status::{
10 finalize_ptr, finalize_void, map_device_error, tfe_status_t, TFE_INVALID_ARGUMENT,
11};
12
13#[no_mangle]
33pub unsafe extern "C" fn tfe_einsum_f64(
34 subscripts: *const std::os::raw::c_char,
35 operands: *const *const TfeTensorF64,
36 num_operands: usize,
37 status: *mut tfe_status_t,
38) -> *mut TfeTensorF64 {
39 let result = catch_unwind(AssertUnwindSafe(|| {
40 let subs = read_c_str(subscripts, "einsum subscripts")?;
41 let ops = read_tensor_refs(operands, num_operands, "einsum operands")?;
42
43 let mut ctx = cpu_context()?;
44 einsum::<Standard<f64>, CpuBackend>(&mut ctx, subs, &ops, None)
45 .and_then(|t| ensure_col_major(&mut ctx, t))
46 .map(tensor_to_handle)
47 .map_err(|e| map_device_error(&e))
48 }));
49
50 finalize_ptr(result, status)
51}
52
53#[no_mangle]
73pub unsafe extern "C" fn tfe_einsum_rrule_f64(
74 subscripts: *const std::os::raw::c_char,
75 operands: *const *const TfeTensorF64,
76 num_operands: usize,
77 cotangent: *const TfeTensorF64,
78 grads_out: *mut *mut TfeTensorF64,
79 status: *mut tfe_status_t,
80) {
81 let result = catch_unwind(AssertUnwindSafe(|| {
82 if cotangent.is_null() || grads_out.is_null() {
83 return Err(TFE_INVALID_ARGUMENT);
84 }
85
86 let subs = read_c_str(subscripts, "einsum_rrule subscripts")?;
87 let ops = read_tensor_refs(operands, num_operands, "einsum_rrule operands")?;
88 let cot = handle_to_ref(cotangent);
89
90 let mut ctx = cpu_context()?;
91 let grads = einsum_rrule::<Standard<f64>, CpuBackend>(&mut ctx, subs, &ops, cot)
92 .map_err(|e| map_device_error(&e))?;
93
94 let out_slice = std::slice::from_raw_parts_mut(grads_out, num_operands);
95 for (i, g) in grads.into_iter().enumerate() {
96 let g = ensure_col_major(&mut ctx, g).map_err(|e| map_device_error(&e))?;
97 out_slice[i] = tensor_to_handle(g);
98 }
99
100 Ok(())
101 }));
102
103 finalize_void(result, status)
104}
105
106#[no_mangle]
129pub unsafe extern "C" fn tfe_einsum_frule_f64(
130 subscripts: *const std::os::raw::c_char,
131 primals: *const *const TfeTensorF64,
132 num_operands: usize,
133 tangents: *const *const TfeTensorF64,
134 status: *mut tfe_status_t,
135) -> *mut TfeTensorF64 {
136 let result = catch_unwind(AssertUnwindSafe(|| {
137 let subs = read_c_str(subscripts, "einsum_frule subscripts")?;
138 let primal_refs = read_tensor_refs(primals, num_operands, "einsum_frule primals")?;
139 let tangent_refs =
140 read_optional_tensor_refs(tangents, num_operands, "einsum_frule tangents")?;
141
142 let mut ctx = cpu_context()?;
143 einsum_frule::<Standard<f64>, CpuBackend>(&mut ctx, subs, &primal_refs, &tangent_refs)
144 .and_then(|t| ensure_col_major(&mut ctx, t))
145 .map(tensor_to_handle)
146 .map_err(|e| map_device_error(&e))
147 }));
148
149 finalize_ptr(result, status)
150}