cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

BFloat16.h (26903B)


      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef EIGEN_BFLOAT16_H
     17 #define EIGEN_BFLOAT16_H
     18 
     19 #define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD)         \
     20   template <>                                                       \
     21   EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED  \
     22   PACKET_BF16 METHOD<PACKET_BF16>(const PACKET_BF16& _x) {          \
     23     return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x)));              \
     24   }
     25 
     26 namespace Eigen {
     27 
     28 struct bfloat16;
     29 
     30 namespace bfloat16_impl {
     31 
     32 // Make our own __bfloat16_raw definition.
     33 struct __bfloat16_raw {
     34   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {}
     35   explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : value(raw) {}
     36   unsigned short value;
     37 };
     38 
     39 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
     40 template <bool AssumeArgumentIsNormalOrInfinityOrZero>
     41 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
     42 // Forward declarations of template specializations, to avoid Visual C++ 2019 errors, saying:
     43 // > error C2908: explicit specialization; 'float_to_bfloat16_rtne' has already been instantiated
     44 template <>
     45 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff);
     46 template <>
     47 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff);
     48 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
     49 
     50 struct bfloat16_base : public __bfloat16_raw {
     51   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base() {}
     52   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {}
     53 };
     54 
     55 } // namespace bfloat16_impl
     56 
     57 // Class definition.
     58 struct bfloat16 : public bfloat16_impl::bfloat16_base {
     59 
     60   typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw;
     61 
     62   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16() {}
     63 
     64   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}
     65 
     66   explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(bool b)
     67       : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
     68 
     69   template<class T>
     70   explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(T val)
     71       : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
     72 
     73   explicit EIGEN_DEVICE_FUNC bfloat16(float f)
     74       : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
     75 
     76   // Following the convention of numpy, converting between complex and
     77   // float will lead to loss of imag value.
     78   template<typename RealScalar>
     79   explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const std::complex<RealScalar>& val)
     80       : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {}
     81 
     82   EIGEN_DEVICE_FUNC operator float() const {  // NOLINT: Allow implicit conversion to float, because it is lossless.
     83     return bfloat16_impl::bfloat16_to_float(*this);
     84   }
     85 };
     86 } // namespace Eigen
     87 
     88 namespace std {
     89 template<>
     90 struct numeric_limits<Eigen::bfloat16> {
     91   static const bool is_specialized = true;
     92   static const bool is_signed = true;
     93   static const bool is_integer = false;
     94   static const bool is_exact = false;
     95   static const bool has_infinity = true;
     96   static const bool has_quiet_NaN = true;
     97   static const bool has_signaling_NaN = true;
     98   static const float_denorm_style has_denorm = std::denorm_absent;
     99   static const bool has_denorm_loss = false;
    100   static const std::float_round_style round_style = numeric_limits<float>::round_style;
    101   static const bool is_iec559 = false;
    102   static const bool is_bounded = true;
    103   static const bool is_modulo = false;
    104   static const int digits = 8;
    105   static const int digits10 = 2;
    106   static const int max_digits10 = 4;
    107   static const int radix = 2;
    108   static const int min_exponent = numeric_limits<float>::min_exponent;
    109   static const int min_exponent10 = numeric_limits<float>::min_exponent10;
    110   static const int max_exponent = numeric_limits<float>::max_exponent;
    111   static const int max_exponent10 = numeric_limits<float>::max_exponent10;
    112   static const bool traps = numeric_limits<float>::traps;
    113   static const bool tinyness_before = numeric_limits<float>::tinyness_before;
    114 
    115   static Eigen::bfloat16 (min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
    116   static Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
    117   static Eigen::bfloat16 (max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
    118   static Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
    119   static Eigen::bfloat16 round_error() { return Eigen::bfloat16(0x3f00); }
    120   static Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
    121   static Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
    122   static Eigen::bfloat16 signaling_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f81); }
    123   static Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
    124 };
    125 
    126 // If std::numeric_limits<T> is specialized, should also specialize
    127 // std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
    128 // std::numeric_limits<const volatile T>
    129 // https://stackoverflow.com/a/16519653/
    130 template<>
    131 struct numeric_limits<const Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
    132 template<>
    133 struct numeric_limits<volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
    134 template<>
    135 struct numeric_limits<const volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
    136 } // namespace std
    137 
    138 namespace Eigen {
    139 
    140 namespace bfloat16_impl {
    141 
    142 // We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler,
    143 // invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation
    144 // of the functions, while the latter can only deal with one of them.
    145 #if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats
    146 
    147 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
    148 // We need to provide emulated *host-side* BF16 operators for clang.
    149 #pragma push_macro("EIGEN_DEVICE_FUNC")
    150 #undef EIGEN_DEVICE_FUNC
    151 #if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16)
    152 #define EIGEN_DEVICE_FUNC __host__
    153 #else // both host and device need emulated ops.
    154 #define EIGEN_DEVICE_FUNC __host__ __device__
    155 #endif
    156 #endif
    157 
    158 // Definitions for CPUs, mostly working through conversion
    159 // to/from fp32.
    160 
    161 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const bfloat16& b) {
    162   return bfloat16(float(a) + float(b));
    163 }
    164 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const int& b) {
    165   return bfloat16(float(a) + static_cast<float>(b));
    166 }
    167 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const int& a, const bfloat16& b) {
    168   return bfloat16(static_cast<float>(a) + float(b));
    169 }
    170 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator * (const bfloat16& a, const bfloat16& b) {
    171   return bfloat16(float(a) * float(b));
    172 }
    173 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a, const bfloat16& b) {
    174   return bfloat16(float(a) - float(b));
    175 }
    176 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, const bfloat16& b) {
    177   return bfloat16(float(a) / float(b));
    178 }
    179 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a) {
    180   bfloat16 result;
    181   result.value = a.value ^ 0x8000;
    182   return result;
    183 }
    184 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a, const bfloat16& b) {
    185   a = bfloat16(float(a) + float(b));
    186   return a;
    187 }
    188 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator *= (bfloat16& a, const bfloat16& b) {
    189   a = bfloat16(float(a) * float(b));
    190   return a;
    191 }
    192 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator -= (bfloat16& a, const bfloat16& b) {
    193   a = bfloat16(float(a) - float(b));
    194   return a;
    195 }
    196 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator /= (bfloat16& a, const bfloat16& b) {
    197   a = bfloat16(float(a) / float(b));
    198   return a;
    199 }
    200 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
    201   a += bfloat16(1);
    202   return a;
    203 }
    204 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
    205   a -= bfloat16(1);
    206   return a;
    207 }
    208 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) {
    209   bfloat16 original_value = a;
    210   ++a;
    211   return original_value;
    212 }
    213 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) {
    214   bfloat16 original_value = a;
    215   --a;
    216   return original_value;
    217 }
    218 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const bfloat16& a, const bfloat16& b) {
    219   return numext::equal_strict(float(a),float(b));
    220 }
    221 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const bfloat16& a, const bfloat16& b) {
    222   return numext::not_equal_strict(float(a), float(b));
    223 }
    224 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const bfloat16& a, const bfloat16& b) {
    225   return float(a) < float(b);
    226 }
    227 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const bfloat16& a, const bfloat16& b) {
    228   return float(a) <= float(b);
    229 }
    230 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const bfloat16& a, const bfloat16& b) {
    231   return float(a) > float(b);
    232 }
    233 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const bfloat16& a, const bfloat16& b) {
    234   return float(a) >= float(b);
    235 }
    236 
    237 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
    238 #pragma pop_macro("EIGEN_DEVICE_FUNC")
    239 #endif
    240 #endif  // Emulate support for bfloat16 floats
    241 
    242 // Division by an index. Do it in full float precision to avoid accuracy
    243 // issues in converting the denominator to bfloat16.
    244 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, Index b) {
    245   return bfloat16(static_cast<float>(a) / static_cast<float>(b));
    246 }
    247 
    248 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) {
    249   __bfloat16_raw output;
    250   if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
    251     output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
    252     return output;
    253   }
    254   const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
    255 #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    256   output.value = p[0];
    257 #else
    258   output.value = p[1];
    259 #endif
    260   return output;
    261 }
    262 
    263 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
    264   return __bfloat16_raw(value);
    265 }
    266 
    267 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(const __bfloat16_raw& bf) {
    268   return bf.value;
    269 }
    270 
    271 // float_to_bfloat16_rtne template specialization that does not make any
    272 // assumption about the value of its function argument (ff).
    273 template <>
    274 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff) {
    275 #if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
    276   // Nothing to do here
    277 #else
    278   __bfloat16_raw output;
    279 
    280   if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
    281     // If the value is a NaN, squash it to a qNaN with msb of fraction set,
    282     // this makes sure after truncation we don't end up with an inf.
    283     //
    284     // qNaN magic: All exponent bits set + most significant bit of fraction
    285     // set.
    286     output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0;
    287   } else {
    288     // Fast rounding algorithm that rounds a half value to nearest even. This
    289     // reduces expected error when we convert a large number of floats. Here
    290     // is how it works:
    291     //
    292     // Definitions:
    293     // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
    294     // with the following tags:
    295     //
    296     // Sign |  Exp (8 bits) | Frac (23 bits)
    297     //  S     EEEEEEEE         FFFFFFLRTTTTTTTTTTTTTTT
    298     //
    299     //  S: Sign bit.
    300     //  E: Exponent bits.
    301     //  F: First 6 bits of fraction.
    302     //  L: Least significant bit of resulting bfloat16 if we truncate away the
    303     //  rest of the float32. This is also the 7th bit of fraction
    304     //  R: Rounding bit, 8th bit of fraction.
    305     //  T: Sticky bits, rest of fraction, 15 bits.
    306     //
    307     // To round half to nearest even, there are 3 cases where we want to round
    308     // down (simply truncate the result of the bits away, which consists of
    309     // rounding bit and sticky bits) and two cases where we want to round up
    310     // (truncate then add one to the result).
    311     //
    312     // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
    313     // 1s) as the rounding bias, adds the rounding bias to the input, then
    314     // truncates the last 16 bits away.
    315     //
    316     // To understand how it works, we can analyze this algorithm case by case:
    317     //
    318     // 1. L = 0, R = 0:
    319     //   Expect: round down, this is less than half value.
    320     //
    321     //   Algorithm:
    322     //   - Rounding bias: 0x7fff + 0 = 0x7fff
    323     //   - Adding rounding bias to input may create any carry, depending on
    324     //   whether there is any value set to 1 in T bits.
    325     //   - R may be set to 1 if there is a carry.
    326     //   - L remains 0.
    327     //   - Note that this case also handles Inf and -Inf, where all fraction
    328     //   bits, including L, R and Ts are all 0. The output remains Inf after
    329     //   this algorithm.
    330     //
    331     // 2. L = 1, R = 0:
    332     //   Expect: round down, this is less than half value.
    333     //
    334     //   Algorithm:
    335     //   - Rounding bias: 0x7fff + 1 = 0x8000
    336     //   - Adding rounding bias to input doesn't change sticky bits but
    337     //   adds 1 to rounding bit.
    338     //   - L remains 1.
    339     //
    340     // 3. L = 0, R = 1, all of T are 0:
    341     //   Expect: round down, this is exactly at half, the result is already
    342     //   even (L=0).
    343     //
    344     //   Algorithm:
    345     //   - Rounding bias: 0x7fff + 0 = 0x7fff
    346     //   - Adding rounding bias to input sets all sticky bits to 1, but
    347     //   doesn't create a carry.
    348     //   - R remains 1.
    349     //   - L remains 0.
    350     //
    351     // 4. L = 1, R = 1:
    352     //   Expect: round up, this is exactly at half, the result needs to be
    353     //   round to the next even number.
    354     //
    355     //   Algorithm:
    356     //   - Rounding bias: 0x7fff + 1 = 0x8000
    357     //   - Adding rounding bias to input doesn't change sticky bits, but
    358     //   creates a carry from rounding bit.
    359     //   - The carry sets L to 0, creates another carry bit and propagate
    360     //   forward to F bits.
    361     //   - If all the F bits are 1, a carry then propagates to the exponent
    362     //   bits, which then creates the minimum value with the next exponent
    363     //   value. Note that we won't have the case where exponents are all 1,
    364     //   since that's either a NaN (handled in the other if condition) or inf
    365     //   (handled in case 1).
    366     //
    367     // 5. L = 0, R = 1, any of T is 1:
    368     //   Expect: round up, this is greater than half.
    369     //
    370     //   Algorithm:
    371     //   - Rounding bias: 0x7fff + 0 = 0x7fff
    372     //   - Adding rounding bias to input creates a carry from sticky bits,
    373     //   sets rounding bit to 0, then create another carry.
    374     //   - The second carry sets L to 1.
    375     //
    376     // Examples:
    377     //
    378     //  Exact half value that is already even:
    379     //    Input:
    380     //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
    381     //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
    382     //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0     1000000000000000
    383     //
    384     //     This falls into case 3. We truncate the rest of 16 bits and no
    385     //     carry is created into F and L:
    386     //
    387     //    Output:
    388     //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
    389     //     S     E E E E E E E E      F F F F F F L
    390     //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0
    391     //
    392     //  Exact half value, round to next even number:
    393     //    Input:
    394     //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
    395     //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
    396     //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 0 1     1000000000000000
    397     //
    398     //     This falls into case 4. We create a carry from R and T,
    399     //     which then propagates into L and F:
    400     //
    401     //    Output:
    402     //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
    403     //     S     E E E E E E E E      F F F F F F L
    404     //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0
    405     //
    406     //
    407     //  Max denormal value round to min normal value:
    408     //    Input:
    409     //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
    410     //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
    411     //     0     0 0 0 0 0 0 0 0      1 1 1 1 1 1 1     1111111111111111
    412     //
    413     //     This falls into case 4. We create a carry from R and T,
    414     //     propagate into L and F, which then propagates into exponent
    415     //     bits:
    416     //
    417     //    Output:
    418     //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
    419     //     S     E E E E E E E E      F F F F F F L
    420     //     0     0 0 0 0 0 0 0 1      0 0 0 0 0 0 0
    421     //
    422     //  Max normal value round to Inf:
    423     //    Input:
    424     //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
    425     //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
    426     //     0     1 1 1 1 1 1 1 0      1 1 1 1 1 1 1     1111111111111111
    427     //
    428     //     This falls into case 4. We create a carry from R and T,
    429     //     propagate into L and F, which then propagates into exponent
    430     //     bits:
    431     //
    432     //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
    433     //     S     E E E E E E E E      F F F F F F L
    434     //     0     1 1 1 1 1 1 1 1      0 0 0 0 0 0 0
    435 
    436     // At this point, ff must be either a normal float, or +/-infinity.
    437     output = float_to_bfloat16_rtne<true>(ff);
    438   }
    439   return output;
    440 #endif
    441 }
    442 
    443 // float_to_bfloat16_rtne template specialization that assumes that its function
    444 // argument (ff) is either a normal floating point number, or +/-infinity, or
    445 // zero. Used to improve the runtime performance of conversion from an integer
    446 // type to bfloat16.
    447 template <>
    448 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff) {
    449 #if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
    450     // Nothing to do here
    451 #else
    452     numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
    453     __bfloat16_raw output;
    454 
    455     // Least significant bit of resulting bfloat.
    456     numext::uint32_t lsb = (input >> 16) & 1;
    457     numext::uint32_t rounding_bias = 0x7fff + lsb;
    458     input += rounding_bias;
    459     output.value = static_cast<numext::uint16_t>(input >> 16);
    460     return output;
    461 #endif
    462 }
    463 
    464 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
    465     float result = 0;
    466     unsigned short* q = reinterpret_cast<unsigned short*>(&result);
    467 #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    468     q[0] = h.value;
    469 #else
    470     q[1] = h.value;
    471 #endif
    472     return result;
    473 }
    474 // --- standard functions ---
    475 
    476 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const bfloat16& a) {
    477   EIGEN_USING_STD(isinf);
    478   return (isinf)(float(a));
    479 }
    480 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const bfloat16& a) {
    481   EIGEN_USING_STD(isnan);
    482   return (isnan)(float(a));
    483 }
    484 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const bfloat16& a) {
    485   return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
    486 }
    487 
    488 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) {
    489   bfloat16 result;
    490   result.value = a.value & 0x7FFF;
    491   return result;
    492 }
    493 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) {
    494    return bfloat16(::expf(float(a)));
    495 }
    496 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) {
    497   return bfloat16(numext::expm1(float(a)));
    498 }
    499 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) {
    500   return bfloat16(::logf(float(a)));
    501 }
    502 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) {
    503   return bfloat16(numext::log1p(float(a)));
    504 }
    505 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) {
    506   return bfloat16(::log10f(float(a)));
    507 }
    508 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) {
    509   return bfloat16(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
    510 }
    511 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
    512     return bfloat16(::sqrtf(float(a)));
    513 }
    514 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
    515   return bfloat16(::powf(float(a), float(b)));
    516 }
    517 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) {
    518   return bfloat16(::sinf(float(a)));
    519 }
    520 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) {
    521   return bfloat16(::cosf(float(a)));
    522 }
    523 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) {
    524   return bfloat16(::tanf(float(a)));
    525 }
    526 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) {
    527   return bfloat16(::asinf(float(a)));
    528 }
    529 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) {
    530   return bfloat16(::acosf(float(a)));
    531 }
    532 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) {
    533   return bfloat16(::atanf(float(a)));
    534 }
    535 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) {
    536   return bfloat16(::sinhf(float(a)));
    537 }
    538 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) {
    539   return bfloat16(::coshf(float(a)));
    540 }
    541 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) {
    542   return bfloat16(::tanhf(float(a)));
    543 }
    544 #if EIGEN_HAS_CXX11_MATH
    545 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) {
    546   return bfloat16(::asinhf(float(a)));
    547 }
    548 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) {
    549   return bfloat16(::acoshf(float(a)));
    550 }
    551 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) {
    552   return bfloat16(::atanhf(float(a)));
    553 }
    554 #endif
    555 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) {
    556   return bfloat16(::floorf(float(a)));
    557 }
    558 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) {
    559   return bfloat16(::ceilf(float(a)));
    560 }
    561 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(const bfloat16& a) {
    562   return bfloat16(::rintf(float(a)));
    563 }
    564 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(const bfloat16& a) {
    565   return bfloat16(::roundf(float(a)));
    566 }
    567 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) {
    568   return bfloat16(::fmodf(float(a), float(b)));
    569 }
    570 
    571 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(const bfloat16& a, const bfloat16& b) {
    572   const float f1 = static_cast<float>(a);
    573   const float f2 = static_cast<float>(b);
    574   return f2 < f1 ? b : a;
    575 }
    576 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(const bfloat16& a, const bfloat16& b) {
    577   const float f1 = static_cast<float>(a);
    578   const float f2 = static_cast<float>(b);
    579   return f1 < f2 ? b : a;
    580 }
    581 
    582 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16& a, const bfloat16& b) {
    583   const float f1 = static_cast<float>(a);
    584   const float f2 = static_cast<float>(b);
    585   return bfloat16(::fminf(f1, f2));
    586 }
    587 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfloat16& b) {
    588   const float f1 = static_cast<float>(a);
    589   const float f2 = static_cast<float>(b);
    590   return bfloat16(::fmaxf(f1, f2));
    591 }
    592 
    593 #ifndef EIGEN_NO_IO
    594 EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const bfloat16& v) {
    595   os << static_cast<float>(v);
    596   return os;
    597 }
    598 #endif
    599 
    600 } // namespace bfloat16_impl
    601 
    602 namespace internal {
    603 
    604 template<>
    605 struct random_default_impl<bfloat16, false, false>
    606 {
    607   static inline bfloat16 run(const bfloat16& x, const bfloat16& y)
    608   {
    609     return x + (y-x) * bfloat16(float(std::rand()) / float(RAND_MAX));
    610   }
    611   static inline bfloat16 run()
    612   {
    613     return run(bfloat16(-1.f), bfloat16(1.f));
    614   }
    615 };
    616 
    617 template<> struct is_arithmetic<bfloat16> { enum { value = true }; };
    618 
    619 } // namespace internal
    620 
    621 template<> struct NumTraits<Eigen::bfloat16>
    622     : GenericNumTraits<Eigen::bfloat16>
    623 {
    624   enum {
    625     IsSigned = true,
    626     IsInteger = false,
    627     IsComplex = false,
    628     RequireInitialization = false
    629   };
    630 
    631   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
    632     return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
    633   }
    634   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() {
    635     return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D);  // bfloat16(5e-2f);
    636 
    637   }
    638   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
    639     return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
    640   }
    641   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
    642     return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
    643   }
    644   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
    645     return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
    646   }
    647   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
    648     return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
    649   }
    650 };
    651 
    652 } // namespace Eigen
    653 
    654 namespace Eigen {
    655 namespace numext {
    656 
    657 template<>
    658 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
    659 bool (isnan)(const Eigen::bfloat16& h) {
    660   return (bfloat16_impl::isnan)(h);
    661 }
    662 
    663 template<>
    664 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
    665 bool (isinf)(const Eigen::bfloat16& h) {
    666   return (bfloat16_impl::isinf)(h);
    667 }
    668 
    669 template<>
    670 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
    671 bool (isfinite)(const Eigen::bfloat16& h) {
    672   return (bfloat16_impl::isfinite)(h);
    673 }
    674 
    675 template <>
    676 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src) {
    677   return Eigen::bfloat16(Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src));
    678 }
    679 
    680 template <>
    681 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src) {
    682   return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
    683 }
    684 
    685 }  // namespace numext
    686 }  // namespace Eigen
    687 
    688 #if EIGEN_HAS_STD_HASH
    689 namespace std {
    690 template <>
    691 struct hash<Eigen::bfloat16> {
    692   EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const {
    693     return static_cast<std::size_t>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a));
    694   }
    695 };
    696 } // namespace std
    697 #endif
    698 
    699 
    700 #endif // EIGEN_BFLOAT16_H