cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

TensorFFT.h (24345B)


      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2015 Jianwei Cui <thucjw@gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_FFT_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_FFT_H
     12 
     13 namespace Eigen {
     14 
     15 /** \class TensorFFT
     16   * \ingroup CXX11_Tensor_Module
     17   *
     18   * \brief Tensor FFT class.
     19   *
     20   * TODO:
     21   * Vectorize the Cooley Tukey and the Bluestein algorithm
     22   * Add support for multithreaded evaluation
     23   * Improve the performance on GPU
     24   */
     25 
     26 template <bool NeedUprade> struct MakeComplex {
     27   template <typename T>
     28   EIGEN_DEVICE_FUNC
     29   T operator() (const T& val) const { return val; }
     30 };
     31 
     32 template <> struct MakeComplex<true> {
     33   template <typename T>
     34   EIGEN_DEVICE_FUNC
     35   std::complex<T> operator() (const T& val) const { return std::complex<T>(val, 0); }
     36 };
     37 
     38 template <> struct MakeComplex<false> {
     39   template <typename T>
     40   EIGEN_DEVICE_FUNC
     41   std::complex<T> operator() (const std::complex<T>& val) const { return val; }
     42 };
     43 
     // Functor selecting which part of a transformed coefficient is returned.
     // The primary template is the identity: the value is passed through
     // untouched for any ResultType that has no dedicated specialization.
     template <int ResultType>
     struct PartOf {
       template <typename T>
       T operator()(const T& val) const {
         return val;
       }
     };
     47 
     48 template <> struct PartOf<RealPart> {
     49   template <typename T> T operator() (const std::complex<T>& val) const { return val.real(); }
     50 };
     51 
     52 template <> struct PartOf<ImagPart> {
     53   template <typename T> T operator() (const std::complex<T>& val) const { return val.imag(); }
     54 };
     55 
     56 namespace internal {
     57 template <typename FFT, typename XprType, int FFTResultType, int FFTDir>
     58 struct traits<TensorFFTOp<FFT, XprType, FFTResultType, FFTDir> > : public traits<XprType> {
     59   typedef traits<XprType> XprTraits;
     60   typedef typename NumTraits<typename XprTraits::Scalar>::Real RealScalar;
     61   typedef typename std::complex<RealScalar> ComplexScalar;
     62   typedef typename XprTraits::Scalar InputScalar;
     63   typedef typename conditional<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar>::type OutputScalar;
     64   typedef typename XprTraits::StorageKind StorageKind;
     65   typedef typename XprTraits::Index Index;
     66   typedef typename XprType::Nested Nested;
     67   typedef typename remove_reference<Nested>::type _Nested;
     68   static const int NumDimensions = XprTraits::NumDimensions;
     69   static const int Layout = XprTraits::Layout;
     70   typedef typename traits<XprType>::PointerType PointerType;
     71 };
     72 
     73 template <typename FFT, typename XprType, int FFTResultType, int FFTDirection>
     74 struct eval<TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection>, Eigen::Dense> {
     75   typedef const TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection>& type;
     76 };
     77 
     78 template <typename FFT, typename XprType, int FFTResultType, int FFTDirection>
     79 struct nested<TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection>, 1, typename eval<TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection> >::type> {
     80   typedef TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection> type;
     81 };
     82 
     83 }  // end namespace internal
     84 
     85 template <typename FFT, typename XprType, int FFTResultType, int FFTDir>
     86 class TensorFFTOp : public TensorBase<TensorFFTOp<FFT, XprType, FFTResultType, FFTDir>, ReadOnlyAccessors> {
     87  public:
     88   typedef typename Eigen::internal::traits<TensorFFTOp>::Scalar Scalar;
     89   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
     90   typedef typename std::complex<RealScalar> ComplexScalar;
     91   typedef typename internal::conditional<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar>::type OutputScalar;
     92   typedef OutputScalar CoeffReturnType;
     93   typedef typename Eigen::internal::nested<TensorFFTOp>::type Nested;
     94   typedef typename Eigen::internal::traits<TensorFFTOp>::StorageKind StorageKind;
     95   typedef typename Eigen::internal::traits<TensorFFTOp>::Index Index;
     96 
     97   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorFFTOp(const XprType& expr, const FFT& fft)
     98       : m_xpr(expr), m_fft(fft) {}
     99 
    100   EIGEN_DEVICE_FUNC
    101   const FFT& fft() const { return m_fft; }
    102 
    103   EIGEN_DEVICE_FUNC
    104   const typename internal::remove_all<typename XprType::Nested>::type& expression() const {
    105     return m_xpr;
    106   }
    107 
    108  protected:
    109   typename XprType::Nested m_xpr;
    110   const FFT m_fft;
    111 };
    112 
    113 // Eval as rvalue
    114 template <typename FFT, typename ArgType, typename Device, int FFTResultType, int FFTDir>
    115 struct TensorEvaluator<const TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir>, Device> {
    116   typedef TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir> XprType;
    117   typedef typename XprType::Index Index;
    118   static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
    119   typedef DSizes<Index, NumDims> Dimensions;
    120   typedef typename XprType::Scalar Scalar;
    121   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    122   typedef typename std::complex<RealScalar> ComplexScalar;
    123   typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
    124   typedef internal::traits<XprType> XprTraits;
    125   typedef typename XprTraits::Scalar InputScalar;
    126   typedef typename internal::conditional<FFTResultType == RealPart || FFTResultType == ImagPart, RealScalar, ComplexScalar>::type OutputScalar;
    127   typedef OutputScalar CoeffReturnType;
    128   typedef typename PacketType<OutputScalar, Device>::type PacketReturnType;
    129   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
    130     typedef StorageMemory<CoeffReturnType, Device> Storage;
    131   typedef typename Storage::Type EvaluatorPointerType;
    132 
    133   enum {
    134     IsAligned = false,
    135     PacketAccess = true,
    136     BlockAccess = false,
    137     PreferBlockAccess = false,
    138     Layout = TensorEvaluator<ArgType, Device>::Layout,
    139     CoordAccess = false,
    140     RawAccess = false
    141   };
    142 
    143   //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
    144   typedef internal::TensorBlockNotImplemented TensorBlock;
    145   //===--------------------------------------------------------------------===//
    146 
    147   EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) : m_fft(op.fft()), m_impl(op.expression(), device), m_data(NULL), m_device(device) {
    148     const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
    149     for (int i = 0; i < NumDims; ++i) {
    150       eigen_assert(input_dims[i] > 0);
    151       m_dimensions[i] = input_dims[i];
    152     }
    153 
    154     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
    155       m_strides[0] = 1;
    156       for (int i = 1; i < NumDims; ++i) {
    157         m_strides[i] = m_strides[i - 1] * m_dimensions[i - 1];
    158       }
    159     } else {
    160       m_strides[NumDims - 1] = 1;
    161       for (int i = NumDims - 2; i >= 0; --i) {
    162         m_strides[i] = m_strides[i + 1] * m_dimensions[i + 1];
    163       }
    164     }
    165     m_size = m_dimensions.TotalSize();
    166   }
    167 
    168   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
    169     return m_dimensions;
    170   }
    171 
    172   EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
    173     m_impl.evalSubExprsIfNeeded(NULL);
    174     if (data) {
    175       evalToBuf(data);
    176       return false;
    177     } else {
    178       m_data = (EvaluatorPointerType)m_device.get((CoeffReturnType*)(m_device.allocate_temp(sizeof(CoeffReturnType) * m_size)));
    179       evalToBuf(m_data);
    180       return true;
    181     }
    182   }
    183 
    184   EIGEN_STRONG_INLINE void cleanup() {
    185     if (m_data) {
    186       m_device.deallocate(m_data);
    187       m_data = NULL;
    188     }
    189     m_impl.cleanup();
    190   }
    191 
    192   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const {
    193     return m_data[index];
    194   }
    195 
    196   template <int LoadMode>
    197   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType
    198   packet(Index index) const {
    199     return internal::ploadt<PacketReturnType, LoadMode>(m_data + index);
    200   }
    201 
    202   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
    203   costPerCoeff(bool vectorized) const {
    204     return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
    205   }
    206 
    207   EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return m_data; }
    #ifdef EIGEN_USE_SYCL
    // SYCL only: binds the placeholder accessor behind m_data to the given
    // command group handler.
    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    void bind(cl::sycl::handler& cgh) const {
      m_data.bind(cgh);
    }
    #endif
    214 
    215  private:
    216   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalToBuf(EvaluatorPointerType data) {
    217     const bool write_to_out = internal::is_same<OutputScalar, ComplexScalar>::value;
    218     ComplexScalar* buf = write_to_out ? (ComplexScalar*)data : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * m_size);
    219 
    220     for (Index i = 0; i < m_size; ++i) {
    221       buf[i] = MakeComplex<internal::is_same<InputScalar, RealScalar>::value>()(m_impl.coeff(i));
    222     }
    223 
    224     for (size_t i = 0; i < m_fft.size(); ++i) {
    225       Index dim = m_fft[i];
    226       eigen_assert(dim >= 0 && dim < NumDims);
    227       Index line_len = m_dimensions[dim];
    228       eigen_assert(line_len >= 1);
    229       ComplexScalar* line_buf = (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * line_len);
    230       const bool is_power_of_two = isPowerOfTwo(line_len);
    231       const Index good_composite = is_power_of_two ? 0 : findGoodComposite(line_len);
    232       const Index log_len = is_power_of_two ? getLog2(line_len) : getLog2(good_composite);
    233 
    234       ComplexScalar* a = is_power_of_two ? NULL : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * good_composite);
    235       ComplexScalar* b = is_power_of_two ? NULL : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * good_composite);
    236       ComplexScalar* pos_j_base_powered = is_power_of_two ? NULL : (ComplexScalar*)m_device.allocate(sizeof(ComplexScalar) * (line_len + 1));
    237       if (!is_power_of_two) {
    238         // Compute twiddle factors
    239         //   t_n = exp(sqrt(-1) * pi * n^2 / line_len)
    240         // for n = 0, 1,..., line_len-1.
    241         // For n > 2 we use the recurrence t_n = t_{n-1}^2 / t_{n-2} * t_1^2
    242 
    243         // The recurrence is correct in exact arithmetic, but causes
    244         // numerical issues for large transforms, especially in
    245         // single-precision floating point.
    246         //
    247         // pos_j_base_powered[0] = ComplexScalar(1, 0);
    248         // if (line_len > 1) {
    249         //   const ComplexScalar pos_j_base = ComplexScalar(
    250         //       numext::cos(M_PI / line_len), numext::sin(M_PI / line_len));
    251         //   pos_j_base_powered[1] = pos_j_base;
    252         //   if (line_len > 2) {
    253         //     const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base;
    254         //     for (int i = 2; i < line_len + 1; ++i) {
    255         //       pos_j_base_powered[i] = pos_j_base_powered[i - 1] *
    256         //           pos_j_base_powered[i - 1] /
    257         //           pos_j_base_powered[i - 2] *
    258         //           pos_j_base_sq;
    259         //     }
    260         //   }
    261         // }
    262         // TODO(rmlarsen): Find a way to use Eigen's vectorized sin
    263         // and cosine functions here.
    264         for (int j = 0; j < line_len + 1; ++j) {
    265           double arg = ((EIGEN_PI * j) * j) / line_len;
    266           std::complex<double> tmp(numext::cos(arg), numext::sin(arg));
    267           pos_j_base_powered[j] = static_cast<ComplexScalar>(tmp);
    268         }
    269       }
    270 
    271       for (Index partial_index = 0; partial_index < m_size / line_len; ++partial_index) {
    272         const Index base_offset = getBaseOffsetFromIndex(partial_index, dim);
    273 
    274         // get data into line_buf
    275         const Index stride = m_strides[dim];
    276         if (stride == 1) {
    277           m_device.memcpy(line_buf, &buf[base_offset], line_len*sizeof(ComplexScalar));
    278         } else {
    279           Index offset = base_offset;
    280           for (int j = 0; j < line_len; ++j, offset += stride) {
    281             line_buf[j] = buf[offset];
    282           }
    283         }
    284 
    285         // process the line
    286         if (is_power_of_two) {
    287           processDataLineCooleyTukey(line_buf, line_len, log_len);
    288         }
    289         else {
    290           processDataLineBluestein(line_buf, line_len, good_composite, log_len, a, b, pos_j_base_powered);
    291         }
    292 
    293         // write back
    294         if (FFTDir == FFT_FORWARD && stride == 1) {
    295           m_device.memcpy(&buf[base_offset], line_buf, line_len*sizeof(ComplexScalar));
    296         } else {
    297           Index offset = base_offset;
    298           const ComplexScalar div_factor =  ComplexScalar(1.0 / line_len, 0);
    299           for (int j = 0; j < line_len; ++j, offset += stride) {
    300              buf[offset] = (FFTDir == FFT_FORWARD) ? line_buf[j] : line_buf[j] * div_factor;
    301           }
    302         }
    303       }
    304       m_device.deallocate(line_buf);
    305       if (!is_power_of_two) {
    306         m_device.deallocate(a);
    307         m_device.deallocate(b);
    308         m_device.deallocate(pos_j_base_powered);
    309       }
    310     }
    311 
    312     if(!write_to_out) {
    313       for (Index i = 0; i < m_size; ++i) {
    314         data[i] = PartOf<FFTResultType>()(buf[i]);
    315       }
    316       m_device.deallocate(buf);
    317     }
    318   }
    319 
    320   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static bool isPowerOfTwo(Index x) {
    321     eigen_assert(x > 0);
    322     return !(x & (x - 1));
    323   }
    324 
    325   // The composite number for padding, used in Bluestein's FFT algorithm
    326   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static Index findGoodComposite(Index n) {
    327     Index i = 2;
    328     while (i < 2 * n - 1) i *= 2;
    329     return i;
    330   }
    331 
    332   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static Index getLog2(Index m) {
    333     Index log2m = 0;
    334     while (m >>= 1) log2m++;
    335     return log2m;
    336   }
    337 
    338   // Call Cooley Tukey algorithm directly, data length must be power of 2
    339   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void processDataLineCooleyTukey(ComplexScalar* line_buf, Index line_len, Index log_len) {
    340     eigen_assert(isPowerOfTwo(line_len));
    341     scramble_FFT(line_buf, line_len);
    342     compute_1D_Butterfly<FFTDir>(line_buf, line_len, log_len);
    343   }
    344 
    345   // Call Bluestein's FFT algorithm, m is a good composite number greater than (2 * n - 1), used as the padding length
    346   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void processDataLineBluestein(ComplexScalar* line_buf, Index line_len, Index good_composite, Index log_len, ComplexScalar* a, ComplexScalar* b, const ComplexScalar* pos_j_base_powered) {
    347     Index n = line_len;
    348     Index m = good_composite;
    349     ComplexScalar* data = line_buf;
    350 
    351     for (Index i = 0; i < n; ++i) {
    352       if(FFTDir == FFT_FORWARD) {
    353         a[i] = data[i] * numext::conj(pos_j_base_powered[i]);
    354       }
    355       else {
    356         a[i] = data[i] * pos_j_base_powered[i];
    357       }
    358     }
    359     for (Index i = n; i < m; ++i) {
    360       a[i] = ComplexScalar(0, 0);
    361     }
    362 
    363     for (Index i = 0; i < n; ++i) {
    364       if(FFTDir == FFT_FORWARD) {
    365         b[i] = pos_j_base_powered[i];
    366       }
    367       else {
    368         b[i] = numext::conj(pos_j_base_powered[i]);
    369       }
    370     }
    371     for (Index i = n; i < m - n; ++i) {
    372       b[i] = ComplexScalar(0, 0);
    373     }
    374     for (Index i = m - n; i < m; ++i) {
    375       if(FFTDir == FFT_FORWARD) {
    376         b[i] = pos_j_base_powered[m-i];
    377       }
    378       else {
    379         b[i] = numext::conj(pos_j_base_powered[m-i]);
    380       }
    381     }
    382 
    383     scramble_FFT(a, m);
    384     compute_1D_Butterfly<FFT_FORWARD>(a, m, log_len);
    385 
    386     scramble_FFT(b, m);
    387     compute_1D_Butterfly<FFT_FORWARD>(b, m, log_len);
    388 
    389     for (Index i = 0; i < m; ++i) {
    390       a[i] *= b[i];
    391     }
    392 
    393     scramble_FFT(a, m);
    394     compute_1D_Butterfly<FFT_REVERSE>(a, m, log_len);
    395 
    396     //Do the scaling after ifft
    397     for (Index i = 0; i < m; ++i) {
    398       a[i] /= m;
    399     }
    400 
    401     for (Index i = 0; i < n; ++i) {
    402       if(FFTDir == FFT_FORWARD) {
    403         data[i] = a[i] * numext::conj(pos_j_base_powered[i]);
    404       }
    405       else {
    406         data[i] = a[i] * pos_j_base_powered[i];
    407       }
    408     }
    409   }
    410 
    411   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static void scramble_FFT(ComplexScalar* data, Index n) {
    412     eigen_assert(isPowerOfTwo(n));
    413     Index j = 1;
    414     for (Index i = 1; i < n; ++i){
    415       if (j > i) {
    416         std::swap(data[j-1], data[i-1]);
    417       }
    418       Index m = n >> 1;
    419       while (m >= 2 && j > m) {
    420         j -= m;
    421         m >>= 1;
    422       }
    423       j += m;
    424     }
    425   }
    426 
    427   template <int Dir>
    428   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_2(ComplexScalar* data) {
    429     ComplexScalar tmp = data[1];
    430     data[1] = data[0] - data[1];
    431     data[0] += tmp;
    432   }
    433 
    434   template <int Dir>
    435   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_4(ComplexScalar* data) {
    436     ComplexScalar tmp[4];
    437     tmp[0] = data[0] + data[1];
    438     tmp[1] = data[0] - data[1];
    439     tmp[2] = data[2] + data[3];
    440     if (Dir == FFT_FORWARD) {
    441       tmp[3] = ComplexScalar(0.0, -1.0) * (data[2] - data[3]);
    442     } else {
    443       tmp[3] = ComplexScalar(0.0, 1.0) * (data[2] - data[3]);
    444     }
    445     data[0] = tmp[0] + tmp[2];
    446     data[1] = tmp[1] + tmp[3];
    447     data[2] = tmp[0] - tmp[2];
    448     data[3] = tmp[1] - tmp[3];
    449   }
    450 
    451   template <int Dir>
    452   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_8(ComplexScalar* data) {
    453     ComplexScalar tmp_1[8];
    454     ComplexScalar tmp_2[8];
    455 
    456     tmp_1[0] = data[0] + data[1];
    457     tmp_1[1] = data[0] - data[1];
    458     tmp_1[2] = data[2] + data[3];
    459     if (Dir == FFT_FORWARD) {
    460       tmp_1[3] = (data[2] - data[3]) * ComplexScalar(0, -1);
    461     } else {
    462       tmp_1[3] = (data[2] - data[3]) * ComplexScalar(0, 1);
    463     }
    464     tmp_1[4] = data[4] + data[5];
    465     tmp_1[5] = data[4] - data[5];
    466     tmp_1[6] = data[6] + data[7];
    467     if (Dir == FFT_FORWARD) {
    468       tmp_1[7] = (data[6] - data[7]) * ComplexScalar(0, -1);
    469     } else {
    470       tmp_1[7] = (data[6] - data[7]) * ComplexScalar(0, 1);
    471     }
    472     tmp_2[0] = tmp_1[0] + tmp_1[2];
    473     tmp_2[1] = tmp_1[1] + tmp_1[3];
    474     tmp_2[2] = tmp_1[0] - tmp_1[2];
    475     tmp_2[3] = tmp_1[1] - tmp_1[3];
    476     tmp_2[4] = tmp_1[4] + tmp_1[6];
    477 // SQRT2DIV2 = sqrt(2)/2
    478 #define SQRT2DIV2 0.7071067811865476
    479     if (Dir == FFT_FORWARD) {
    480       tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, -SQRT2DIV2);
    481       tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, -1);
    482       tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, -SQRT2DIV2);
    483     } else {
    484       tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, SQRT2DIV2);
    485       tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, 1);
    486       tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, SQRT2DIV2);
    487     }
    488     data[0] = tmp_2[0] + tmp_2[4];
    489     data[1] = tmp_2[1] + tmp_2[5];
    490     data[2] = tmp_2[2] + tmp_2[6];
    491     data[3] = tmp_2[3] + tmp_2[7];
    492     data[4] = tmp_2[0] - tmp_2[4];
    493     data[5] = tmp_2[1] - tmp_2[5];
    494     data[6] = tmp_2[2] - tmp_2[6];
    495     data[7] = tmp_2[3] - tmp_2[7];
    496   }
    497 
    498   template <int Dir>
    499   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void butterfly_1D_merge(
    500       ComplexScalar* data, Index n, Index n_power_of_2) {
    501     // Original code:
    502     // RealScalar wtemp = std::sin(M_PI/n);
    503     // RealScalar wpi =  -std::sin(2 * M_PI/n);
    504     const RealScalar wtemp = m_sin_PI_div_n_LUT[n_power_of_2];
    505     const RealScalar wpi = (Dir == FFT_FORWARD)
    506                                ? m_minus_sin_2_PI_div_n_LUT[n_power_of_2]
    507                                : -m_minus_sin_2_PI_div_n_LUT[n_power_of_2];
    508 
    509     const ComplexScalar wp(wtemp, wpi);
    510     const ComplexScalar wp_one = wp + ComplexScalar(1, 0);
    511     const ComplexScalar wp_one_2 = wp_one * wp_one;
    512     const ComplexScalar wp_one_3 = wp_one_2 * wp_one;
    513     const ComplexScalar wp_one_4 = wp_one_3 * wp_one;
    514     const Index n2 = n / 2;
    515     ComplexScalar w(1.0, 0.0);
    516     for (Index i = 0; i < n2; i += 4) {
    517        ComplexScalar temp0(data[i + n2] * w);
    518        ComplexScalar temp1(data[i + 1 + n2] * w * wp_one);
    519        ComplexScalar temp2(data[i + 2 + n2] * w * wp_one_2);
    520        ComplexScalar temp3(data[i + 3 + n2] * w * wp_one_3);
    521        w = w * wp_one_4;
    522 
    523        data[i + n2] = data[i] - temp0;
    524        data[i] += temp0;
    525 
    526        data[i + 1 + n2] = data[i + 1] - temp1;
    527        data[i + 1] += temp1;
    528 
    529        data[i + 2 + n2] = data[i + 2] - temp2;
    530        data[i + 2] += temp2;
    531 
    532        data[i + 3 + n2] = data[i + 3] - temp3;
    533        data[i + 3] += temp3;
    534     }
    535   }
    536 
    537  template <int Dir>
    538   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_1D_Butterfly(
    539       ComplexScalar* data, Index n, Index n_power_of_2) {
    540     eigen_assert(isPowerOfTwo(n));
    541     if (n > 8) {
    542       compute_1D_Butterfly<Dir>(data, n / 2, n_power_of_2 - 1);
    543       compute_1D_Butterfly<Dir>(data + n / 2, n / 2, n_power_of_2 - 1);
    544       butterfly_1D_merge<Dir>(data, n, n_power_of_2);
    545     } else if (n == 8) {
    546       butterfly_8<Dir>(data);
    547     } else if (n == 4) {
    548       butterfly_4<Dir>(data);
    549     } else if (n == 2) {
    550       butterfly_2<Dir>(data);
    551     }
    552   }
    553 
    554   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index getBaseOffsetFromIndex(Index index, Index omitted_dim) const {
    555     Index result = 0;
    556 
    557     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
    558       for (int i = NumDims - 1; i > omitted_dim; --i) {
    559         const Index partial_m_stride = m_strides[i] / m_dimensions[omitted_dim];
    560         const Index idx = index / partial_m_stride;
    561         index -= idx * partial_m_stride;
    562         result += idx * m_strides[i];
    563       }
    564       result += index;
    565     }
    566     else {
    567       for (Index i = 0; i < omitted_dim; ++i) {
    568         const Index partial_m_stride = m_strides[i] / m_dimensions[omitted_dim];
    569         const Index idx = index / partial_m_stride;
    570         index -= idx * partial_m_stride;
    571         result += idx * m_strides[i];
    572       }
    573       result += index;
    574     }
    575     // Value of index_coords[omitted_dim] is not determined to this step
    576     return result;
    577   }
    578 
    579   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index getIndexFromOffset(Index base, Index omitted_dim, Index offset) const {
    580     Index result = base + offset * m_strides[omitted_dim] ;
    581     return result;
    582   }
    583 
    584  protected:
    585   Index m_size;
    586   const FFT EIGEN_DEVICE_REF m_fft;
    587   Dimensions m_dimensions;
    588   array<Index, NumDims> m_strides;
    589   TensorEvaluator<ArgType, Device> m_impl;
    590   EvaluatorPointerType m_data;
    591   const Device EIGEN_DEVICE_REF m_device;
    592 
    593   // The 32-entry tables below are indexed by log2(n), so they support a maximum FFT size of 2^31 for each dimension
    594   // m_sin_PI_div_n_LUT[i] = (-2) * std::sin(M_PI / std::pow(2,i))^2, i.e. std::cos(2 * M_PI / std::pow(2,i)) - 1;
    595   const RealScalar m_sin_PI_div_n_LUT[32] = {
    596     RealScalar(0.0),
    597     RealScalar(-2),
    598     RealScalar(-0.999999999999999),
    599     RealScalar(-0.292893218813453),
    600     RealScalar(-0.0761204674887130),
    601     RealScalar(-0.0192147195967696),
    602     RealScalar(-0.00481527332780311),
    603     RealScalar(-0.00120454379482761),
    604     RealScalar(-3.01181303795779e-04),
    605     RealScalar(-7.52981608554592e-05),
    606     RealScalar(-1.88247173988574e-05),
    607     RealScalar(-4.70619042382852e-06),
    608     RealScalar(-1.17654829809007e-06),
    609     RealScalar(-2.94137117780840e-07),
    610     RealScalar(-7.35342821488550e-08),
    611     RealScalar(-1.83835707061916e-08),
    612     RealScalar(-4.59589268710903e-09),
    613     RealScalar(-1.14897317243732e-09),
    614     RealScalar(-2.87243293150586e-10),
    615     RealScalar( -7.18108232902250e-11),
    616     RealScalar(-1.79527058227174e-11),
    617     RealScalar(-4.48817645568941e-12),
    618     RealScalar(-1.12204411392298e-12),
    619     RealScalar(-2.80511028480785e-13),
    620     RealScalar(-7.01277571201985e-14),
    621     RealScalar(-1.75319392800498e-14),
    622     RealScalar(-4.38298482001247e-15),
    623     RealScalar(-1.09574620500312e-15),
    624     RealScalar(-2.73936551250781e-16),
    625     RealScalar(-6.84841378126949e-17),
    626     RealScalar(-1.71210344531737e-17),
    627     RealScalar(-4.28025861329343e-18)
    628   };
    629 
    630   // m_minus_sin_2_PI_div_n_LUT[i] = -std::sin(2 * M_PI / std::pow(2,i));
    631   const RealScalar m_minus_sin_2_PI_div_n_LUT[32] = {
    632     RealScalar(0.0),
    633     RealScalar(0.0),
    634     RealScalar(-1.00000000000000e+00),
    635     RealScalar(-7.07106781186547e-01),
    636     RealScalar(-3.82683432365090e-01),
    637     RealScalar(-1.95090322016128e-01),
    638     RealScalar(-9.80171403295606e-02),
    639     RealScalar(-4.90676743274180e-02),
    640     RealScalar(-2.45412285229123e-02),
    641     RealScalar(-1.22715382857199e-02),
    642     RealScalar(-6.13588464915448e-03),
    643     RealScalar(-3.06795676296598e-03),
    644     RealScalar(-1.53398018628477e-03),
    645     RealScalar(-7.66990318742704e-04),
    646     RealScalar(-3.83495187571396e-04),
    647     RealScalar(-1.91747597310703e-04),
    648     RealScalar(-9.58737990959773e-05),
    649     RealScalar(-4.79368996030669e-05),
    650     RealScalar(-2.39684498084182e-05),
    651     RealScalar(-1.19842249050697e-05),
    652     RealScalar(-5.99211245264243e-06),
    653     RealScalar(-2.99605622633466e-06),
    654     RealScalar(-1.49802811316901e-06),
    655     RealScalar(-7.49014056584716e-07),
    656     RealScalar(-3.74507028292384e-07),
    657     RealScalar(-1.87253514146195e-07),
    658     RealScalar(-9.36267570730981e-08),
    659     RealScalar(-4.68133785365491e-08),
    660     RealScalar(-2.34066892682746e-08),
    661     RealScalar(-1.17033446341373e-08),
    662     RealScalar(-5.85167231706864e-09),
    663     RealScalar(-2.92583615853432e-09)
    664   };
    665 };
    666 
    667 }  // end namespace Eigen
    668 
    669 #endif  // EIGEN_CXX11_TENSOR_TENSOR_FFT_H