tenferro_capi/
svd_api.rs

1use std::panic::{catch_unwind, AssertUnwindSafe};
2
3use tenferro_device::LogicalMemorySpace;
4use tenferro_linalg::{svd, svd_frule, svd_rrule, SvdCotangent, SvdOptions};
5use tenferro_tensor::{MemoryOrder, Tensor};
6
7use crate::ffi_utils::{cpu_context, read_usize_slice};
8use crate::handle::{handle_to_ref, tensor_to_handle, TfeTensorF64};
9use crate::status::{
10    finalize_ptr, finalize_void, map_ad_error, map_device_error, tfe_status_t,
11    TFE_INVALID_ARGUMENT, TFE_SHAPE_MISMATCH,
12};
13
14fn build_svd_options(max_rank: usize, cutoff: f64) -> Option<SvdOptions> {
15    let mr = if max_rank == 0 { None } else { Some(max_rank) };
16    let co = if cutoff < 0.0 { None } else { Some(cutoff) };
17    if mr.is_none() && co.is_none() {
18        None
19    } else {
20        Some(SvdOptions {
21            max_rank: mr,
22            cutoff: co,
23        })
24    }
25}
26
27fn matricize(
28    tensor: &Tensor<f64>,
29    left: &[usize],
30    right: &[usize],
31) -> Result<(Tensor<f64>, Vec<usize>, Vec<usize>), tfe_status_t> {
32    let dims = tensor.dims();
33    let mut seen = vec![false; dims.len()];
34
35    for &l in left {
36        if l >= dims.len() || seen[l] {
37            return Err(TFE_INVALID_ARGUMENT);
38        }
39        seen[l] = true;
40    }
41    for &r in right {
42        if r >= dims.len() || seen[r] {
43            return Err(TFE_INVALID_ARGUMENT);
44        }
45        seen[r] = true;
46    }
47    if left.len() + right.len() != dims.len() || seen.iter().any(|&v| !v) {
48        return Err(TFE_INVALID_ARGUMENT);
49    }
50
51    let left_dims: Vec<usize> = left.iter().map(|&i| dims[i]).collect();
52    let right_dims: Vec<usize> = right.iter().map(|&i| dims[i]).collect();
53    let m: usize = left_dims.iter().product();
54    let n: usize = right_dims.iter().product();
55
56    let mut perm: Vec<usize> = Vec::with_capacity(dims.len());
57    perm.extend_from_slice(left);
58    perm.extend_from_slice(right);
59
60    let permuted = tensor
61        .permute(&perm)
62        .map_err(|_| TFE_INVALID_ARGUMENT)?
63        .contiguous(MemoryOrder::ColumnMajor)
64        .reshape(&[m, n])
65        .map_err(|_| TFE_SHAPE_MISMATCH)?
66        .contiguous(MemoryOrder::ColumnMajor);
67
68    Ok((permuted, left_dims, right_dims))
69}
70
/// Invert an axis permutation: if `perm[i] == p`, the result maps `p` back to `i`.
///
/// Panics if `perm` contains a value `>= perm.len()`. Callers pass permutations
/// already validated by `matricize`.
fn inverse_permutation(perm: &[usize]) -> Vec<usize> {
    let mut inverse = vec![0; perm.len()];
    perm.iter()
        .enumerate()
        .for_each(|(slot, &target)| inverse[target] = slot);
    inverse
}
78
79fn u_cotangent_to_matrix(
80    cot_u: &Tensor<f64>,
81    left_dims: &[usize],
82) -> Result<(Tensor<f64>, usize), tfe_status_t> {
83    let u_dims = cot_u.dims();
84    if u_dims.len() != left_dims.len() + 1 || &u_dims[..left_dims.len()] != left_dims {
85        return Err(TFE_SHAPE_MISMATCH);
86    }
87    let k = u_dims[left_dims.len()];
88    let m: usize = left_dims.iter().product();
89    let mat = cot_u
90        .contiguous(MemoryOrder::ColumnMajor)
91        .reshape(&[m, k])
92        .map_err(|_| TFE_SHAPE_MISMATCH)?
93        .contiguous(MemoryOrder::ColumnMajor);
94    Ok((mat, k))
95}
96
97fn vt_cotangent_to_matrix(
98    cot_vt: &Tensor<f64>,
99    right_dims: &[usize],
100) -> Result<(Tensor<f64>, usize), tfe_status_t> {
101    let vt_dims = cot_vt.dims();
102    if vt_dims.len() != right_dims.len() + 1 || &vt_dims[1..] != right_dims {
103        return Err(TFE_SHAPE_MISMATCH);
104    }
105    let k = vt_dims[0];
106    let n: usize = right_dims.iter().product();
107    let mat = cot_vt
108        .contiguous(MemoryOrder::ColumnMajor)
109        .reshape(&[k, n])
110        .map_err(|_| TFE_SHAPE_MISMATCH)?
111        .contiguous(MemoryOrder::ColumnMajor);
112    Ok((mat, k))
113}
114
115fn validate_s_cotangent(cot_s: &Tensor<f64>) -> Result<usize, tfe_status_t> {
116    let dims = cot_s.dims();
117    if dims.len() != 1 {
118        return Err(TFE_SHAPE_MISMATCH);
119    }
120    Ok(dims[0])
121}
122
123fn u_matrix_to_public(u: Tensor<f64>, left_dims: &[usize]) -> Result<Tensor<f64>, tfe_status_t> {
124    if u.dims().len() != 2 {
125        return Err(TFE_SHAPE_MISMATCH);
126    }
127    let k = u.dims()[1];
128    let mut out_dims = left_dims.to_vec();
129    out_dims.push(k);
130    u.reshape(&out_dims)
131        .map_err(|_| TFE_SHAPE_MISMATCH)
132        .map(|t| t.contiguous(MemoryOrder::ColumnMajor))
133}
134
135fn vt_matrix_to_public(vt: Tensor<f64>, right_dims: &[usize]) -> Result<Tensor<f64>, tfe_status_t> {
136    if vt.dims().len() != 2 {
137        return Err(TFE_SHAPE_MISMATCH);
138    }
139    let k = vt.dims()[0];
140    let mut out_dims = vec![k];
141    out_dims.extend_from_slice(right_dims);
142    vt.reshape(&out_dims)
143        .map_err(|_| TFE_SHAPE_MISMATCH)
144        .map(|t| t.contiguous(MemoryOrder::ColumnMajor))
145}
146
147fn grad_matrix_to_public(
148    grad_matrix: Tensor<f64>,
149    original_dims: &[usize],
150    left: &[usize],
151    right: &[usize],
152    left_dims: &[usize],
153    right_dims: &[usize],
154) -> Result<Tensor<f64>, tfe_status_t> {
155    if grad_matrix.dims().len() != 2 {
156        return Err(TFE_SHAPE_MISMATCH);
157    }
158
159    let mut permuted_dims = left_dims.to_vec();
160    permuted_dims.extend_from_slice(right_dims);
161    let reshaped = grad_matrix
162        .contiguous(MemoryOrder::ColumnMajor)
163        .reshape(&permuted_dims)
164        .map_err(|_| TFE_SHAPE_MISMATCH)?;
165
166    let mut perm = Vec::with_capacity(original_dims.len());
167    perm.extend_from_slice(left);
168    perm.extend_from_slice(right);
169    let inv_perm = inverse_permutation(&perm);
170
171    reshaped
172        .permute(&inv_perm)
173        .map_err(|_| TFE_INVALID_ARGUMENT)
174        .map(|t| t.contiguous(MemoryOrder::ColumnMajor))
175}
176
177/// Compute the SVD of a tensor after matricizing by `left` and `right`.
178///
179/// # Safety
180///
181/// - `tensor` must be valid and non-null.
182/// - `left` and `right` must point to valid index arrays.
183/// - `u_out`, `s_out`, `vt_out` must be valid, non-null pointers.
184/// - `status` must be valid.
185///
186/// # Examples (C)
187///
188/// ```c
189/// size_t left[] = {0};
190/// size_t right[] = {1};
191/// tfe_tensor_f64 *u, *s, *vt;
192/// tfe_status_t status;
193/// tfe_svd_f64(a, left, 1, right, 1, 0, -1.0, &u, &s, &vt, &status);
194/// ```
195#[no_mangle]
196pub unsafe extern "C" fn tfe_svd_f64(
197    tensor: *const TfeTensorF64,
198    left: *const usize,
199    left_len: usize,
200    right: *const usize,
201    right_len: usize,
202    max_rank: usize,
203    cutoff: f64,
204    u_out: *mut *mut TfeTensorF64,
205    s_out: *mut *mut TfeTensorF64,
206    vt_out: *mut *mut TfeTensorF64,
207    status: *mut tfe_status_t,
208) {
209    let result = catch_unwind(AssertUnwindSafe(|| {
210        if tensor.is_null() || u_out.is_null() || s_out.is_null() || vt_out.is_null() {
211            return Err(TFE_INVALID_ARGUMENT);
212        }
213        let t = handle_to_ref(tensor);
214        let left_indices = read_usize_slice(left, left_len, "svd left indices")?;
215        let right_indices = read_usize_slice(right, right_len, "svd right indices")?;
216
217        let (matrix, left_dims, right_dims) = matricize(t, left_indices, right_indices)?;
218
219        let opts = build_svd_options(max_rank, cutoff);
220        let mut ctx = cpu_context()?;
221        let result = svd(&mut ctx, &matrix, opts.as_ref()).map_err(|e| map_device_error(&e))?;
222
223        let k = result.s.len();
224        let mut u_dims: Vec<usize> = left_dims;
225        u_dims.push(k);
226        let u_reshaped = result
227            .u
228            .reshape(&u_dims)
229            .map_err(|e| map_device_error(&e))?
230            .contiguous(MemoryOrder::ColumnMajor);
231
232        let mut vt_dims: Vec<usize> = vec![k];
233        vt_dims.extend_from_slice(&right_dims);
234        let vt_reshaped = result
235            .vt
236            .reshape(&vt_dims)
237            .map_err(|e| map_device_error(&e))?
238            .contiguous(MemoryOrder::ColumnMajor);
239
240        *u_out = tensor_to_handle(u_reshaped);
241        *s_out = tensor_to_handle(result.s);
242        *vt_out = tensor_to_handle(vt_reshaped);
243        Ok(())
244    }));
245
246    finalize_void(result, status)
247}
248
249/// Reverse-mode rule (VJP) for SVD.
250///
251/// # Safety
252///
253/// - `tensor` must be valid and non-null.
254/// - `left` and `right` must point to valid index arrays.
255/// - `status` must be valid.
256///
257/// # Examples (C)
258///
259/// ```c
260/// tfe_tensor_f64 *grad = tfe_svd_rrule_f64(a, left, 1, right, 1, 0, -1.0, NULL, cot_s, NULL, &status);
261/// ```
262#[no_mangle]
263pub unsafe extern "C" fn tfe_svd_rrule_f64(
264    tensor: *const TfeTensorF64,
265    left: *const usize,
266    left_len: usize,
267    right: *const usize,
268    right_len: usize,
269    max_rank: usize,
270    cutoff: f64,
271    cotangent_u: *const TfeTensorF64,
272    cotangent_s: *const TfeTensorF64,
273    cotangent_vt: *const TfeTensorF64,
274    status: *mut tfe_status_t,
275) -> *mut TfeTensorF64 {
276    let result = catch_unwind(AssertUnwindSafe(|| {
277        if tensor.is_null() {
278            return Err(TFE_INVALID_ARGUMENT);
279        }
280
281        let t = handle_to_ref(tensor);
282        let left_indices = read_usize_slice(left, left_len, "svd_rrule left indices")?;
283        let right_indices = read_usize_slice(right, right_len, "svd_rrule right indices")?;
284        let original_dims = t.dims().to_vec();
285        let (matrix, left_dims, right_dims) = matricize(t, left_indices, right_indices)?;
286
287        let mut inferred_k: Option<usize> = None;
288        let cot_u = if cotangent_u.is_null() {
289            None
290        } else {
291            let (u_mat, k) = u_cotangent_to_matrix(handle_to_ref(cotangent_u), &left_dims)?;
292            inferred_k = Some(k);
293            Some(u_mat)
294        };
295        let cot_s = if cotangent_s.is_null() {
296            None
297        } else {
298            let cot_s_ref = handle_to_ref(cotangent_s);
299            let k = validate_s_cotangent(cot_s_ref)?;
300            if let Some(prev) = inferred_k {
301                if prev != k {
302                    return Err(TFE_SHAPE_MISMATCH);
303                }
304            } else {
305                inferred_k = Some(k);
306            }
307            Some(cot_s_ref.clone())
308        };
309        let cot_vt = if cotangent_vt.is_null() {
310            None
311        } else {
312            let (vt_mat, k) = vt_cotangent_to_matrix(handle_to_ref(cotangent_vt), &right_dims)?;
313            if let Some(prev) = inferred_k {
314                if prev != k {
315                    return Err(TFE_SHAPE_MISMATCH);
316                }
317            } else {
318                inferred_k = Some(k);
319            }
320            Some(vt_mat)
321        };
322        let _ = inferred_k;
323
324        let cotangent = SvdCotangent {
325            u: cot_u,
326            s: cot_s,
327            vt: cot_vt,
328        };
329
330        let opts = build_svd_options(max_rank, cutoff);
331        let mut ctx = cpu_context()?;
332        let grad_matrix = svd_rrule(&mut ctx, &matrix, &cotangent, opts.as_ref())
333            .map_err(|e| map_ad_error(&e))?;
334        let grad = grad_matrix_to_public(
335            grad_matrix,
336            &original_dims,
337            left_indices,
338            right_indices,
339            &left_dims,
340            &right_dims,
341        )?;
342
343        Ok(tensor_to_handle(grad))
344    }));
345
346    finalize_ptr(result, status)
347}
348
349/// Forward-mode rule (JVP) for SVD.
350///
351/// # Safety
352///
353/// - `tensor` must be valid and non-null.
354/// - `left` and `right` must point to valid index arrays.
355/// - `u_out`, `s_out`, `vt_out` must be valid, non-null pointers.
356/// - `status` must be valid.
357///
358/// # Examples (C)
359///
360/// ```c
361/// tfe_svd_frule_f64(a, left, 1, right, 1, 0, -1.0, da, &du, &ds, &dvt, &status);
362/// ```
363#[no_mangle]
364pub unsafe extern "C" fn tfe_svd_frule_f64(
365    tensor: *const TfeTensorF64,
366    left: *const usize,
367    left_len: usize,
368    right: *const usize,
369    right_len: usize,
370    max_rank: usize,
371    cutoff: f64,
372    tangent: *const TfeTensorF64,
373    u_out: *mut *mut TfeTensorF64,
374    s_out: *mut *mut TfeTensorF64,
375    vt_out: *mut *mut TfeTensorF64,
376    status: *mut tfe_status_t,
377) {
378    let result = catch_unwind(AssertUnwindSafe(|| {
379        if tensor.is_null() || u_out.is_null() || s_out.is_null() || vt_out.is_null() {
380            return Err(TFE_INVALID_ARGUMENT);
381        }
382
383        let t = handle_to_ref(tensor);
384        let left_indices = read_usize_slice(left, left_len, "svd_frule left indices")?;
385        let right_indices = read_usize_slice(right, right_len, "svd_frule right indices")?;
386        let (matrix, left_dims, right_dims) = matricize(t, left_indices, right_indices)?;
387
388        let tang = if tangent.is_null() {
389            Tensor::<f64>::zeros(
390                matrix.dims(),
391                LogicalMemorySpace::MainMemory,
392                MemoryOrder::ColumnMajor,
393            )
394            .map_err(|e| map_device_error(&e))?
395        } else {
396            let tang_tensor = handle_to_ref(tangent);
397            let (tang_matrix, _, _) = matricize(tang_tensor, left_indices, right_indices)?;
398            tang_matrix
399        };
400
401        let opts = build_svd_options(max_rank, cutoff);
402        let mut ctx = cpu_context()?;
403        let (_primal, tangent_result) =
404            svd_frule(&mut ctx, &matrix, &tang, opts.as_ref()).map_err(|e| map_ad_error(&e))?;
405
406        let u_public = u_matrix_to_public(tangent_result.u, &left_dims)?;
407        let vt_public = vt_matrix_to_public(tangent_result.vt, &right_dims)?;
408
409        *u_out = tensor_to_handle(u_public);
410        *s_out = tensor_to_handle(tangent_result.s);
411        *vt_out = tensor_to_handle(vt_public);
412        Ok(())
413    }));
414
415    finalize_void(result, status)
416}