bfloat16_float.cpp (17931B)
1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // This Source Code Form is subject to the terms of the Mozilla 5 // Public License v. 2.0. If a copy of the MPL was not distributed 6 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 7 8 #include <sstream> 9 #include <memory> 10 #include <math.h> 11 12 #include "main.h" 13 14 #include <Eigen/src/Core/arch/Default/BFloat16.h> 15 16 #define VERIFY_BFLOAT16_BITS_EQUAL(h, bits) \ 17 VERIFY_IS_EQUAL((numext::bit_cast<numext::uint16_t>(h)), (static_cast<numext::uint16_t>(bits))) 18 19 // Make sure it's possible to forward declare Eigen::bfloat16 20 namespace Eigen { 21 struct bfloat16; 22 } 23 24 using Eigen::bfloat16; 25 26 float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa, 27 uint32_t low_mantissa) { 28 float dest; 29 uint32_t src = (sign << 31) + (exponent << 23) + (high_mantissa << 16) + low_mantissa; 30 memcpy(static_cast<void*>(&dest), 31 static_cast<const void*>(&src), sizeof(dest)); 32 return dest; 33 } 34 35 template<typename T> 36 void test_roundtrip() { 37 // Representable T round trip via bfloat16 38 VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(-std::numeric_limits<T>::infinity()))), -std::numeric_limits<T>::infinity()); 39 VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(std::numeric_limits<T>::infinity()))), std::numeric_limits<T>::infinity()); 40 VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(-1.0)))), T(-1.0)); 41 VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(-0.5)))), T(-0.5)); 42 VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(-0.0)))), T(-0.0)); 43 VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(1.0)))), T(1.0)); 44 VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(0.5)))), T(0.5)); 45 VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(0.0)))), T(0.0)); 46 } 47 48 void test_conversion() 49 { 50 using Eigen::bfloat16_impl::__bfloat16_raw; 51 52 // Round-trip casts 53 VERIFY_IS_EQUAL( 54 numext::bit_cast<bfloat16>(numext::bit_cast<numext::uint16_t>(bfloat16(1.0f))), 55 bfloat16(1.0f)); 56 VERIFY_IS_EQUAL( 57 numext::bit_cast<bfloat16>(numext::bit_cast<numext::uint16_t>(bfloat16(0.5f))), 58 bfloat16(0.5f)); 59 VERIFY_IS_EQUAL( 60 numext::bit_cast<bfloat16>(numext::bit_cast<numext::uint16_t>(bfloat16(-0.33333f))), 61 bfloat16(-0.33333f)); 62 VERIFY_IS_EQUAL( 63 numext::bit_cast<bfloat16>(numext::bit_cast<numext::uint16_t>(bfloat16(0.0f))), 64 bfloat16(0.0f)); 65 66 // Conversion from float. 67 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(1.0f), 0x3f80); 68 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0.5f), 0x3f00); 69 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0.33333f), 0x3eab); 70 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(3.38e38f), 0x7f7e); 71 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(3.40e38f), 0x7f80); // Becomes infinity. 72 73 // Verify round-to-nearest-even behavior. 74 float val1 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c00))); 75 float val2 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c01))); 76 float val3 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c02))); 77 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0.5f * (val1 + val2)), 0x3c00); 78 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0.5f * (val2 + val3)), 0x3c02); 79 80 // Conversion from int. 81 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(-1), 0xbf80); 82 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0), 0x0000); 83 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(1), 0x3f80); 84 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(2), 0x4000); 85 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(3), 0x4040); 86 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(12), 0x4140); 87 88 // Conversion from bool. 89 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(false), 0x0000); 90 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(true), 0x3f80); 91 92 // Conversion to bool 93 VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(3)), true); 94 VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(0.33333f)), true); 95 VERIFY_IS_EQUAL(bfloat16(-0.0), false); 96 VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(0.0)), false); 97 98 // Explicit conversion to float. 99 VERIFY_IS_EQUAL(static_cast<float>(bfloat16(__bfloat16_raw(0x0000))), 0.0f); 100 VERIFY_IS_EQUAL(static_cast<float>(bfloat16(__bfloat16_raw(0x3f80))), 1.0f); 101 102 // Implicit conversion to float 103 VERIFY_IS_EQUAL(bfloat16(__bfloat16_raw(0x0000)), 0.0f); 104 VERIFY_IS_EQUAL(bfloat16(__bfloat16_raw(0x3f80)), 1.0f); 105 106 // Zero representations 107 VERIFY_IS_EQUAL(bfloat16(0.0f), bfloat16(0.0f)); 108 VERIFY_IS_EQUAL(bfloat16(-0.0f), bfloat16(0.0f)); 109 VERIFY_IS_EQUAL(bfloat16(-0.0f), bfloat16(-0.0f)); 110 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0.0f), 0x0000); 111 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(-0.0f), 0x8000); 112 113 // Default is zero 114 VERIFY_IS_EQUAL(static_cast<float>(bfloat16()), 0.0f); 115 116 // Representable floats round trip via bfloat16 117 test_roundtrip<float>(); 118 test_roundtrip<double>(); 119 test_roundtrip<std::complex<float> >(); 120 test_roundtrip<std::complex<double> >(); 121 122 // Conversion 123 Array<float,1,100> a; 124 for (int i = 0; i < 100; i++) a(i) = i + 1.25; 125 Array<bfloat16,1,100> b = a.cast<bfloat16>(); 126 Array<float,1,100> c = b.cast<float>(); 127 for (int i = 0; i < 100; ++i) { 128 VERIFY_LE(numext::abs(c(i) - a(i)), a(i) / 128); 129 } 130 131 // Epsilon 132 VERIFY_LE(1.0f, static_cast<float>((std::numeric_limits<bfloat16>::epsilon)() + bfloat16(1.0f))); 133 VERIFY_IS_EQUAL(1.0f, static_cast<float>((std::numeric_limits<bfloat16>::epsilon)() / bfloat16(2.0f) + bfloat16(1.0f))); 134 135 // Negate 136 VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(3.0f)), -3.0f); 137 VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(-4.5f)), 4.5f); 138 139 140 #if !EIGEN_COMP_MSVC 141 // Visual Studio errors out on divisions by 0 142 VERIFY((numext::isnan)(static_cast<float>(bfloat16(0.0 / 0.0)))); 143 VERIFY((numext::isinf)(static_cast<float>(bfloat16(1.0 / 0.0)))); 144 VERIFY((numext::isinf)(static_cast<float>(bfloat16(-1.0 / 0.0)))); 145 146 // Visual Studio errors out on divisions by 0 147 VERIFY((numext::isnan)(bfloat16(0.0 / 0.0))); 148 VERIFY((numext::isinf)(bfloat16(1.0 / 0.0))); 149 VERIFY((numext::isinf)(bfloat16(-1.0 / 0.0))); 150 #endif 151 152 // NaNs and infinities. 153 VERIFY(!(numext::isinf)(static_cast<float>(bfloat16(3.38e38f)))); // Largest finite number. 154 VERIFY(!(numext::isnan)(static_cast<float>(bfloat16(0.0f)))); 155 VERIFY((numext::isinf)(static_cast<float>(bfloat16(__bfloat16_raw(0xff80))))); 156 VERIFY((numext::isnan)(static_cast<float>(bfloat16(__bfloat16_raw(0xffc0))))); 157 VERIFY((numext::isinf)(static_cast<float>(bfloat16(__bfloat16_raw(0x7f80))))); 158 VERIFY((numext::isnan)(static_cast<float>(bfloat16(__bfloat16_raw(0x7fc0))))); 159 160 // Exactly same checks as above, just directly on the bfloat16 representation. 161 VERIFY(!(numext::isinf)(bfloat16(__bfloat16_raw(0x7bff)))); 162 VERIFY(!(numext::isnan)(bfloat16(__bfloat16_raw(0x0000)))); 163 VERIFY((numext::isinf)(bfloat16(__bfloat16_raw(0xff80)))); 164 VERIFY((numext::isnan)(bfloat16(__bfloat16_raw(0xffc0)))); 165 VERIFY((numext::isinf)(bfloat16(__bfloat16_raw(0x7f80)))); 166 VERIFY((numext::isnan)(bfloat16(__bfloat16_raw(0x7fc0)))); 167 168 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(BinaryToFloat(0x0, 0xff, 0x40, 0x0)), 0x7fc0); 169 VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(BinaryToFloat(0x1, 0xff, 0x40, 0x0)), 0xffc0); 170 } 171 172 void test_numtraits() 173 { 174 std::cout << "epsilon = " << NumTraits<bfloat16>::epsilon() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>(NumTraits<bfloat16>::epsilon()) << ")" << std::endl; 175 std::cout << "highest = " << NumTraits<bfloat16>::highest() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>(NumTraits<bfloat16>::highest()) << ")" << std::endl; 176 std::cout << "lowest = " << NumTraits<bfloat16>::lowest() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>(NumTraits<bfloat16>::lowest()) << ")" << std::endl; 177 std::cout << "min = " << (std::numeric_limits<bfloat16>::min)() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>((std::numeric_limits<bfloat16>::min)()) << ")" << std::endl; 178 std::cout << "denorm min = " << (std::numeric_limits<bfloat16>::denorm_min)() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>((std::numeric_limits<bfloat16>::denorm_min)()) << ")" << std::endl; 179 std::cout << "infinity = " << NumTraits<bfloat16>::infinity() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>(NumTraits<bfloat16>::infinity()) << ")" << std::endl; 180 std::cout << "quiet nan = " << NumTraits<bfloat16>::quiet_NaN() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>(NumTraits<bfloat16>::quiet_NaN()) << ")" << std::endl; 181 std::cout << "signaling nan = " << std::numeric_limits<bfloat16>::signaling_NaN() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>(std::numeric_limits<bfloat16>::signaling_NaN()) << ")" << std::endl; 182 183 VERIFY(NumTraits<bfloat16>::IsSigned); 184 185 VERIFY_IS_EQUAL( 186 numext::bit_cast<numext::uint16_t>(std::numeric_limits<bfloat16>::infinity()), 187 numext::bit_cast<numext::uint16_t>(bfloat16(std::numeric_limits<float>::infinity())) ); 188 // There is no guarantee that casting a 32-bit NaN to bfloat16 has a precise 189 // bit pattern. We test that it is in fact a NaN, then test the signaling 190 // bit (msb of significand is 1 for quiet, 0 for signaling). 191 const numext::uint16_t BFLOAT16_QUIET_BIT = 0x0040; 192 VERIFY( 193 (numext::isnan)(std::numeric_limits<bfloat16>::quiet_NaN()) 194 && (numext::isnan)(bfloat16(std::numeric_limits<float>::quiet_NaN())) 195 && ((numext::bit_cast<numext::uint16_t>(std::numeric_limits<bfloat16>::quiet_NaN()) & BFLOAT16_QUIET_BIT) > 0) 196 && ((numext::bit_cast<numext::uint16_t>(bfloat16(std::numeric_limits<float>::quiet_NaN())) & BFLOAT16_QUIET_BIT) > 0) ); 197 // After a cast to bfloat16, a signaling NaN may become non-signaling. Thus, 198 // we check that both are NaN, and that only the `numeric_limits` version is 199 // signaling. 200 VERIFY( 201 (numext::isnan)(std::numeric_limits<bfloat16>::signaling_NaN()) 202 && (numext::isnan)(bfloat16(std::numeric_limits<float>::signaling_NaN())) 203 && ((numext::bit_cast<numext::uint16_t>(std::numeric_limits<bfloat16>::signaling_NaN()) & BFLOAT16_QUIET_BIT) == 0) ); 204 205 VERIFY( (std::numeric_limits<bfloat16>::min)() > bfloat16(0.f) ); 206 VERIFY( (std::numeric_limits<bfloat16>::denorm_min)() > bfloat16(0.f) ); 207 VERIFY_IS_EQUAL( (std::numeric_limits<bfloat16>::denorm_min)()/bfloat16(2), bfloat16(0.f) ); 208 } 209 210 void test_arithmetic() 211 { 212 VERIFY_IS_EQUAL(static_cast<float>(bfloat16(2) + bfloat16(2)), 4); 213 VERIFY_IS_EQUAL(static_cast<float>(bfloat16(2) + bfloat16(-2)), 0); 214 VERIFY_IS_APPROX(static_cast<float>(bfloat16(0.33333f) + bfloat16(0.66667f)), 1.0f); 215 VERIFY_IS_EQUAL(static_cast<float>(bfloat16(2.0f) * bfloat16(-5.5f)), -11.0f); 216 VERIFY_IS_APPROX(static_cast<float>(bfloat16(1.0f) / bfloat16(3.0f)), 0.3339f); 217 VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(4096.0f)), -4096.0f); 218 VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(-4096.0f)), 4096.0f); 219 } 220 221 void test_comparison() 222 { 223 VERIFY(bfloat16(1.0f) > bfloat16(0.5f)); 224 VERIFY(bfloat16(0.5f) < bfloat16(1.0f)); 225 VERIFY(!(bfloat16(1.0f) < bfloat16(0.5f))); 226 VERIFY(!(bfloat16(0.5f) > bfloat16(1.0f))); 227 228 VERIFY(!(bfloat16(4.0f) > bfloat16(4.0f))); 229 VERIFY(!(bfloat16(4.0f) < bfloat16(4.0f))); 230 231 VERIFY(!(bfloat16(0.0f) < bfloat16(-0.0f))); 232 VERIFY(!(bfloat16(-0.0f) < bfloat16(0.0f))); 233 VERIFY(!(bfloat16(0.0f) > bfloat16(-0.0f))); 234 VERIFY(!(bfloat16(-0.0f) > bfloat16(0.0f))); 235 236 VERIFY(bfloat16(0.2f) > bfloat16(-1.0f)); 237 VERIFY(bfloat16(-1.0f) < bfloat16(0.2f)); 238 VERIFY(bfloat16(-16.0f) < bfloat16(-15.0f)); 239 240 VERIFY(bfloat16(1.0f) == bfloat16(1.0f)); 241 VERIFY(bfloat16(1.0f) != bfloat16(2.0f)); 242 243 // Comparisons with NaNs and infinities. 244 #if !EIGEN_COMP_MSVC 245 // Visual Studio errors out on divisions by 0 246 VERIFY(!(bfloat16(0.0 / 0.0) == bfloat16(0.0 / 0.0))); 247 VERIFY(bfloat16(0.0 / 0.0) != bfloat16(0.0 / 0.0)); 248 249 VERIFY(!(bfloat16(1.0) == bfloat16(0.0 / 0.0))); 250 VERIFY(!(bfloat16(1.0) < bfloat16(0.0 / 0.0))); 251 VERIFY(!(bfloat16(1.0) > bfloat16(0.0 / 0.0))); 252 VERIFY(bfloat16(1.0) != bfloat16(0.0 / 0.0)); 253 254 VERIFY(bfloat16(1.0) < bfloat16(1.0 / 0.0)); 255 VERIFY(bfloat16(1.0) > bfloat16(-1.0 / 0.0)); 256 #endif 257 } 258 259 void test_basic_functions() 260 { 261 VERIFY_IS_EQUAL(static_cast<float>(numext::abs(bfloat16(3.5f))), 3.5f); 262 VERIFY_IS_EQUAL(static_cast<float>(abs(bfloat16(3.5f))), 3.5f); 263 VERIFY_IS_EQUAL(static_cast<float>(numext::abs(bfloat16(-3.5f))), 3.5f); 264 VERIFY_IS_EQUAL(static_cast<float>(abs(bfloat16(-3.5f))), 3.5f); 265 266 VERIFY_IS_EQUAL(static_cast<float>(numext::floor(bfloat16(3.5f))), 3.0f); 267 VERIFY_IS_EQUAL(static_cast<float>(floor(bfloat16(3.5f))), 3.0f); 268 VERIFY_IS_EQUAL(static_cast<float>(numext::floor(bfloat16(-3.5f))), -4.0f); 269 VERIFY_IS_EQUAL(static_cast<float>(floor(bfloat16(-3.5f))), -4.0f); 270 271 VERIFY_IS_EQUAL(static_cast<float>(numext::ceil(bfloat16(3.5f))), 4.0f); 272 VERIFY_IS_EQUAL(static_cast<float>(ceil(bfloat16(3.5f))), 4.0f); 273 VERIFY_IS_EQUAL(static_cast<float>(numext::ceil(bfloat16(-3.5f))), -3.0f); 274 VERIFY_IS_EQUAL(static_cast<float>(ceil(bfloat16(-3.5f))), -3.0f); 275 276 VERIFY_IS_APPROX(static_cast<float>(numext::sqrt(bfloat16(0.0f))), 0.0f); 277 VERIFY_IS_APPROX(static_cast<float>(sqrt(bfloat16(0.0f))), 0.0f); 278 VERIFY_IS_APPROX(static_cast<float>(numext::sqrt(bfloat16(4.0f))), 2.0f); 279 VERIFY_IS_APPROX(static_cast<float>(sqrt(bfloat16(4.0f))), 2.0f); 280 281 VERIFY_IS_APPROX(static_cast<float>(numext::pow(bfloat16(0.0f), bfloat16(1.0f))), 0.0f); 282 VERIFY_IS_APPROX(static_cast<float>(pow(bfloat16(0.0f), bfloat16(1.0f))), 0.0f); 283 VERIFY_IS_APPROX(static_cast<float>(numext::pow(bfloat16(2.0f), bfloat16(2.0f))), 4.0f); 284 VERIFY_IS_APPROX(static_cast<float>(pow(bfloat16(2.0f), bfloat16(2.0f))), 4.0f); 285 286 VERIFY_IS_EQUAL(static_cast<float>(numext::exp(bfloat16(0.0f))), 1.0f); 287 VERIFY_IS_EQUAL(static_cast<float>(exp(bfloat16(0.0f))), 1.0f); 288 VERIFY_IS_APPROX(static_cast<float>(numext::exp(bfloat16(EIGEN_PI))), 20.f + static_cast<float>(EIGEN_PI)); 289 VERIFY_IS_APPROX(static_cast<float>(exp(bfloat16(EIGEN_PI))), 20.f + static_cast<float>(EIGEN_PI)); 290 291 VERIFY_IS_EQUAL(static_cast<float>(numext::expm1(bfloat16(0.0f))), 0.0f); 292 VERIFY_IS_EQUAL(static_cast<float>(expm1(bfloat16(0.0f))), 0.0f); 293 VERIFY_IS_APPROX(static_cast<float>(numext::expm1(bfloat16(2.0f))), 6.375f); 294 VERIFY_IS_APPROX(static_cast<float>(expm1(bfloat16(2.0f))), 6.375f); 295 296 VERIFY_IS_EQUAL(static_cast<float>(numext::log(bfloat16(1.0f))), 0.0f); 297 VERIFY_IS_EQUAL(static_cast<float>(log(bfloat16(1.0f))), 0.0f); 298 VERIFY_IS_APPROX(static_cast<float>(numext::log(bfloat16(10.0f))), 2.296875f); 299 VERIFY_IS_APPROX(static_cast<float>(log(bfloat16(10.0f))), 2.296875f); 300 301 VERIFY_IS_EQUAL(static_cast<float>(numext::log1p(bfloat16(0.0f))), 0.0f); 302 VERIFY_IS_EQUAL(static_cast<float>(log1p(bfloat16(0.0f))), 0.0f); 303 VERIFY_IS_APPROX(static_cast<float>(numext::log1p(bfloat16(10.0f))), 2.390625f); 304 VERIFY_IS_APPROX(static_cast<float>(log1p(bfloat16(10.0f))), 2.390625f); 305 } 306 307 void test_trigonometric_functions() 308 { 309 VERIFY_IS_APPROX(numext::cos(bfloat16(0.0f)), bfloat16(cosf(0.0f))); 310 VERIFY_IS_APPROX(cos(bfloat16(0.0f)), bfloat16(cosf(0.0f))); 311 VERIFY_IS_APPROX(numext::cos(bfloat16(EIGEN_PI)), bfloat16(cosf(EIGEN_PI))); 312 // VERIFY_IS_APPROX(numext::cos(bfloat16(EIGEN_PI/2)), bfloat16(cosf(EIGEN_PI/2))); 313 // VERIFY_IS_APPROX(numext::cos(bfloat16(3*EIGEN_PI/2)), bfloat16(cosf(3*EIGEN_PI/2))); 314 VERIFY_IS_APPROX(numext::cos(bfloat16(3.5f)), bfloat16(cosf(3.5f))); 315 316 VERIFY_IS_APPROX(numext::sin(bfloat16(0.0f)), bfloat16(sinf(0.0f))); 317 VERIFY_IS_APPROX(sin(bfloat16(0.0f)), bfloat16(sinf(0.0f))); 318 // VERIFY_IS_APPROX(numext::sin(bfloat16(EIGEN_PI)), bfloat16(sinf(EIGEN_PI))); 319 VERIFY_IS_APPROX(numext::sin(bfloat16(EIGEN_PI/2)), bfloat16(sinf(EIGEN_PI/2))); 320 VERIFY_IS_APPROX(numext::sin(bfloat16(3*EIGEN_PI/2)), bfloat16(sinf(3*EIGEN_PI/2))); 321 VERIFY_IS_APPROX(numext::sin(bfloat16(3.5f)), bfloat16(sinf(3.5f))); 322 323 VERIFY_IS_APPROX(numext::tan(bfloat16(0.0f)), bfloat16(tanf(0.0f))); 324 VERIFY_IS_APPROX(tan(bfloat16(0.0f)), bfloat16(tanf(0.0f))); 325 // VERIFY_IS_APPROX(numext::tan(bfloat16(EIGEN_PI)), bfloat16(tanf(EIGEN_PI))); 326 // VERIFY_IS_APPROX(numext::tan(bfloat16(EIGEN_PI/2)), bfloat16(tanf(EIGEN_PI/2))); 327 // VERIFY_IS_APPROX(numext::tan(bfloat16(3*EIGEN_PI/2)), bfloat16(tanf(3*EIGEN_PI/2))); 328 VERIFY_IS_APPROX(numext::tan(bfloat16(3.5f)), bfloat16(tanf(3.5f))); 329 } 330 331 void test_array() 332 { 333 typedef Array<bfloat16,1,Dynamic> ArrayXh; 334 Index size = internal::random<Index>(1,10); 335 Index i = internal::random<Index>(0,size-1); 336 ArrayXh a1 = ArrayXh::Random(size), a2 = ArrayXh::Random(size); 337 VERIFY_IS_APPROX( a1+a1, bfloat16(2)*a1 ); 338 VERIFY( (a1.abs() >= bfloat16(0)).all() ); 339 VERIFY_IS_APPROX( (a1*a1).sqrt(), a1.abs() ); 340 341 VERIFY( ((a1.min)(a2) <= (a1.max)(a2)).all() ); 342 a1(i) = bfloat16(-10.); 343 VERIFY_IS_EQUAL( a1.minCoeff(), bfloat16(-10.) ); 344 a1(i) = bfloat16(10.); 345 VERIFY_IS_EQUAL( a1.maxCoeff(), bfloat16(10.) ); 346 347 std::stringstream ss; 348 ss << a1; 349 } 350 351 void test_product() 352 { 353 typedef Matrix<bfloat16,Dynamic,Dynamic> MatrixXh; 354 Index rows = internal::random<Index>(1,EIGEN_TEST_MAX_SIZE); 355 Index cols = internal::random<Index>(1,EIGEN_TEST_MAX_SIZE); 356 Index depth = internal::random<Index>(1,EIGEN_TEST_MAX_SIZE); 357 MatrixXh Ah = MatrixXh::Random(rows,depth); 358 MatrixXh Bh = MatrixXh::Random(depth,cols); 359 MatrixXh Ch = MatrixXh::Random(rows,cols); 360 MatrixXf Af = Ah.cast<float>(); 361 MatrixXf Bf = Bh.cast<float>(); 362 MatrixXf Cf = Ch.cast<float>(); 363 VERIFY_IS_APPROX(Ch.noalias()+=Ah*Bh, (Cf.noalias()+=Af*Bf).cast<bfloat16>()); 364 } 365 366 EIGEN_DECLARE_TEST(bfloat16_float) 367 { 368 CALL_SUBTEST(test_numtraits()); 369 for(int i = 0; i < g_repeat; i++) { 370 CALL_SUBTEST(test_conversion()); 371 CALL_SUBTEST(test_arithmetic()); 372 CALL_SUBTEST(test_comparison()); 373 CALL_SUBTEST(test_basic_functions()); 374 CALL_SUBTEST(test_trigonometric_functions()); 375 CALL_SUBTEST(test_array()); 376 CALL_SUBTEST(test_product()); 377 } 378 }