tenferro/shape_infer.rs

//! Output-shape and output-dtype inference for `StdTensorOp`.
//!
//! Called during `StdTensorOp -> ExecProgram` lowering to populate
//! `ExecInstruction::output_shapes` and `ExecInstruction::dtype`.

use std::collections::HashMap;

use tenferro_einsum::Subscripts;
use tenferro_ops::dim_expr::DimExpr;
use tenferro_ops::std_tensor_op::StdTensorOp;
use tenferro_tensor::{DType, DotGeneralConfig, GatherConfig, PadConfig, SliceConfig};

/// Infer output dtype for a single instruction given its op and input dtypes.
///
/// Panics if `input_dtypes` does not supply the inputs the op requires
/// (which shouldn't happen in well-formed SSA programs).
pub fn infer_output_dtype(op: &StdTensorOp, input_dtypes: &[DType]) -> DType {
    match op {
        StdTensorOp::Constant { dtype, .. } => *dtype,
        StdTensorOp::Convert { to, .. } => *to,
        // Real matrices can have complex eigenvalues, so `Eig` promotes real
        // inputs to the matching complex dtype.
        StdTensorOp::Eig { input_dtype, .. } => match input_dtype {
            DType::F32 | DType::C32 => DType::C32,
            DType::F64 | DType::C64 => DType::C64,
        },
        StdTensorOp::Add
        | StdTensorOp::Mul
        | StdTensorOp::Neg
        | StdTensorOp::Conj
        | StdTensorOp::Div
        | StdTensorOp::Abs
        | StdTensorOp::Sign
        | StdTensorOp::Maximum
        | StdTensorOp::Minimum
        | StdTensorOp::Select
        | StdTensorOp::Clamp
        | StdTensorOp::Exp
        | StdTensorOp::Log
        | StdTensorOp::Sin
        | StdTensorOp::Cos
        | StdTensorOp::Tanh
        | StdTensorOp::Sqrt
        | StdTensorOp::Rsqrt
        | StdTensorOp::Pow
        | StdTensorOp::Expm1
        | StdTensorOp::Log1p
        | StdTensorOp::Transpose { .. }
        | StdTensorOp::Reshape { .. }
        | StdTensorOp::BroadcastInDim { .. }
        | StdTensorOp::ReduceSum { .. }
        | StdTensorOp::ReduceProd { .. }
        | StdTensorOp::ReduceMax { .. }
        | StdTensorOp::ReduceMin { .. }
        | StdTensorOp::ExtractDiag { .. }
        | StdTensorOp::EmbedDiag { .. }
        | StdTensorOp::Tril { .. }
        | StdTensorOp::Triu { .. }
        | StdTensorOp::Gather(_)
        | StdTensorOp::Scatter(_)
        | StdTensorOp::Slice(_)
        | StdTensorOp::DynamicSlice { .. }
        | StdTensorOp::Pad(_)
        | StdTensorOp::Concatenate { .. }
        | StdTensorOp::Reverse { .. }
        | StdTensorOp::DotGeneral(_)
        | StdTensorOp::NaryEinsum { .. }
        | StdTensorOp::Cholesky { .. }
        | StdTensorOp::Lu { .. }
        | StdTensorOp::Svd { .. }
        | StdTensorOp::Qr { .. }
        | StdTensorOp::Eigh { .. }
        | StdTensorOp::TriangularSolve { .. }
        | StdTensorOp::ValidateNonsingular { .. }
        | StdTensorOp::DynamicTruncate { .. }
        | StdTensorOp::PadToMatch { .. } => input_dtypes[0],
        // Comparisons keep the operand dtype: `DType` has no boolean variant.
        StdTensorOp::Compare(_) => input_dtypes[0],
        StdTensorOp::ShapeOf { .. } => DType::F64,
    }
}
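
// Worked example (comment-only sketch; fields elided with `..` are omitted,
// so this illustrates the rules rather than compiling as-is):
//
//     infer_output_dtype(Eig { input_dtype: DType::F32, .. }, [DType::F32])
//         // -> DType::C32 (real input, complex eigenvalues)
//     infer_output_dtype(Add, [DType::F64, DType::F64])
//         // -> DType::F64 (follows the first input)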

/// Infer output shapes for a single instruction.
///
/// Returns one shape per output slot: a single entry for ordinary ops, and
/// one entry per factor for multi-output linalg ops such as `Svd` and `Lu`.
pub fn infer_output_shapes(op: &StdTensorOp, input_shapes: &[&[DimExpr]]) -> Vec<Vec<DimExpr>> {
    match op {
        StdTensorOp::Add => vec![same_or_scalar_broadcast_shape(
            require_input(op, input_shapes, 0),
            require_input(op, input_shapes, 1),
        )],
        StdTensorOp::Mul => vec![same_or_scalar_broadcast_shape(
            require_input(op, input_shapes, 0),
            require_input(op, input_shapes, 1),
        )],
        StdTensorOp::Neg
        | StdTensorOp::Conj
        | StdTensorOp::Div
        | StdTensorOp::Abs
        | StdTensorOp::Sign
        | StdTensorOp::Maximum
        | StdTensorOp::Minimum
        | StdTensorOp::Compare(_)
        | StdTensorOp::Select
        | StdTensorOp::Clamp
        | StdTensorOp::Exp
        | StdTensorOp::Log
        | StdTensorOp::Sin
        | StdTensorOp::Cos
        | StdTensorOp::Tanh
        | StdTensorOp::Sqrt
        | StdTensorOp::Rsqrt
        | StdTensorOp::Pow
        | StdTensorOp::Expm1
        | StdTensorOp::Log1p
        | StdTensorOp::Convert { .. }
        | StdTensorOp::Tril { .. }
        | StdTensorOp::Triu { .. }
        | StdTensorOp::Reverse { .. }
        | StdTensorOp::Scatter(_)
        | StdTensorOp::Cholesky { .. }
        | StdTensorOp::ValidateNonsingular { .. } => {
            vec![require_input(op, input_shapes, 0).to_vec()]
        }
        StdTensorOp::Transpose { perm } => {
            vec![permute_shape(require_input(op, input_shapes, 0), perm)]
        }
        StdTensorOp::Reshape { to_shape, .. } => vec![to_shape.clone()],
        StdTensorOp::BroadcastInDim { shape, .. } => vec![shape.clone()],
        StdTensorOp::Constant { .. } => vec![Vec::new()],
        StdTensorOp::ReduceSum { axes, .. }
        | StdTensorOp::ReduceProd { axes, .. }
        | StdTensorOp::ReduceMax { axes, .. }
        | StdTensorOp::ReduceMin { axes, .. } => {
            vec![reduced_shape(require_input(op, input_shapes, 0), axes)]
        }
        StdTensorOp::ExtractDiag { axis_a, axis_b } => {
            vec![extract_diag_shape(
                require_input(op, input_shapes, 0),
                *axis_a,
                *axis_b,
            )]
        }
        StdTensorOp::EmbedDiag { axis_a, axis_b } => {
            vec![embed_diag_shape(
                require_input(op, input_shapes, 0),
                *axis_a,
                *axis_b,
            )]
        }
        StdTensorOp::Gather(config) => vec![gather_shape(
            require_input(op, input_shapes, 0),
            require_input(op, input_shapes, 1),
            config,
        )],
        StdTensorOp::Slice(config) => vec![slice_shape(require_input(op, input_shapes, 0), config)],
        StdTensorOp::DynamicSlice { slice_sizes } => {
            vec![slice_sizes.iter().copied().map(DimExpr::Const).collect()]
        }
        StdTensorOp::Pad(config) => vec![pad_shape(require_input(op, input_shapes, 0), config)],
        StdTensorOp::DotGeneral(config) => vec![dot_general_shape(
            require_input(op, input_shapes, 0),
            require_input(op, input_shapes, 1),
            config,
        )],
        StdTensorOp::NaryEinsum {
            subscripts,
            n_inputs,
        } => {
            assert_eq!(
                input_shapes.len(),
                *n_inputs,
                "NaryEinsum expects {n_inputs} inputs, got {}",
                input_shapes.len()
            );
            vec![einsum_output_shape(subscripts, input_shapes)]
        }
        StdTensorOp::Concatenate { axis } => vec![concatenate_shape(input_shapes, *axis)],
        StdTensorOp::ShapeOf { .. } => vec![Vec::new()],
        StdTensorOp::DynamicTruncate { axis } => {
            // The truncated extent is only known at runtime; statically the
            // input shape is kept as an upper bound.
            let shape = require_input(op, input_shapes, 0).to_vec();
            assert!(
                *axis < shape.len(),
                "DynamicTruncate axis {axis} out of bounds for rank {}",
                shape.len()
            );
            vec![shape]
        }
        StdTensorOp::PadToMatch { axis } => vec![pad_to_match_shape(
            require_input(op, input_shapes, 0),
            require_input(op, input_shapes, 1),
            *axis,
        )],
        StdTensorOp::Lu { .. } => lu_shapes(require_input(op, input_shapes, 0)),
        StdTensorOp::Svd { .. } => svd_shapes(require_input(op, input_shapes, 0)),
        StdTensorOp::Qr { .. } => qr_shapes(require_input(op, input_shapes, 0)),
        StdTensorOp::Eigh { .. } | StdTensorOp::Eig { .. } => {
            eig_like_shapes(require_input(op, input_shapes, 0))
        }
        StdTensorOp::TriangularSolve { .. } => vec![require_input(op, input_shapes, 1).to_vec()],
    }
}
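
// Worked example (sketch; assumes `perm` is `Transpose`'s only field, as the
// match above suggests):
//
//     let op = StdTensorOp::Transpose { perm: vec![1, 0] };
//     let shapes = infer_output_shapes(&op, &[&[DimExpr::Const(2), DimExpr::Const(3)]]);
//     // shapes == vec![vec![DimExpr::Const(3), DimExpr::Const(2)]]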

fn require_input<'a>(
    op: &StdTensorOp,
    input_shapes: &'a [&[DimExpr]],
    idx: usize,
) -> &'a [DimExpr] {
    input_shapes.get(idx).copied().unwrap_or_else(|| {
        panic!(
            "{op:?} expects input index {idx}, got {} input shapes",
            input_shapes.len()
        )
    })
}

fn permute_shape(input_shape: &[DimExpr], perm: &[usize]) -> Vec<DimExpr> {
    perm.iter().map(|&axis| input_shape[axis].clone()).collect()
}

fn reduced_shape(input_shape: &[DimExpr], axes: &[usize]) -> Vec<DimExpr> {
    input_shape
        .iter()
        .enumerate()
        .filter_map(|(axis, dim)| (!axes.contains(&axis)).then_some(dim.clone()))
        .collect()
}

fn same_or_scalar_broadcast_shape(lhs_shape: &[DimExpr], rhs_shape: &[DimExpr]) -> Vec<DimExpr> {
    if lhs_shape.is_empty() {
        rhs_shape.to_vec()
    } else if rhs_shape.is_empty() {
        lhs_shape.to_vec()
    } else {
        // Shapes downstream of dynamic ops such as DynamicTruncate are only
        // approximate, so rather than asserting lhs == rhs this preserves the
        // historical "follow lhs" behavior.
        lhs_shape.to_vec()
    }
}

fn extract_diag_shape(input_shape: &[DimExpr], axis_a: usize, axis_b: usize) -> Vec<DimExpr> {
    assert!(
        axis_a < input_shape.len() && axis_b < input_shape.len(),
        "ExtractDiag axes ({axis_a}, {axis_b}) out of bounds for rank {}",
        input_shape.len()
    );
    assert_ne!(axis_a, axis_b, "ExtractDiag requires distinct axes");
    // Removing `axis_b` shifts every later axis down by one, including the
    // diagonal's output position when `axis_a > axis_b`.
    let diag_output_axis = if axis_a < axis_b { axis_a } else { axis_a - 1 };
    let diag_dim = dim_min(input_shape[axis_a].clone(), input_shape[axis_b].clone());
    let mut output_shape = input_shape.to_vec();
    output_shape.remove(axis_b);
    output_shape[diag_output_axis] = diag_dim;
    output_shape
}
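
// Worked example (comment-only sketch): extracting the (0, 1) diagonal of a
// [5, 7, B] tensor removes axis 1 and clamps the diagonal to min(5, 7):
//
//     extract_diag_shape([5, 7, B], 0, 1)  // -> [5, B]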

fn embed_diag_shape(input_shape: &[DimExpr], axis_a: usize, axis_b: usize) -> Vec<DimExpr> {
    assert!(
        axis_a < input_shape.len(),
        "EmbedDiag axis_a {axis_a} out of bounds for rank {}",
        input_shape.len()
    );
    assert!(
        axis_b <= input_shape.len(),
        "EmbedDiag axis_b {axis_b} out of bounds for rank {}",
        input_shape.len()
    );
    let mut output_shape = input_shape.to_vec();
    output_shape.insert(axis_b, input_shape[axis_a].clone());
    output_shape
}

fn dot_general_shape(
    lhs_shape: &[DimExpr],
    rhs_shape: &[DimExpr],
    config: &DotGeneralConfig,
) -> Vec<DimExpr> {
    assert_eq!(
        lhs_shape.len(),
        config.lhs_rank,
        "DotGeneral lhs rank mismatch: config={}, actual={}",
        config.lhs_rank,
        lhs_shape.len()
    );
    assert_eq!(
        rhs_shape.len(),
        config.rhs_rank,
        "DotGeneral rhs rank mismatch: config={}, actual={}",
        config.rhs_rank,
        rhs_shape.len()
    );

    let lhs_free = (0..config.lhs_rank).filter(|axis| {
        !config.lhs_contracting_dims.contains(axis) && !config.lhs_batch_dims.contains(axis)
    });
    let rhs_free = (0..config.rhs_rank).filter(|axis| {
        !config.rhs_contracting_dims.contains(axis) && !config.rhs_batch_dims.contains(axis)
    });

    // Output layout: lhs free dims, then rhs free dims, then batch dims
    // (this file keeps batch dims trailing; see `matrix_parts` below).
    let mut output_shape = Vec::new();
    output_shape.extend(lhs_free.map(|axis| lhs_shape[axis].clone()));
    output_shape.extend(rhs_free.map(|axis| rhs_shape[axis].clone()));
    output_shape.extend(
        config
            .lhs_batch_dims
            .iter()
            .map(|&axis| lhs_shape[axis].clone()),
    );
    output_shape
}
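
// Worked examples (comment-only sketch). A plain matmul contracts lhs axis 1
// against rhs axis 0 with no batch dims:
//
//     lhs [M, K], rhs [K, N], contracting ([1], [0])  // -> [M, N]
//
// With trailing batch dims on both sides:
//
//     lhs [M, K, B], rhs [K, N, B], batch ([2], [2])  // -> [M, N, B]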

fn gather_shape(
    operand_shape: &[DimExpr],
    index_shape: &[DimExpr],
    config: &GatherConfig,
) -> Vec<DimExpr> {
    // XLA-style gather: the output interleaves batch dims taken from the
    // index tensor with offset (window) dims taken from `slice_sizes`.
    assert_eq!(
        config.slice_sizes.len(),
        operand_shape.len(),
        "gather: slice_sizes rank mismatch"
    );

    let batch_shape = if config.index_vector_dim == index_shape.len() {
        index_shape.to_vec()
    } else {
        index_shape
            .iter()
            .enumerate()
            .filter_map(|(axis, dim)| (axis != config.index_vector_dim).then_some(dim.clone()))
            .collect()
    };

    let window_dims: Vec<usize> = (0..operand_shape.len())
        .filter(|dim| !config.collapsed_slice_dims.contains(dim))
        .collect();
    assert_eq!(
        config.offset_dims.len(),
        window_dims.len(),
        "gather: offset_dims length mismatch"
    );

    let out_rank = batch_shape.len() + config.offset_dims.len();
    let mut out_shape = vec![DimExpr::Const(0); out_rank];
    let mut out_axis_to_operand_dim = vec![None; out_rank];
    for (offset_axis, &out_axis) in config.offset_dims.iter().enumerate() {
        out_axis_to_operand_dim[out_axis] = Some(window_dims[offset_axis]);
    }

    let mut batch_axis = 0usize;
    for out_axis in 0..out_rank {
        if let Some(operand_dim) = out_axis_to_operand_dim[out_axis] {
            out_shape[out_axis] = DimExpr::Const(config.slice_sizes[operand_dim]);
        } else {
            out_shape[out_axis] = batch_shape[batch_axis].clone();
            batch_axis += 1;
        }
    }

    out_shape
}
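
// Worked example (comment-only sketch): gathering N whole rows from a
// [R, C] operand with start indices of shape [N, 1]:
//
//     slice_sizes = [1, C], collapsed_slice_dims = [0],
//     offset_dims = [1], index_vector_dim = 1
//     // batch shape [N] + window dim C -> output [N, C]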

fn slice_shape(input_shape: &[DimExpr], config: &SliceConfig) -> Vec<DimExpr> {
    let rank = input_shape.len();
    assert_eq!(config.starts.len(), rank, "slice: starts rank mismatch");
    assert_eq!(config.limits.len(), rank, "slice: limits rank mismatch");
    assert_eq!(config.strides.len(), rank, "slice: strides rank mismatch");
    (0..rank)
        .map(|axis| {
            let span = config.limits[axis] - config.starts[axis];
            DimExpr::Const((span + config.strides[axis] - 1) / config.strides[axis])
        })
        .collect()
}
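
// Worked example (comment-only sketch): starts [1], limits [7], strides [2]
// on a rank-1 input covers indices {1, 3, 5}:
//
//     slice_shape([10], { starts: [1], limits: [7], strides: [2] })
//     // -> [ceil((7 - 1) / 2)] == [3]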

fn pad_shape(input_shape: &[DimExpr], config: &PadConfig) -> Vec<DimExpr> {
    let rank = input_shape.len();
    assert_eq!(
        config.edge_padding_low.len(),
        rank,
        "pad: edge_padding_low rank mismatch"
    );
    assert_eq!(
        config.edge_padding_high.len(),
        rank,
        "pad: edge_padding_high rank mismatch"
    );
    assert_eq!(
        config.interior_padding.len(),
        rank,
        "pad: interior_padding rank mismatch"
    );

    input_shape
        .iter()
        .enumerate()
        .map(|(axis, dim)| {
            assert!(
                config.interior_padding[axis] >= 0,
                "pad: interior padding must be non-negative on axis {axis}"
            );
            if let DimExpr::Const(extent) = dim {
                let base = if *extent == 0 {
                    0
                } else {
                    (*extent as i64 - 1) * (config.interior_padding[axis] + 1) + 1
                };
                let padded = config.edge_padding_low[axis] + config.edge_padding_high[axis] + base;
                DimExpr::Const(
                    usize::try_from(padded)
                        .expect("pad: output extent must be representable as usize"),
                )
            } else if config.interior_padding[axis] == 0 {
                add_signed(
                    dim.clone(),
                    config.edge_padding_low[axis] + config.edge_padding_high[axis],
                )
            } else {
                let stride = DimExpr::Const((config.interior_padding[axis] + 1) as usize);
                let stretched = dim_add(
                    dim_mul(dim_sub(dim.clone(), DimExpr::Const(1)), stride),
                    DimExpr::Const(1),
                );
                add_signed(
                    stretched,
                    config.edge_padding_low[axis] + config.edge_padding_high[axis],
                )
            }
        })
        .collect()
}
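
// Worked example (comment-only sketch): a constant extent of 3 with interior
// padding 1, low 2, high 0:
//
//     base   = (3 - 1) * (1 + 1) + 1 = 5   // elements plus interior gaps
//     padded = 2 + 0 + 5 = 7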

fn add_signed(expr: DimExpr, amount: i64) -> DimExpr {
    if amount >= 0 {
        dim_add(expr, DimExpr::Const(amount as usize))
    } else {
        dim_sub(expr, DimExpr::Const((-amount) as usize))
    }
}

fn concatenate_shape(input_shapes: &[&[DimExpr]], axis: usize) -> Vec<DimExpr> {
    let first = input_shapes
        .first()
        .copied()
        .expect("concatenate expects at least one input shape");
    assert!(axis < first.len(), "concatenate axis {axis} out of bounds");
    // Non-axis dims are taken from the first input; agreement across inputs
    // is the caller's responsibility.
    let mut output_shape = first.to_vec();
    let axis_dim = input_shapes
        .iter()
        .skip(1)
        .fold(first[axis].clone(), |acc, shape| {
            dim_add(acc, shape[axis].clone())
        });
    output_shape[axis] = axis_dim;
    output_shape
}

fn einsum_output_shape(subscripts: &str, input_shapes: &[&[DimExpr]]) -> Vec<DimExpr> {
    let parsed = Subscripts::parse(subscripts)
        .unwrap_or_else(|err| panic!("invalid einsum subscripts {subscripts:?}: {err}"));
    assert_eq!(
        parsed.inputs.len(),
        input_shapes.len(),
        "einsum subscripts expect {} inputs, got {}",
        parsed.inputs.len(),
        input_shapes.len()
    );

    let mut label_dims: HashMap<u32, DimExpr> = HashMap::new();
    for (labels, shape) in parsed.inputs.iter().zip(input_shapes.iter()) {
        assert_eq!(
            labels.len(),
            shape.len(),
            "einsum input rank mismatch: labels={}, shape={}",
            labels.len(),
            shape.len()
        );
        for (&label, dim) in labels.iter().zip(shape.iter()) {
            if let Some(existing) = label_dims.get(&label) {
                // Only concrete (Const) extents can be cross-checked;
                // symbolic dims are trusted as-is.
                if let (DimExpr::Const(lhs), DimExpr::Const(rhs)) = (existing, dim) {
                    assert_eq!(
                        lhs, rhs,
                        "einsum label {label} has inconsistent concrete sizes {lhs} vs {rhs}"
                    );
                }
            } else {
                label_dims.insert(label, dim.clone());
            }
        }
    }

    parsed
        .output
        .iter()
        .map(|label| {
            label_dims
                .get(label)
                .cloned()
                .unwrap_or_else(|| panic!("einsum output label {label} missing from inputs"))
        })
        .collect()
}
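
// Worked example (comment-only sketch): "ij,jk->ik" with inputs [2, 3] and
// [3, 4] binds i=2, j=3, k=4, so the output labels resolve to [2, 4].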

fn pad_to_match_shape(
    input_shape: &[DimExpr],
    reference_shape: &[DimExpr],
    axis: usize,
) -> Vec<DimExpr> {
    assert!(
        axis < input_shape.len(),
        "PadToMatch input axis {axis} out of bounds"
    );
    assert!(
        axis < reference_shape.len(),
        "PadToMatch reference axis {axis} out of bounds"
    );
    let mut output_shape = input_shape.to_vec();
    output_shape[axis] = dim_max(input_shape[axis].clone(), reference_shape[axis].clone());
    output_shape
}

/// Split a shape into (rows, cols, batch): dims 0 and 1 are the matrix and
/// any remaining dims are trailing batch dims.
fn matrix_parts(input_shape: &[DimExpr]) -> (&DimExpr, &DimExpr, &[DimExpr]) {
    assert!(
        input_shape.len() >= 2,
        "linalg op expects rank >= 2, got {}",
        input_shape.len()
    );
    (&input_shape[0], &input_shape[1], &input_shape[2..])
}

fn svd_shapes(input_shape: &[DimExpr]) -> Vec<Vec<DimExpr>> {
    // Thin (economy) SVD: U is m x k, S has length k, Vt is k x n,
    // with k = min(m, n).
    let (m, n, batch) = matrix_parts(input_shape);
    let k = dim_min(m.clone(), n.clone());
    let mut u_shape = vec![m.clone(), k.clone()];
    u_shape.extend_from_slice(batch);
    let mut s_shape = vec![k.clone()];
    s_shape.extend_from_slice(batch);
    let mut vt_shape = vec![k, n.clone()];
    vt_shape.extend_from_slice(batch);
    vec![u_shape, s_shape, vt_shape]
}
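
// Worked example (comment-only sketch): an input of shape [5, 3, B] yields
// U [5, 3, B], S [3, B], Vt [3, 3, B] with k = min(5, 3) = 3.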

fn qr_shapes(input_shape: &[DimExpr]) -> Vec<Vec<DimExpr>> {
    // Thin QR: Q is m x k, R is k x n, with k = min(m, n).
    let (m, n, batch) = matrix_parts(input_shape);
    let k = dim_min(m.clone(), n.clone());
    let mut q_shape = vec![m.clone(), k.clone()];
    q_shape.extend_from_slice(batch);
    let mut r_shape = vec![k, n.clone()];
    r_shape.extend_from_slice(batch);
    vec![q_shape, r_shape]
}

fn lu_shapes(input_shape: &[DimExpr]) -> Vec<Vec<DimExpr>> {
    let (m, n, batch) = matrix_parts(input_shape);
    let k = dim_min(m.clone(), n.clone());
    let mut p_shape = vec![m.clone(), m.clone()];
    p_shape.extend_from_slice(batch);
    let mut l_shape = vec![m.clone(), k.clone()];
    l_shape.extend_from_slice(batch);
    let mut u_shape = vec![k, n.clone()];
    u_shape.extend_from_slice(batch);
    // The fourth output has the bare batch shape: one scalar per matrix.
    vec![p_shape, l_shape, u_shape, batch.to_vec()]
}
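
// Worked example (comment-only sketch): LU of a [5, 3, B] input yields
// P [5, 5, B], L [5, 3, B], U [3, 3, B], plus a per-matrix scalar of
// shape [B].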

fn eig_like_shapes(input_shape: &[DimExpr]) -> Vec<Vec<DimExpr>> {
    // Eigendecomposition assumes square matrices, so only the leading
    // extent is used.
    let (n, _, batch) = matrix_parts(input_shape);
    let mut values_shape = vec![n.clone()];
    values_shape.extend_from_slice(batch);
    let mut vectors_shape = vec![n.clone(), n.clone()];
    vectors_shape.extend_from_slice(batch);
    vec![values_shape, vectors_shape]
}

// The `dim_*` helpers constant-fold when both operands are concrete and fall
// back to building a symbolic expression node otherwise.

fn dim_add(lhs: DimExpr, rhs: DimExpr) -> DimExpr {
    match (lhs, rhs) {
        (DimExpr::Const(lhs), DimExpr::Const(rhs)) => DimExpr::Const(lhs + rhs),
        (lhs, rhs) => DimExpr::add(lhs, rhs),
    }
}

fn dim_sub(lhs: DimExpr, rhs: DimExpr) -> DimExpr {
    match (lhs, rhs) {
        (DimExpr::Const(lhs), DimExpr::Const(rhs)) => DimExpr::Const(lhs - rhs),
        (lhs, rhs) => DimExpr::sub(lhs, rhs),
    }
}

fn dim_mul(lhs: DimExpr, rhs: DimExpr) -> DimExpr {
    match (lhs, rhs) {
        (DimExpr::Const(lhs), DimExpr::Const(rhs)) => DimExpr::Const(lhs * rhs),
        (lhs, rhs) => DimExpr::mul(lhs, rhs),
    }
}

fn dim_min(lhs: DimExpr, rhs: DimExpr) -> DimExpr {
    match (lhs, rhs) {
        (DimExpr::Const(lhs), DimExpr::Const(rhs)) => DimExpr::Const(lhs.min(rhs)),
        (lhs, rhs) => DimExpr::min(lhs, rhs),
    }
}

fn dim_max(lhs: DimExpr, rhs: DimExpr) -> DimExpr {
    match (lhs, rhs) {
        (DimExpr::Const(lhs), DimExpr::Const(rhs)) => DimExpr::Const(lhs.max(rhs)),
        (lhs, rhs) => DimExpr::max(lhs, rhs),
    }
}
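
#[cfg(test)]
mod tests {
    // A minimal test sketch for the pure helpers above. It assumes `DimExpr`
    // derives `Debug` and `PartialEq` (needed by `assert_eq!`); if it does
    // not, these checks would have to match on variants instead.
    use super::*;

    fn c(extent: usize) -> DimExpr {
        DimExpr::Const(extent)
    }

    #[test]
    fn permute_reduce_and_concat() {
        // Transposing [2, 3, 4] with perm [2, 0, 1] yields [4, 2, 3].
        assert_eq!(
            permute_shape(&[c(2), c(3), c(4)], &[2, 0, 1]),
            vec![c(4), c(2), c(3)]
        );
        // Reducing axis 1 of [2, 3, 4] yields [2, 4].
        assert_eq!(reduced_shape(&[c(2), c(3), c(4)], &[1]), vec![c(2), c(4)]);
        // Concatenating [2, 3] and [2, 5] along axis 1 yields [2, 8].
        let a = [c(2), c(3)];
        let b = [c(2), c(5)];
        let inputs: [&[DimExpr]; 2] = [&a, &b];
        assert_eq!(concatenate_shape(&inputs, 1), vec![c(2), c(8)]);
    }

    #[test]
    fn dim_helpers_fold_constants() {
        // Concrete operands fold to a single Const instead of building nodes.
        assert_eq!(dim_add(c(5), c(7)), c(12));
        assert_eq!(dim_min(c(5), c(7)), c(5));
        assert_eq!(dim_max(c(5), c(7)), c(7));
    }
}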