1use num_complex::{Complex32, Complex64};
2use tenferro_ops::dim_expr::DimExpr;
3use tenferro_ops::std_tensor_op::StdTensorOp;
4use tenferro_tensor::{CompareDir, DType, DotGeneralConfig};
5
6use crate::sym_dim::SymDim;
7use crate::traced::{
8 apply_binary, apply_multi_output, apply_nullary, apply_unary, concrete_shape, TracedTensor,
9};
10
/// Casts `input` to the element type `to`, recording the conversion in the trace.
pub fn convert(input: &TracedTensor, to: DType) -> TracedTensor {
    input.convert(to)
}
23
24fn sym_shape(shape: &[usize]) -> Vec<SymDim> {
25 shape.iter().copied().map(SymDim::from).collect()
26}
27
28fn input_shape_expr(tensor: &TracedTensor) -> Vec<DimExpr> {
29 tensor
30 .shape_hint
31 .as_ref()
32 .and_then(|shape| {
33 shape
34 .iter()
35 .map(SymDim::constant_value)
36 .collect::<Option<Vec<_>>>()
37 })
38 .map(|shape| DimExpr::from_concrete(&shape))
39 .unwrap_or_else(|| DimExpr::input_shape(0, tensor.rank))
40}
41
/// Singular value decomposition with the default numerical cutoff (`1e-12`).
///
/// Returns `(u, s, vt)`; see [`svd_with_eps`] for the shapes.
pub fn svd(a: &TracedTensor) -> (TracedTensor, TracedTensor, TracedTensor) {
    svd_with_eps(a, 1e-12)
}
52
53pub fn svd_with_eps(a: &TracedTensor, eps: f64) -> (TracedTensor, TracedTensor, TracedTensor) {
61 let shape = concrete_shape(a);
62 let m = shape[0];
63 let n = shape[1];
64 let k = m.min(n);
65 let batch = &shape[2..];
66 let op = StdTensorOp::Svd {
67 eps,
68 input_shape: input_shape_expr(a),
69 };
70 let mut u_shape = vec![m, k];
71 u_shape.extend_from_slice(batch);
72 let mut s_shape = vec![k];
73 s_shape.extend_from_slice(batch);
74 let mut vt_shape = vec![k, n];
75 vt_shape.extend_from_slice(batch);
76 let mut results = apply_multi_output(
77 op,
78 a,
79 vec![
80 sym_shape(&u_shape),
81 sym_shape(&s_shape),
82 sym_shape(&vt_shape),
83 ],
84 )
85 .into_iter();
86 match (
87 results.next(),
88 results.next(),
89 results.next(),
90 results.next(),
91 ) {
92 (Some(u), Some(s), Some(vt), None) => (u, s, vt),
93 _ => unreachable!("svd must produce exactly three outputs"),
94 }
95}
96
97pub fn qr(a: &TracedTensor) -> (TracedTensor, TracedTensor) {
105 let shape = concrete_shape(a);
106 let m = shape[0];
107 let n = shape[1];
108 let k = m.min(n);
109 let batch = &shape[2..];
110 let mut q_shape = vec![m, k];
111 q_shape.extend_from_slice(batch);
112 let mut r_shape = vec![k, n];
113 r_shape.extend_from_slice(batch);
114 let mut results = apply_multi_output(
115 StdTensorOp::Qr {
116 input_shape: input_shape_expr(a),
117 },
118 a,
119 vec![sym_shape(&q_shape), sym_shape(&r_shape)],
120 )
121 .into_iter();
122 match (results.next(), results.next(), results.next()) {
123 (Some(q), Some(r), None) => (q, r),
124 _ => unreachable!("qr must produce exactly two outputs"),
125 }
126}
127
/// Symmetric/Hermitian eigendecomposition with the default cutoff (`1e-12`).
///
/// Returns `(values, vectors)`; see [`eigh_with_eps`].
pub fn eigh(a: &TracedTensor) -> (TracedTensor, TracedTensor) {
    eigh_with_eps(a, 1e-12)
}
138
139pub fn eigh_with_eps(a: &TracedTensor, eps: f64) -> (TracedTensor, TracedTensor) {
147 let shape = concrete_shape(a);
148 let n = shape[0];
149 let batch = &shape[2..];
150 let op = StdTensorOp::Eigh {
151 eps,
152 input_shape: input_shape_expr(a),
153 };
154 let mut vals_shape = vec![n];
155 vals_shape.extend_from_slice(batch);
156 let mut vecs_shape = vec![n, n];
157 vecs_shape.extend_from_slice(batch);
158 let mut results =
159 apply_multi_output(op, a, vec![sym_shape(&vals_shape), sym_shape(&vecs_shape)]).into_iter();
160 match (results.next(), results.next(), results.next()) {
161 (Some(values), Some(vectors), None) => (values, vectors),
162 _ => unreachable!("eigh must produce exactly two outputs"),
163 }
164}
165
166pub fn cholesky(a: &TracedTensor) -> TracedTensor {
174 let shape = concrete_shape(a);
175 apply_unary(
176 StdTensorOp::Cholesky {
177 input_shape: input_shape_expr(a),
178 },
179 a,
180 a.rank,
181 Some(sym_shape(&shape)),
182 )
183}
184
185pub fn lu(a: &TracedTensor) -> (TracedTensor, TracedTensor, TracedTensor, TracedTensor) {
195 let shape = concrete_shape(a);
196 let m = shape[0];
197 let n = shape[1];
198 let k = m.min(n);
199 let batch = &shape[2..];
200 let mut p_shape = vec![m, m];
201 p_shape.extend_from_slice(batch);
202 let mut l_shape = vec![m, k];
203 l_shape.extend_from_slice(batch);
204 let mut u_shape = vec![k, n];
205 u_shape.extend_from_slice(batch);
206 let parity_shape = batch.to_vec();
207 let mut results = apply_multi_output(
208 StdTensorOp::Lu {
209 input_shape: input_shape_expr(a),
210 },
211 a,
212 vec![
213 sym_shape(&p_shape),
214 sym_shape(&l_shape),
215 sym_shape(&u_shape),
216 sym_shape(&parity_shape),
217 ],
218 )
219 .into_iter();
220 match (
221 results.next(),
222 results.next(),
223 results.next(),
224 results.next(),
225 results.next(),
226 ) {
227 (Some(p), Some(l), Some(u), Some(parity), None) => (p, l, u, parity),
228 _ => unreachable!("lu must produce exactly four outputs"),
229 }
230}
231
232pub fn eig(a: &TracedTensor) -> (TracedTensor, TracedTensor) {
242 let shape = concrete_shape(a);
243 let n = shape[0];
244 let batch = &shape[2..];
245 let mut vals_shape = vec![n];
246 vals_shape.extend_from_slice(batch);
247 let mut vecs_shape = vec![n, n];
248 vecs_shape.extend_from_slice(batch);
249 let eig_dtype = eig_output_dtype(a.dtype);
250 let mut results = apply_multi_output(
251 StdTensorOp::Eig {
252 input_dtype: a.dtype,
253 input_shape: input_shape_expr(a),
254 },
255 a,
256 vec![sym_shape(&vals_shape), sym_shape(&vecs_shape)],
257 )
258 .into_iter();
259 match (results.next(), results.next(), results.next()) {
260 (Some(mut values), Some(mut vectors), None) => {
261 values.dtype = eig_dtype;
262 vectors.dtype = eig_dtype;
263 (values, vectors)
264 }
265 _ => unreachable!("eig must produce exactly two outputs"),
266 }
267}
268
269fn validate_nonsingular(u: &TracedTensor) -> TracedTensor {
270 apply_unary(
271 StdTensorOp::ValidateNonsingular {
272 input_shape: input_shape_expr(u),
273 },
274 u,
275 u.rank,
276 u.shape_hint.clone(),
277 )
278}
279
280pub fn solve(a: &TracedTensor, b: &TracedTensor) -> TracedTensor {
288 let a_shape = concrete_shape(a);
289 let b_shape = concrete_shape(b);
290 if has_zero_dim(&a_shape) || has_zero_dim(&b_shape) {
291 return zeros_like(b);
292 }
293
294 let do_solve = |a: &TracedTensor, b: &TracedTensor| -> TracedTensor {
295 let (p, l, u, _) = lu(a);
296 let u = validate_nonsingular(&u);
297 let pb = matmul_preserve_trailing_batch(&p, b);
298 let z = triangular_solve(&l, &pb, true, true, false, true);
299 triangular_solve(&u, &z, true, false, false, false)
300 };
301
302 if let Some(matrix_rhs_shape) = batched_vector_rhs_shape(a, b) {
303 let b2d = b.reshape(&matrix_rhs_shape);
304 let x2d = do_solve(a, &b2d);
305 x2d.reshape(&b_shape)
306 } else {
307 do_solve(a, b)
308 }
309}
310
311pub fn triangular_solve(
319 a: &TracedTensor,
320 b: &TracedTensor,
321 left_side: bool,
322 lower: bool,
323 transpose_a: bool,
324 unit_diagonal: bool,
325) -> TracedTensor {
326 let a_shape = concrete_shape(a);
327 let b_shape = concrete_shape(b);
328 if has_zero_dim(&a_shape) || has_zero_dim(&b_shape) {
329 return zeros_like(b);
330 }
331 let op = StdTensorOp::TriangularSolve {
332 left_side,
333 lower,
334 transpose_a,
335 unit_diagonal,
336 lhs_shape: input_shape_expr(a),
337 rhs_shape: input_shape_expr(b),
338 };
339 if let Some(matrix_rhs_shape) = batched_vector_rhs_shape(a, b) {
340 let b2d = b.reshape(&matrix_rhs_shape);
341 let x2d = apply_binary(
342 StdTensorOp::TriangularSolve {
343 left_side,
344 lower,
345 transpose_a,
346 unit_diagonal,
347 lhs_shape: input_shape_expr(a),
348 rhs_shape: DimExpr::from_concrete(&matrix_rhs_shape),
349 },
350 a,
351 &b2d,
352 matrix_rhs_shape.len(),
353 Some(sym_shape(&matrix_rhs_shape)),
354 );
355 x2d.reshape(&b_shape)
356 } else {
357 apply_binary(op, a, b, b.rank, b.shape_hint.clone())
358 }
359}
360
361pub fn slogdet(a: &TracedTensor) -> (TracedTensor, TracedTensor) {
369 let shape = concrete_shape(a);
370 let batch_shape = &shape[2..];
371 if has_zero_dim(&shape) {
372 let sign = broadcast_scalar(one_scalar(a.dtype), batch_shape);
373 let logabsdet = broadcast_scalar(zero_scalar(a.dtype), batch_shape);
374 return (sign, logabsdet);
375 }
376
377 let (_, _, u, parity) = lu(a);
378 let diag_u = u.extract_diag(0, 1);
379 let sign_u = reduce_prod(&diag_u.sign(), &[0]);
380 let sign = &parity * &sign_u;
381 let logabsdet = diag_u.abs().log().reduce_sum(&[0]);
382 (sign, logabsdet)
383}
384
385pub fn det(a: &TracedTensor) -> TracedTensor {
393 let (sign, logabsdet) = slogdet(a);
394 &sign * &logabsdet.exp()
395}
396
397pub fn inv(a: &TracedTensor) -> TracedTensor {
405 let shape = concrete_shape(a);
406 let eye = eye_like(a, shape[0]);
407 solve(a, &eye)
408}
409
410pub fn eigvalsh(a: &TracedTensor) -> TracedTensor {
418 eigh(a).0
419}
420
421pub fn eigvals(a: &TracedTensor) -> TracedTensor {
429 eig(a).0
430}
431
432pub fn pinv(a: &TracedTensor) -> TracedTensor {
440 let shape = concrete_shape(a);
441 let max_dim = match (shape.first(), shape.get(1)) {
442 (Some(&m), Some(&n)) => m.max(n),
443 (Some(&m), None) => m,
444 _ => 0,
445 };
446 pinv_with_rtol(a, default_pinv_rtol(a.dtype, max_dim))
447}
448
/// Moore–Penrose pseudo-inverse via SVD, discarding singular values at or
/// below `rtol * max|s|` (computed per batch element).
pub fn pinv_with_rtol(a: &TracedTensor, rtol: f64) -> TracedTensor {
    let shape = concrete_shape(a);
    if has_zero_dim(&shape) {
        // Empty matrix: pinv of [m, n, batch...] is all zeros of [n, m, batch...].
        let mut out_shape = vec![shape[1], shape[0]];
        out_shape.extend_from_slice(&shape[2..]);
        return zeros_of_shape(a.dtype, out_shape);
    }

    let (u, s, vt) = svd(a);
    let abs_s = s.abs();
    let s_max = reduce_max(&abs_s, &[0]);
    let s_max_shape = concrete_shape(&s_max);
    // Per-batch cutoff: rtol (clamped to >= 0) times the largest singular value.
    let threshold = &s_max * &broadcast_scalar(scalar_real(s.dtype, rtol.max(0.0)), &s_max_shape);
    let s_shape = concrete_shape(&s);
    let threshold = broadcast_batch_scalar_to_leading_axis(&threshold, &s_shape);
    let mask = compare_dir(&abs_s, &threshold, CompareDir::Gt);
    let ones = ones_like(&s);
    // Branch-free reciprocal: where mask == 1, denom == s and s_inv == 1/s;
    // where mask == 0, the numerator is 0 and denom == s + 1 keeps the
    // division away from 0/0, so rejected singular values invert to 0.
    let denom = &s + &(&ones + &(-&mask));
    let s_inv = &mask / &denom;

    // pinv = V · diag(s_inv) · U^H, with V and U^H via conjugate transposition.
    let v = vt.conj().transpose(&matrix_transpose_perm(vt.rank));
    let uh = u.conj().transpose(&matrix_transpose_perm(u.rank));
    let s_inv_diag = s_inv.embed_diag(0, 1);
    let vs = matmul_preserve_trailing_batch(&v, &s_inv_diag);
    matmul_preserve_trailing_batch(&vs, &uh)
}
484
485pub fn norm(
495 a: &TracedTensor,
496 ord: Option<f64>,
497 dim: Option<&[usize]>,
498 keepdim: bool,
499) -> TracedTensor {
500 let axes = dim.map_or_else(|| (0..a.rank).collect::<Vec<_>>(), |dims| dims.to_vec());
501 if axes.is_empty() {
502 return a.clone();
503 }
504
505 let out = match axes.len() {
506 1 => vector_norm(a, axes[0], ord),
507 2 => matrix_norm(a, &axes, ord),
508 _ => {
509 let abs = a.abs();
510 match ord {
511 None => frobenius_norm(&abs, &axes),
512 Some(p) if p == f64::INFINITY => reduce_max(&abs, &axes),
513 Some(p) if p == f64::NEG_INFINITY => reduce_min(&abs, &axes),
514 Some(p) if p == 0.0 => count_nonzero(&abs, &axes),
515 Some(p) => p_norm(&abs, &axes, p),
516 }
517 }
518 };
519 let shape = concrete_shape(a);
520 restore_keepdim(out, &shape, &axes, keepdim)
521}
522
523fn eig_output_dtype(dtype: DType) -> DType {
524 match dtype {
525 DType::F64 | DType::C64 => DType::C64,
526 DType::F32 | DType::C32 => DType::C32,
527 }
528}
529
530fn scalar_real(dtype: DType, value: f64) -> TracedTensor {
531 match dtype {
532 DType::F64 => apply_nullary(
533 StdTensorOp::constant_f64(value),
534 0,
535 DType::F64,
536 Some(vec![]),
537 ),
538 DType::F32 => apply_nullary(
539 StdTensorOp::constant_f32(value as f32),
540 0,
541 DType::F32,
542 Some(vec![]),
543 ),
544 DType::C64 => apply_nullary(
545 StdTensorOp::constant_c64(Complex64::new(value, 0.0)),
546 0,
547 DType::C64,
548 Some(vec![]),
549 ),
550 DType::C32 => apply_nullary(
551 StdTensorOp::constant_c32(Complex32::new(value as f32, 0.0)),
552 0,
553 DType::C32,
554 Some(vec![]),
555 ),
556 }
557}
558
/// Rank-0 constant `0` of the given dtype.
fn zero_scalar(dtype: DType) -> TracedTensor {
    scalar_real(dtype, 0.0)
}
562
/// Rank-0 constant `1` of the given dtype.
fn one_scalar(dtype: DType) -> TracedTensor {
    scalar_real(dtype, 1.0)
}
566
/// All-zeros tensor with the same dtype and concrete shape as `input`.
fn zeros_like(input: &TracedTensor) -> TracedTensor {
    zeros_of_shape(input.dtype, concrete_shape(input))
}
570
/// All-zeros tensor of the given dtype and shape, built by broadcasting a
/// scalar zero.
fn zeros_of_shape(dtype: DType, shape: Vec<usize>) -> TracedTensor {
    broadcast_scalar(zero_scalar(dtype), &shape)
}
574
575fn ones_like(input: &TracedTensor) -> TracedTensor {
576 let shape = concrete_shape(input);
577 broadcast_scalar(one_scalar(input.dtype), &shape)
578}
579
580fn eye_like(anchor: &TracedTensor, size: usize) -> TracedTensor {
581 let mut vector_shape = vec![size];
582 let anchor_shape = concrete_shape(anchor);
583 vector_shape.extend_from_slice(&anchor_shape[2..]);
584 let diagonal = broadcast_scalar(one_scalar(anchor.dtype), &vector_shape);
585 diagonal.embed_diag(0, 1)
586}
587
588fn frobenius_norm(abs: &TracedTensor, axes: &[usize]) -> TracedTensor {
589 let squared = abs.pow(&scalar_real(abs.dtype, 2.0));
590 squared.reduce_sum(axes).sqrt()
591}
592
593fn p_norm(abs: &TracedTensor, axes: &[usize], p: f64) -> TracedTensor {
594 let power = abs.pow(&scalar_real(abs.dtype, p));
595 let inv_p = scalar_real(abs.dtype, 1.0 / p);
596 power.reduce_sum(axes).pow(&inv_p)
597}
598
599fn default_pinv_rtol(dtype: DType, max_dim: usize) -> f64 {
600 let eps = match dtype {
601 DType::F32 | DType::C32 => f32::EPSILON as f64,
602 DType::F64 | DType::C64 => f64::EPSILON,
603 };
604 eps * max_dim as f64
605}
606
607fn vector_norm(a: &TracedTensor, axis: usize, ord: Option<f64>) -> TracedTensor {
608 let abs = a.abs();
609 match ord {
610 None => frobenius_norm(&abs, &[axis]),
611 Some(p) if p == 0.0 => count_nonzero(&abs, &[axis]),
612 Some(p) if p == f64::INFINITY => reduce_max(&abs, &[axis]),
613 Some(p) if p == f64::NEG_INFINITY => reduce_min(&abs, &[axis]),
614 Some(p) => p_norm(&abs, &[axis], p),
615 }
616}
617
618fn matrix_norm(a: &TracedTensor, axes: &[usize], ord: Option<f64>) -> TracedTensor {
619 let matrix = move_axes_to_front(a, axes);
620 let abs = matrix.abs();
621 match ord {
622 None => frobenius_norm(&abs, &[0, 1]),
623 Some(p) if p == f64::INFINITY => matrix_row_sum_norm(&abs, true),
624 Some(p) if p == f64::NEG_INFINITY => matrix_row_sum_norm(&abs, false),
625 Some(p) if p == 1.0 => matrix_col_sum_norm(&abs, true),
626 Some(p) if p == -1.0 => matrix_col_sum_norm(&abs, false),
627 Some(p) if p == 2.0 => {
628 let singular_values = svd(&matrix).1.abs();
629 reduce_max(&singular_values, &[0])
630 }
631 Some(p) if p == -2.0 => {
632 let singular_values = svd(&matrix).1.abs();
633 reduce_min(&singular_values, &[0])
634 }
635 Some(p) if p == 0.0 => count_nonzero(&abs, &[0, 1]),
636 Some(p) => p_norm(&abs, &[0, 1], p),
637 }
638}
639
640fn count_nonzero(abs: &TracedTensor, axes: &[usize]) -> TracedTensor {
641 let mask = compare_dir(abs, &zero_scalar(abs.dtype), CompareDir::Gt);
642 mask.reduce_sum(axes)
643}
644
645fn matrix_row_sum_norm(abs: &TracedTensor, take_max: bool) -> TracedTensor {
646 let row_sums = abs.reduce_sum(&[1]);
647 if take_max {
648 reduce_max(&row_sums, &[0])
649 } else {
650 reduce_min(&row_sums, &[0])
651 }
652}
653
654fn matrix_col_sum_norm(abs: &TracedTensor, take_max: bool) -> TracedTensor {
655 let col_sums = abs.reduce_sum(&[0]);
656 if take_max {
657 reduce_max(&col_sums, &[0])
658 } else {
659 reduce_min(&col_sums, &[0])
660 }
661}
662
663fn move_axes_to_front(tensor: &TracedTensor, axes: &[usize]) -> TracedTensor {
664 if axes.iter().enumerate().all(|(index, &axis)| index == axis) {
665 return tensor.clone();
666 }
667
668 let mut selected = vec![false; tensor.rank];
669 for &axis in axes {
670 selected[axis] = true;
671 }
672
673 let mut perm = Vec::with_capacity(tensor.rank);
674 perm.extend_from_slice(axes);
675 for axis in 0..tensor.rank {
676 if !selected[axis] {
677 perm.push(axis);
678 }
679 }
680 tensor.transpose(&perm)
681}
682
683fn restore_keepdim(
684 reduced: TracedTensor,
685 original_shape: &[usize],
686 axes: &[usize],
687 keepdim: bool,
688) -> TracedTensor {
689 if !keepdim {
690 return reduced;
691 }
692 let mut kept_shape = original_shape.to_vec();
693 for &axis in axes {
694 kept_shape[axis] = 1;
695 }
696 reduced.reshape(&kept_shape)
697}
698
699fn reduce_prod(input: &TracedTensor, axes: &[usize]) -> TracedTensor {
700 let input_shape = concrete_shape(input);
701 let out_shape = reduced_shape(&input_shape, axes);
702 apply_unary(
703 StdTensorOp::ReduceProd {
704 axes: axes.to_vec(),
705 input_shape: DimExpr::input_shape(0, input.rank),
706 },
707 input,
708 input.rank - axes.len(),
709 Some(sym_shape(&out_shape)),
710 )
711}
712
713fn reduce_max(input: &TracedTensor, axes: &[usize]) -> TracedTensor {
714 let input_shape = concrete_shape(input);
715 let out_shape = reduced_shape(&input_shape, axes);
716 apply_unary(
717 StdTensorOp::ReduceMax {
718 axes: axes.to_vec(),
719 input_shape: DimExpr::input_shape(0, input.rank),
720 },
721 input,
722 input.rank - axes.len(),
723 Some(sym_shape(&out_shape)),
724 )
725}
726
727fn reduce_min(input: &TracedTensor, axes: &[usize]) -> TracedTensor {
728 let input_shape = concrete_shape(input);
729 let out_shape = reduced_shape(&input_shape, axes);
730 apply_unary(
731 StdTensorOp::ReduceMin {
732 axes: axes.to_vec(),
733 input_shape: DimExpr::input_shape(0, input.rank),
734 },
735 input,
736 input.rank - axes.len(),
737 Some(sym_shape(&out_shape)),
738 )
739}
740
741fn compare_dir(lhs: &TracedTensor, rhs: &TracedTensor, dir: CompareDir) -> TracedTensor {
742 let (lhs, rhs) = broadcast_binary(lhs, rhs);
743 apply_binary(
744 StdTensorOp::Compare(dir),
745 &lhs,
746 &rhs,
747 lhs.rank,
748 lhs.shape_hint.clone(),
749 )
750}
751
752fn broadcast_scalar(input: TracedTensor, shape: &[usize]) -> TracedTensor {
753 let input_shape = concrete_shape(&input);
754 if input_shape == shape {
755 return input;
756 }
757 input.broadcast_in_dim(shape, &[])
758}
759
760fn broadcast_batch_scalar_to_leading_axis(input: &TracedTensor, shape: &[usize]) -> TracedTensor {
761 let input_shape = concrete_shape(input);
762 if input_shape == shape {
763 return input.clone();
764 }
765 let dims: Vec<usize> = (1..shape.len()).collect();
766 input.broadcast_in_dim(shape, &dims)
767}
768
769fn matmul_preserve_trailing_batch(lhs: &TracedTensor, rhs: &TracedTensor) -> TracedTensor {
770 let rank = lhs.rank;
771 let batch_dims: Vec<usize> = (2..rank).collect();
772 lhs.dot_general(
773 rhs,
774 DotGeneralConfig {
775 lhs_contracting_dims: vec![1],
776 rhs_contracting_dims: vec![0],
777 lhs_batch_dims: batch_dims.clone(),
778 rhs_batch_dims: batch_dims,
779 lhs_rank: rank,
780 rhs_rank: rank,
781 },
782 )
783}
784
/// Permutation swapping axes 0 and 1 and leaving the rest in place.
fn matrix_transpose_perm(rank: usize) -> Vec<usize> {
    let mut axes = Vec::with_capacity(rank);
    axes.extend(0..rank);
    axes.swap(0, 1);
    axes
}
790
/// Shape after removing the axes in `axes` from `shape`.
fn reduced_shape(shape: &[usize], axes: &[usize]) -> Vec<usize> {
    shape
        .iter()
        .enumerate()
        .filter(|(axis, _)| !axes.contains(axis))
        .map(|(_, &dim)| dim)
        .collect()
}
797
798fn batched_vector_rhs_shape(a: &TracedTensor, b: &TracedTensor) -> Option<Vec<usize>> {
799 let a_shape = concrete_shape(a);
800 let b_shape = concrete_shape(b);
801
802 if b_shape.len() == 1 {
803 return Some(vec![b_shape[0], 1]);
804 }
805
806 let is_batched_vector_rhs = a_shape.len() == b_shape.len() + 1
807 && !b_shape.is_empty()
808 && b_shape[0] == a_shape[0]
809 && b_shape[1..] == a_shape[2..];
810 if !is_batched_vector_rhs {
811 return None;
812 }
813
814 let mut rhs_shape = vec![b_shape[0], 1];
815 rhs_shape.extend_from_slice(&b_shape[1..]);
816 Some(rhs_shape)
817}
818
/// True when any dimension of `shape` is zero.
fn has_zero_dim(shape: &[usize]) -> bool {
    shape.iter().any(|&dim| dim == 0)
}
822
/// NumPy-style broadcast of two shapes, aligning from the trailing axis and
/// padding the shorter shape with 1s. Returns `None` when a pair of
/// dimensions is incompatible (unequal and neither is 1).
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let dim_at = |shape: &[usize], index: usize| -> usize {
        let pad = rank - shape.len();
        if index < pad {
            1
        } else {
            shape[index - pad]
        }
    };
    (0..rank)
        .map(|index| {
            let (a_dim, b_dim) = (dim_at(a, index), dim_at(b, index));
            if a_dim == b_dim || b_dim == 1 {
                Some(a_dim)
            } else if a_dim == 1 {
                Some(b_dim)
            } else {
                None
            }
        })
        .collect()
}
849
/// Broadcasts `tensor` to `target_shape`, aligning from the trailing axis.
///
/// Size-1 source axes that expand are dropped from the source view (via a
/// reshape) and the rest are mapped to their target positions with
/// `broadcast_in_dim`.
///
/// # Panics
/// Panics when `tensor`'s rank exceeds the target rank or a dimension is
/// neither equal to the target nor 1.
fn broadcast_to(tensor: &TracedTensor, target_shape: &[usize]) -> TracedTensor {
    let tensor_shape = concrete_shape(tensor);
    if tensor_shape == target_shape {
        return tensor.clone();
    }

    assert!(
        tensor.rank <= target_shape.len(),
        "cannot broadcast higher-rank shape {:?} to {:?}",
        tensor_shape,
        target_shape
    );

    // Source axis i corresponds to target axis i + rank_diff (trailing-aligned).
    let rank_diff = target_shape.len() - tensor.rank;
    let mut source_shape = Vec::with_capacity(tensor.rank);
    let mut dims = Vec::with_capacity(tensor.rank);
    for (src_axis, &src_dim) in tensor_shape.iter().enumerate() {
        let dst_axis = src_axis + rank_diff;
        let dst_dim = target_shape[dst_axis];
        assert!(
            src_dim == dst_dim || src_dim == 1,
            "cannot broadcast shape {:?} to {:?}",
            tensor_shape,
            target_shape
        );
        // Expanding size-1 axes are omitted from the source view entirely;
        // broadcast_in_dim materializes them on the output side.
        if src_dim == 1 && dst_dim != 1 {
            continue;
        }
        source_shape.push(src_dim);
        dims.push(dst_axis);
    }

    // Only reshape when some size-1 axis was dropped above.
    let source = if source_shape == tensor_shape {
        tensor.clone()
    } else {
        tensor.reshape(&source_shape)
    };
    source.broadcast_in_dim(target_shape, &dims)
}
889
890fn broadcast_binary(a: &TracedTensor, b: &TracedTensor) -> (TracedTensor, TracedTensor) {
891 if a.shape_hint == b.shape_hint && a.rank == b.rank {
892 return (a.clone(), b.clone());
893 }
894 let a_shape = concrete_shape(a);
895 let b_shape = concrete_shape(b);
896 let target = broadcast_shape(&a_shape, &b_shape).unwrap_or_else(|| {
897 panic!(
898 "incompatible shapes for broadcast: {:?} and {:?}",
899 a_shape, b_shape
900 )
901 });
902 (broadcast_to(a, &target), broadcast_to(b, &target))
903}