tfe_svd_rrule_f64

Function tfe_svd_rrule_f64 

#[unsafe(no_mangle)]
pub unsafe extern "C" fn tfe_svd_rrule_f64(
    _tensor: *const TfeTensorF64,
    _left: *const usize,
    _left_len: usize,
    _right: *const usize,
    _right_len: usize,
    _max_rank: usize,
    _cutoff: f64,
    _cotangent_u: *const TfeTensorF64,
    _cotangent_s: *const TfeTensorF64,
    _cotangent_vt: *const TfeTensorF64,
    _status: *mut tfe_status_t,
) -> *mut TfeTensorF64

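Reverse-mode rule (vector-Jacobian product, VJP) for SVD.

Computes the gradient with respect to the input tensor, given cotangents for the three SVD outputs U, S, and Vt. Any of the cotangent pointers may be null; a null cotangent is treated as zero. The returned gradient tensor is released by the caller with tfe_tensor_f64_release, as in the example.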
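For readers less familiar with the term, the defining VJP identity and its simplest special case can be written out for the matrix case. This is the standard identity for a real thin SVD, not a statement of how tfe_svd_rrule_f64 computes its result, and it assumes the tensor is viewed as a matrix A whose rows are indexed by the left axes and whose columns are indexed by the right axes. Bars denote cotangents; a null cotangent simply drops its term.

```latex
\begin{aligned}
A &= U \operatorname{diag}(s)\, V^{\mathsf T},\\
\langle \bar A, \mathrm{d}A \rangle
  &= \langle \bar U, \mathrm{d}U \rangle
   + \langle \bar s, \mathrm{d}s \rangle
   + \langle \bar V^{\mathsf T}, \mathrm{d}V^{\mathsf T} \rangle
   \quad \text{for every perturbation } \mathrm{d}A,\\
\mathrm{d}s_i &= u_i^{\mathsf T}\, \mathrm{d}A\, v_i
  \;\Longrightarrow\;
  \bar A = U \operatorname{diag}(\bar s)\, V^{\mathsf T}
  \quad \text{when } \bar U = 0,\ \bar V^{\mathsf T} = 0 .
\end{aligned}
```

The last line uses the first-order perturbation of a singular value, which holds when the singular values involved are distinct.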

Safety

  • tensor must be a valid, non-null tensor pointer.
  • left must point to left_len valid usize values.
  • right must point to right_len valid usize values.
  • cotangent_u, cotangent_s, cotangent_vt may each be null.
  • status must be a valid, non-null pointer.
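Only the null-pointer parts of this contract can be checked from C. A minimal sketch of a caller-side guard follows, assuming the caller wants to reject obviously invalid arguments before crossing the FFI boundary. The name svd_rrule_args_ok is hypothetical and not part of the library; it takes untyped pointers so that it compiles without the tfe header.

```c
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical helper, not part of the tfe API. It rejects calls that
   would violate the non-null requirements in the Safety list. It cannot
   tell whether a non-null pointer is actually valid, or whether left and
   right really hold left_len and right_len values. The cotangent pointers
   are not checked because each of them may legitimately be null. */
static bool svd_rrule_args_ok(const void *tensor,
                              const size_t *left,
                              const size_t *right,
                              const void *status)
{
    if (tensor == NULL || status == NULL)
        return false;
    /* Conservatively require non-null axis arrays even when the length
       is zero (assumption: the contract does not say null is allowed). */
    if (left == NULL || right == NULL)
        return false;
    return true;
}
```

Such a guard only catches mistakes that are cheap to detect; it does not replace the Safety requirements.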

Examples (C)

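// Assumes `a` (the input tensor) and `cot_s` (the cotangent for S)
// were created earlier; they are not defined in this snippet.
size_t left[] = {0};
size_t right[] = {1};
tfe_status_t status;
// Only need gradient through singular values
tfe_tensor_f64 *grad = tfe_svd_rrule_f64(
    a, left, 1, right, 1,
    0,       // max_rank
    -1.0,    // cutoff
    NULL,    // no cotangent for U
    cot_s,   // cotangent for S
    NULL,    // no cotangent for Vt
    &status);
// Inspect `status` before using `grad`.
tfe_tensor_f64_release(grad);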
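When only the S cotangent is supplied, the result for a matrix can be cross-checked against the textbook identity grad = U diag(sbar) Vt, valid when the singular values involved are distinct. The sketch below is an independent reference computation on plain row-major double arrays, not the library's implementation. It assumes that with left = {0} and right = {1} the tensor is decomposed as an ordinary m x n matrix, and that U (m x k), the S cotangent sbar (length k), and Vt (k x n) are available as dense arrays. The function name svd_s_only_vjp is hypothetical.

```c
#include <stddef.h>

/* Hypothetical reference routine: gradient of a loss that depends on
   A = U diag(s) Vt only through s. All arrays are dense and row-major:
   u is m x k, sbar has length k, vt is k x n, and grad_a (output) is
   m x n. Computes grad_a = U diag(sbar) Vt. */
static void svd_s_only_vjp(size_t m, size_t n, size_t k,
                           const double *u, const double *sbar,
                           const double *vt, double *grad_a)
{
    for (size_t i = 0; i < m; ++i) {
        for (size_t j = 0; j < n; ++j) {
            double acc = 0.0;
            /* Sum over singular triplets: u[:, r] * sbar[r] * vt[r, :]. */
            for (size_t r = 0; r < k; ++r)
                acc += u[i * k + r] * sbar[r] * vt[r * n + j];
            grad_a[i * n + j] = acc;
        }
    }
}
```

For example, with U = [[0, 1], [1, 0]], sbar = {2, 5}, and Vt the 2 x 2 identity, the routine fills grad_a with {0, 5, 2, 0}, i.e. each column of U scaled by its singular-value cotangent.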