tfe_tropical_einsum_rrule_maxplus_f64

Function tfe_tropical_einsum_rrule_maxplus_f64 

#[unsafe(no_mangle)]
pub unsafe extern "C" fn tfe_tropical_einsum_rrule_maxplus_f64(
    _subscripts: *const c_char,
    _operands: *const *const TfeTensorF64,
    _num_operands: usize,
    _cotangent: *const TfeTensorF64,
    _grads_out: *mut *mut TfeTensorF64,
    _status: *mut tfe_status_t,
)

Reverse-mode rule (VJP) for MaxPlus tropical einsum.

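Computes one gradient tensor per input operand given the output cotangent. The caller must provide grads_out as a pre-allocated array of num_operands pointer slots; the function writes one gradient tensor pointer into each slot, and, as in the example, the caller releases each one with tfe_tensor_f64_release.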
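For the matrix case "ij,jk->ik", the max-plus forward pass is C[i,k] = max_j (A[i,j] + B[j,k]), and a reverse-mode rule conventionally routes each cotangent entry to the operand entries at the maximizing j. The sketch below illustrates that idea on plain row-major buffers. It is not the library's implementation: the function name maxplus_matmul_vjp, the dense buffer layout, and the choice to break ties in favor of the first maximizing j are all assumptions made for illustration.

```c
#include <assert.h>
#include <stddef.h>

/* Illustrative VJP for C[i,k] = max_j (A[i,j] + B[j,k]).
 * Hypothetical helper, not part of the tfe API.
 * a is m x n, b is n x p, grad_c is m x p, all row-major.
 * grad_a (m x n) and grad_b (n x p) must be zero-initialized by the caller. */
static void maxplus_matmul_vjp(const double *a, const double *b,
                               const double *grad_c,
                               size_t m, size_t n, size_t p,
                               double *grad_a, double *grad_b)
{
    assert(n > 0); /* the max over an empty index set is not handled here */

    for (size_t i = 0; i < m; ++i) {
        for (size_t k = 0; k < p; ++k) {
            /* Find the contraction index j that attains the maximum. */
            size_t best_j = 0;
            double best = a[i * n] + b[k];
            for (size_t j = 1; j < n; ++j) {
                double v = a[i * n + j] + b[j * p + k];
                if (v > best) { /* strict '>' keeps the first maximizer (assumed tie rule) */
                    best = v;
                    best_j = j;
                }
            }
            /* Only the winning pair (A[i,best_j], B[best_j,k]) affects C[i,k],
             * so the whole cotangent entry flows to those two positions. */
            grad_a[i * n + best_j] += grad_c[i * p + k];
            grad_b[best_j * p + k] += grad_c[i * p + k];
        }
    }
}
```

Each output entry contributes its cotangent to exactly one entry of each operand, so the resulting gradients are sparse selections of grad_c rather than dense matrix products as in ordinary einsum.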

§Safety

  • subscripts must be a valid null-terminated C string.
  • operands must point to an array of num_operands valid tensor pointers.
  • cotangent must be a valid, non-null tensor pointer.
  • grads_out must point to a caller-allocated, writable array of num_operands slots of type *mut TfeTensorF64.
  • status must be a valid, non-null pointer.

§Examples (C)

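/* a, b, and grad_c are tensors created earlier (not shown). */
const tfe_tensor_f64 *ops[] = {a, b};
tfe_tensor_f64 *grads[2];   /* one output slot per operand */
tfe_status_t status;

tfe_tropical_einsum_rrule_maxplus_f64("ij,jk->ik", ops, 2, grad_c, grads, &status);
/* Inspect status before reading grads[0] and grads[1]. */

/* The caller releases the returned gradient tensors. */
tfe_tensor_f64_release(grads[0]);
tfe_tensor_f64_release(grads[1]);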