1use std::any::Any;
37use std::hash::Hasher;
38use std::mem::MaybeUninit;
39use std::sync::Arc;
40
41#[cfg(feature = "autodiff")]
42use computegraph::types::{LocalValueId, OperationRole, ValueKey, ValueRef};
43use num_complex::Complex;
44use num_traits::{Float, FromPrimitive, Zero};
45use rustfft::{FftNum, FftPlanner};
46#[cfg(feature = "autodiff")]
47use tenferro_ad::extension::{ExtensionAdRule, ExtensionRegistryError, ExtensionRuleSet};
48use tenferro_extension_macros::define_extension_runtime;
49#[cfg(feature = "autodiff")]
50use tenferro_ops::ad::PrimitiveRuleBuilder;
51#[cfg(feature = "autodiff")]
52use tenferro_ops::std_tensor_op::StdTensorOp;
53#[cfg(feature = "autodiff")]
54use tenferro_ops::ShapeGuardContext;
55use tenferro_ops::SymDim;
56use tenferro_runtime::extension::{apply, ExtensionExecutionContext, ExtensionOp};
57use tenferro_runtime::{Error, Result, TracedTensor};
58use tenferro_tensor::{
59 DType, DeviceKind, MemoryKind, Placement, Tensor, TensorBackend, TensorRead, TypedTensor,
60};
61#[cfg(feature = "autodiff")]
62use tidu::{ADRuleError, ADRuleKind, ADRuleResult};
63
/// Family identifier reported by the FFT extension op, its AD rule, and the
/// runtime registration in this file.
pub const FFT_EXTENSION_FAMILY_ID: &str = "tenferro-fft.fft.v1";
75
/// Method-style access to the FFT entry points of this file.
///
/// Every method takes an optional transform length `n`, a signed `axis`, and
/// a normalization mode, and returns the transformed traced tensor.
pub trait TracedTensorFftExt {
    /// Forward FFT along `axis`.
    fn fft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor>;
    /// Inverse complex FFT along `axis`.
    fn ifft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor>;
    /// Real-input FFT along `axis`.
    fn rfft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor>;
    /// Inverse of the real-input FFT along `axis`.
    fn irfft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor>;
}
83
/// Forwards each trait method to the free function of the same name.
impl TracedTensorFftExt for TracedTensor {
    fn fft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor> {
        // Resolves to the module-level `fft`, not to this method.
        fft(self, n, axis, norm)
    }

    fn ifft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor> {
        ifft(self, n, axis, norm)
    }

    fn rfft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor> {
        rfft(self, n, axis, norm)
    }

    fn irfft(&self, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor> {
        irfft(self, n, axis, norm)
    }
}
101
/// Normalization convention for a transform of length `n`.
///
/// The scale factors listed on the variants are the ones computed by
/// `scale_for`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FftNorm {
    /// Forward transforms are unscaled; inverse transforms are scaled by `1 / n`.
    #[default]
    Backward,
    /// Forward transforms are scaled by `1 / n`; inverse transforms are unscaled.
    Forward,
    /// Both directions are scaled by `1 / sqrt(n)`.
    Ortho,
}
124
#[cfg(feature = "autodiff")]
impl FftNorm {
    /// Normalization used by the adjoint of a complex-to-complex transform:
    /// `Backward` and `Forward` trade places, `Ortho` maps to itself.
    fn c2c_adjoint(self) -> Self {
        match self {
            Self::Ortho => Self::Ortho,
            Self::Forward => Self::Backward,
            Self::Backward => Self::Forward,
        }
    }
}
135
/// Which transform an `FftOp` performs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum FftKind {
    /// Complex input, complex output; `forward` selects the forward plan,
    /// otherwise the inverse plan is used.
    C2C { forward: bool },
    /// Real input, complex output; when `onesided` the output axis keeps
    /// `len / 2 + 1` entries, otherwise all `len` entries.
    R2C { onesided: bool },
    /// Complex input, real output.
    C2R,
}
142
/// Payload of one FFT extension operation.
#[derive(Clone, Debug, PartialEq)]
struct FftOp {
    /// Transform variant.
    kind: FftKind,
    /// Non-negative index of the axis to transform.
    axis: usize,
    /// Explicit transform length; `None` derives it from the input axis.
    n: Option<usize>,
    /// Normalization convention.
    norm: FftNorm,
}
150
151impl FftOp {
152 fn new(kind: FftKind, axis: usize, n: Option<usize>, norm: FftNorm) -> Self {
153 Self {
154 kind,
155 axis,
156 n,
157 norm,
158 }
159 }
160
161 #[cfg(feature = "autodiff")]
162 fn c2c_adjoint(&self) -> Option<Self> {
163 match self.kind {
164 FftKind::C2C { forward } => Some(Self {
165 kind: FftKind::C2C { forward: !forward },
166 axis: self.axis,
167 n: self.n,
168 norm: self.norm.c2c_adjoint(),
169 }),
170 FftKind::R2C { .. } | FftKind::C2R => None,
171 }
172 }
173}
174
175impl ExtensionOp for FftOp {
176 fn family_id(&self) -> &'static str {
177 FFT_EXTENSION_FAMILY_ID
178 }
179
180 fn payload_hash(&self, hasher: &mut dyn Hasher) {
181 let kind = match self.kind {
182 FftKind::C2C { forward: true } => 0,
183 FftKind::C2C { forward: false } => 1,
184 FftKind::R2C { onesided: true } => 2,
185 FftKind::R2C { onesided: false } => 3,
186 FftKind::C2R => 4,
187 };
188 hasher.write_u8(kind);
189 hasher.write_usize(self.axis);
190 match self.n {
191 Some(n) => {
192 hasher.write_u8(1);
193 hasher.write_usize(n);
194 }
195 None => hasher.write_u8(0),
196 }
197 let norm = match self.norm {
198 FftNorm::Backward => 0,
199 FftNorm::Forward => 1,
200 FftNorm::Ortho => 2,
201 };
202 hasher.write_u8(norm);
203 }
204
205 fn payload_eq(&self, other: &dyn ExtensionOp) -> bool {
206 other
207 .as_any()
208 .downcast_ref::<FftOp>()
209 .is_some_and(|that| self == that)
210 }
211
212 fn clone_arc(&self) -> Arc<dyn ExtensionOp> {
213 Arc::new(self.clone())
214 }
215
216 fn as_any(&self) -> &dyn Any {
217 self
218 }
219
220 fn input_count(&self) -> usize {
221 1
222 }
223
224 fn output_count(&self) -> usize {
225 1
226 }
227
228 fn infer_output_meta(
229 &self,
230 input_dtypes: &[DType],
231 input_shapes: &[&[SymDim]],
232 ) -> Vec<(DType, Vec<SymDim>)> {
233 let [input_dtype] = input_dtypes else {
237 return Vec::new();
238 };
239 let [input_shape] = input_shapes else {
240 return Vec::new();
241 };
242 if self.axis >= input_shape.len() {
243 return Vec::new();
244 }
245
246 let mut out_shape = input_shape.to_vec();
247 let output_dtype = match self.kind {
248 FftKind::C2C { .. } => {
249 if !matches!(input_dtype, DType::C32 | DType::C64) {
250 return Vec::new();
251 }
252 *input_dtype
253 }
254 FftKind::R2C { onesided } => {
255 let len = transform_len_dim(self.n, &input_shape[self.axis]);
256 out_shape[self.axis] = if onesided { len / 2usize + 1usize } else { len };
257 match input_dtype {
258 DType::F32 => DType::C32,
259 DType::F64 => DType::C64,
260 _ => return Vec::new(),
261 }
262 }
263 FftKind::C2R => {
264 out_shape[self.axis] = match self.n {
265 Some(n) => SymDim::from(n),
266 None => (input_shape[self.axis].clone() - 1usize) * 2usize,
267 };
268 match input_dtype {
269 DType::C32 => DType::F32,
270 DType::C64 => DType::F64,
271 _ => return Vec::new(),
272 }
273 }
274 };
275
276 if matches!(self.kind, FftKind::C2C { .. }) {
277 out_shape[self.axis] = transform_len_dim(self.n, &input_shape[self.axis]);
278 }
279
280 vec![(output_dtype, out_shape)]
281 }
282
283 fn eager_execute(&self, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>> {
284 execute_host_fft_op(self, inputs)
285 }
286}
287
288fn execute_host_fft_op(op: &FftOp, inputs: &[&Tensor]) -> tenferro_tensor::Result<Vec<Tensor>> {
289 if inputs.len() != 1 {
290 return Err(tenferro_tensor::Error::InvalidConfig {
291 op: "tenferro-fft",
292 message: format!("expected 1 input, got {}", inputs.len()),
293 });
294 }
295 validate_host_fft_input(fft_op_name(op.kind), inputs[0])?;
296
297 let output = match (op.kind, inputs[0]) {
298 (FftKind::C2C { forward }, Tensor::C64(input)) => {
299 Tensor::C64(TypedTensor::from_vec_col_major(
300 output_shape_c2c(input.shape(), op.axis, op.n)?,
301 execute_c2c(input, op.axis, op.n, forward, op.norm)?,
302 )?)
303 }
304 (FftKind::C2C { forward }, Tensor::C32(input)) => {
305 Tensor::C32(TypedTensor::from_vec_col_major(
306 output_shape_c2c(input.shape(), op.axis, op.n)?,
307 execute_c2c(input, op.axis, op.n, forward, op.norm)?,
308 )?)
309 }
310 (FftKind::R2C { onesided }, Tensor::F64(input)) => {
311 Tensor::C64(TypedTensor::from_vec_col_major(
312 output_shape_r2c(input.shape(), op.axis, op.n, onesided)?,
313 execute_r2c(input, op.axis, op.n, onesided, op.norm)?,
314 )?)
315 }
316 (FftKind::R2C { onesided }, Tensor::F32(input)) => {
317 Tensor::C32(TypedTensor::from_vec_col_major(
318 output_shape_r2c(input.shape(), op.axis, op.n, onesided)?,
319 execute_r2c(input, op.axis, op.n, onesided, op.norm)?,
320 )?)
321 }
322 (FftKind::C2R, Tensor::C64(input)) => Tensor::F64(TypedTensor::from_vec_col_major(
323 output_shape_c2r(input.shape(), op.axis, op.n)?,
324 execute_c2r(input, op.axis, op.n, op.norm)?,
325 )?),
326 (FftKind::C2R, Tensor::C32(input)) => Tensor::F32(TypedTensor::from_vec_col_major(
327 output_shape_c2r(input.shape(), op.axis, op.n)?,
328 execute_c2r(input, op.axis, op.n, op.norm)?,
329 )?),
330 (kind, other) => {
331 return Err(tenferro_tensor::Error::DTypeMismatch {
332 op: match kind {
333 FftKind::C2C { .. } => "fft",
334 FftKind::R2C { .. } => "rfft",
335 FftKind::C2R => "irfft",
336 },
337 lhs: expected_dtype_for(kind),
338 rhs: other.dtype(),
339 });
340 }
341 };
342 Ok(vec![output])
343}
344
/// Returns the placement reported by `input`.
fn tensor_placement(input: &Tensor) -> &Placement {
    input.placement()
}
348
/// Returns whether `input` reports itself as a backend buffer.
fn tensor_has_backend_buffer(input: &Tensor) -> bool {
    input.is_backend_buffer()
}
352
353fn validate_host_fft_input(op: &'static str, input: &Tensor) -> tenferro_tensor::Result<()> {
354 let placement = tensor_placement(input);
355 let is_device = matches!(placement.memory_kind, MemoryKind::Device);
356 if !is_device && !tensor_has_backend_buffer(input) {
357 return Ok(());
358 }
359
360 let location = match placement.device.as_ref().map(|device| &device.kind) {
361 Some(DeviceKind::Gpu(kind)) => format!("GPU backend {kind:?}"),
362 Some(kind) => format!("device kind {kind:?}"),
363 None if is_device => "device tensor without device metadata".to_string(),
364 None => "backend buffer".to_string(),
365 };
366 Err(tenferro_tensor::Error::backend_failure(
367 op,
368 format!(
369 "tenferro-fft supports host tensors only; unsupported {location} input; \
370 download the tensor to CPU before FFT"
371 ),
372 ))
373}
374
/// Automatic-differentiation rule for the FFT extension family.
#[cfg(feature = "autodiff")]
#[derive(Debug)]
struct FftAdRule;
378
#[cfg(feature = "autodiff")]
impl ExtensionAdRule for FftAdRule {
    fn family_id(&self) -> &'static str {
        FFT_EXTENSION_FAMILY_ID
    }

    /// Forward-mode rule: the tangent of a complex-to-complex FFT is the same
    /// op applied to the input tangent.
    ///
    /// Other transform kinds are rejected as unsupported; a missing input
    /// tangent yields a missing output tangent.
    fn linearize(
        &self,
        op: &dyn ExtensionOp,
        builder: &mut dyn PrimitiveRuleBuilder,
        _primal_in: &[ValueKey<StdTensorOp>],
        _primal_out: &[ValueKey<StdTensorOp>],
        tangent_in: &[Option<LocalValueId>],
        _ctx: &mut ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        let fft_op = fft_payload(op, ADRuleKind::Jvp)?;
        let FftKind::C2C { .. } = fft_op.kind else {
            return Err(ADRuleError::unsupported(
                fft_ad_family_id(fft_op.kind),
                ADRuleKind::Jvp,
            ));
        };

        let Some(dx) = tangent_in[0] else {
            return Ok(vec![None]);
        };
        let tangent_out = builder.add_operation(
            StdTensorOp::Extension(Arc::new(fft_op.clone())),
            vec![ValueRef::Local(dx)],
            OperationRole::Linearized {
                active_mask: vec![true],
            },
        );
        Ok(vec![Some(tangent_out[0])])
    }

    /// Reverse-mode rule: the cotangent of a complex-to-complex FFT is the
    /// adjoint op (see `FftOp::c2c_adjoint`) applied to the output cotangent.
    ///
    /// Other transform kinds are rejected as unsupported; a missing output
    /// cotangent yields a missing input cotangent.
    fn transpose_rule(
        &self,
        op: &dyn ExtensionOp,
        builder: &mut dyn PrimitiveRuleBuilder,
        cotangent_out: &[Option<LocalValueId>],
        _inputs: &[ValueRef<StdTensorOp>],
        _mode: &OperationRole,
        _ctx: &mut ShapeGuardContext,
    ) -> ADRuleResult<Vec<Option<LocalValueId>>> {
        let fft_op = fft_payload(op, ADRuleKind::Transpose)?;
        let FftKind::C2C { .. } = fft_op.kind else {
            return Err(ADRuleError::unsupported(
                fft_ad_family_id(fft_op.kind),
                ADRuleKind::Transpose,
            ));
        };

        let Some(ct) = cotangent_out[0] else {
            return Ok(vec![None]);
        };
        let Some(adjoint_op) = fft_op.c2c_adjoint() else {
            return Err(ADRuleError::unsupported(
                FFT_EXTENSION_FAMILY_ID,
                ADRuleKind::Transpose,
            ));
        };
        let cotangent_in = builder.add_operation(
            StdTensorOp::Extension(Arc::new(adjoint_op)),
            vec![ValueRef::Local(ct)],
            OperationRole::Linearized {
                active_mask: vec![true],
            },
        );
        Ok(vec![Some(cotangent_in[0])])
    }
}
452
/// Returns a rule set containing the FFT AD rule.
///
/// # Errors
///
/// Propagates the registry error returned by `ExtensionRuleSet::with_rule`.
#[cfg(feature = "autodiff")]
pub fn ad_rules() -> std::result::Result<ExtensionRuleSet, ExtensionRegistryError> {
    ExtensionRuleSet::new().with_rule(Arc::new(FftAdRule))
}
458
/// Runtime execution hook for already-materialized inputs.
///
/// The execution context is unused: the op always runs through the host
/// implementation.
fn execute_fft_extension<B: TensorBackend + 'static>(
    op: &FftOp,
    inputs: &[&Tensor],
    _ctx: &mut ExtensionExecutionContext<'_, B>,
) -> tenferro_tensor::Result<Vec<Tensor>> {
    execute_host_fft_op(op, inputs)
}
466
467fn execute_fft_extension_reads<B: TensorBackend + 'static>(
468 op: &FftOp,
469 inputs: &[TensorRead<'_>],
470 ctx: &mut ExtensionExecutionContext<'_, B>,
471) -> tenferro_tensor::Result<Vec<Tensor>> {
472 let _ = ctx;
473 let materialized_inputs: Vec<Tensor> = inputs
476 .iter()
477 .map(TensorRead::to_tensor)
478 .collect::<tenferro_tensor::Result<_>>()?;
479 let input_refs: Vec<&Tensor> = materialized_inputs.iter().collect();
480 execute_host_fft_op(op, &input_refs)
481}
482
// Wires `FftOp` and the two execution hooks above into the extension runtime
// under `FFT_EXTENSION_FAMILY_ID`.
// NOTE(review): presumably this expands to the `FftRuntime` type and the
// `register_runtime` function named below — confirm against the macro.
define_extension_runtime! {
    runtime = FftRuntime,
    family_id = FFT_EXTENSION_FAMILY_ID,
    op_type = FftOp,
    execute = execute_fft_extension,
    execute_reads = execute_fft_extension_reads,
    register_fn = register_runtime,
}
491
492fn fft(input: &TracedTensor, n: Option<usize>, axis: isize, norm: FftNorm) -> Result<TracedTensor> {
516 let kind = match input.dtype {
517 DType::C32 | DType::C64 => FftKind::C2C { forward: true },
518 DType::F32 | DType::F64 => FftKind::R2C { onesided: false },
519 DType::I32 | DType::I64 | DType::Bool => {
520 return Err(fft_config_error(
521 "fft",
522 format!(
523 "fft expects real or complex floating input, got {:?}",
524 input.dtype
525 ),
526 ))
527 }
528 };
529 apply_unary_fft("fft", input, kind, n, axis, norm)
530}
531
532fn ifft(
553 input: &TracedTensor,
554 n: Option<usize>,
555 axis: isize,
556 norm: FftNorm,
557) -> Result<TracedTensor> {
558 if !matches!(input.dtype, DType::C32 | DType::C64) {
559 return Err(fft_config_error(
560 "ifft",
561 format!("ifft expects C32 or C64 input; got {:?}", input.dtype),
562 ));
563 }
564 apply_unary_fft(
565 "ifft",
566 input,
567 FftKind::C2C { forward: false },
568 n,
569 axis,
570 norm,
571 )
572}
573
574fn rfft(
599 input: &TracedTensor,
600 n: Option<usize>,
601 axis: isize,
602 norm: FftNorm,
603) -> Result<TracedTensor> {
604 if !matches!(input.dtype, DType::F32 | DType::F64) {
605 return Err(fft_config_error(
606 "rfft",
607 format!("rfft expects F32 or F64 input; got {:?}", input.dtype),
608 ));
609 }
610 apply_unary_fft(
611 "rfft",
612 input,
613 FftKind::R2C { onesided: true },
614 n,
615 axis,
616 norm,
617 )
618}
619
620fn irfft(
648 input: &TracedTensor,
649 n: Option<usize>,
650 axis: isize,
651 norm: FftNorm,
652) -> Result<TracedTensor> {
653 if !matches!(input.dtype, DType::C32 | DType::C64) {
654 return Err(fft_config_error(
655 "irfft",
656 format!("irfft expects C32 or C64 input; got {:?}", input.dtype),
657 ));
658 }
659 apply_unary_fft("irfft", input, FftKind::C2R, n, axis, norm)
660}
661
662fn apply_unary_fft(
663 op_name: &'static str,
664 input: &TracedTensor,
665 kind: FftKind,
666 n: Option<usize>,
667 axis: isize,
668 norm: FftNorm,
669) -> Result<TracedTensor> {
670 validate_n(op_name, n)?;
671 let axis = normalize_axis(op_name, axis, input.rank)?;
672 validate_resolved_transform_len(op_name, input, n, axis)?;
673 let op = Arc::new(FftOp::new(kind, axis, n, norm));
674 let mut outputs = apply(op, &[input])?;
675 outputs
676 .pop()
677 .ok_or_else(|| Error::Internal("FFT extension declares exactly one output".into()))
678}
679
680fn normalize_axis(op: &'static str, axis: isize, rank: usize) -> Result<usize> {
681 if rank == 0 {
682 return Err(fft_config_error(op, "tenferro-fft requires rank >= 1"));
683 }
684 let rank_isize = rank as isize;
685 let normalized = if axis < 0 { rank_isize + axis } else { axis };
686 if normalized < 0 || normalized >= rank_isize {
687 return Err(fft_config_error(
688 op,
689 format!("tenferro-fft axis {axis} out of bounds for rank {rank}"),
690 ));
691 }
692 Ok(normalized as usize)
693}
694
695fn validate_n(op: &'static str, n: Option<usize>) -> Result<()> {
696 if n == Some(0) {
697 return Err(fft_config_error(
698 op,
699 "tenferro-fft transform length n must be positive",
700 ));
701 }
702 Ok(())
703}
704
705fn validate_resolved_transform_len(
706 op: &'static str,
707 input: &TracedTensor,
708 n: Option<usize>,
709 axis: usize,
710) -> Result<()> {
711 if n.is_some() {
712 return Ok(());
713 }
714 if input
715 .try_concrete_shape()
716 .and_then(|shape| shape.get(axis).copied())
717 == Some(0)
718 {
719 return Err(fft_config_error(
720 op,
721 "tenferro-fft transform length n must be positive",
722 ));
723 }
724 Ok(())
725}
726
/// Wraps `message` in an `InvalidConfig` tensor error tagged with `op`, and
/// lifts it into the runtime error type.
fn fft_config_error(op: &'static str, message: impl std::fmt::Display) -> Error {
    Error::TensorRuntime(tenferro_tensor::Error::InvalidConfig {
        op,
        message: message.to_string(),
    })
}
733
734fn transform_len_dim(n: Option<usize>, input_dim: &SymDim) -> SymDim {
735 n.map(SymDim::from).unwrap_or_else(|| input_dim.clone())
736}
737
/// Dtype reported as "expected" in dtype-mismatch errors for `kind`.
///
/// Only the 64-bit variant is reported, although `execute_host_fft_op` also
/// accepts the matching 32-bit dtype.
fn expected_dtype_for(kind: FftKind) -> DType {
    match kind {
        FftKind::C2C { .. } | FftKind::C2R => DType::C64,
        FftKind::R2C { .. } => DType::F64,
    }
}
744
745fn fft_op_name(kind: FftKind) -> &'static str {
746 match kind {
747 FftKind::C2C { forward: true } => "fft",
748 FftKind::C2C { forward: false } => "ifft",
749 FftKind::R2C { .. } => "rfft",
750 FftKind::C2R => "irfft",
751 }
752}
753
/// Family identifier reported in "unsupported" AD errors for `kind`.
///
/// Complex-to-complex ops use the shared family id; the real-input and
/// real-output kinds report their own identifiers.
#[cfg(feature = "autodiff")]
fn fft_ad_family_id(kind: FftKind) -> &'static str {
    match kind {
        FftKind::C2C { .. } => FFT_EXTENSION_FAMILY_ID,
        FftKind::R2C { .. } => "tenferro-fft.rfft.v1",
        FftKind::C2R => "tenferro-fft.irfft.v1",
    }
}
762
/// Downcasts a generic extension op to the FFT payload.
///
/// # Errors
///
/// Returns an "unsupported" error for `rule` when `op` is not an `FftOp`.
#[cfg(feature = "autodiff")]
fn fft_payload<'a>(op: &'a dyn ExtensionOp, rule: ADRuleKind) -> ADRuleResult<&'a FftOp> {
    match op.as_any().downcast_ref::<FftOp>() {
        Some(payload) => Ok(payload),
        None => Err(ADRuleError::unsupported(FFT_EXTENSION_FAMILY_ID, rule)),
    }
}
769
770fn output_shape_c2c(
771 shape: &[usize],
772 axis: usize,
773 n: Option<usize>,
774) -> tenferro_tensor::Result<Vec<usize>> {
775 let len = transform_len(shape, axis, n)?;
776 let mut out_shape = shape.to_vec();
777 out_shape[axis] = len;
778 Ok(out_shape)
779}
780
781fn output_shape_r2c(
782 shape: &[usize],
783 axis: usize,
784 n: Option<usize>,
785 onesided: bool,
786) -> tenferro_tensor::Result<Vec<usize>> {
787 let len = transform_len(shape, axis, n)?;
788 let mut out_shape = shape.to_vec();
789 out_shape[axis] = if onesided { len / 2 + 1 } else { len };
790 Ok(out_shape)
791}
792
793fn output_shape_c2r(
794 shape: &[usize],
795 axis: usize,
796 n: Option<usize>,
797) -> tenferro_tensor::Result<Vec<usize>> {
798 validate_axis("irfft", shape, axis)?;
799 let input_len = shape[axis];
800 if input_len == 0 {
801 return Err(tenferro_tensor::Error::InvalidConfig {
802 op: "irfft",
803 message: "input spectrum axis length must be positive".to_string(),
804 });
805 }
806 let len = match n {
807 Some(len) => len,
808 None => input_len
809 .checked_sub(1)
810 .and_then(|len| len.checked_mul(2))
811 .ok_or_else(|| tenferro_tensor::Error::InvalidConfig {
812 op: "irfft",
813 message: "default output length overflows usize".to_string(),
814 })?,
815 };
816 if len == 0 {
817 return Err(tenferro_tensor::Error::InvalidConfig {
818 op: "irfft",
819 message: "output length must be positive".to_string(),
820 });
821 }
822 let mut out_shape = shape.to_vec();
823 out_shape[axis] = len;
824 Ok(out_shape)
825}
826
827fn transform_len(shape: &[usize], axis: usize, n: Option<usize>) -> tenferro_tensor::Result<usize> {
828 validate_axis("fft", shape, axis)?;
829 let len = n.unwrap_or(shape[axis]);
830 if len == 0 {
831 return Err(tenferro_tensor::Error::InvalidConfig {
832 op: "fft",
833 message: "transform length must be positive".to_string(),
834 });
835 }
836 Ok(len)
837}
838
839fn validate_axis(op: &'static str, shape: &[usize], axis: usize) -> tenferro_tensor::Result<()> {
840 if axis >= shape.len() {
841 return Err(tenferro_tensor::Error::AxisOutOfBounds {
842 op,
843 axis,
844 rank: shape.len(),
845 });
846 }
847 Ok(())
848}
849
850fn checked_shape_product(
851 op: &'static str,
852 role: &'static str,
853 shape: &[usize],
854) -> tenferro_tensor::Result<usize> {
855 shape
856 .iter()
857 .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
858 .ok_or_else(|| tenferro_tensor::Error::InvalidConfig {
859 op,
860 message: format!("{role} shape product overflows usize"),
861 })
862}
863
864fn checked_mul(
865 op: &'static str,
866 role: &'static str,
867 lhs: usize,
868 rhs: usize,
869) -> tenferro_tensor::Result<usize> {
870 lhs.checked_mul(rhs)
871 .ok_or_else(|| tenferro_tensor::Error::InvalidConfig {
872 op,
873 message: format!("{role} overflows usize"),
874 })
875}
876
877fn checked_add(
878 op: &'static str,
879 role: &'static str,
880 lhs: usize,
881 rhs: usize,
882) -> tenferro_tensor::Result<usize> {
883 lhs.checked_add(rhs)
884 .ok_or_else(|| tenferro_tensor::Error::InvalidConfig {
885 op,
886 message: format!("{role} overflows usize"),
887 })
888}
889
/// Allocates a vector of `len` uninitialized slots.
///
/// Every slot is a `MaybeUninit<T>`, so holding the vector is sound; the
/// caller must write each slot before treating the contents as `T`.
fn uninit_output_vec<T>(len: usize) -> Vec<MaybeUninit<T>> {
    let mut slots = Vec::with_capacity(len);
    slots.resize_with(len, MaybeUninit::uninit);
    slots
}
897
/// Reinterprets a fully written `Vec<MaybeUninit<T>>` as `Vec<T>` without
/// copying or reallocating.
///
/// # Safety
///
/// Every element of `output` (indices `0..output.len()`) must have been
/// initialized with a valid `T` before this call.
unsafe fn assume_init_output_vec<T>(output: Vec<MaybeUninit<T>>) -> Vec<T> {
    // `ManuallyDrop` gives up ownership of the allocation *before* the raw
    // pointer is taken, so the vector is never moved again while the pointer
    // is live (the pattern recommended over `mem::forget` by the std docs).
    let mut output = std::mem::ManuallyDrop::new(output);
    let len = output.len();
    let capacity = output.capacity();
    let ptr = output.as_mut_ptr().cast::<T>();
    // SAFETY: `ptr`, `len`, and `capacity` come from a live `Vec` whose
    // destructor will not run; `MaybeUninit<T>` has the same size and
    // alignment as `T`; the caller guarantees all `len` elements are
    // initialized.
    unsafe { Vec::from_raw_parts(ptr, len, capacity) }
}
907
/// Complex-to-complex FFT of `input` along `axis`, returned as a flat buffer
/// in the layout of the output shape from `output_shape_c2c`.
///
/// Each lane along `axis` is truncated or zero-padded to the transform
/// length, transformed with the forward or inverse plan, and scaled per
/// `norm`.
///
/// # Errors
///
/// Propagates shape/overflow errors, host-data access errors, and scale
/// conversion errors.
fn execute_c2c<T>(
    input: &TypedTensor<Complex<T>>,
    axis: usize,
    n: Option<usize>,
    forward: bool,
    norm: FftNorm,
) -> tenferro_tensor::Result<Vec<Complex<T>>>
where
    T: FftNum + Float + FromPrimitive,
{
    let in_shape = input.shape();
    let fft_len = transform_len(in_shape, axis, n)?;
    let out_shape = output_shape_c2c(in_shape, axis, n)?;
    let out_axis_len = out_shape[axis];
    let input_data = input.host_data()?;
    let output_len = checked_shape_product("fft", "output", &out_shape)?;
    let mut output = uninit_output_vec(output_len);
    let mut planner = FftPlanner::<T>::new();
    let fft_plan = if forward {
        planner.plan_fft_forward(fft_len)
    } else {
        planner.plan_fft_inverse(fft_len)
    };
    let scale: T = scale_for(norm, forward, fft_len)?;
    // Scratch buffer reused for every lane.
    let mut lane = vec![Complex::zero(); fft_len];

    for_axis_lane(in_shape, axis, out_axis_len, |lane_ctx| {
        // Zero first so a lane shorter than `fft_len` is zero-padded.
        lane.fill(Complex::zero());
        let copy_len = lane_ctx.in_axis_len.min(fft_len);
        for (k, slot) in lane.iter_mut().take(copy_len).enumerate() {
            *slot = input_data[lane_ctx.input_offset(k)?];
        }
        fft_plan.process(&mut lane);
        // Skip the multiply when the normalization is the identity.
        if scale != T::one() {
            for value in &mut lane {
                *value = *value * scale;
            }
        }
        for (k, value) in lane.iter().take(out_axis_len).copied().enumerate() {
            output[lane_ctx.output_offset(k)?].write(value);
        }
        Ok(())
    })?;

    // SAFETY: reached only when every lane callback succeeded; each callback
    // writes `out_axis_len` elements of its lane, and `for_axis_lane` visits
    // every lane of the output shape, so all `output_len` slots are written.
    Ok(unsafe { assume_init_output_vec(output) })
}
956
/// Real-to-complex FFT of `input` along `axis`, returned as a flat buffer in
/// the layout of the output shape from `output_shape_r2c`.
///
/// Each real lane is truncated or zero-padded to the transform length,
/// promoted to complex with zero imaginary part, run through a full forward
/// complex FFT, scaled per `norm`, and then the first `out_axis_len` bins are
/// kept (all bins, or `len / 2 + 1` when `onesided`).
///
/// # Errors
///
/// Propagates shape/overflow errors, host-data access errors, and scale
/// conversion errors.
fn execute_r2c<T>(
    input: &TypedTensor<T>,
    axis: usize,
    n: Option<usize>,
    onesided: bool,
    norm: FftNorm,
) -> tenferro_tensor::Result<Vec<Complex<T>>>
where
    T: FftNum + Float + FromPrimitive,
{
    let in_shape = input.shape();
    let fft_len = transform_len(in_shape, axis, n)?;
    let out_shape = output_shape_r2c(in_shape, axis, n, onesided)?;
    let out_axis_len = out_shape[axis];
    let input_data = input.host_data()?;
    let output_len = checked_shape_product("rfft", "output", &out_shape)?;
    let mut output = uninit_output_vec(output_len);
    let mut planner = FftPlanner::<T>::new();
    let fft_plan = planner.plan_fft_forward(fft_len);
    let scale: T = scale_for(norm, true, fft_len)?;
    // Scratch buffer reused for every lane.
    let mut lane = vec![Complex::zero(); fft_len];

    for_axis_lane(in_shape, axis, out_axis_len, |lane_ctx| {
        // Zero first so a lane shorter than `fft_len` is zero-padded.
        lane.fill(Complex::zero());
        let copy_len = lane_ctx.in_axis_len.min(fft_len);
        for (k, slot) in lane.iter_mut().take(copy_len).enumerate() {
            *slot = Complex::new(input_data[lane_ctx.input_offset(k)?], T::zero());
        }
        fft_plan.process(&mut lane);
        // Skip the multiply when the normalization is the identity.
        if scale != T::one() {
            for value in &mut lane {
                *value = *value * scale;
            }
        }
        // Keeps only the leading bins when the output is one-sided.
        for (k, value) in lane.iter().take(out_axis_len).copied().enumerate() {
            output[lane_ctx.output_offset(k)?].write(value);
        }
        Ok(())
    })?;

    // SAFETY: reached only when every lane callback succeeded; each callback
    // writes `out_axis_len` elements of its lane (`out_axis_len <= fft_len`),
    // and `for_axis_lane` visits every lane of the output shape, so all
    // `output_len` slots are written.
    Ok(unsafe { assume_init_output_vec(output) })
}
1001
/// Complex-to-real inverse transform of `input` along `axis`, returned as a
/// flat buffer in the layout of the output shape from `output_shape_c2r`.
///
/// For each lane, up to `out_axis_len / 2 + 1` input bins are copied, the
/// remaining bins are filled with the complex conjugate of their mirror bin,
/// a full inverse complex FFT is run, and the real part of each result is
/// scaled per `norm`.
///
/// # Errors
///
/// Propagates shape/overflow errors, host-data access errors, and scale
/// conversion errors.
fn execute_c2r<T>(
    input: &TypedTensor<Complex<T>>,
    axis: usize,
    n: Option<usize>,
    norm: FftNorm,
) -> tenferro_tensor::Result<Vec<T>>
where
    T: FftNum + Float + FromPrimitive,
{
    let in_shape = input.shape();
    let out_shape = output_shape_c2r(in_shape, axis, n)?;
    let out_axis_len = out_shape[axis];
    // Number of one-sided bins needed for a real signal of `out_axis_len`.
    let expected_half = out_axis_len / 2 + 1;
    let input_data = input.host_data()?;
    let output_len = checked_shape_product("irfft", "output", &out_shape)?;
    let mut output = uninit_output_vec(output_len);
    let mut planner = FftPlanner::<T>::new();
    let fft_plan = planner.plan_fft_inverse(out_axis_len);
    let scale: T = scale_for(norm, false, out_axis_len)?;
    // Scratch buffer reused for every lane.
    let mut lane = vec![Complex::zero(); out_axis_len];

    for_axis_lane(in_shape, axis, out_axis_len, |lane_ctx| {
        // Zero first so missing input bins count as zero.
        lane.fill(Complex::zero());
        let copy_len = lane_ctx.in_axis_len.min(expected_half);
        for (k, slot) in lane.iter_mut().take(copy_len).enumerate() {
            *slot = input_data[lane_ctx.input_offset(k)?];
        }
        // Rebuild the upper half of the spectrum from conjugate symmetry.
        for k in expected_half..out_axis_len {
            let mirror = out_axis_len - k;
            // `k >= 1` here, so `mirror` is always below `lane.len()`.
            if mirror < lane.len() {
                lane[k] = lane[mirror].conj();
            }
        }
        fft_plan.process(&mut lane);
        // Only the real part is kept; the scale is applied unconditionally.
        for (k, value) in lane.iter().take(out_axis_len).enumerate() {
            output[lane_ctx.output_offset(k)?].write(value.re * scale);
        }
        Ok(())
    })?;

    // SAFETY: reached only when every lane callback succeeded; each callback
    // writes all `out_axis_len` elements of its lane, and `for_axis_lane`
    // visits every lane of the output shape, so all `output_len` slots are
    // written.
    Ok(unsafe { assume_init_output_vec(output) })
}
1046
1047fn scale_for<T>(norm: FftNorm, forward: bool, n: usize) -> tenferro_tensor::Result<T>
1048where
1049 T: Float + FromPrimitive,
1050{
1051 let len = T::from_usize(n).ok_or_else(|| tenferro_tensor::Error::InvalidConfig {
1052 op: "tenferro_fft::scale_for",
1053 message: format!("FFT length {n} cannot be represented as scalar"),
1054 })?;
1055 Ok(match (norm, forward) {
1056 (FftNorm::Backward, true) | (FftNorm::Forward, false) => T::one(),
1057 (FftNorm::Backward, false) | (FftNorm::Forward, true) => T::one() / len,
1058 (FftNorm::Ortho, _) => T::one() / len.sqrt(),
1059 })
1060}
1061
/// Addressing information for one lane along the transformed axis.
#[derive(Clone, Copy)]
struct LaneContext {
    /// Flat index of the lane's first element in the input buffer.
    input_base: usize,
    /// Flat index of the lane's first element in the output buffer.
    output_base: usize,
    /// Distance, in elements, between consecutive entries of the lane.
    axis_stride: usize,
    /// Length of the input along the transformed axis.
    in_axis_len: usize,
}
1069
1070impl LaneContext {
1071 fn input_offset(self, k: usize) -> tenferro_tensor::Result<usize> {
1072 let lane_offset = checked_mul("fft", "input lane offset", k, self.axis_stride)?;
1073 checked_add("fft", "input element offset", self.input_base, lane_offset)
1074 }
1075
1076 fn output_offset(self, k: usize) -> tenferro_tensor::Result<usize> {
1077 let lane_offset = checked_mul("fft", "output lane offset", k, self.axis_stride)?;
1078 checked_add(
1079 "fft",
1080 "output element offset",
1081 self.output_base,
1082 lane_offset,
1083 )
1084 }
1085}
1086
1087fn for_axis_lane(
1088 in_shape: &[usize],
1089 axis: usize,
1090 out_axis_len: usize,
1091 mut f: impl FnMut(LaneContext) -> tenferro_tensor::Result<()>,
1092) -> tenferro_tensor::Result<()> {
1093 let in_axis_len = in_shape[axis];
1094 let axis_stride = checked_shape_product("fft", "axis stride", &in_shape[..axis])?;
1095 let outer = checked_shape_product("fft", "outer lane count", &in_shape[axis + 1..])?;
1096 let in_block = checked_mul("fft", "input lane block", axis_stride, in_axis_len)?;
1097 let out_block = checked_mul("fft", "output lane block", axis_stride, out_axis_len)?;
1098 let _input_len = checked_mul("fft", "input lane coverage", outer, in_block)?;
1099 let _output_len = checked_mul("fft", "output lane coverage", outer, out_block)?;
1100
1101 for outer_idx in 0..outer {
1102 let in_outer_base = checked_mul("fft", "input outer base", outer_idx, in_block)?;
1103 let out_outer_base = checked_mul("fft", "output outer base", outer_idx, out_block)?;
1104 for inner in 0..axis_stride {
1105 let input_base = checked_add("fft", "input lane base", in_outer_base, inner)?;
1106 let output_base = checked_add("fft", "output lane base", out_outer_base, inner)?;
1107 f(LaneContext {
1108 input_base,
1109 output_base,
1110 axis_stride,
1111 in_axis_len,
1112 })?;
1113 }
1114 }
1115 Ok(())
1116}
1117
#[cfg(test)]
mod tests {
    use super::*;

    /// Malformed inputs to `infer_output_meta` (wrong input counts, an
    /// unsupported dtype, an out-of-range axis) yield an empty result.
    #[test]
    fn fft_infer_output_meta_rejects_invalid_trait_inputs_without_panicking() {
        let op = FftOp::new(FftKind::R2C { onesided: true }, 0, None, FftNorm::Backward);
        let shape = [SymDim::from(4usize)];

        assert!(op.infer_output_meta(&[], &[&shape]).is_empty());
        assert!(op.infer_output_meta(&[DType::F64], &[]).is_empty());
        assert!(op.infer_output_meta(&[DType::I64], &[&shape]).is_empty());

        // Axis 2 does not exist in a rank-1 shape.
        let bad_axis = FftOp::new(FftKind::C2C { forward: true }, 2, None, FftNorm::Backward);
        assert!(bad_axis
            .infer_output_meta(&[DType::C64], &[&shape])
            .is_empty());
    }

    /// An overflowing shape product is reported as an error.
    #[test]
    fn checked_shape_product_rejects_overflow_before_allocation() {
        let err = checked_shape_product("fft", "output", &[usize::MAX, 2])
            .expect_err("overflowing output shape should be rejected");

        assert!(err.to_string().contains("overflows usize"), "{err}");
    }

    /// The default `2 * (len - 1)` output length is overflow-checked.
    #[test]
    fn irfft_default_output_length_rejects_overflow() {
        let err = output_shape_c2r(&[usize::MAX], 0, None)
            .expect_err("default irfft output length should reject overflow");

        assert!(err.to_string().contains("overflows usize"), "{err}");
    }

    /// Lane block sizes are overflow-checked before any callback runs.
    #[test]
    fn axis_lane_layout_rejects_stride_overflow() {
        let err = for_axis_lane(&[usize::MAX, 2], 1, 2, |_| Ok(()))
            .expect_err("lane layout should reject stride overflow");

        assert!(err.to_string().contains("overflows usize"), "{err}");
    }
}