use std::panic::{catch_unwind, AssertUnwindSafe};

use tenferro_device::LogicalMemorySpace;
use tenferro_linalg::{svd, svd_frule, svd_rrule, SvdCotangent, SvdOptions};
use tenferro_tensor::{MemoryOrder, Tensor};

use crate::ffi_utils::{cpu_context, read_usize_slice};
use crate::handle::{handle_to_ref, tensor_to_handle, TfeTensorF64};
use crate::status::{
    finalize_ptr, finalize_void, map_ad_error, map_device_error, tfe_status_t,
    TFE_INVALID_ARGUMENT, TFE_SHAPE_MISMATCH,
};

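/// Builds `SvdOptions` from the C-facing sentinel values: `max_rank == 0`
/// means "no rank cap" and a negative `cutoff` means "no truncation cutoff".
/// Returns `None` when both options are disabled.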
fn build_svd_options(max_rank: usize, cutoff: f64) -> Option<SvdOptions> {
    let mr = if max_rank == 0 { None } else { Some(max_rank) };
    let co = if cutoff < 0.0 { None } else { Some(cutoff) };
    if mr.is_none() && co.is_none() {
        None
    } else {
        Some(SvdOptions {
            max_rank: mr,
            cutoff: co,
        })
    }
}

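/// Permutes `tensor` so the `left` indices come first and the `right` indices
/// last, then reshapes it into a column-major `m x n` matrix with
/// `m = prod(left dims)` and `n = prod(right dims)`. Returns the matrix along
/// with the permuted left and right dimension lists. Fails with
/// `TFE_INVALID_ARGUMENT` unless `left` and `right` together name every index
/// of the tensor exactly once.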
fn matricize(
    tensor: &Tensor<f64>,
    left: &[usize],
    right: &[usize],
) -> Result<(Tensor<f64>, Vec<usize>, Vec<usize>), tfe_status_t> {
    let dims = tensor.dims();
    let mut seen = vec![false; dims.len()];

    for &l in left {
        if l >= dims.len() || seen[l] {
            return Err(TFE_INVALID_ARGUMENT);
        }
        seen[l] = true;
    }
    for &r in right {
        if r >= dims.len() || seen[r] {
            return Err(TFE_INVALID_ARGUMENT);
        }
        seen[r] = true;
    }
    if left.len() + right.len() != dims.len() || seen.iter().any(|&v| !v) {
        return Err(TFE_INVALID_ARGUMENT);
    }

    let left_dims: Vec<usize> = left.iter().map(|&i| dims[i]).collect();
    let right_dims: Vec<usize> = right.iter().map(|&i| dims[i]).collect();
    let m: usize = left_dims.iter().product();
    let n: usize = right_dims.iter().product();

    let mut perm: Vec<usize> = Vec::with_capacity(dims.len());
    perm.extend_from_slice(left);
    perm.extend_from_slice(right);

    let permuted = tensor
        .permute(&perm)
        .map_err(|_| TFE_INVALID_ARGUMENT)?
        .contiguous(MemoryOrder::ColumnMajor)
        .reshape(&[m, n])
        .map_err(|_| TFE_SHAPE_MISMATCH)?
        .contiguous(MemoryOrder::ColumnMajor);

    Ok((permuted, left_dims, right_dims))
}

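/// Returns the inverse of `perm`, i.e. `inv[perm[i]] == i` for all `i`.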
fn inverse_permutation(perm: &[usize]) -> Vec<usize> {
    let mut inv = vec![0; perm.len()];
    for (i, &p) in perm.iter().enumerate() {
        inv[p] = i;
    }
    inv
}

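/// Flattens a U cotangent of shape `left_dims x [k]` into an `m x k`
/// column-major matrix, returning the matrix and the bond dimension `k`.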
fn u_cotangent_to_matrix(
    cot_u: &Tensor<f64>,
    left_dims: &[usize],
) -> Result<(Tensor<f64>, usize), tfe_status_t> {
    let u_dims = cot_u.dims();
    if u_dims.len() != left_dims.len() + 1 || &u_dims[..left_dims.len()] != left_dims {
        return Err(TFE_SHAPE_MISMATCH);
    }
    let k = u_dims[left_dims.len()];
    let m: usize = left_dims.iter().product();
    let mat = cot_u
        .contiguous(MemoryOrder::ColumnMajor)
        .reshape(&[m, k])
        .map_err(|_| TFE_SHAPE_MISMATCH)?
        .contiguous(MemoryOrder::ColumnMajor);
    Ok((mat, k))
}

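/// Flattens a Vt cotangent of shape `[k] x right_dims` into a `k x n`
/// column-major matrix, returning the matrix and the bond dimension `k`.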
fn vt_cotangent_to_matrix(
    cot_vt: &Tensor<f64>,
    right_dims: &[usize],
) -> Result<(Tensor<f64>, usize), tfe_status_t> {
    let vt_dims = cot_vt.dims();
    if vt_dims.len() != right_dims.len() + 1 || &vt_dims[1..] != right_dims {
        return Err(TFE_SHAPE_MISMATCH);
    }
    let k = vt_dims[0];
    let n: usize = right_dims.iter().product();
    let mat = cot_vt
        .contiguous(MemoryOrder::ColumnMajor)
        .reshape(&[k, n])
        .map_err(|_| TFE_SHAPE_MISMATCH)?
        .contiguous(MemoryOrder::ColumnMajor);
    Ok((mat, k))
}

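/// Checks that an S cotangent is one-dimensional and returns its length.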
fn validate_s_cotangent(cot_s: &Tensor<f64>) -> Result<usize, tfe_status_t> {
    let dims = cot_s.dims();
    if dims.len() != 1 {
        return Err(TFE_SHAPE_MISMATCH);
    }
    Ok(dims[0])
}

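/// Reshapes an `m x k` U matrix back to the public shape `left_dims x [k]`.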
fn u_matrix_to_public(u: Tensor<f64>, left_dims: &[usize]) -> Result<Tensor<f64>, tfe_status_t> {
    if u.dims().len() != 2 {
        return Err(TFE_SHAPE_MISMATCH);
    }
    let k = u.dims()[1];
    let mut out_dims = left_dims.to_vec();
    out_dims.push(k);
    u.reshape(&out_dims)
        .map_err(|_| TFE_SHAPE_MISMATCH)
        .map(|t| t.contiguous(MemoryOrder::ColumnMajor))
}

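/// Reshapes a `k x n` Vt matrix back to the public shape `[k] x right_dims`.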
fn vt_matrix_to_public(vt: Tensor<f64>, right_dims: &[usize]) -> Result<Tensor<f64>, tfe_status_t> {
    if vt.dims().len() != 2 {
        return Err(TFE_SHAPE_MISMATCH);
    }
    let k = vt.dims()[0];
    let mut out_dims = vec![k];
    out_dims.extend_from_slice(right_dims);
    vt.reshape(&out_dims)
        .map_err(|_| TFE_SHAPE_MISMATCH)
        .map(|t| t.contiguous(MemoryOrder::ColumnMajor))
}

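/// Reshapes an `m x n` gradient matrix to the permuted tensor shape, then
/// applies the inverse of the matricization permutation so the result has the
/// same index order as the original input tensor.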
fn grad_matrix_to_public(
    grad_matrix: Tensor<f64>,
    original_dims: &[usize],
    left: &[usize],
    right: &[usize],
    left_dims: &[usize],
    right_dims: &[usize],
) -> Result<Tensor<f64>, tfe_status_t> {
    if grad_matrix.dims().len() != 2 {
        return Err(TFE_SHAPE_MISMATCH);
    }

    let mut permuted_dims = left_dims.to_vec();
    permuted_dims.extend_from_slice(right_dims);
    let reshaped = grad_matrix
        .contiguous(MemoryOrder::ColumnMajor)
        .reshape(&permuted_dims)
        .map_err(|_| TFE_SHAPE_MISMATCH)?;

    let mut perm = Vec::with_capacity(original_dims.len());
    perm.extend_from_slice(left);
    perm.extend_from_slice(right);
    let inv_perm = inverse_permutation(&perm);

    reshaped
        .permute(&inv_perm)
        .map_err(|_| TFE_INVALID_ARGUMENT)
        .map(|t| t.contiguous(MemoryOrder::ColumnMajor))
}

#[no_mangle]
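/// Computes a truncated SVD of `tensor`, matricized over the `left`/`right`
/// index partition. On success, writes handles for U (shape
/// `left_dims x [k]`), S (length `k`), and Vt (shape `[k] x right_dims`) to
/// `u_out`, `s_out`, and `vt_out`. `max_rank == 0` disables the rank cap and
/// a negative `cutoff` disables the truncation cutoff.
///
/// # Safety
/// `tensor` must be a valid tensor handle, `left`/`right` must point to
/// `left_len`/`right_len` readable `usize` values, and the output pointers
/// must be valid for writes.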
pub unsafe extern "C" fn tfe_svd_f64(
    tensor: *const TfeTensorF64,
    left: *const usize,
    left_len: usize,
    right: *const usize,
    right_len: usize,
    max_rank: usize,
    cutoff: f64,
    u_out: *mut *mut TfeTensorF64,
    s_out: *mut *mut TfeTensorF64,
    vt_out: *mut *mut TfeTensorF64,
    status: *mut tfe_status_t,
) {
    let result = catch_unwind(AssertUnwindSafe(|| {
        if tensor.is_null() || u_out.is_null() || s_out.is_null() || vt_out.is_null() {
            return Err(TFE_INVALID_ARGUMENT);
        }
        let t = handle_to_ref(tensor);
        let left_indices = read_usize_slice(left, left_len, "svd left indices")?;
        let right_indices = read_usize_slice(right, right_len, "svd right indices")?;

        let (matrix, left_dims, right_dims) = matricize(t, left_indices, right_indices)?;

        let opts = build_svd_options(max_rank, cutoff);
        let mut ctx = cpu_context()?;
        let result = svd(&mut ctx, &matrix, opts.as_ref()).map_err(|e| map_device_error(&e))?;

        let k = result.s.len();
        let mut u_dims: Vec<usize> = left_dims;
        u_dims.push(k);
        let u_reshaped = result
            .u
            .reshape(&u_dims)
            .map_err(|e| map_device_error(&e))?
            .contiguous(MemoryOrder::ColumnMajor);

        let mut vt_dims: Vec<usize> = vec![k];
        vt_dims.extend_from_slice(&right_dims);
        let vt_reshaped = result
            .vt
            .reshape(&vt_dims)
            .map_err(|e| map_device_error(&e))?
            .contiguous(MemoryOrder::ColumnMajor);

        *u_out = tensor_to_handle(u_reshaped);
        *s_out = tensor_to_handle(result.s);
        *vt_out = tensor_to_handle(vt_reshaped);
        Ok(())
    }));

    finalize_void(result, status)
}

#[no_mangle]
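/// Reverse-mode rule for `tfe_svd_f64`: given cotangents for U, S, and/or Vt,
/// returns a gradient tensor with the same shape and index order as `tensor`.
/// Null cotangent handles are treated as absent; the cotangents that are
/// present must agree on the bond dimension `k`, otherwise
/// `TFE_SHAPE_MISMATCH` is reported.
///
/// # Safety
/// `tensor` must be a valid tensor handle, `left`/`right` must point to
/// `left_len`/`right_len` readable `usize` values, and each non-null
/// cotangent must be a valid tensor handle.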
pub unsafe extern "C" fn tfe_svd_rrule_f64(
    tensor: *const TfeTensorF64,
    left: *const usize,
    left_len: usize,
    right: *const usize,
    right_len: usize,
    max_rank: usize,
    cutoff: f64,
    cotangent_u: *const TfeTensorF64,
    cotangent_s: *const TfeTensorF64,
    cotangent_vt: *const TfeTensorF64,
    status: *mut tfe_status_t,
) -> *mut TfeTensorF64 {
    let result = catch_unwind(AssertUnwindSafe(|| {
        if tensor.is_null() {
            return Err(TFE_INVALID_ARGUMENT);
        }

        let t = handle_to_ref(tensor);
        let left_indices = read_usize_slice(left, left_len, "svd_rrule left indices")?;
        let right_indices = read_usize_slice(right, right_len, "svd_rrule right indices")?;
        let original_dims = t.dims().to_vec();
        let (matrix, left_dims, right_dims) = matricize(t, left_indices, right_indices)?;

        let mut inferred_k: Option<usize> = None;
        let cot_u = if cotangent_u.is_null() {
            None
        } else {
            let (u_mat, k) = u_cotangent_to_matrix(handle_to_ref(cotangent_u), &left_dims)?;
            inferred_k = Some(k);
            Some(u_mat)
        };
        let cot_s = if cotangent_s.is_null() {
            None
        } else {
            let cot_s_ref = handle_to_ref(cotangent_s);
            let k = validate_s_cotangent(cot_s_ref)?;
            if let Some(prev) = inferred_k {
                if prev != k {
                    return Err(TFE_SHAPE_MISMATCH);
                }
            } else {
                inferred_k = Some(k);
            }
            Some(cot_s_ref.clone())
        };
        let cot_vt = if cotangent_vt.is_null() {
            None
        } else {
            let (vt_mat, k) = vt_cotangent_to_matrix(handle_to_ref(cotangent_vt), &right_dims)?;
            if let Some(prev) = inferred_k {
                if prev != k {
                    return Err(TFE_SHAPE_MISMATCH);
                }
            } else {
                inferred_k = Some(k);
            }
            Some(vt_mat)
        };
        // `inferred_k` only exists to cross-check that the provided cotangents
        // agree on the bond dimension; it is not needed past this point.
        let _ = inferred_k;

        let cotangent = SvdCotangent {
            u: cot_u,
            s: cot_s,
            vt: cot_vt,
        };

        let opts = build_svd_options(max_rank, cutoff);
        let mut ctx = cpu_context()?;
        let grad_matrix = svd_rrule(&mut ctx, &matrix, &cotangent, opts.as_ref())
            .map_err(|e| map_ad_error(&e))?;
        let grad = grad_matrix_to_public(
            grad_matrix,
            &original_dims,
            left_indices,
            right_indices,
            &left_dims,
            &right_dims,
        )?;

        Ok(tensor_to_handle(grad))
    }));

    finalize_ptr(result, status)
}

#[no_mangle]
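/// Forward-mode rule for `tfe_svd_f64`: pushes `tangent` through the SVD and
/// writes the tangents of U, S, and Vt (in the same public shapes as
/// `tfe_svd_f64`) to the output pointers. A null `tangent` is treated as a
/// zero tangent.
///
/// # Safety
/// `tensor` must be a valid tensor handle, `left`/`right` must point to
/// `left_len`/`right_len` readable `usize` values, `tangent` must be null or
/// a valid tensor handle, and the output pointers must be valid for writes.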
pub unsafe extern "C" fn tfe_svd_frule_f64(
    tensor: *const TfeTensorF64,
    left: *const usize,
    left_len: usize,
    right: *const usize,
    right_len: usize,
    max_rank: usize,
    cutoff: f64,
    tangent: *const TfeTensorF64,
    u_out: *mut *mut TfeTensorF64,
    s_out: *mut *mut TfeTensorF64,
    vt_out: *mut *mut TfeTensorF64,
    status: *mut tfe_status_t,
) {
    let result = catch_unwind(AssertUnwindSafe(|| {
        if tensor.is_null() || u_out.is_null() || s_out.is_null() || vt_out.is_null() {
            return Err(TFE_INVALID_ARGUMENT);
        }

        let t = handle_to_ref(tensor);
        let left_indices = read_usize_slice(left, left_len, "svd_frule left indices")?;
        let right_indices = read_usize_slice(right, right_len, "svd_frule right indices")?;
        let (matrix, left_dims, right_dims) = matricize(t, left_indices, right_indices)?;

        let tang = if tangent.is_null() {
            // A null tangent stands in for a zero tangent of the matricized shape.
            Tensor::<f64>::zeros(
                matrix.dims(),
                LogicalMemorySpace::MainMemory,
                MemoryOrder::ColumnMajor,
            )
            .map_err(|e| map_device_error(&e))?
        } else {
            let tang_tensor = handle_to_ref(tangent);
            let (tang_matrix, _, _) = matricize(tang_tensor, left_indices, right_indices)?;
            tang_matrix
        };

        let opts = build_svd_options(max_rank, cutoff);
        let mut ctx = cpu_context()?;
        let (_primal, tangent_result) =
            svd_frule(&mut ctx, &matrix, &tang, opts.as_ref()).map_err(|e| map_ad_error(&e))?;

        let u_public = u_matrix_to_public(tangent_result.u, &left_dims)?;
        let vt_public = vt_matrix_to_public(tangent_result.vt, &right_dims)?;

        *u_out = tensor_to_handle(u_public);
        *s_out = tensor_to_handle(tangent_result.s);
        *vt_out = tensor_to_handle(vt_public);
        Ok(())
    }));

    finalize_void(result, status)
}