cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

cxx11_tensor_image_patch.cpp (36037B)


      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #include "main.h"
     11 
     12 #include <Eigen/CXX11/Tensor>
     13 
     14 using Eigen::Tensor;
     15 
     16 void test_simple_patch()
     17 {
     18   Tensor<float, 4> tensor(2,3,5,7);
     19   tensor.setRandom();
     20   Tensor<float, 4, RowMajor> tensor_row_major = tensor.swap_layout();
     21   VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(3));
     22   VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(2));
     23   VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(1));
     24   VERIFY_IS_EQUAL(tensor.dimension(3), tensor_row_major.dimension(0));
     25 
     26   // Single pixel patch: ColMajor
     27   Tensor<float, 5> single_pixel_patch;
     28   single_pixel_patch = tensor.extract_image_patches(1, 1);
     29   VERIFY_IS_EQUAL(single_pixel_patch.dimension(0), 2);
     30   VERIFY_IS_EQUAL(single_pixel_patch.dimension(1), 1);
     31   VERIFY_IS_EQUAL(single_pixel_patch.dimension(2), 1);
     32   VERIFY_IS_EQUAL(single_pixel_patch.dimension(3), 3*5);
     33   VERIFY_IS_EQUAL(single_pixel_patch.dimension(4), 7);
     34 
     35   // Single pixel patch: RowMajor
     36   Tensor<float, 5, RowMajor> single_pixel_patch_row_major;
     37   single_pixel_patch_row_major = tensor_row_major.extract_image_patches(1, 1);
     38   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(0), 7);
     39   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(1), 3*5);
     40   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(2), 1);
     41   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(3), 1);
     42   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(4), 2);
     43 
     44   for (int i = 0; i < tensor.size(); ++i) {
     45     // ColMajor
     46     if (tensor.data()[i] != single_pixel_patch.data()[i]) {
     47       std::cout << "Mismatch detected at index " << i << " : "
     48            << tensor.data()[i] << " vs " << single_pixel_patch.data()[i]
     49            << std::endl;
     50     }
     51     VERIFY_IS_EQUAL(single_pixel_patch.data()[i], tensor.data()[i]);
     52     // RowMajor
     53     if (tensor_row_major.data()[i] != single_pixel_patch_row_major.data()[i]) {
     54       std::cout << "Mismatch detected at index " << i << " : "
     55            << tensor.data()[i] << " vs "
     56            << single_pixel_patch_row_major.data()[i] << std::endl;
     57     }
     58     VERIFY_IS_EQUAL(single_pixel_patch_row_major.data()[i],
     59                     tensor_row_major.data()[i]);
     60     VERIFY_IS_EQUAL(tensor.data()[i], tensor_row_major.data()[i]);
     61     VERIFY_IS_EQUAL(single_pixel_patch.data()[i],
     62                     single_pixel_patch_row_major.data()[i]);
     63   }
     64 
     65   // Entire image patch: ColMajor
     66   Tensor<float, 5> entire_image_patch;
     67   entire_image_patch = tensor.extract_image_patches(3, 5);
     68   VERIFY_IS_EQUAL(entire_image_patch.dimension(0), 2);
     69   VERIFY_IS_EQUAL(entire_image_patch.dimension(1), 3);
     70   VERIFY_IS_EQUAL(entire_image_patch.dimension(2), 5);
     71   VERIFY_IS_EQUAL(entire_image_patch.dimension(3), 3*5);
     72   VERIFY_IS_EQUAL(entire_image_patch.dimension(4), 7);
     73 
     74   // Entire image patch: RowMajor
     75   Tensor<float, 5, RowMajor> entire_image_patch_row_major;
     76   entire_image_patch_row_major = tensor_row_major.extract_image_patches(3, 5);
     77   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(0), 7);
     78   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(1), 3*5);
     79   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(2), 5);
     80   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(3), 3);
     81   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(4), 2);
     82 
     83   for (int i = 0; i < 3; ++i) {
     84     for (int j = 0; j < 5; ++j) {
     85       int patchId = i+3*j;
     86       for (int r = 0; r < 3; ++r) {
     87         for (int c = 0; c < 5; ++c) {
     88           for (int d = 0; d < 2; ++d) {
     89             for (int b = 0; b < 7; ++b) {
     90               float expected = 0.0f;
     91               float expected_row_major = 0.0f;
     92               if (r-1+i >= 0 && c-2+j >= 0 && r-1+i < 3 && c-2+j < 5) {
     93                 expected = tensor(d, r-1+i, c-2+j, b);
     94                 expected_row_major = tensor_row_major(b, c-2+j, r-1+i, d);
     95               }
     96               // ColMajor
     97               if (entire_image_patch(d, r, c, patchId, b) != expected) {
     98                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
     99               }
    100               VERIFY_IS_EQUAL(entire_image_patch(d, r, c, patchId, b), expected);
    101               // RowMajor
    102               if (entire_image_patch_row_major(b, patchId, c, r, d) !=
    103                   expected_row_major) {
    104                 std::cout << "Mismatch detected at index i=" << i << " j=" << j
    105                      << " r=" << r << " c=" << c << " d=" << d << " b=" << b
    106                      << std::endl;
    107               }
    108               VERIFY_IS_EQUAL(entire_image_patch_row_major(b, patchId, c, r, d),
    109                               expected_row_major);
    110               // Check that ColMajor and RowMajor agree.
    111               VERIFY_IS_EQUAL(expected, expected_row_major);
    112             }
    113           }
    114         }
    115       }
    116     }
    117   }
    118 
    119   // 2D patch: ColMajor
    120   Tensor<float, 5> twod_patch;
    121   twod_patch = tensor.extract_image_patches(2, 2);
    122   VERIFY_IS_EQUAL(twod_patch.dimension(0), 2);
    123   VERIFY_IS_EQUAL(twod_patch.dimension(1), 2);
    124   VERIFY_IS_EQUAL(twod_patch.dimension(2), 2);
    125   VERIFY_IS_EQUAL(twod_patch.dimension(3), 3*5);
    126   VERIFY_IS_EQUAL(twod_patch.dimension(4), 7);
    127 
    128   // 2D patch: RowMajor
    129   Tensor<float, 5, RowMajor> twod_patch_row_major;
    130   twod_patch_row_major = tensor_row_major.extract_image_patches(2, 2);
    131   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(0), 7);
    132   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(1), 3*5);
    133   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(2), 2);
    134   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(3), 2);
    135   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(4), 2);
    136 
    137 
    138   // Based on the calculation described in TensorTraits.h, padding happens to be 0.
    139   int row_padding = 0;
    140   int col_padding = 0;
    141   int stride = 1;
    142 
    143   for (int i = 0; i < 3; ++i) {
    144     for (int j = 0; j < 5; ++j) {
    145       int patchId = i+3*j;
    146       for (int r = 0; r < 2; ++r) {
    147         for (int c = 0; c < 2; ++c) {
    148           for (int d = 0; d < 2; ++d) {
    149             for (int b = 0; b < 7; ++b) {
    150               float expected = 0.0f;
    151               float expected_row_major = 0.0f;
    152               int row_offset = r*stride + i - row_padding;
    153               int col_offset = c*stride + j - col_padding;
    154               // ColMajor
    155               if (row_offset >= 0 && col_offset >= 0 && row_offset < tensor.dimension(1) && col_offset < tensor.dimension(2)) {
    156                 expected = tensor(d, row_offset, col_offset, b);
    157               }
    158               if (twod_patch(d, r, c, patchId, b) != expected) {
    159                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
    160               }
    161               VERIFY_IS_EQUAL(twod_patch(d, r, c, patchId, b), expected);
    162 
    163               // RowMajor
    164               if (row_offset >= 0 && col_offset >= 0 && row_offset < tensor_row_major.dimension(2) && col_offset < tensor_row_major.dimension(1)) {
    165                 expected_row_major = tensor_row_major(b, col_offset, row_offset, d);
    166 
    167               }
    168               if (twod_patch_row_major(b, patchId, c, r, d) != expected_row_major) {
    169                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
    170               }
    171               VERIFY_IS_EQUAL(twod_patch_row_major(b, patchId, c, r, d), expected_row_major);
    172               // Check that ColMajor and RowMajor agree.
    173               VERIFY_IS_EQUAL(expected, expected_row_major);
    174             }
    175           }
    176         }
    177       }
    178     }
    179   }
    180 }
    181 
    182 // Verifies VALID padding (no padding) with incrementing values.
    183 void test_patch_padding_valid()
    184 {
    185   int input_depth = 3;
    186   int input_rows = 3;
    187   int input_cols = 3;
    188   int input_batches = 1;
    189   int ksize = 2;  // Corresponds to the Rows and Cols for tensor.extract_image_patches<>.
    190   int stride = 2;  // Only same stride is supported.
    191   Tensor<float, 4> tensor(input_depth, input_rows, input_cols, input_batches);
    192   // Initializes tensor with incrementing numbers.
    193   for (int i = 0; i < tensor.size(); ++i) {
    194     tensor.data()[i] = i + 1;
    195   }
    196   // ColMajor
    197   Tensor<float, 5> result = tensor.extract_image_patches(ksize, ksize, stride, stride, 1, 1, PADDING_VALID);
    198 
    199   VERIFY_IS_EQUAL(result.dimension(0), input_depth);  // depth
    200   VERIFY_IS_EQUAL(result.dimension(1), ksize);  // kernel rows
    201   VERIFY_IS_EQUAL(result.dimension(2), ksize);  // kernel cols
    202   VERIFY_IS_EQUAL(result.dimension(3), 1);  // number of patches
    203   VERIFY_IS_EQUAL(result.dimension(4), input_batches);  // number of batches
    204 
    205   // RowMajor
    206   Tensor<float, 4, RowMajor> tensor_row_major = tensor.swap_layout();
    207   VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(3));
    208   VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(2));
    209   VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(1));
    210   VERIFY_IS_EQUAL(tensor.dimension(3), tensor_row_major.dimension(0));
    211 
    212   Tensor<float, 5, RowMajor> result_row_major = tensor_row_major.extract_image_patches(ksize, ksize, stride, stride, 1, 1, PADDING_VALID);
    213   VERIFY_IS_EQUAL(result.dimension(0), result_row_major.dimension(4));
    214   VERIFY_IS_EQUAL(result.dimension(1), result_row_major.dimension(3));
    215   VERIFY_IS_EQUAL(result.dimension(2), result_row_major.dimension(2));
    216   VERIFY_IS_EQUAL(result.dimension(3), result_row_major.dimension(1));
    217   VERIFY_IS_EQUAL(result.dimension(4), result_row_major.dimension(0));
    218 
    219   // No padding is carried out.
    220   int row_padding = 0;
    221   int col_padding = 0;
    222 
    223   for (int i = 0; (i+stride+ksize-1) < input_rows; i += stride) {  // input rows
    224     for (int j = 0; (j+stride+ksize-1) < input_cols; j += stride) {  // input cols
    225       int patchId = i+input_rows*j;
    226       for (int r = 0; r < ksize; ++r) {  // patch rows
    227         for (int c = 0; c < ksize; ++c) {  // patch cols
    228           for (int d = 0; d < input_depth; ++d) {  // depth
    229             for (int b = 0; b < input_batches; ++b) {  // batch
    230               float expected = 0.0f;
    231               float expected_row_major = 0.0f;
    232               int row_offset = r + i - row_padding;
    233               int col_offset = c + j - col_padding;
    234               if (row_offset >= 0 && col_offset >= 0 && row_offset < input_rows && col_offset < input_cols) {
    235                 expected = tensor(d, row_offset, col_offset, b);
    236                 expected_row_major = tensor_row_major(b, col_offset, row_offset, d);
    237               }
    238               // ColMajor
    239               if (result(d, r, c, patchId, b) != expected) {
    240                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
    241               }
    242               VERIFY_IS_EQUAL(result(d, r, c, patchId, b), expected);
    243               // RowMajor
    244               if (result_row_major(b, patchId, c, r, d) != expected_row_major) {
    245                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
    246               }
    247               VERIFY_IS_EQUAL(result_row_major(b, patchId, c, r, d), expected_row_major);
    248               // Check that ColMajor and RowMajor agree.
    249               VERIFY_IS_EQUAL(expected, expected_row_major);
    250             }
    251           }
    252         }
    253       }
    254     }
    255   }
    256 }
    257 
    258 // Verifies VALID padding (no padding) with the same value.
    259 void test_patch_padding_valid_same_value()
    260 {
    261   int input_depth = 1;
    262   int input_rows = 5;
    263   int input_cols = 5;
    264   int input_batches = 2;
    265   int ksize = 3;  // Corresponds to the Rows and Cols for tensor.extract_image_patches<>.
    266   int stride = 2;  // Only same stride is supported.
    267   // ColMajor
    268   Tensor<float, 4> tensor(input_depth, input_rows, input_cols, input_batches);
    269   tensor = tensor.constant(11.0f);
    270   Tensor<float, 5> result = tensor.extract_image_patches(ksize, ksize, stride, stride, 1, 1, PADDING_VALID);
    271 
    272   VERIFY_IS_EQUAL(result.dimension(0), input_depth);  // depth
    273   VERIFY_IS_EQUAL(result.dimension(1), ksize);  // kernel rows
    274   VERIFY_IS_EQUAL(result.dimension(2), ksize);  // kernel cols
    275   VERIFY_IS_EQUAL(result.dimension(3), 4);  // number of patches
    276   VERIFY_IS_EQUAL(result.dimension(4), input_batches);  // number of batches
    277 
    278   // RowMajor
    279   Tensor<float, 4, RowMajor> tensor_row_major = tensor.swap_layout();
    280   VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(3));
    281   VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(2));
    282   VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(1));
    283   VERIFY_IS_EQUAL(tensor.dimension(3), tensor_row_major.dimension(0));
    284 
    285   Tensor<float, 5, RowMajor> result_row_major = tensor_row_major.extract_image_patches(ksize, ksize, stride, stride, 1, 1, PADDING_VALID);
    286   VERIFY_IS_EQUAL(result.dimension(0), result_row_major.dimension(4));
    287   VERIFY_IS_EQUAL(result.dimension(1), result_row_major.dimension(3));
    288   VERIFY_IS_EQUAL(result.dimension(2), result_row_major.dimension(2));
    289   VERIFY_IS_EQUAL(result.dimension(3), result_row_major.dimension(1));
    290   VERIFY_IS_EQUAL(result.dimension(4), result_row_major.dimension(0));
    291 
    292   // No padding is carried out.
    293   int row_padding = 0;
    294   int col_padding = 0;
    295 
    296   for (int i = 0; (i+stride+ksize-1) <= input_rows; i += stride) {  // input rows
    297     for (int j = 0; (j+stride+ksize-1) <= input_cols; j += stride) {  // input cols
    298       int patchId = i+input_rows*j;
    299       for (int r = 0; r < ksize; ++r) {  // patch rows
    300         for (int c = 0; c < ksize; ++c) {  // patch cols
    301           for (int d = 0; d < input_depth; ++d) {  // depth
    302             for (int b = 0; b < input_batches; ++b) {  // batch
    303               float expected = 0.0f;
    304               float expected_row_major = 0.0f;
    305               int row_offset = r + i - row_padding;
    306               int col_offset = c + j - col_padding;
    307               if (row_offset >= 0 && col_offset >= 0 && row_offset < input_rows && col_offset < input_cols) {
    308                 expected = tensor(d, row_offset, col_offset, b);
    309                 expected_row_major = tensor_row_major(b, col_offset, row_offset, d);
    310               }
    311               // ColMajor
    312               if (result(d, r, c, patchId, b) != expected) {
    313                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
    314               }
    315               VERIFY_IS_EQUAL(result(d, r, c, patchId, b), expected);
    316               // RowMajor
    317               if (result_row_major(b, patchId, c, r, d) != expected_row_major) {
    318                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
    319               }
    320               VERIFY_IS_EQUAL(result_row_major(b, patchId, c, r, d), expected_row_major);
    321               // Check that ColMajor and RowMajor agree.
    322               VERIFY_IS_EQUAL(expected, expected_row_major);
    323             }
    324           }
    325         }
    326       }
    327     }
    328   }
    329 }
    330 
    331 // Verifies SAME padding.
    332 void test_patch_padding_same()
    333 {
    334   int input_depth = 3;
    335   int input_rows = 4;
    336   int input_cols = 2;
    337   int input_batches = 1;
    338   int ksize = 2;  // Corresponds to the Rows and Cols for tensor.extract_image_patches<>.
    339   int stride = 2;  // Only same stride is supported.
    340   // ColMajor
    341   Tensor<float, 4> tensor(input_depth, input_rows, input_cols, input_batches);
    342   // Initializes tensor with incrementing numbers.
    343   for (int i = 0; i < tensor.size(); ++i) {
    344     tensor.data()[i] = i + 1;
    345   }
    346   Tensor<float, 5> result = tensor.extract_image_patches(ksize, ksize, stride, stride, PADDING_SAME);
    347 
    348   VERIFY_IS_EQUAL(result.dimension(0), input_depth);  // depth
    349   VERIFY_IS_EQUAL(result.dimension(1), ksize);  // kernel rows
    350   VERIFY_IS_EQUAL(result.dimension(2), ksize);  // kernel cols
    351   VERIFY_IS_EQUAL(result.dimension(3), 2);  // number of patches
    352   VERIFY_IS_EQUAL(result.dimension(4), input_batches);  // number of batches
    353 
    354   // RowMajor
    355   Tensor<float, 4, RowMajor> tensor_row_major = tensor.swap_layout();
    356   VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(3));
    357   VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(2));
    358   VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(1));
    359   VERIFY_IS_EQUAL(tensor.dimension(3), tensor_row_major.dimension(0));
    360 
    361   Tensor<float, 5, RowMajor> result_row_major = tensor_row_major.extract_image_patches(ksize, ksize, stride, stride, PADDING_SAME);
    362   VERIFY_IS_EQUAL(result.dimension(0), result_row_major.dimension(4));
    363   VERIFY_IS_EQUAL(result.dimension(1), result_row_major.dimension(3));
    364   VERIFY_IS_EQUAL(result.dimension(2), result_row_major.dimension(2));
    365   VERIFY_IS_EQUAL(result.dimension(3), result_row_major.dimension(1));
    366   VERIFY_IS_EQUAL(result.dimension(4), result_row_major.dimension(0));
    367 
    368   // Based on the calculation described in TensorTraits.h, padding happens to be
    369   // 0.
    370   int row_padding = 0;
    371   int col_padding = 0;
    372 
    373   for (int i = 0; (i+stride+ksize-1) <= input_rows; i += stride) {  // input rows
    374     for (int j = 0; (j+stride+ksize-1) <= input_cols; j += stride) {  // input cols
    375       int patchId = i+input_rows*j;
    376       for (int r = 0; r < ksize; ++r) {  // patch rows
    377         for (int c = 0; c < ksize; ++c) {  // patch cols
    378           for (int d = 0; d < input_depth; ++d) {  // depth
    379             for (int b = 0; b < input_batches; ++b) {  // batch
    380               float expected = 0.0f;
    381               float expected_row_major = 0.0f;
    382               int row_offset = r*stride + i - row_padding;
    383               int col_offset = c*stride + j - col_padding;
    384               if (row_offset >= 0 && col_offset >= 0 && row_offset < input_rows && col_offset < input_cols) {
    385                 expected = tensor(d, row_offset, col_offset, b);
    386                 expected_row_major = tensor_row_major(b, col_offset, row_offset, d);
    387               }
    388               // ColMajor
    389               if (result(d, r, c, patchId, b) != expected) {
    390                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
    391               }
    392               VERIFY_IS_EQUAL(result(d, r, c, patchId, b), expected);
    393               // RowMajor
    394               if (result_row_major(b, patchId, c, r, d) != expected_row_major) {
    395                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
    396               }
    397               VERIFY_IS_EQUAL(result_row_major(b, patchId, c, r, d), expected_row_major);
    398               // Check that ColMajor and RowMajor agree.
    399               VERIFY_IS_EQUAL(expected, expected_row_major);
    400             }
    401           }
    402         }
    403       }
    404     }
    405   }
    406 }
    407 
    408 // Verifies that SAME padding, when computed as negative values, will be clipped
    409 // to zero.
    410 void test_patch_padding_same_negative_padding_clip_to_zero() {
    411   int input_depth = 1;
    412   int input_rows = 15;
    413   int input_cols = 1;
    414   int input_batches = 1;
    415   int ksize = 1;  // Corresponds to the Rows and Cols for
    416                   // tensor.extract_image_patches<>.
    417   int row_stride = 5;
    418   int col_stride = 1;
    419   // ColMajor
    420   Tensor<float, 4> tensor(input_depth, input_rows, input_cols, input_batches);
    421   // Initializes tensor with incrementing numbers.
    422   for (int i = 0; i < tensor.size(); ++i) {
    423     tensor.data()[i] = i + 1;
    424   }
    425   Tensor<float, 5> result = tensor.extract_image_patches(
    426       ksize, ksize, row_stride, col_stride, 1, 1, PADDING_SAME);
    427   // row padding will be computed as -2 originally and then be clipped to 0.
    428   VERIFY_IS_EQUAL(result.coeff(0), 1.0f);
    429   VERIFY_IS_EQUAL(result.coeff(1), 6.0f);
    430   VERIFY_IS_EQUAL(result.coeff(2), 11.0f);
    431 
    432   VERIFY_IS_EQUAL(result.dimension(0), input_depth);    // depth
    433   VERIFY_IS_EQUAL(result.dimension(1), ksize);          // kernel rows
    434   VERIFY_IS_EQUAL(result.dimension(2), ksize);          // kernel cols
    435   VERIFY_IS_EQUAL(result.dimension(3), 3);              // number of patches
    436   VERIFY_IS_EQUAL(result.dimension(4), input_batches);  // number of batches
    437 
    438   // RowMajor
    439   Tensor<float, 4, RowMajor> tensor_row_major = tensor.swap_layout();
    440   VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(3));
    441   VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(2));
    442   VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(1));
    443   VERIFY_IS_EQUAL(tensor.dimension(3), tensor_row_major.dimension(0));
    444 
    445   Tensor<float, 5, RowMajor> result_row_major =
    446       tensor_row_major.extract_image_patches(ksize, ksize, row_stride,
    447                                              col_stride, 1, 1, PADDING_SAME);
    448   VERIFY_IS_EQUAL(result_row_major.coeff(0), 1.0f);
    449   VERIFY_IS_EQUAL(result_row_major.coeff(1), 6.0f);
    450   VERIFY_IS_EQUAL(result_row_major.coeff(2), 11.0f);
    451 
    452   VERIFY_IS_EQUAL(result.dimension(0), result_row_major.dimension(4));
    453   VERIFY_IS_EQUAL(result.dimension(1), result_row_major.dimension(3));
    454   VERIFY_IS_EQUAL(result.dimension(2), result_row_major.dimension(2));
    455   VERIFY_IS_EQUAL(result.dimension(3), result_row_major.dimension(1));
    456   VERIFY_IS_EQUAL(result.dimension(4), result_row_major.dimension(0));
    457 }
    458 
    459 void test_patch_no_extra_dim()
    460 {
    461   Tensor<float, 3> tensor(2,3,5);
    462   tensor.setRandom();
    463   Tensor<float, 3, RowMajor> tensor_row_major = tensor.swap_layout();
    464   VERIFY_IS_EQUAL(tensor.dimension(0), tensor_row_major.dimension(2));
    465   VERIFY_IS_EQUAL(tensor.dimension(1), tensor_row_major.dimension(1));
    466   VERIFY_IS_EQUAL(tensor.dimension(2), tensor_row_major.dimension(0));
    467 
    468   // Single pixel patch: ColMajor
    469   Tensor<float, 4> single_pixel_patch;
    470   single_pixel_patch = tensor.extract_image_patches(1, 1);
    471   VERIFY_IS_EQUAL(single_pixel_patch.dimension(0), 2);
    472   VERIFY_IS_EQUAL(single_pixel_patch.dimension(1), 1);
    473   VERIFY_IS_EQUAL(single_pixel_patch.dimension(2), 1);
    474   VERIFY_IS_EQUAL(single_pixel_patch.dimension(3), 3*5);
    475 
    476   // Single pixel patch: RowMajor
    477   Tensor<float, 4, RowMajor> single_pixel_patch_row_major;
    478   single_pixel_patch_row_major = tensor_row_major.extract_image_patches(1, 1);
    479   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(0), 3*5);
    480   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(1), 1);
    481   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(2), 1);
    482   VERIFY_IS_EQUAL(single_pixel_patch_row_major.dimension(3), 2);
    483 
    484   for (int i = 0; i < tensor.size(); ++i) {
    485     // ColMajor
    486     if (tensor.data()[i] != single_pixel_patch.data()[i]) {
    487       std::cout << "Mismatch detected at index " << i << " : " << tensor.data()[i] << " vs " << single_pixel_patch.data()[i] << std::endl;
    488     }
    489     VERIFY_IS_EQUAL(single_pixel_patch.data()[i], tensor.data()[i]);
    490     // RowMajor
    491     if (tensor_row_major.data()[i] != single_pixel_patch_row_major.data()[i]) {
    492       std::cout << "Mismatch detected at index " << i << " : "
    493            << tensor.data()[i] << " vs "
    494            << single_pixel_patch_row_major.data()[i] << std::endl;
    495     }
    496     VERIFY_IS_EQUAL(single_pixel_patch_row_major.data()[i],
    497                     tensor_row_major.data()[i]);
    498     VERIFY_IS_EQUAL(tensor.data()[i], tensor_row_major.data()[i]);
    499     VERIFY_IS_EQUAL(single_pixel_patch.data()[i],
    500                     single_pixel_patch_row_major.data()[i]);
    501   }
    502 
    503   // Entire image patch: ColMajor
    504   Tensor<float, 4> entire_image_patch;
    505   entire_image_patch = tensor.extract_image_patches(3, 5);
    506   VERIFY_IS_EQUAL(entire_image_patch.dimension(0), 2);
    507   VERIFY_IS_EQUAL(entire_image_patch.dimension(1), 3);
    508   VERIFY_IS_EQUAL(entire_image_patch.dimension(2), 5);
    509   VERIFY_IS_EQUAL(entire_image_patch.dimension(3), 3*5);
    510 
    511   // Entire image patch: RowMajor
    512   Tensor<float, 4, RowMajor> entire_image_patch_row_major;
    513   entire_image_patch_row_major = tensor_row_major.extract_image_patches(3, 5);
    514   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(0), 3*5);
    515   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(1), 5);
    516   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(2), 3);
    517   VERIFY_IS_EQUAL(entire_image_patch_row_major.dimension(3), 2);
    518 
    519   for (int i = 0; i < 3; ++i) {
    520     for (int j = 0; j < 5; ++j) {
    521       int patchId = i+3*j;
    522       for (int r = 0; r < 3; ++r) {
    523         for (int c = 0; c < 5; ++c) {
    524           for (int d = 0; d < 2; ++d) {
    525             float expected = 0.0f;
    526             float expected_row_major = 0.0f;
    527             if (r-1+i >= 0 && c-2+j >= 0 && r-1+i < 3 && c-2+j < 5) {
    528               expected = tensor(d, r-1+i, c-2+j);
    529               expected_row_major = tensor_row_major(c-2+j, r-1+i, d);
    530             }
    531             // ColMajor
    532             if (entire_image_patch(d, r, c, patchId) != expected) {
    533               std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << std::endl;
    534             }
    535             VERIFY_IS_EQUAL(entire_image_patch(d, r, c, patchId), expected);
    536             // RowMajor
    537             if (entire_image_patch_row_major(patchId, c, r, d) !=
    538                 expected_row_major) {
    539               std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << std::endl;
    540             }
    541             VERIFY_IS_EQUAL(entire_image_patch_row_major(patchId, c, r, d),
    542                             expected_row_major);
    543             // Check that ColMajor and RowMajor agree.
    544             VERIFY_IS_EQUAL(expected, expected_row_major);
    545           }
    546         }
    547       }
    548     }
    549   }
    550 
    551   // 2D patch: ColMajor
    552   Tensor<float, 4> twod_patch;
    553   twod_patch = tensor.extract_image_patches(2, 2);
    554   VERIFY_IS_EQUAL(twod_patch.dimension(0), 2);
    555   VERIFY_IS_EQUAL(twod_patch.dimension(1), 2);
    556   VERIFY_IS_EQUAL(twod_patch.dimension(2), 2);
    557   VERIFY_IS_EQUAL(twod_patch.dimension(3), 3*5);
    558 
    559   // 2D patch: RowMajor
    560   Tensor<float, 4, RowMajor> twod_patch_row_major;
    561   twod_patch_row_major = tensor_row_major.extract_image_patches(2, 2);
    562   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(0), 3*5);
    563   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(1), 2);
    564   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(2), 2);
    565   VERIFY_IS_EQUAL(twod_patch_row_major.dimension(3), 2);
    566 
    567   // Based on the calculation described in TensorTraits.h, padding happens to be 0.
    568   int row_padding = 0;
    569   int col_padding = 0;
    570   int stride = 1;
    571 
    572   for (int i = 0; i < 3; ++i) {
    573     for (int j = 0; j < 5; ++j) {
    574       int patchId = i+3*j;
    575       for (int r = 0; r < 2; ++r) {
    576         for (int c = 0; c < 2; ++c) {
    577           for (int d = 0; d < 2; ++d) {
    578             float expected = 0.0f;
    579             float expected_row_major = 0.0f;
    580             int row_offset = r*stride + i - row_padding;
    581             int col_offset = c*stride + j - col_padding;
    582             // ColMajor
    583             if (row_offset >= 0 && col_offset >= 0 && row_offset < tensor.dimension(1) && col_offset < tensor.dimension(2)) {
    584               expected = tensor(d, row_offset, col_offset);
    585             }
    586             if (twod_patch(d, r, c, patchId) != expected) {
    587               std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << std::endl;
    588             }
    589             VERIFY_IS_EQUAL(twod_patch(d, r, c, patchId), expected);
    590             // RowMajor
    591             if (row_offset >= 0 && col_offset >= 0 && row_offset < tensor_row_major.dimension(1) && col_offset < tensor_row_major.dimension(0)) {
    592               expected_row_major = tensor_row_major(col_offset, row_offset, d);
    593             }
    594             if (twod_patch_row_major(patchId, c, r, d) != expected_row_major) {
    595               std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << std::endl;
    596             }
    597             VERIFY_IS_EQUAL(twod_patch_row_major(patchId, c, r, d), expected_row_major);
    598             // Check that ColMajor and RowMajor agree.
    599             VERIFY_IS_EQUAL(expected, expected_row_major);
    600           }
    601         }
    602       }
    603     }
    604   }
    605 }
    606 
    607 void test_imagenet_patches()
    608 {
    609   // Test the code on typical configurations used by the 'imagenet' benchmarks at
    610   // https://github.com/soumith/convnet-benchmarks
    611   // ColMajor
    612   Tensor<float, 4> l_in(3, 128, 128, 16);
    613   l_in.setRandom();
    614   Tensor<float, 5> l_out = l_in.extract_image_patches(11, 11);
    615   VERIFY_IS_EQUAL(l_out.dimension(0), 3);
    616   VERIFY_IS_EQUAL(l_out.dimension(1), 11);
    617   VERIFY_IS_EQUAL(l_out.dimension(2), 11);
    618   VERIFY_IS_EQUAL(l_out.dimension(3), 128*128);
    619   VERIFY_IS_EQUAL(l_out.dimension(4), 16);
    620 
    621   // RowMajor
    622   Tensor<float, 5, RowMajor> l_out_row_major = l_in.swap_layout().extract_image_patches(11, 11);
    623   VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 16);
    624   VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 128*128);
    625   VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 11);
    626   VERIFY_IS_EQUAL(l_out_row_major.dimension(3), 11);
    627   VERIFY_IS_EQUAL(l_out_row_major.dimension(4), 3);
    628 
    629   for (int b = 0; b < 16; ++b) {
    630     for (int i = 0; i < 128; ++i) {
    631       for (int j = 0; j < 128; ++j) {
    632         int patchId = i+128*j;
    633         for (int c = 0; c < 11; ++c) {
    634           for (int r = 0; r < 11; ++r) {
    635             for (int d = 0; d < 3; ++d) {
    636               float expected = 0.0f;
    637               if (r-5+i >= 0 && c-5+j >= 0 && r-5+i < 128 && c-5+j < 128) {
    638                 expected = l_in(d, r-5+i, c-5+j, b);
    639               }
    640               // ColMajor
    641               if (l_out(d, r, c, patchId, b) != expected) {
    642                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
    643               }
    644               VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected);
    645               // RowMajor
    646               if (l_out_row_major(b, patchId, c, r, d) !=
    647                   expected) {
    648                 std::cout << "Mismatch detected at index i=" << i << " j=" << j
    649                      << " r=" << r << " c=" << c << " d=" << d << " b=" << b
    650                      << std::endl;
    651               }
    652               VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d),
    653                               expected);
    654             }
    655           }
    656         }
    657       }
    658     }
    659   }
    660 
    661   // ColMajor
    662   l_in.resize(16, 64, 64, 32);
    663   l_in.setRandom();
    664   l_out = l_in.extract_image_patches(9, 9);
    665   VERIFY_IS_EQUAL(l_out.dimension(0), 16);
    666   VERIFY_IS_EQUAL(l_out.dimension(1), 9);
    667   VERIFY_IS_EQUAL(l_out.dimension(2), 9);
    668   VERIFY_IS_EQUAL(l_out.dimension(3), 64*64);
    669   VERIFY_IS_EQUAL(l_out.dimension(4), 32);
    670 
    671   // RowMajor
    672   l_out_row_major = l_in.swap_layout().extract_image_patches(9, 9);
    673   VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 32);
    674   VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 64*64);
    675   VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 9);
    676   VERIFY_IS_EQUAL(l_out_row_major.dimension(3), 9);
    677   VERIFY_IS_EQUAL(l_out_row_major.dimension(4), 16);
    678 
    679   for (int b = 0; b < 32; ++b) {
    680     for (int i = 0; i < 64; ++i) {
    681       for (int j = 0; j < 64; ++j) {
    682         int patchId = i+64*j;
    683         for (int c = 0; c < 9; ++c) {
    684           for (int r = 0; r < 9; ++r) {
    685             for (int d = 0; d < 16; ++d) {
    686               float expected = 0.0f;
    687               if (r-4+i >= 0 && c-4+j >= 0 && r-4+i < 64 && c-4+j < 64) {
    688                 expected = l_in(d, r-4+i, c-4+j, b);
    689               }
    690               // ColMajor
    691               if (l_out(d, r, c, patchId, b) != expected) {
    692                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
    693               }
    694               VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected);
    695               // RowMajor
    696               if (l_out_row_major(b, patchId, c, r, d) != expected) {
    697                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
    698               }
    699               VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected);
    700             }
    701           }
    702         }
    703       }
    704     }
    705   }
    706 
    707   // ColMajor
    708   l_in.resize(32, 16, 16, 32);
    709   l_in.setRandom();
    710   l_out = l_in.extract_image_patches(7, 7);
    711   VERIFY_IS_EQUAL(l_out.dimension(0), 32);
    712   VERIFY_IS_EQUAL(l_out.dimension(1), 7);
    713   VERIFY_IS_EQUAL(l_out.dimension(2), 7);
    714   VERIFY_IS_EQUAL(l_out.dimension(3), 16*16);
    715   VERIFY_IS_EQUAL(l_out.dimension(4), 32);
    716 
    717   // RowMajor
    718   l_out_row_major = l_in.swap_layout().extract_image_patches(7, 7);
    719   VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 32);
    720   VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 16*16);
    721   VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 7);
    722   VERIFY_IS_EQUAL(l_out_row_major.dimension(3), 7);
    723   VERIFY_IS_EQUAL(l_out_row_major.dimension(4), 32);
    724 
    725   for (int b = 0; b < 32; ++b) {
    726     for (int i = 0; i < 16; ++i) {
    727       for (int j = 0; j < 16; ++j) {
    728         int patchId = i+16*j;
    729         for (int c = 0; c < 7; ++c) {
    730           for (int r = 0; r < 7; ++r) {
    731             for (int d = 0; d < 32; ++d) {
    732               float expected = 0.0f;
    733               if (r-3+i >= 0 && c-3+j >= 0 && r-3+i < 16 && c-3+j < 16) {
    734                 expected = l_in(d, r-3+i, c-3+j, b);
    735               }
    736               // ColMajor
    737               if (l_out(d, r, c, patchId, b) != expected) {
    738                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
    739               }
    740               VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected);
    741               // RowMajor
    742               if (l_out_row_major(b, patchId, c, r, d) != expected) {
    743                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
    744               }
    745               VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected);
    746             }
    747           }
    748         }
    749       }
    750     }
    751   }
    752 
    753   // ColMajor
    754   l_in.resize(64, 13, 13, 32);
    755   l_in.setRandom();
    756   l_out = l_in.extract_image_patches(3, 3);
    757   VERIFY_IS_EQUAL(l_out.dimension(0), 64);
    758   VERIFY_IS_EQUAL(l_out.dimension(1), 3);
    759   VERIFY_IS_EQUAL(l_out.dimension(2), 3);
    760   VERIFY_IS_EQUAL(l_out.dimension(3), 13*13);
    761   VERIFY_IS_EQUAL(l_out.dimension(4), 32);
    762 
    763   // RowMajor
    764   l_out_row_major = l_in.swap_layout().extract_image_patches(3, 3);
    765   VERIFY_IS_EQUAL(l_out_row_major.dimension(0), 32);
    766   VERIFY_IS_EQUAL(l_out_row_major.dimension(1), 13*13);
    767   VERIFY_IS_EQUAL(l_out_row_major.dimension(2), 3);
    768   VERIFY_IS_EQUAL(l_out_row_major.dimension(3), 3);
    769   VERIFY_IS_EQUAL(l_out_row_major.dimension(4), 64);
    770 
    771   for (int b = 0; b < 32; ++b) {
    772     for (int i = 0; i < 13; ++i) {
    773       for (int j = 0; j < 13; ++j) {
    774         int patchId = i+13*j;
    775         for (int c = 0; c < 3; ++c) {
    776           for (int r = 0; r < 3; ++r) {
    777             for (int d = 0; d < 64; ++d) {
    778               float expected = 0.0f;
    779               if (r-1+i >= 0 && c-1+j >= 0 && r-1+i < 13 && c-1+j < 13) {
    780                 expected = l_in(d, r-1+i, c-1+j, b);
    781               }
    782               // ColMajor
    783               if (l_out(d, r, c, patchId, b) != expected) {
    784                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
    785               }
    786               VERIFY_IS_EQUAL(l_out(d, r, c, patchId, b), expected);
    787               // RowMajor
    788               if (l_out_row_major(b, patchId, c, r, d) != expected) {
    789                 std::cout << "Mismatch detected at index i=" << i << " j=" << j << " r=" << r << " c=" << c << " d=" << d << " b=" << b << std::endl;
    790               }
    791               VERIFY_IS_EQUAL(l_out_row_major(b, patchId, c, r, d), expected);
    792             }
    793           }
    794         }
    795       }
    796     }
    797   }
    798 }
    799 
// Entry point of the cxx11_tensor_image_patch test: invokes the seven
// test_* image-patch functions through the CALL_SUBTEST_1..7 macros.
// The subtest numbers do not follow the order in which the functions are
// defined (e.g. the negative-padding case is subtest 7), so new subtests
// should take the next free number rather than renumbering existing ones.
// NOTE(review): CALL_SUBTEST_<n> comes from main.h (not shown here);
// presumably <n> selects which separately built test part runs the call —
// confirm against main.h before reordering or renumbering.
EIGEN_DECLARE_TEST(cxx11_tensor_image_patch)
{
  CALL_SUBTEST_1(test_simple_patch());
  CALL_SUBTEST_2(test_patch_no_extra_dim());
  CALL_SUBTEST_3(test_patch_padding_valid());
  CALL_SUBTEST_4(test_patch_padding_valid_same_value());
  CALL_SUBTEST_5(test_patch_padding_same());
  CALL_SUBTEST_6(test_imagenet_patches());
  CALL_SUBTEST_7(test_patch_padding_same_negative_padding_clip_to_zero());
}