1use std::collections::HashMap;
7
8use tenferro_einsum::Subscripts;
9use tenferro_ops::dim_expr::DimExpr;
10use tenferro_ops::std_tensor_op::StdTensorOp;
11use tenferro_tensor::{DType, DotGeneralConfig, GatherConfig, PadConfig, SliceConfig};
12
/// Infers the element dtype of `op`'s output from the input dtypes.
///
/// Most ops are dtype-preserving and propagate `input_dtypes[0]`; the
/// exceptions (explicit dtypes, Eig's real-to-complex promotion, ShapeOf)
/// are handled first. Panics via indexing if `input_dtypes` is empty for an
/// op that propagates its input dtype.
pub fn infer_output_dtype(op: &StdTensorOp, input_dtypes: &[DType]) -> DType {
    match op {
        // Ops that carry their output dtype explicitly in the op itself.
        StdTensorOp::Constant { dtype, .. } => *dtype,
        StdTensorOp::Convert { to, .. } => *to,
        // General eigendecomposition yields complex results even for real
        // inputs; only the precision (32- vs 64-bit) carries over.
        StdTensorOp::Eig { input_dtype, .. } => match input_dtype {
            DType::F32 | DType::C32 => DType::C32,
            DType::F64 | DType::C64 => DType::C64,
        },
        // Dtype-preserving ops: the output dtype is that of the first input.
        StdTensorOp::Add
        | StdTensorOp::Mul
        | StdTensorOp::Neg
        | StdTensorOp::Conj
        | StdTensorOp::Div
        | StdTensorOp::Abs
        | StdTensorOp::Sign
        | StdTensorOp::Maximum
        | StdTensorOp::Minimum
        | StdTensorOp::Select
        | StdTensorOp::Clamp
        | StdTensorOp::Exp
        | StdTensorOp::Log
        | StdTensorOp::Sin
        | StdTensorOp::Cos
        | StdTensorOp::Tanh
        | StdTensorOp::Sqrt
        | StdTensorOp::Rsqrt
        | StdTensorOp::Pow
        | StdTensorOp::Expm1
        | StdTensorOp::Log1p
        | StdTensorOp::Transpose { .. }
        | StdTensorOp::Reshape { .. }
        | StdTensorOp::BroadcastInDim { .. }
        | StdTensorOp::ReduceSum { .. }
        | StdTensorOp::ReduceProd { .. }
        | StdTensorOp::ReduceMax { .. }
        | StdTensorOp::ReduceMin { .. }
        | StdTensorOp::ExtractDiag { .. }
        | StdTensorOp::EmbedDiag { .. }
        | StdTensorOp::Tril { .. }
        | StdTensorOp::Triu { .. }
        | StdTensorOp::Gather(_)
        | StdTensorOp::Scatter(_)
        | StdTensorOp::Slice(_)
        | StdTensorOp::DynamicSlice { .. }
        | StdTensorOp::Pad(_)
        | StdTensorOp::Concatenate { .. }
        | StdTensorOp::Reverse { .. }
        | StdTensorOp::DotGeneral(_)
        | StdTensorOp::NaryEinsum { .. }
        | StdTensorOp::Cholesky { .. }
        | StdTensorOp::Lu { .. }
        | StdTensorOp::Svd { .. }
        | StdTensorOp::Qr { .. }
        | StdTensorOp::Eigh { .. }
        | StdTensorOp::TriangularSolve { .. }
        | StdTensorOp::ValidateNonsingular { .. }
        | StdTensorOp::DynamicTruncate { .. }
        | StdTensorOp::PadToMatch { .. } => input_dtypes[0],
        // NOTE(review): Compare also propagates the input dtype rather than a
        // boolean type — kept as its own arm, presumably for future
        // special-casing; confirm the intended comparison result dtype.
        StdTensorOp::Compare(_) => input_dtypes[0],
        // ShapeOf produces an F64 value regardless of the input dtype.
        StdTensorOp::ShapeOf { .. } => DType::F64,
    }
}
79
/// Infers the output shape(s) of `op` from its input shapes.
///
/// Returns one `Vec<DimExpr>` per output; most ops have exactly one, the
/// matrix factorizations (Lu/Svd/Qr/Eig/Eigh) return several. Panics when a
/// required input shape is missing or a config disagrees with the input
/// ranks.
pub fn infer_output_shapes(op: &StdTensorOp, input_shapes: &[&[DimExpr]]) -> Vec<Vec<DimExpr>> {
    match op {
        // Add/Mul allow a rank-0 operand to broadcast against the other.
        // NOTE(review): the other binary ops (Div, Maximum, ...) fall through
        // to the shape-preserving arm below and take input 0's shape verbatim
        // — confirm scalar broadcast is intentionally Add/Mul-only.
        StdTensorOp::Add => vec![same_or_scalar_broadcast_shape(
            require_input(op, input_shapes, 0),
            require_input(op, input_shapes, 1),
        )],
        StdTensorOp::Mul => vec![same_or_scalar_broadcast_shape(
            require_input(op, input_shapes, 0),
            require_input(op, input_shapes, 1),
        )],
        // Shape-preserving ops: the output shape equals input 0's shape.
        StdTensorOp::Neg
        | StdTensorOp::Conj
        | StdTensorOp::Div
        | StdTensorOp::Abs
        | StdTensorOp::Sign
        | StdTensorOp::Maximum
        | StdTensorOp::Minimum
        | StdTensorOp::Compare(_)
        | StdTensorOp::Select
        | StdTensorOp::Clamp
        | StdTensorOp::Exp
        | StdTensorOp::Log
        | StdTensorOp::Sin
        | StdTensorOp::Cos
        | StdTensorOp::Tanh
        | StdTensorOp::Sqrt
        | StdTensorOp::Rsqrt
        | StdTensorOp::Pow
        | StdTensorOp::Expm1
        | StdTensorOp::Log1p
        | StdTensorOp::Convert { .. }
        | StdTensorOp::Tril { .. }
        | StdTensorOp::Triu { .. }
        | StdTensorOp::Reverse { .. }
        | StdTensorOp::Scatter(_)
        | StdTensorOp::Cholesky { .. }
        | StdTensorOp::ValidateNonsingular { .. } => {
            vec![require_input(op, input_shapes, 0).to_vec()]
        }
        StdTensorOp::Transpose { perm } => {
            vec![permute_shape(require_input(op, input_shapes, 0), perm)]
        }
        // Target shapes carried directly in the op.
        StdTensorOp::Reshape { to_shape, .. } => vec![to_shape.clone()],
        StdTensorOp::BroadcastInDim { shape, .. } => vec![shape.clone()],
        // Constants are scalars (rank 0).
        StdTensorOp::Constant { .. } => vec![Vec::new()],
        // Reductions drop the reduced axes (no keepdims).
        StdTensorOp::ReduceSum { axes, .. }
        | StdTensorOp::ReduceProd { axes, .. }
        | StdTensorOp::ReduceMax { axes, .. }
        | StdTensorOp::ReduceMin { axes, .. } => {
            vec![reduced_shape(require_input(op, input_shapes, 0), axes)]
        }
        StdTensorOp::ExtractDiag { axis_a, axis_b } => {
            vec![extract_diag_shape(
                require_input(op, input_shapes, 0),
                *axis_a,
                *axis_b,
            )]
        }
        StdTensorOp::EmbedDiag { axis_a, axis_b } => {
            vec![embed_diag_shape(
                require_input(op, input_shapes, 0),
                *axis_a,
                *axis_b,
            )]
        }
        // Gather takes the operand shape (input 0) and the index shape
        // (input 1).
        StdTensorOp::Gather(config) => vec![gather_shape(
            require_input(op, input_shapes, 0),
            require_input(op, input_shapes, 1),
            config,
        )],
        StdTensorOp::Slice(config) => vec![slice_shape(require_input(op, input_shapes, 0), config)],
        // DynamicSlice has run-time starts but statically-known slice sizes,
        // so the output shape is fully concrete.
        StdTensorOp::DynamicSlice { slice_sizes } => {
            vec![slice_sizes.iter().copied().map(DimExpr::Const).collect()]
        }
        StdTensorOp::Pad(config) => vec![pad_shape(require_input(op, input_shapes, 0), config)],
        StdTensorOp::DotGeneral(config) => vec![dot_general_shape(
            require_input(op, input_shapes, 0),
            require_input(op, input_shapes, 1),
            config,
        )],
        StdTensorOp::NaryEinsum {
            subscripts,
            n_inputs,
        } => {
            assert_eq!(
                input_shapes.len(),
                *n_inputs,
                "NaryEinsum expects {n_inputs} inputs, got {}",
                input_shapes.len()
            );
            vec![einsum_output_shape(subscripts, input_shapes)]
        }
        StdTensorOp::Concatenate { axis } => vec![concatenate_shape(input_shapes, *axis)],
        // ShapeOf yields a scalar.
        StdTensorOp::ShapeOf { .. } => vec![Vec::new()],
        // NOTE(review): DynamicTruncate only validates the axis and returns
        // the input shape unchanged — presumably the truncated extent is only
        // known at run time; confirm against the op's execution semantics.
        StdTensorOp::DynamicTruncate { axis } => {
            let shape = require_input(op, input_shapes, 0).to_vec();
            assert!(
                *axis < shape.len(),
                "DynamicTruncate axis {axis} out of bounds for rank {}",
                shape.len()
            );
            vec![shape]
        }
        StdTensorOp::PadToMatch { axis } => vec![pad_to_match_shape(
            require_input(op, input_shapes, 0),
            require_input(op, input_shapes, 1),
            *axis,
        )],
        StdTensorOp::Lu { .. } => lu_shapes(require_input(op, input_shapes, 0)),
        StdTensorOp::Svd { .. } => svd_shapes(require_input(op, input_shapes, 0)),
        StdTensorOp::Qr { .. } => qr_shapes(require_input(op, input_shapes, 0)),
        StdTensorOp::Eigh { .. } | StdTensorOp::Eig { .. } => {
            eig_like_shapes(require_input(op, input_shapes, 0))
        }
        // The solution of a triangular solve has the rhs (input 1) shape.
        StdTensorOp::TriangularSolve { .. } => vec![require_input(op, input_shapes, 1).to_vec()],
    }
}
202
203fn require_input<'a>(
204 op: &StdTensorOp,
205 input_shapes: &'a [&[DimExpr]],
206 idx: usize,
207) -> &'a [DimExpr] {
208 input_shapes.get(idx).copied().unwrap_or_else(|| {
209 panic!(
210 "{op:?} expects input index {idx}, got {} input shapes",
211 input_shapes.len()
212 )
213 })
214}
215
216fn permute_shape(input_shape: &[DimExpr], perm: &[usize]) -> Vec<DimExpr> {
217 perm.iter().map(|&axis| input_shape[axis].clone()).collect()
218}
219
220fn reduced_shape(input_shape: &[DimExpr], axes: &[usize]) -> Vec<DimExpr> {
221 input_shape
222 .iter()
223 .enumerate()
224 .filter_map(|(axis, dim)| (!axes.contains(&axis)).then_some(dim.clone()))
225 .collect()
226}
227
228fn same_or_scalar_broadcast_shape(lhs_shape: &[DimExpr], rhs_shape: &[DimExpr]) -> Vec<DimExpr> {
229 if lhs_shape.is_empty() {
230 rhs_shape.to_vec()
231 } else if rhs_shape.is_empty() {
232 lhs_shape.to_vec()
233 } else {
234 lhs_shape.to_vec()
238 }
239}
240
241fn extract_diag_shape(input_shape: &[DimExpr], axis_a: usize, axis_b: usize) -> Vec<DimExpr> {
242 assert!(
243 axis_a < input_shape.len() && axis_b < input_shape.len(),
244 "ExtractDiag axes ({axis_a}, {axis_b}) out of bounds for rank {}",
245 input_shape.len()
246 );
247 assert_ne!(axis_a, axis_b, "ExtractDiag requires distinct axes");
248 let diag_output_axis = if axis_a < axis_b { axis_a } else { axis_a - 1 };
249 let diag_dim = dim_min(input_shape[axis_a].clone(), input_shape[axis_b].clone());
250 let mut output_shape = input_shape.to_vec();
251 output_shape.remove(axis_b);
252 output_shape[diag_output_axis] = diag_dim;
253 output_shape
254}
255
256fn embed_diag_shape(input_shape: &[DimExpr], axis_a: usize, axis_b: usize) -> Vec<DimExpr> {
257 assert!(
258 axis_a < input_shape.len(),
259 "EmbedDiag axis_a {axis_a} out of bounds for rank {}",
260 input_shape.len()
261 );
262 assert!(
263 axis_b <= input_shape.len(),
264 "EmbedDiag axis_b {axis_b} out of bounds for rank {}",
265 input_shape.len()
266 );
267 let mut output_shape = input_shape.to_vec();
268 output_shape.insert(axis_b, input_shape[axis_a].clone());
269 output_shape
270}
271
/// Shape rule for `DotGeneral`.
///
/// Output layout: lhs free dims (in lhs order), then rhs free dims (in rhs
/// order), then the batch dims (taken from the lhs side).
///
/// NOTE(review): batch dims trailing — rather than leading, as in XLA's
/// dot_general — appears to be this crate's convention; it matches the
/// linalg helpers below, which also keep batch extents last. Confirm against
/// the kernels. Contracting/batch extents are not cross-checked between lhs
/// and rhs here; only the ranks recorded in the config are validated.
fn dot_general_shape(
    lhs_shape: &[DimExpr],
    rhs_shape: &[DimExpr],
    config: &DotGeneralConfig,
) -> Vec<DimExpr> {
    assert_eq!(
        lhs_shape.len(),
        config.lhs_rank,
        "DotGeneral lhs rank mismatch: config={}, actual={}",
        config.lhs_rank,
        lhs_shape.len()
    );
    assert_eq!(
        rhs_shape.len(),
        config.rhs_rank,
        "DotGeneral rhs rank mismatch: config={}, actual={}",
        config.rhs_rank,
        rhs_shape.len()
    );

    // Free dims are everything that is neither contracted nor batched.
    let lhs_free = (0..config.lhs_rank).filter(|axis| {
        !config.lhs_contracting_dims.contains(axis) && !config.lhs_batch_dims.contains(axis)
    });
    let rhs_free = (0..config.rhs_rank).filter(|axis| {
        !config.rhs_contracting_dims.contains(axis) && !config.rhs_batch_dims.contains(axis)
    });

    let mut output_shape = Vec::new();
    output_shape.extend(lhs_free.map(|axis| lhs_shape[axis].clone()));
    output_shape.extend(rhs_free.map(|axis| rhs_shape[axis].clone()));
    // Batch extents come from the lhs; the rhs batch extents are assumed to
    // match (not checked here).
    output_shape.extend(
        config
            .lhs_batch_dims
            .iter()
            .map(|&axis| lhs_shape[axis].clone()),
    );
    output_shape
}
310
/// Shape rule for `Gather`.
///
/// The config fields (`index_vector_dim`, `collapsed_slice_dims`,
/// `offset_dims`, `slice_sizes`) mirror XLA-style gather semantics — TODO
/// confirm the exact correspondence against the op's kernel.
///
/// Output rank = index batch rank + `offset_dims.len()`. The axes named in
/// `offset_dims` carry slice-window extents (from `slice_sizes`); every
/// other output axis carries a batch extent from `index_shape`, in order.
fn gather_shape(
    operand_shape: &[DimExpr],
    index_shape: &[DimExpr],
    config: &GatherConfig,
) -> Vec<DimExpr> {
    assert_eq!(
        config.slice_sizes.len(),
        operand_shape.len(),
        "gather: slice_sizes rank mismatch"
    );

    // Batch shape = index shape with the index-vector axis removed. When
    // `index_vector_dim` equals the index rank, the indices are implicitly
    // treated as having a trailing size-1 index vector, so every index axis
    // is a batch axis.
    let batch_shape = if config.index_vector_dim == index_shape.len() {
        index_shape.to_vec()
    } else {
        index_shape
            .iter()
            .enumerate()
            .filter_map(|(axis, dim)| (axis != config.index_vector_dim).then_some(dim.clone()))
            .collect()
    };

    // Operand dims that survive collapsing become the slice-window dims, in
    // operand order; offset_dims[i] says which OUTPUT axis window i lands on.
    let window_dims: Vec<usize> = (0..operand_shape.len())
        .filter(|dim| !config.collapsed_slice_dims.contains(dim))
        .collect();
    assert_eq!(
        config.offset_dims.len(),
        window_dims.len(),
        "gather: offset_dims length mismatch"
    );

    let out_rank = batch_shape.len() + config.offset_dims.len();
    let mut out_shape = vec![DimExpr::Const(0); out_rank];
    // Map each output axis to the operand dim whose slice size it carries
    // (None = the axis carries a batch extent instead).
    let mut out_axis_to_operand_dim = vec![None; out_rank];
    for (offset_axis, &out_axis) in config.offset_dims.iter().enumerate() {
        out_axis_to_operand_dim[out_axis] = Some(window_dims[offset_axis]);
    }

    // Fill window axes from slice_sizes; remaining axes consume the batch
    // extents left-to-right.
    let mut batch_axis = 0usize;
    for out_axis in 0..out_rank {
        if let Some(operand_dim) = out_axis_to_operand_dim[out_axis] {
            out_shape[out_axis] = DimExpr::Const(config.slice_sizes[operand_dim]);
        } else {
            out_shape[out_axis] = batch_shape[batch_axis].clone();
            batch_axis += 1;
        }
    }

    out_shape
}
360
361fn slice_shape(input_shape: &[DimExpr], config: &SliceConfig) -> Vec<DimExpr> {
362 let rank = input_shape.len();
363 assert_eq!(config.starts.len(), rank, "slice: starts rank mismatch");
364 assert_eq!(config.limits.len(), rank, "slice: limits rank mismatch");
365 assert_eq!(config.strides.len(), rank, "slice: strides rank mismatch");
366 (0..rank)
367 .map(|axis| {
368 let span = config.limits[axis] - config.starts[axis];
369 DimExpr::Const((span + config.strides[axis] - 1) / config.strides[axis])
370 })
371 .collect()
372}
373
/// Shape rule for `Pad` (XLA-style: low/high edge padding plus interior
/// padding between elements — TODO confirm correspondence with the kernel).
///
/// Edge padding may be negative (a crop); interior padding must be
/// non-negative. For a concrete extent `e` the result is
/// `low + high + (e == 0 ? 0 : (e - 1) * (interior + 1) + 1)`; for symbolic
/// extents the same formula is built out of `DimExpr` arithmetic.
fn pad_shape(input_shape: &[DimExpr], config: &PadConfig) -> Vec<DimExpr> {
    let rank = input_shape.len();
    assert_eq!(
        config.edge_padding_low.len(),
        rank,
        "pad: edge_padding_low rank mismatch"
    );
    assert_eq!(
        config.edge_padding_high.len(),
        rank,
        "pad: edge_padding_high rank mismatch"
    );
    assert_eq!(
        config.interior_padding.len(),
        rank,
        "pad: interior_padding rank mismatch"
    );

    input_shape
        .iter()
        .enumerate()
        .map(|(axis, dim)| {
            assert!(
                config.interior_padding[axis] >= 0,
                "pad: interior padding must be non-negative on axis {axis}"
            );
            if let DimExpr::Const(extent) = dim {
                // Concrete extent: evaluate in i64 so negative edge padding
                // (cropping) is handled, then require a non-negative result.
                let base = if *extent == 0 {
                    0
                } else {
                    (*extent as i64 - 1) * (config.interior_padding[axis] + 1) + 1
                };
                let padded = config.edge_padding_low[axis] + config.edge_padding_high[axis] + base;
                DimExpr::Const(
                    usize::try_from(padded)
                        .expect("pad: output extent must be representable as usize"),
                )
            } else if config.interior_padding[axis] == 0 {
                // Symbolic extent, no interior padding: just shift by the
                // (possibly negative) total edge padding.
                add_signed(
                    dim.clone(),
                    config.edge_padding_low[axis] + config.edge_padding_high[axis],
                )
            } else {
                // Symbolic extent with interior padding: (dim - 1) * stride + 1.
                // NOTE(review): unlike the concrete branch, this does not
                // special-case a zero extent — a symbolic dim that evaluates
                // to 0 would underflow in `dim_sub`; confirm zero extents
                // cannot reach here symbolically.
                let stride = DimExpr::Const((config.interior_padding[axis] + 1) as usize);
                let stretched = dim_add(
                    dim_mul(dim_sub(dim.clone(), DimExpr::Const(1)), stride),
                    DimExpr::Const(1),
                );
                add_signed(
                    stretched,
                    config.edge_padding_low[axis] + config.edge_padding_high[axis],
                )
            }
        })
        .collect()
}
430
431fn add_signed(expr: DimExpr, amount: i64) -> DimExpr {
432 if amount >= 0 {
433 dim_add(expr, DimExpr::Const(amount as usize))
434 } else {
435 dim_sub(expr, DimExpr::Const((-amount) as usize))
436 }
437}
438
439fn concatenate_shape(input_shapes: &[&[DimExpr]], axis: usize) -> Vec<DimExpr> {
440 let first = input_shapes
441 .first()
442 .copied()
443 .expect("concatenate expects at least one input shape");
444 assert!(axis < first.len(), "concatenate axis {axis} out of bounds");
445 let mut output_shape = first.to_vec();
446 let axis_dim = input_shapes
447 .iter()
448 .skip(1)
449 .fold(first[axis].clone(), |acc, shape| {
450 dim_add(acc, shape[axis].clone())
451 });
452 output_shape[axis] = axis_dim;
453 output_shape
454}
455
/// Computes the output shape of an einsum from its subscripts string and the
/// input shapes.
///
/// Each subscript label is bound to the first dim it labels; concrete sizes
/// for repeated labels are cross-checked, but a symbolic first binding wins
/// and later concrete occurrences are NOT checked against it — only
/// Const-vs-Const pairs are compared. Panics on invalid subscripts, input
/// count/rank mismatches, inconsistent concrete sizes, or an output label
/// that never appears in any input.
fn einsum_output_shape(subscripts: &str, input_shapes: &[&[DimExpr]]) -> Vec<DimExpr> {
    let parsed = Subscripts::parse(subscripts)
        .unwrap_or_else(|err| panic!("invalid einsum subscripts {subscripts:?}: {err}"));
    assert_eq!(
        parsed.inputs.len(),
        input_shapes.len(),
        "einsum subscripts expect {} inputs, got {}",
        parsed.inputs.len(),
        input_shapes.len()
    );

    // First occurrence of each label binds its dim expression.
    let mut label_dims: HashMap<u32, DimExpr> = HashMap::new();
    for (labels, shape) in parsed.inputs.iter().zip(input_shapes.iter()) {
        assert_eq!(
            labels.len(),
            shape.len(),
            "einsum input rank mismatch: labels={}, shape={}",
            labels.len(),
            shape.len()
        );
        for (&label, dim) in labels.iter().zip(shape.iter()) {
            if let Some(existing) = label_dims.get(&label) {
                // Only fully concrete pairs can be validated.
                if let (DimExpr::Const(lhs), DimExpr::Const(rhs)) = (existing, dim) {
                    assert_eq!(
                        lhs, rhs,
                        "einsum label {label} has inconsistent concrete sizes {lhs} vs {rhs}"
                    );
                }
            } else {
                label_dims.insert(label, dim.clone());
            }
        }
    }

    // Output shape = the bound dim of each output label, in output order.
    parsed
        .output
        .iter()
        .map(|label| {
            label_dims
                .get(label)
                .cloned()
                .unwrap_or_else(|| panic!("einsum output label {label} missing from inputs"))
        })
        .collect()
}
501
502fn pad_to_match_shape(
503 input_shape: &[DimExpr],
504 reference_shape: &[DimExpr],
505 axis: usize,
506) -> Vec<DimExpr> {
507 assert!(
508 axis < input_shape.len(),
509 "PadToMatch input axis {axis} out of bounds"
510 );
511 assert!(
512 axis < reference_shape.len(),
513 "PadToMatch reference axis {axis} out of bounds"
514 );
515 let mut output_shape = input_shape.to_vec();
516 output_shape[axis] = dim_max(input_shape[axis].clone(), reference_shape[axis].clone());
517 output_shape
518}
519
520fn matrix_parts(input_shape: &[DimExpr]) -> (&DimExpr, &DimExpr, &[DimExpr]) {
521 assert!(
522 input_shape.len() >= 2,
523 "linalg op expects rank >= 2, got {}",
524 input_shape.len()
525 );
526 (&input_shape[0], &input_shape[1], &input_shape[2..])
527}
528
529fn svd_shapes(input_shape: &[DimExpr]) -> Vec<Vec<DimExpr>> {
530 let (m, n, batch) = matrix_parts(input_shape);
531 let k = dim_min(m.clone(), n.clone());
532 let mut u_shape = vec![m.clone(), k.clone()];
533 u_shape.extend_from_slice(batch);
534 let mut s_shape = vec![k.clone()];
535 s_shape.extend_from_slice(batch);
536 let mut vt_shape = vec![k, n.clone()];
537 vt_shape.extend_from_slice(batch);
538 vec![u_shape, s_shape, vt_shape]
539}
540
541fn qr_shapes(input_shape: &[DimExpr]) -> Vec<Vec<DimExpr>> {
542 let (m, n, batch) = matrix_parts(input_shape);
543 let k = dim_min(m.clone(), n.clone());
544 let mut q_shape = vec![m.clone(), k.clone()];
545 q_shape.extend_from_slice(batch);
546 let mut r_shape = vec![k, n.clone()];
547 r_shape.extend_from_slice(batch);
548 vec![q_shape, r_shape]
549}
550
551fn lu_shapes(input_shape: &[DimExpr]) -> Vec<Vec<DimExpr>> {
552 let (m, n, batch) = matrix_parts(input_shape);
553 let k = dim_min(m.clone(), n.clone());
554 let mut p_shape = vec![m.clone(), m.clone()];
555 p_shape.extend_from_slice(batch);
556 let mut l_shape = vec![m.clone(), k.clone()];
557 l_shape.extend_from_slice(batch);
558 let mut u_shape = vec![k, n.clone()];
559 u_shape.extend_from_slice(batch);
560 vec![p_shape, l_shape, u_shape, batch.to_vec()]
561}
562
563fn eig_like_shapes(input_shape: &[DimExpr]) -> Vec<Vec<DimExpr>> {
564 let (n, _, batch) = matrix_parts(input_shape);
565 let mut values_shape = vec![n.clone()];
566 values_shape.extend_from_slice(batch);
567 let mut vectors_shape = vec![n.clone(), n.clone()];
568 vectors_shape.extend_from_slice(batch);
569 vec![values_shape, vectors_shape]
570}
571
572fn dim_add(lhs: DimExpr, rhs: DimExpr) -> DimExpr {
573 match (lhs, rhs) {
574 (DimExpr::Const(lhs), DimExpr::Const(rhs)) => DimExpr::Const(lhs + rhs),
575 (lhs, rhs) => DimExpr::add(lhs, rhs),
576 }
577}
578
579fn dim_sub(lhs: DimExpr, rhs: DimExpr) -> DimExpr {
580 match (lhs, rhs) {
581 (DimExpr::Const(lhs), DimExpr::Const(rhs)) => DimExpr::Const(lhs - rhs),
582 (lhs, rhs) => DimExpr::sub(lhs, rhs),
583 }
584}
585
586fn dim_mul(lhs: DimExpr, rhs: DimExpr) -> DimExpr {
587 match (lhs, rhs) {
588 (DimExpr::Const(lhs), DimExpr::Const(rhs)) => DimExpr::Const(lhs * rhs),
589 (lhs, rhs) => DimExpr::mul(lhs, rhs),
590 }
591}
592
593fn dim_min(lhs: DimExpr, rhs: DimExpr) -> DimExpr {
594 match (lhs, rhs) {
595 (DimExpr::Const(lhs), DimExpr::Const(rhs)) => DimExpr::Const(lhs.min(rhs)),
596 (lhs, rhs) => DimExpr::min(lhs, rhs),
597 }
598}
599
600fn dim_max(lhs: DimExpr, rhs: DimExpr) -> DimExpr {
601 match (lhs, rhs) {
602 (DimExpr::Const(lhs), DimExpr::Const(rhs)) => DimExpr::Const(lhs.max(rhs)),
603 (lhs, rhs) => DimExpr::max(lhs, rhs),
604 }
605}