Skip to main content

tenferro_fft/
lib.rs

1//! FFT extension operations for tenferro.
2//!
3//! This crate is an out-of-tree `ExtensionOp` package. The initial
4//! implementation executes on host tensors through `rustfft`; it does not add
5//! FFT to the core `tenferro` backend trait surface.
6//!
7//! # Examples
8//!
9//! ```
10//! use num_complex::Complex64;
11//! use tenferro_cpu::CpuBackend;
12//! use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
13//! use tenferro_fft::{FftNorm, TracedTensorFftExt};
14//!
15//! let x = TracedTensor::from_vec_col_major(
16//!     vec![4],
17//!     vec![
18//!         Complex64::new(1.0, 0.0),
19//!         Complex64::new(2.0, 0.0),
20//!         Complex64::new(3.0, 0.0),
21//!         Complex64::new(4.0, 0.0),
22//!     ],
23//! )
24//! .unwrap();
25//! let y = x.fft(None, -1, FftNorm::Backward).unwrap();
26//!
27//! let mut compiler = GraphCompiler::new();
28//! let program = compiler.compile(&y).unwrap();
29//! let mut executor = GraphExecutor::new(CpuBackend::new());
30//! executor.register_extension(tenferro_fft::register_runtime).unwrap();
31//! let out = executor.run(&program).unwrap();
32//! assert_eq!(out.shape(), &[4]);
33//! assert_eq!(out.as_slice::<Complex64>().unwrap()[0], Complex64::new(10.0, 0.0));
34//! ```
35
36use std::any::Any;
37use std::hash::Hasher;
38use std::mem::MaybeUninit;
39use std::sync::Arc;
40
41#[cfg(feature = "autodiff")]
42use computegraph::types::{LocalValueId, OperationRole, ValueKey, ValueRef};
43use num_complex::Complex;
44use num_traits::{Float, FromPrimitive, Zero};
45use rustfft::{FftNum, FftPlanner};
46#[cfg(feature = "autodiff")]
47use tenferro_ad::extension::{ExtensionAdRule, ExtensionRegistryError, ExtensionRuleSet};
48use tenferro_extension_macros::define_extension_runtime;
49#[cfg(feature = "autodiff")]
50use tenferro_ops::ad::PrimitiveRuleBuilder;
51#[cfg(feature = "autodiff")]
52use tenferro_ops::std_tensor_op::StdTensorOp;
53#[cfg(feature = "autodiff")]
54use tenferro_ops::ShapeGuardContext;
55use tenferro_ops::SymDim;
56use tenferro_runtime::extension::{apply, ExtensionExecutionContext, ExtensionOp};
57use tenferro_runtime::{Error, Result, TracedTensor};
58use tenferro_tensor::{
59    DType, DeviceKind, MemoryKind, Placement, Tensor, TensorBackend, TensorRead, TypedTensor,
60};
61#[cfg(feature = "autodiff")]
62use tidu::{ADRuleError, ADRuleKind, ADRuleResult};
63
/// Family identifier shared by every FFT op in this crate.
///
/// The FFT op payload reports it from `ExtensionOp::family_id`, the extension
/// runtime is generated for it, and the autodiff rule is registered under it.
///
/// # Examples
///
/// ```
/// assert_eq!(
///     tenferro_fft::FFT_EXTENSION_FAMILY_ID,
///     "tenferro-fft.fft.v1"
/// );
/// ```
pub const FFT_EXTENSION_FAMILY_ID: &str = "tenferro-fft.fft.v1";
75
/// FFT extension methods for [`TracedTensor`].
///
/// Each method builds a one-dimensional transform along `axis` as an FFT
/// extension op applied to the traced tensor. The crate examples execute the
/// result through `GraphCompiler` / `GraphExecutor` after registering
/// `tenferro_fft::register_runtime`.
///
/// Shared parameters:
/// - `n`: optional transform length; `None` derives it from the input axis.
/// - `axis`: axis index; negative values count from the end (`-1` is the
///   last axis).
/// - `norm`: normalization convention, see [`FftNorm`].
///
/// # Errors
///
/// Every method returns an error for a rank-0 input, an out-of-range `axis`,
/// `n == Some(0)`, a concrete zero-length `axis` when `n` is `None`, or an
/// input dtype that the transform does not accept.
pub trait TracedTensorFftExt {
    /// Forward FFT. Complex (`C32`/`C64`) input runs a complex-to-complex
    /// transform; real (`F32`/`F64`) input returns the full two-sided complex
    /// spectrum. Integer and boolean inputs are rejected.
    fn fft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor>;
    /// Inverse complex-to-complex FFT; requires `C32` or `C64` input.
    fn ifft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor>;
    /// Real-to-complex FFT that keeps the one-sided spectrum of length
    /// `n / 2 + 1` along `axis`; requires `F32` or `F64` input.
    fn rfft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor>;
    /// Complex-to-real inverse FFT; requires `C32` or `C64` input. With
    /// `n == None` the output length along `axis` is `2 * (m - 1)` for an
    /// input spectrum of length `m`.
    fn irfft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor>;
}
83
84impl TracedTensorFftExt for TracedTensor {
85    fn fft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor> {
86        fft(self, n, axis, norm)
87    }
88
89    fn ifft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor> {
90        ifft(self, n, axis, norm)
91    }
92
93    fn rfft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor> {
94        rfft(self, n, axis, norm)
95    }
96
97    fn irfft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor> {
98        irfft(self, n, axis, norm)
99    }
100}
101
/// Normalization convention applied by the FFT transforms.
///
/// Below, `n` is the transform length along the transformed axis. The
/// default, [`FftNorm::Backward`], matches the NumPy, JAX, and PyTorch
/// defaults: the forward transform is unscaled and the inverse transform is
/// scaled by `1 / n`.
///
/// # Examples
///
/// ```
/// use tenferro_fft::FftNorm;
///
/// assert_eq!(FftNorm::default(), FftNorm::Backward);
/// ```
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum FftNorm {
    /// Forward transforms are unscaled; inverse transforms are scaled by `1 / n`.
    #[default]
    Backward,
    /// Forward transforms are scaled by `1 / n`.
    Forward,
    /// Both forward and inverse transforms are scaled by `1 / sqrt(n)`.
    Ortho,
}
124
#[cfg(feature = "autodiff")]
impl FftNorm {
    /// Normalization of the adjoint of a complex-to-complex transform that
    /// uses `self`.
    ///
    /// The adjoint of the unscaled DFT is the unscaled inverse DFT, so the
    /// `Backward` and `Forward` conventions swap, while the unitary `Ortho`
    /// convention is its own adjoint.
    fn c2c_adjoint(self) -> Self {
        match self {
            Self::Ortho => Self::Ortho,
            Self::Forward => Self::Backward,
            Self::Backward => Self::Forward,
        }
    }
}
135
/// Which one-dimensional transform an FFT op performs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FftKind {
    /// Complex-to-complex; `forward: false` selects the inverse transform.
    C2C { forward: bool },
    /// Real-to-complex; `onesided: true` keeps only `len / 2 + 1` output bins.
    R2C { onesided: bool },
    /// Complex-to-real (one-sided spectrum in, real signal out).
    C2R,
}
142
143#[derive(Clone, Debug, PartialEq)]
144struct FftOp {
145    kind: FftKind,
146    axis: usize,
147    n: Option<usize>,
148    norm: FftNorm,
149}
150
151impl FftOp {
152    fn new(kind: FftKind, axis: usize, n: Option<usize>, norm: FftNorm) -> Self {
153        Self {
154            kind,
155            axis,
156            n,
157            norm,
158        }
159    }
160
161    #[cfg(feature = "autodiff")]
162    fn c2c_adjoint(&self) -> Option<Self> {
163        match self.kind {
164            FftKind::C2C { forward } => Some(Self {
165                kind: FftKind::C2C { forward: !forward },
166                axis: self.axis,
167                n: self.n,
168                norm: self.norm.c2c_adjoint(),
169            }),
170            FftKind::R2C { .. } | FftKind::C2R => None,
171        }
172    }
173}
174
175impl ExtensionOp for FftOp {
176    fn family_id(&self) -> &'static str {
177        FFT_EXTENSION_FAMILY_ID
178    }
179
180    fn payload_hash(&self, hasher: &mut dyn Hasher) {
181        let kind = match self.kind {
182            FftKind::C2C { forward: true } => 0,
183            FftKind::C2C { forward: false } => 1,
184            FftKind::R2C { onesided: true } => 2,
185            FftKind::R2C { onesided: false } => 3,
186            FftKind::C2R => 4,
187        };
188        hasher.write_u8(kind);
189        hasher.write_usize(self.axis);
190        match self.n {
191            Some(n) => {
192                hasher.write_u8(1);
193                hasher.write_usize(n);
194            }
195            None => hasher.write_u8(0),
196        }
197        let norm = match self.norm {
198            FftNorm::Backward => 0,
199            FftNorm::Forward => 1,
200            FftNorm::Ortho => 2,
201        };
202        hasher.write_u8(norm);
203    }
204
205    fn payload_eq(&self, other: &dyn ExtensionOp) -> bool {
206        other
207            .as_any()
208            .downcast_ref::<FftOp>()
209            .is_some_and(|that| self == that)
210    }
211
212    fn clone_arc(&self) -> Arc<dyn ExtensionOp> {
213        Arc::new(self.clone())
214    }
215
216    fn as_any(&self) -> &dyn Any {
217        self
218    }
219
220    fn input_count(&self) -> usize {
221        1
222    }
223
224    fn output_count(&self) -> usize {
225        1
226    }
227
228    fn infer_output_meta(
229        &self,
230        input_dtypes: &[DType],
231        input_shapes: &[&[SymDim]],
232    ) -> Vec<(DType, Vec<SymDim>)> {
233        // Public FFT constructors validate dtype/axis/n before building this
234        // op. The extension trait is non-fallible, so direct invalid trait
235        // calls return an output-count mismatch sentinel instead of panicking.
236        let [input_dtype] = input_dtypes else {
237            return Vec::new();
238        };
239        let [input_shape] = input_shapes else {
240            return Vec::new();
241        };
242        if self.axis >= input_shape.len() {
243            return Vec::new();
244        }
245
246        let mut out_shape = input_shape.to_vec();
247        let output_dtype = match self.kind {
248            FftKind::C2C { .. } => {
249                if !matches!(input_dtype, DType::C32 | DType::C64) {
250                    return Vec::new();
251                }
252                *input_dtype
253            }
254            FftKind::R2C { onesided } => {
255                let len = transform_len_dim(self.n, &input_shape[self.axis]);
256                out_shape[self.axis] = if onesided { len / 2usize + 1usize } else { len };
257                match input_dtype {
258                    DType::F32 => DType::C32,
259                    DType::F64 => DType::C64,
260                    _ => return Vec::new(),
261                }
262            }
263            FftKind::C2R => {
264                out_shape[self.axis] = match self.n {
265                    Some(n) => SymDim::from(n),
266                    None => (input_shape[self.axis].clone() - 1usize) * 2usize,
267                };
268                match input_dtype {
269                    DType::C32 => DType::F32,
270                    DType::C64 => DType::F64,
271                    _ => return Vec::new(),
272                }
273            }
274        };
275
276        if matches!(self.kind, FftKind::C2C { .. }) {
277            out_shape[self.axis] = transform_len_dim(self.n, &input_shape[self.axis]);
278        }
279
280        vec![(output_dtype, out_shape)]
281    }
282
283    fn eager_execute(&self, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>> {
284        execute_host_fft_op(self, inputs)
285    }
286}
287
288fn execute_host_fft_op(op: &FftOp, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>> {
289    if inputs.len() != 1 {
290        return Err(tenferro_tensor::Error::InvalidConfig {
291            op: "tenferro-fft",
292            message: format!("expected 1 input, got {}", inputs.len()),
293        });
294    }
295    validate_host_fft_input(fft_op_name(op.kind), inputs[0])?;
296
297    let output = match (op.kind, inputs[0]) {
298        (FftKind::C2C { forward }, Tensor::C64(input)) => {
299            Tensor::C64(TypedTensor::from_vec_col_major(
300                output_shape_c2c(input.shape(), op.axis, op.n)?,
301                execute_c2c(input, op.axis, op.n, forward, op.norm)?,
302            )?)
303        }
304        (FftKind::C2C { forward }, Tensor::C32(input)) => {
305            Tensor::C32(TypedTensor::from_vec_col_major(
306                output_shape_c2c(input.shape(), op.axis, op.n)?,
307                execute_c2c(input, op.axis, op.n, forward, op.norm)?,
308            )?)
309        }
310        (FftKind::R2C { onesided }, Tensor::F64(input)) => {
311            Tensor::C64(TypedTensor::from_vec_col_major(
312                output_shape_r2c(input.shape(), op.axis, op.n, onesided)?,
313                execute_r2c(input, op.axis, op.n, onesided, op.norm)?,
314            )?)
315        }
316        (FftKind::R2C { onesided }, Tensor::F32(input)) => {
317            Tensor::C32(TypedTensor::from_vec_col_major(
318                output_shape_r2c(input.shape(), op.axis, op.n, onesided)?,
319                execute_r2c(input, op.axis, op.n, onesided, op.norm)?,
320            )?)
321        }
322        (FftKind::C2R, Tensor::C64(input)) => Tensor::F64(TypedTensor::from_vec_col_major(
323            output_shape_c2r(input.shape(), op.axis, op.n)?,
324            execute_c2r(input, op.axis, op.n, op.norm)?,
325        )?),
326        (FftKind::C2R, Tensor::C32(input)) => Tensor::F32(TypedTensor::from_vec_col_major(
327            output_shape_c2r(input.shape(), op.axis, op.n)?,
328            execute_c2r(input, op.axis, op.n, op.norm)?,
329        )?),
330        (kind, other) => {
331            return Err(tenferro_tensor::Error::DTypeMismatch {
332                op: match kind {
333                    FftKind::C2C { .. } => "fft",
334                    FftKind::R2C { .. } => "rfft",
335                    FftKind::C2R => "irfft",
336                },
337                lhs: expected_dtype_for(kind),
338                rhs: other.dtype(),
339            });
340        }
341    };
342    Ok(vec![output])
343}
344
345fn tensor_placement(input: &Tensor) -> &Placement {
346    input.placement()
347}
348
349fn tensor_has_backend_buffer(input: &Tensor) -> bool {
350    input.is_backend_buffer()
351}
352
353fn validate_host_fft_input(op: &'static str, input: &Tensor) -> tenferro_tensor::Result<()> {
354    let placement = tensor_placement(input);
355    let is_device = matches!(placement.memory_kind, MemoryKind::Device);
356    if !is_device && !tensor_has_backend_buffer(input) {
357        return Ok(());
358    }
359
360    let location = match placement.device.as_ref().map(|device| &device.kind) {
361        Some(DeviceKind::Gpu(kind)) => format!("GPU backend {kind:?}"),
362        Some(kind) => format!("device kind {kind:?}"),
363        None if is_device => "device tensor without device metadata".to_string(),
364        None => "backend buffer".to_string(),
365    };
366    Err(tenferro_tensor::Error::backend_failure(
367        op,
368        format!(
369            "tenferro-fft supports host tensors only; unsupported {location} input; \
370             download the tensor to CPU before FFT"
371        ),
372    ))
373}
374
/// Autodiff rule for the FFT extension family.
///
/// Only complex-to-complex transforms are differentiated:
/// - the JVP applies the same (linear) op to the input tangent;
/// - the transpose applies the adjoint op from `FftOp::c2c_adjoint`.
///
/// Real-to-complex and complex-to-real ops are reported as unsupported, and so
/// is transposing an op with an explicit `n` (see `transpose_rule`).
#[cfg(feature = "autodiff")]
#[derive(Debug)]
struct FftAdRule;

#[cfg(feature = "autodiff")]
impl ExtensionAdRule for FftAdRule {
    fn family_id(&self) -> &'static str {
        FFT_EXTENSION_FAMILY_ID
    }

    fn linearize(
        &self,
        op: &dyn ExtensionOp,
        builder: &mut dyn PrimitiveRuleBuilder,
        _primal_in: &[ValueKey<StdTensorOp>],
        _primal_out: &[ValueKey<StdTensorOp>],
        tangent_in: &[Option<LocalValueId>],
        _ctx: &mut ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        let fft_op = fft_payload(op, ADRuleKind::Jvp)?;
        if !matches!(fft_op.kind, FftKind::C2C { .. }) {
            return Err(ADRuleError::unsupported(
                fft_ad_family_id(fft_op.kind),
                ADRuleKind::Jvp,
            ));
        }

        // Padding/truncation to `n` followed by the DFT is linear in the
        // input, so the tangent is the same op applied to the input tangent.
        match tangent_in[0] {
            Some(dx) => {
                let outputs = builder.add_operation(
                    StdTensorOp::Extension(Arc::new(fft_op.clone())),
                    vec![ValueRef::Local(dx)],
                    OperationRole::Linearized {
                        active_mask: vec![true],
                    },
                );
                Ok(vec![Some(outputs[0])])
            }
            None => Ok(vec![None]),
        }
    }

    fn transpose_rule(
        &self,
        op: &dyn ExtensionOp,
        builder: &mut dyn PrimitiveRuleBuilder,
        cotangent_out: &[Option<LocalValueId>],
        _inputs: &[ValueRef<StdTensorOp>],
        _mode: &OperationRole,
        _ctx: &mut ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        let fft_op = fft_payload(op, ADRuleKind::Transpose)?;
        if !matches!(fft_op.kind, FftKind::C2C { .. }) {
            return Err(ADRuleError::unsupported(
                fft_ad_family_id(fft_op.kind),
                ADRuleKind::Transpose,
            ));
        }

        match cotangent_out[0] {
            Some(ct) => {
                // With an explicit `n`, the primal pads or truncates the input
                // axis from its length m to n before transforming. The adjoint
                // must therefore map the length-n cotangent back to length m
                // (inverse transform, then truncate/pad). The adjoint FFT op
                // alone produces a length-n result, which is only correct when
                // n == m, and the input length is not consulted here. Reject
                // instead of emitting a wrongly shaped cotangent.
                if fft_op.n.is_some() {
                    return Err(ADRuleError::unsupported(
                        FFT_EXTENSION_FAMILY_ID,
                        ADRuleKind::Transpose,
                    ));
                }
                let adjoint_op = fft_op.c2c_adjoint().ok_or_else(|| {
                    ADRuleError::unsupported(FFT_EXTENSION_FAMILY_ID, ADRuleKind::Transpose)
                })?;
                let outputs = builder.add_operation(
                    StdTensorOp::Extension(Arc::new(adjoint_op)),
                    vec![ValueRef::Local(ct)],
                    OperationRole::Linearized {
                        active_mask: vec![true],
                    },
                );
                Ok(vec![Some(outputs[0])])
            }
            None => Ok(vec![None]),
        }
    }
}
452
/// Build the extension AD rule set for the FFT family.
///
/// The set contains a single rule, registered under
/// [`FFT_EXTENSION_FAMILY_ID`].
///
/// # Errors
///
/// Propagates any `ExtensionRegistryError` reported by
/// `ExtensionRuleSet::with_rule`.
#[cfg(feature = "autodiff")]
pub fn ad_rules() -> std::result::Result<ExtensionRuleSet, ExtensionRegistryError> {
    let rule = Arc::new(FftAdRule);
    ExtensionRuleSet::new().with_rule(rule)
}
458
459fn execute_fft_extension<B: TensorBackend + 'static>(
460    op: &FftOp,
461    inputs: &[&Tensor],
462    _ctx: &mut ExtensionExecutionContext<'_, B>,
463) -> tenferro_tensor::Result<Vec<Tensor>> {
464    execute_host_fft_op(op, inputs)
465}
466
467fn execute_fft_extension_reads<B: TensorBackend + 'static>(
468    op: &FftOp,
469    inputs: &[TensorRead<'_>],
470    ctx: &mut ExtensionExecutionContext<'_, B>,
471) -> tenferro_tensor::Result<Vec<Tensor>> {
472    let _ = ctx;
473    // rustfft consumes compact host tensors; materialization is explicit so
474    // backend-backed views produce a normal error instead of an implicit path.
475    let materialized_inputs: Vec<Tensor> = inputs
476        .iter()
477        .map(TensorRead::to_tensor)
478        .collect::<tenferro_tensor::Result<_>>()?;
479    let input_refs: Vec<&Tensor> = materialized_inputs.iter().collect();
480    execute_host_fft_op(op, &input_refs)
481}
482
// Generate the runtime glue for the FFT extension family. The expansion lives
// in `tenferro_extension_macros` and is not visible here. From its arguments,
// it ties `FftOp` payloads of `FFT_EXTENSION_FAMILY_ID` to the two host entry
// points `execute_fft_extension` (borrowed tensors) and
// `execute_fft_extension_reads` (`TensorRead` inputs). It also emits
// `register_runtime`, which the crate examples pass to
// `GraphExecutor::register_extension`.
define_extension_runtime! {
    runtime = FftRuntime,
    family_id = FFT_EXTENSION_FAMILY_ID,
    op_type = FftOp,
    execute = execute_fft_extension,
    execute_reads = execute_fft_extension_reads,
    register_fn = register_runtime,
}
491
492/// Build a one-dimensional FFT along `axis`.
493///
494/// Complex inputs use a complex-to-complex transform. Real inputs use a
495/// real-to-complex transform that returns the full complex spectrum.
496///
497/// # Examples
498///
499/// ```
500/// use num_complex::Complex64;
501/// use tenferro_cpu::CpuBackend;
502/// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
503/// use tenferro_fft::{FftNorm, TracedTensorFftExt};
504///
505/// let x = TracedTensor::from_vec_col_major(vec![2], vec![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)]).unwrap();
506/// let y = x.fft(None, -1, FftNorm::Backward).unwrap();
507///
508/// let mut compiler = GraphCompiler::new();
509/// let program = compiler.compile(&y).unwrap();
510/// let mut executor = GraphExecutor::new(CpuBackend::new());
511/// executor.register_extension(tenferro_fft::register_runtime).unwrap();
512/// let out = executor.run(&program).unwrap();
513/// assert_eq!(out.as_slice::<Complex64>().unwrap()[0], Complex64::new(3.0, 0.0));
514/// ```
515fn fft(input: &TracedTensor, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor> {
516    let kind = match input.dtype {
517        DType::C32 | DType::C64 => FftKind::C2C { forward: true },
518        DType::F32 | DType::F64 => FftKind::R2C { onesided: false },
519        DType::I32 | DType::I64 | DType::Bool => {
520            return Err(fft_config_error(
521                "fft",
522                format!(
523                    "fft expects real or complex floating input, got {:?}",
524                    input.dtype
525                ),
526            ))
527        }
528    };
529    apply_unary_fft("fft", input, kind, n, axis, norm)
530}
531
532/// Build a one-dimensional inverse FFT along `axis`.
533///
534/// # Examples
535///
536/// ```
537/// use num_complex::Complex64;
538/// use tenferro_cpu::CpuBackend;
539/// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
540/// use tenferro_fft::{FftNorm, TracedTensorFftExt};
541///
542/// let spectrum = TracedTensor::from_vec_col_major(vec![2], vec![Complex64::new(3.0, 0.0), Complex64::new(-1.0, 0.0)]).unwrap();
543/// let y = spectrum.ifft(None, -1, FftNorm::Backward).unwrap();
544///
545/// let mut compiler = GraphCompiler::new();
546/// let program = compiler.compile(&y).unwrap();
547/// let mut executor = GraphExecutor::new(CpuBackend::new());
548/// executor.register_extension(tenferro_fft::register_runtime).unwrap();
549/// let out = executor.run(&program).unwrap();
550/// assert_eq!(out.as_slice::<Complex64>().unwrap()[0], Complex64::new(1.0, 0.0));
551/// ```
552fn ifft(
553    input: &TracedTensor,
554    n: Option<usize>,
555    axis: isize,
556    norm: FftNorm,
557) -> Result<TracedTensor> {
558    if !matches!(input.dtype, DType::C32 | DType::C64) {
559        return Err(fft_config_error(
560            "ifft",
561            format!("ifft expects C32 or C64 input; got {:?}", input.dtype),
562        ));
563    }
564    apply_unary_fft(
565        "ifft",
566        input,
567        FftKind::C2C { forward: false },
568        n,
569        axis,
570        norm,
571    )
572}
573
574/// Build a one-dimensional real FFT along `axis`.
575///
576/// The output keeps only the Hermitian one-sided spectrum with axis length
577/// `n / 2 + 1`.
578///
579/// # Examples
580///
581/// ```
582/// use num_complex::Complex64;
583/// use tenferro_cpu::CpuBackend;
584/// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
585/// use tenferro_fft::{FftNorm, TracedTensorFftExt};
586///
587/// let x = TracedTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]).unwrap();
588/// let y = x.rfft(None, -1, FftNorm::Backward).unwrap();
589///
590/// let mut compiler = GraphCompiler::new();
591/// let program = compiler.compile(&y).unwrap();
592/// let mut executor = GraphExecutor::new(CpuBackend::new());
593/// executor.register_extension(tenferro_fft::register_runtime).unwrap();
594/// let out = executor.run(&program).unwrap();
595/// assert_eq!(out.shape(), &[2]);
596/// assert_eq!(out.as_slice::<Complex64>().unwrap()[0], Complex64::new(3.0, 0.0));
597/// ```
598fn rfft(
599    input: &TracedTensor,
600    n: Option<usize>,
601    axis: isize,
602    norm: FftNorm,
603) -> Result<TracedTensor> {
604    if !matches!(input.dtype, DType::F32 | DType::F64) {
605        return Err(fft_config_error(
606            "rfft",
607            format!("rfft expects F32 or F64 input; got {:?}", input.dtype),
608        ));
609    }
610    apply_unary_fft(
611        "rfft",
612        input,
613        FftKind::R2C { onesided: true },
614        n,
615        axis,
616        norm,
617    )
618}
619
620/// Build a one-dimensional inverse real FFT along `axis`.
621///
622/// If `n` is `None`, the output length is inferred as twice one less than the
623/// input spectrum length.
624///
625/// # Examples
626///
627/// ```
628/// use num_complex::Complex64;
629/// use tenferro_cpu::CpuBackend;
630/// use tenferro_runtime::{GraphCompiler, GraphExecutor, TracedTensor};
631/// use tenferro_fft::{FftNorm, TracedTensorFftExt};
632///
633/// let spectrum = TracedTensor::from_vec_col_major(
634///     vec![2],
635///     vec![Complex64::new(3.0, 0.0), Complex64::new(-1.0, 0.0)],
636/// )
637/// .unwrap();
638/// let y = spectrum.irfft(Some(2), -1, FftNorm::Backward).unwrap();
639///
640/// let mut compiler = GraphCompiler::new();
641/// let program = compiler.compile(&y).unwrap();
642/// let mut executor = GraphExecutor::new(CpuBackend::new());
643/// executor.register_extension(tenferro_fft::register_runtime).unwrap();
644/// let out = executor.run(&program).unwrap();
645/// assert_eq!(out.as_slice::<f64>().unwrap(), &[1.0, 2.0]);
646/// ```
647fn irfft(
648    input: &TracedTensor,
649    n: Option<usize>,
650    axis: isize,
651    norm: FftNorm,
652) -> Result<TracedTensor> {
653    if !matches!(input.dtype, DType::C32 | DType::C64) {
654        return Err(fft_config_error(
655            "irfft",
656            format!("irfft expects C32 or C64 input; got {:?}", input.dtype),
657        ));
658    }
659    apply_unary_fft("irfft", input, FftKind::C2R, n, axis, norm)
660}
661
662fn apply_unary_fft(
663    op_name: &'static str,
664    input: &TracedTensor,
665    kind: FftKind,
666    n: Option<usize>,
667    axis: isize,
668    norm: FftNorm,
669) -> Result<TracedTensor> {
670    validate_n(op_name, n)?;
671    let axis = normalize_axis(op_name, axis, input.rank)?;
672    validate_resolved_transform_len(op_name, input, n, axis)?;
673    let op = Arc::new(FftOp::new(kind, axis, n, norm));
674    let mut outputs = apply(op, &[input])?;
675    outputs
676        .pop()
677        .ok_or_else(|| Error::Internal("FFT extension declares exactly one output".into()))
678}
679
680fn normalize_axis(op: &'static str, axis: isize, rank: usize) -> Result<usize> {
681    if rank == 0 {
682        return Err(fft_config_error(op, "tenferro-fft requires rank >= 1"));
683    }
684    let rank_isize = rank as isize;
685    let normalized = if axis < 0 { rank_isize + axis } else { axis };
686    if normalized < 0 || normalized >= rank_isize {
687        return Err(fft_config_error(
688            op,
689            format!("tenferro-fft axis {axis} out of bounds for rank {rank}"),
690        ));
691    }
692    Ok(normalized as usize)
693}
694
695fn validate_n(op: &'static str, n: Option<usize>) -> Result<()> {
696    if n == Some(0) {
697        return Err(fft_config_error(
698            op,
699            "tenferro-fft transform length n must be positive",
700        ));
701    }
702    Ok(())
703}
704
705fn validate_resolved_transform_len(
706    op: &'static str,
707    input: &TracedTensor,
708    n: Option<usize>,
709    axis: usize,
710) -> Result<()> {
711    if n.is_some() {
712        return Ok(());
713    }
714    if input
715        .try_concrete_shape()
716        .and_then(|shape| shape.get(axis).copied())
717        == Some(0)
718    {
719        return Err(fft_config_error(
720            op,
721            "tenferro-fft transform length n must be positive",
722        ));
723    }
724    Ok(())
725}
726
727fn fft_config_error(op: &'static str, message: impl std::fmt::Display) -> Error {
728    Error::TensorRuntime(tenferro_tensor::Error::InvalidConfig {
729        op,
730        message: message.to_string(),
731    })
732}
733
734fn transform_len_dim(n: Option<usize>, input_dim: &SymDim) -> SymDim {
735    n.map(SymDim::from).unwrap_or_else(|| input_dim.clone())
736}
737
738fn expected_dtype_for(kind: FftKind) -> DType {
739    match kind {
740        FftKind::C2C { .. } | FftKind::C2R => DType::C64,
741        FftKind::R2C { .. } => DType::F64,
742    }
743}
744
745fn fft_op_name(kind: FftKind) -> &'static str {
746    match kind {
747        FftKind::C2C { forward: true } => "fft",
748        FftKind::C2C { forward: false } => "ifft",
749        FftKind::R2C { .. } => "rfft",
750        FftKind::C2R => "irfft",
751    }
752}
753
/// Identifier reported in AD "unsupported" errors for `kind`.
///
/// Complex-to-complex ops report the registered family id. Real-input and
/// real-output ops report their own descriptive ids.
#[cfg(feature = "autodiff")]
fn fft_ad_family_id(kind: FftKind) -> &'static str {
    match kind {
        FftKind::R2C { .. } => "tenferro-fft.rfft.v1",
        FftKind::C2R => "tenferro-fft.irfft.v1",
        FftKind::C2C { .. } => FFT_EXTENSION_FAMILY_ID,
    }
}

/// Downcast an extension op to its `FftOp` payload.
///
/// Any other payload type is reported as unsupported for `rule`.
#[cfg(feature = "autodiff")]
fn fft_payload<'a>(op: &'a dyn ExtensionOp, rule: ADRuleKind) -> ADRuleResult<&'a FftOp> {
    match op.as_any().downcast_ref::<FftOp>() {
        Some(payload) => Ok(payload),
        None => Err(ADRuleError::unsupported(FFT_EXTENSION_FAMILY_ID, rule)),
    }
}
769
770fn output_shape_c2c(
771    shape: &[usize],
772    axis: usize,
773    n: Option<usize>,
774) -> tenferro_tensor::Result<Vec<usize>> {
775    let len = transform_len(shape, axis, n)?;
776    let mut out_shape = shape.to_vec();
777    out_shape[axis] = len;
778    Ok(out_shape)
779}
780
781fn output_shape_r2c(
782    shape: &[usize],
783    axis: usize,
784    n: Option<usize>,
785    onesided: bool,
786) -> tenferro_tensor::Result<Vec<usize>> {
787    let len = transform_len(shape, axis, n)?;
788    let mut out_shape = shape.to_vec();
789    out_shape[axis] = if onesided { len / 2 + 1 } else { len };
790    Ok(out_shape)
791}
792
793fn output_shape_c2r(
794    shape: &[usize],
795    axis: usize,
796    n: Option<usize>,
797) -> tenferro_tensor::Result<Vec<usize>> {
798    validate_axis("irfft", shape, axis)?;
799    let input_len = shape[axis];
800    if input_len == 0 {
801        return Err(tenferro_tensor::Error::InvalidConfig {
802            op: "irfft",
803            message: "input spectrum axis length must be positive".to_string(),
804        });
805    }
806    let len = match n {
807        Some(len) => len,
808        None => input_len
809            .checked_sub(1)
810            .and_then(|len| len.checked_mul(2))
811            .ok_or_else(|| tenferro_tensor::Error::InvalidConfig {
812                op: "irfft",
813                message: "default output length overflows usize".to_string(),
814            })?,
815    };
816    if len == 0 {
817        return Err(tenferro_tensor::Error::InvalidConfig {
818            op: "irfft",
819            message: "output length must be positive".to_string(),
820        });
821    }
822    let mut out_shape = shape.to_vec();
823    out_shape[axis] = len;
824    Ok(out_shape)
825}
826
827fn transform_len(shape: &[usize], axis: usize, n: Option<usize>) -> tenferro_tensor::Result<usize> {
828    validate_axis("fft", shape, axis)?;
829    let len = n.unwrap_or(shape[axis]);
830    if len == 0 {
831        return Err(tenferro_tensor::Error::InvalidConfig {
832            op: "fft",
833            message: "transform length must be positive".to_string(),
834        });
835    }
836    Ok(len)
837}
838
839fn validate_axis(op: &'static str, shape: &[usize], axis: usize) -> tenferro_tensor::Result<()> {
840    if axis >= shape.len() {
841        return Err(tenferro_tensor::Error::AxisOutOfBounds {
842            op,
843            axis,
844            rank: shape.len(),
845        });
846    }
847    Ok(())
848}
849
850fn checked_shape_product(
851    op: &'static str,
852    role: &'static str,
853    shape: &[usize],
854) -> tenferro_tensor::Result<usize> {
855    shape
856        .iter()
857        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
858        .ok_or_else(|| tenferro_tensor::Error::InvalidConfig {
859            op,
860            message: format!("{role} shape product overflows usize"),
861        })
862}
863
864fn checked_mul(
865    op: &'static str,
866    role: &'static str,
867    lhs: usize,
868    rhs: usize,
869) -> tenferro_tensor::Result<usize> {
870    lhs.checked_mul(rhs)
871        .ok_or_else(|| tenferro_tensor::Error::InvalidConfig {
872            op,
873            message: format!("{role} overflows usize"),
874        })
875}
876
877fn checked_add(
878    op: &'static str,
879    role: &'static str,
880    lhs: usize,
881    rhs: usize,
882) -> tenferro_tensor::Result<usize> {
883    lhs.checked_add(rhs)
884        .ok_or_else(|| tenferro_tensor::Error::InvalidConfig {
885            op,
886            message: format!("{role} overflows usize"),
887        })
888}
889
/// Allocate `len` uninitialized output slots.
///
/// Every slot must be written (e.g. with `MaybeUninit::write`) before the
/// vector is handed to `assume_init_output_vec`.
fn uninit_output_vec<T>(len: usize) -> Vec<MaybeUninit<T>> {
    // `MaybeUninit::uninit` performs no initialization, so this needs no
    // `unsafe` `set_len`. `Take<RepeatWith>` reports an exact length, so
    // `collect` allocates once.
    std::iter::repeat_with(MaybeUninit::uninit).take(len).collect()
}
897
/// Reinterpret fully written output slots as a `Vec<T>` without copying.
///
/// # Safety
///
/// Every element of `output` must have been initialized (for example with
/// `MaybeUninit::write`) before this call. Ownership of each initialized
/// value moves into the returned vector.
unsafe fn assume_init_output_vec<T>(output: Vec<MaybeUninit<T>>) -> Vec<T> {
    // `ManuallyDrop` keeps the source vector from freeing the allocation that
    // is handed to `Vec::from_raw_parts` below. This is the pattern recommended
    // by the `from_raw_parts` documentation instead of `mem::forget`.
    let mut output = std::mem::ManuallyDrop::new(output);
    let len = output.len();
    let capacity = output.capacity();
    let ptr = output.as_mut_ptr().cast::<T>();
    // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`, so the
    // pointer, length, and capacity describe a valid `Vec<T>` allocation. The
    // caller guarantees every slot is initialized, and `ManuallyDrop` ensures
    // the allocation has exactly one owner.
    unsafe { Vec::from_raw_parts(ptr, len, capacity) }
}
907
908fn execute_c2c<T>(
909    input: &TypedTensor<Complex<T>>,
910    axis: usize,
911    n: Option<usize>,
912    forward: bool,
913    norm: FftNorm,
914) -> tenferro_tensor::Result<Vec<Complex<T>>>
915where
916    T: FftNum + Float + FromPrimitive,
917{
918    let in_shape = input.shape();
919    let fft_len = transform_len(in_shape, axis, n)?;
920    let out_shape = output_shape_c2c(in_shape, axis, n)?;
921    let out_axis_len = out_shape[axis];
922    let input_data = input.host_data()?;
923    let output_len = checked_shape_product("fft", "output", &out_shape)?;
924    let mut output = uninit_output_vec(output_len);
925    let mut planner = FftPlanner::<T>::new();
926    let fft_plan = if forward {
927        planner.plan_fft_forward(fft_len)
928    } else {
929        planner.plan_fft_inverse(fft_len)
930    };
931    let scale: T = scale_for(norm, forward, fft_len)?;
932    let mut lane = vec![Complex::zero(); fft_len];
933
934    for_axis_lane(in_shape, axis, out_axis_len, |lane_ctx| {
935        lane.fill(Complex::zero());
936        let copy_len = lane_ctx.in_axis_len.min(fft_len);
937        for (k, slot) in lane.iter_mut().take(copy_len).enumerate() {
938            *slot = input_data[lane_ctx.input_offset(k)?];
939        }
940        fft_plan.process(&mut lane);
941        if scale != T::one() {
942            for value in &mut lane {
943                *value = *value * scale;
944            }
945        }
946        for (k, value) in lane.iter().take(out_axis_len).copied().enumerate() {
947            output[lane_ctx.output_offset(k)?].write(value);
948        }
949        Ok(())
950    })?;
951
952    // SAFETY: `for_axis_lane` covers every element in the compact column-major
953    // output exactly once, and each lane writes all `out_axis_len` positions.
954    Ok(unsafe { assume_init_output_vec(output) })
955}
956
957fn execute_r2c<T>(
958    input: &TypedTensor<T>,
959    axis: usize,
960    n: Option<usize>,
961    onesided: bool,
962    norm: FftNorm,
963) -> tenferro_tensor::Result<Vec<Complex<T>>>
964where
965    T: FftNum + Float + FromPrimitive,
966{
967    let in_shape = input.shape();
968    let fft_len = transform_len(in_shape, axis, n)?;
969    let out_shape = output_shape_r2c(in_shape, axis, n, onesided)?;
970    let out_axis_len = out_shape[axis];
971    let input_data = input.host_data()?;
972    let output_len = checked_shape_product("rfft", "output", &out_shape)?;
973    let mut output = uninit_output_vec(output_len);
974    let mut planner = FftPlanner::<T>::new();
975    let fft_plan = planner.plan_fft_forward(fft_len);
976    let scale: T = scale_for(norm, true, fft_len)?;
977    let mut lane = vec![Complex::zero(); fft_len];
978
979    for_axis_lane(in_shape, axis, out_axis_len, |lane_ctx| {
980        lane.fill(Complex::zero());
981        let copy_len = lane_ctx.in_axis_len.min(fft_len);
982        for (k, slot) in lane.iter_mut().take(copy_len).enumerate() {
983            *slot = Complex::new(input_data[lane_ctx.input_offset(k)?], T::zero());
984        }
985        fft_plan.process(&mut lane);
986        if scale != T::one() {
987            for value in &mut lane {
988                *value = *value * scale;
989            }
990        }
991        for (k, value) in lane.iter().take(out_axis_len).copied().enumerate() {
992            output[lane_ctx.output_offset(k)?].write(value);
993        }
994        Ok(())
995    })?;
996
997    // SAFETY: `for_axis_lane` covers every element in the compact column-major
998    // output exactly once, and each lane writes all `out_axis_len` positions.
999    Ok(unsafe { assume_init_output_vec(output) })
1000}
1001
1002fn execute_c2r<T>(
1003    input: &TypedTensor<Complex<T>>,
1004    axis: usize,
1005    n: Option<usize>,
1006    norm: FftNorm,
1007) -> tenferro_tensor::Result<Vec<T>>
1008where
1009    T: FftNum + Float + FromPrimitive,
1010{
1011    let in_shape = input.shape();
1012    let out_shape = output_shape_c2r(in_shape, axis, n)?;
1013    let out_axis_len = out_shape[axis];
1014    let expected_half = out_axis_len / 2 + 1;
1015    let input_data = input.host_data()?;
1016    let output_len = checked_shape_product("irfft", "output", &out_shape)?;
1017    let mut output = uninit_output_vec(output_len);
1018    let mut planner = FftPlanner::<T>::new();
1019    let fft_plan = planner.plan_fft_inverse(out_axis_len);
1020    let scale: T = scale_for(norm, false, out_axis_len)?;
1021    let mut lane = vec![Complex::zero(); out_axis_len];
1022
1023    for_axis_lane(in_shape, axis, out_axis_len, |lane_ctx| {
1024        lane.fill(Complex::zero());
1025        let copy_len = lane_ctx.in_axis_len.min(expected_half);
1026        for (k, slot) in lane.iter_mut().take(copy_len).enumerate() {
1027            *slot = input_data[lane_ctx.input_offset(k)?];
1028        }
1029        for k in expected_half..out_axis_len {
1030            let mirror = out_axis_len - k;
1031            if mirror < lane.len() {
1032                lane[k] = lane[mirror].conj();
1033            }
1034        }
1035        fft_plan.process(&mut lane);
1036        for (k, value) in lane.iter().take(out_axis_len).enumerate() {
1037            output[lane_ctx.output_offset(k)?].write(value.re * scale);
1038        }
1039        Ok(())
1040    })?;
1041
1042    // SAFETY: `for_axis_lane` covers every element in the compact column-major
1043    // output exactly once, and each lane writes all `out_axis_len` positions.
1044    Ok(unsafe { assume_init_output_vec(output) })
1045}
1046
1047fn scale_for<T>(norm: FftNorm, forward: bool, n: usize) -> tenferro_tensor::Result<T>
1048where
1049    T: Float + FromPrimitive,
1050{
1051    let len = T::from_usize(n).ok_or_else(|| tenferro_tensor::Error::InvalidConfig {
1052        op: "tenferro_fft::scale_for",
1053        message: format!("FFT length {n} cannot be represented as scalar"),
1054    })?;
1055    Ok(match (norm, forward) {
1056        (FftNorm::Backward, true) | (FftNorm::Forward, false) => T::one(),
1057        (FftNorm::Backward, false) | (FftNorm::Forward, true) => T::one() / len,
1058        (FftNorm::Ortho, _) => T::one() / len.sqrt(),
1059    })
1060}
1061
1062#[derive(Clone, Copy)]
1063struct LaneContext {
1064    input_base: usize,
1065    output_base: usize,
1066    axis_stride: usize,
1067    in_axis_len: usize,
1068}
1069
1070impl LaneContext {
1071    fn input_offset(self, k: usize) -> tenferro_tensor::Result<usize> {
1072        let lane_offset = checked_mul("fft", "input lane offset", k, self.axis_stride)?;
1073        checked_add("fft", "input element offset", self.input_base, lane_offset)
1074    }
1075
1076    fn output_offset(self, k: usize) -> tenferro_tensor::Result<usize> {
1077        let lane_offset = checked_mul("fft", "output lane offset", k, self.axis_stride)?;
1078        checked_add(
1079            "fft",
1080            "output element offset",
1081            self.output_base,
1082            lane_offset,
1083        )
1084    }
1085}
1086
1087fn for_axis_lane(
1088    in_shape: &[usize],
1089    axis: usize,
1090    out_axis_len: usize,
1091    mut f: impl FnMut(LaneContext) -> tenferro_tensor::Result<()>,
1092) -> tenferro_tensor::Result<()> {
1093    let in_axis_len = in_shape[axis];
1094    let axis_stride = checked_shape_product("fft", "axis stride", &in_shape[..axis])?;
1095    let outer = checked_shape_product("fft", "outer lane count", &in_shape[axis + 1..])?;
1096    let in_block = checked_mul("fft", "input lane block", axis_stride, in_axis_len)?;
1097    let out_block = checked_mul("fft", "output lane block", axis_stride, out_axis_len)?;
1098    let _input_len = checked_mul("fft", "input lane coverage", outer, in_block)?;
1099    let _output_len = checked_mul("fft", "output lane coverage", outer, out_block)?;
1100
1101    for outer_idx in 0..outer {
1102        let in_outer_base = checked_mul("fft", "input outer base", outer_idx, in_block)?;
1103        let out_outer_base = checked_mul("fft", "output outer base", outer_idx, out_block)?;
1104        for inner in 0..axis_stride {
1105            let input_base = checked_add("fft", "input lane base", in_outer_base, inner)?;
1106            let output_base = checked_add("fft", "output lane base", out_outer_base, inner)?;
1107            f(LaneContext {
1108                input_base,
1109                output_base,
1110                axis_stride,
1111                in_axis_len,
1112            })?;
1113        }
1114    }
1115    Ok(())
1116}
1117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fft_infer_output_meta_rejects_invalid_trait_inputs_without_panicking() {
        let op = FftOp::new(FftKind::R2C { onesided: true }, 0, None, FftNorm::Backward);
        let shape = [SymDim::from(4usize)];

        assert!(op.infer_output_meta(&[], &[&shape]).is_empty());
        assert!(op.infer_output_meta(&[DType::F64], &[]).is_empty());
        assert!(op.infer_output_meta(&[DType::I64], &[&shape]).is_empty());

        let bad_axis = FftOp::new(FftKind::C2C { forward: true }, 2, None, FftNorm::Backward);
        assert!(bad_axis
            .infer_output_meta(&[DType::C64], &[&shape])
            .is_empty());
    }

    #[test]
    fn checked_shape_product_rejects_overflow_before_allocation() {
        let err = checked_shape_product("fft", "output", &[usize::MAX, 2])
            .expect_err("overflowing output shape should be rejected");

        assert!(err.to_string().contains("overflows usize"), "{err}");
    }

    #[test]
    fn irfft_default_output_length_rejects_overflow() {
        let err = output_shape_c2r(&[usize::MAX], 0, None)
            .expect_err("default irfft output length should reject overflow");

        assert!(err.to_string().contains("overflows usize"), "{err}");
    }

    #[test]
    fn axis_lane_layout_rejects_stride_overflow() {
        let err = for_axis_lane(&[usize::MAX, 2], 1, 2, |_| Ok(()))
            .expect_err("lane layout should reject stride overflow");

        assert!(err.to_string().contains("overflows usize"), "{err}");
    }

    #[test]
    fn scale_for_follows_norm_conventions() {
        // Length 4 keeps every expected factor exactly representable.
        let scale = |norm: FftNorm, forward: bool| -> f64 {
            scale_for(norm, forward, 4).expect("length 4 is representable")
        };

        assert_eq!(scale(FftNorm::Backward, true), 1.0);
        assert_eq!(scale(FftNorm::Backward, false), 0.25);
        assert_eq!(scale(FftNorm::Forward, true), 0.25);
        assert_eq!(scale(FftNorm::Forward, false), 1.0);
        assert_eq!(scale(FftNorm::Ortho, true), 0.5);
        assert_eq!(scale(FftNorm::Ortho, false), 0.5);
    }

    #[test]
    fn uninit_output_roundtrips_once_every_slot_is_written() {
        let mut output = uninit_output_vec::<String>(3);
        for (slot, text) in output.iter_mut().zip(["a", "b", "c"]) {
            slot.write(text.to_owned());
        }

        // SAFETY: every slot was written exactly once by the loop above.
        let values = unsafe { assume_init_output_vec(output) };
        assert_eq!(values, ["a", "b", "c"]);
    }
}