cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

EventCount.h (9121B)


      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
     11 #define EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
     12 
     13 namespace Eigen {
     14 
     15 // EventCount allows waiting for arbitrary predicates in non-blocking
     16 // algorithms. Think of a condition variable, but the wait predicate does not
     17 // need to be protected by a mutex. Usage:
     18 // Waiting thread does:
     19 //
     20 //   if (predicate)
     21 //     return act();
     22 //   EventCount::Waiter& w = waiters[my_index];
     23 //   ec.Prewait();
     24 //   if (predicate) {
     25 //     ec.CancelWait();
     26 //     return act();
     27 //   }
     28 //   ec.CommitWait(&w);
     29 //
     30 // Notifying thread does:
     31 //
     32 //   predicate = true;
     33 //   ec.Notify(true);
     34 //
     35 // Notify is cheap if there are no waiting threads. Prewait/CommitWait are not
     36 // cheap, but they are executed only if the preceding predicate check has
     37 // failed.
     38 //
     39 // Algorithm outline:
     40 // There are two main variables: predicate (managed by user) and state_.
     41 // Operation closely resembles Dekker's mutual exclusion algorithm:
     42 // https://en.wikipedia.org/wiki/Dekker%27s_algorithm
     43 // Waiting thread sets state_ then checks predicate, Notifying thread sets
     44 // predicate then checks state_. Due to seq_cst fences in between these
     45 // operations it is guaranteed that either waiter will see predicate change
     46 // and won't block, or notifying thread will see state_ change and will unblock
     47 // the waiter, or both. But it can't happen that both threads don't see each
     48 // other's changes, which would lead to deadlock.
     49 class EventCount {
     50  public:
     51   class Waiter;
     52 
          // Binds the event count to the caller-owned `waiters` array (stored by
          // reference, not copied). The initial state has an empty waiter stack
          // and zero prewait/signal counters.
     53   EventCount(MaxSizeVector<Waiter>& waiters)
     54       : state_(kStackMask), waiters_(waiters) {
            // Waiter indexes live in the low kWaiterBits bits of state_, and the
            // all-ones value (kStackMask) is reserved to mean "empty stack".
     55     eigen_plain_assert(waiters.size() < (1 << kWaiterBits) - 1);
     56   }
     57 
     58   ~EventCount() {
     59     // Ensure there are no waiters.
     60     eigen_plain_assert(state_.load() == kStackMask);
     61   }
     62 
     63   // Prewait prepares for waiting.
     64   // After calling Prewait, the thread must re-check the wait predicate
     65   // and then call either CancelWait or CommitWait.
          // Implementation: adds one to the prewait counter of state_ in a CAS loop.
     66   void Prewait() {
     67     uint64_t state = state_.load(std::memory_order_relaxed);
     68     for (;;) {
     69       CheckState(state);
     70       uint64_t newstate = state + kWaiterInc;
     71       CheckState(newstate);
              // On failure compare_exchange_weak refreshes `state` with the current
              // value, so the next iteration retries against up-to-date data.
     72       if (state_.compare_exchange_weak(state, newstate,
     73                                        std::memory_order_seq_cst))
     74         return;
     75     }
     76   }
     77 
     78   // CommitWait commits waiting after Prewait.
          // `w` must point at an element of the waiters_ array: its stack index is
          // derived by pointer subtraction from &waiters_[0].
          // If a signal is pending it is consumed and the call returns without
          // blocking; otherwise `w` is pushed onto the waiter stack and the calling
          // thread parks until Unpark marks it as signaled.
     79   void CommitWait(Waiter* w) {
            // The epoch must only occupy the ABA-counter bits of the state word.
     80     eigen_plain_assert((w->epoch & ~kEpochMask) == 0);
     81     w->state = Waiter::kNotSignaled;
            // Stack entry for this waiter: array index in the low bits combined
            // with the waiter's current epoch in the high (ABA counter) bits.
     82     const uint64_t me = (w - &waiters_[0]) | w->epoch;
     83     uint64_t state = state_.load(std::memory_order_seq_cst);
     84     for (;;) {
     85       CheckState(state, true);
     86       uint64_t newstate;
     87       if ((state & kSignalMask) != 0) {
     88         // Consume the signal and return immediately.
     89         newstate = state - kWaiterInc - kSignalInc;
     90       } else {
     91         // Remove this thread from pre-wait counter and add to the waiter stack.
                // The signal field is known to be zero in this branch, so only the
                // decremented prewait counter is carried over next to `me`.
     92         newstate = ((state & kWaiterMask) - kWaiterInc) | me;
                // Remember the previous stack top (index + epoch) as our successor.
     93         w->next.store(state & (kStackMask | kEpochMask),
     94                       std::memory_order_relaxed);
     95       }
     96       CheckState(newstate);
     97       if (state_.compare_exchange_weak(state, newstate,
     98                                        std::memory_order_acq_rel)) {
                // `state` still holds the value the CAS succeeded against, so this
                // re-tests which of the two branches above was taken.
     99         if ((state & kSignalMask) == 0) {
                  // Advance the epoch so the next push of this waiter produces a
                  // different stack word, then block until unparked.
    100           w->epoch += kEpochInc;
    101           Park(w);
    102         }
    103         return;
    104       }
    105     }
    106   }
    107 
    108   // CancelWait cancels effects of the previous Prewait call.
    109   void CancelWait() {
    110     uint64_t state = state_.load(std::memory_order_relaxed);
    111     for (;;) {
    112       CheckState(state, true);
    113       uint64_t newstate = state - kWaiterInc;
    114       // We don't know if the thread was also notified or not,
    115       // so we should not consume a signal unconditionally.
    116       // Only if number of waiters is equal to number of signals,
    117       // we know that the thread was notified and we must take away the signal.
    118       if (((state & kWaiterMask) >> kWaiterShift) ==
    119           ((state & kSignalMask) >> kSignalShift))
    120         newstate -= kSignalInc;
    121       CheckState(newstate);
    122       if (state_.compare_exchange_weak(state, newstate,
    123                                        std::memory_order_acq_rel))
    124         return;
    125     }
    126   }
    127 
    128   // Notify wakes one or all waiting threads.
    129   // Must be called after changing the associated wait predicate.
          // notifyAll == true: every prewait thread is given a signal and every
          // thread on the waiter stack is unparked.
          // notifyAll == false: a prewait thread that has no signal yet is
          // preferred; otherwise a single waiter is popped from the stack and
          // unparked.
    130   void Notify(bool notifyAll) {
    131     std::atomic_thread_fence(std::memory_order_seq_cst);
    132     uint64_t state = state_.load(std::memory_order_acquire);
    133     for (;;) {
    134       CheckState(state);
    135       const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
    136       const uint64_t signals = (state & kSignalMask) >> kSignalShift;
    137       // Easy case: no waiters.
              // (Empty stack and every prewait thread already has a signal.)
    138       if ((state & kStackMask) == kStackMask && waiters == signals) return;
    139       uint64_t newstate;
    140       if (notifyAll) {
    141         // Empty wait stack and set signal to number of pre-wait threads.
    142         newstate =
    143             (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask;
    144       } else if (signals < waiters) {
    145         // There is a thread in pre-wait state, unblock it.
    146         newstate = state + kSignalInc;
    147       } else {
    148         // Pop a waiter from list and unpark it.
    149         Waiter* w = &waiters_[state & kStackMask];
    150         uint64_t next = w->next.load(std::memory_order_relaxed);
    151         newstate = (state & (kWaiterMask | kSignalMask)) | next;
    152       }
    153       CheckState(newstate);
    154       if (state_.compare_exchange_weak(state, newstate,
    155                                        std::memory_order_acq_rel)) {
    156         if (!notifyAll && (signals < waiters))
    157           return;  // unblocked pre-wait thread
                // Nothing was on the stack, so there is no parked thread to wake.
    158         if ((state & kStackMask) == kStackMask) return;
    159         Waiter* w = &waiters_[state & kStackMask];
                // Unpark follows `next` links. For a single notification the link
                // is cut so only this waiter is woken; for notifyAll the whole
                // chain starting at the old stack top is woken.
    160         if (!notifyAll) w->next.store(kStackMask, std::memory_order_relaxed);
    161         Unpark(w);
    162         return;
    163       }
    164     }
    165   }
    166 
          // Per-thread wait record. Instances live in the array passed to the
          // EventCount constructor and are identified by their index in it.
    167   class Waiter {
    168     friend class EventCount;
    169     // Align to 128 byte boundary to prevent false sharing with other Waiter
    170     // objects in the same vector.
            // `next` holds the stack word (index + epoch bits) of the entry below
            // this one on the waiter stack; written by CommitWait and Notify.
    171     EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<uint64_t> next;
            // mu/cv implement the actual blocking in Park/Unpark and guard `state`
            // there.
    172     std::mutex mu;
    173     std::condition_variable cv;
            // ABA counter contribution, kept pre-shifted inside kEpochMask and
            // bumped by kEpochInc on every push.
    174     uint64_t epoch = 0;
            // One of the enumerators below.
    175     unsigned state = kNotSignaled;
    176     enum {
    177       kNotSignaled,
    178       kWaiting,
    179       kSignaled,
    180     };
    181   };
    182 
    183  private:
    184   // State_ layout:
    185   // - low kWaiterBits is a stack of waiters that committed wait
    186   //   (indexes in waiters_ array are used as stack elements,
    187   //   kStackMask means empty stack).
    188   // - next kWaiterBits is count of waiters in prewait state.
    189   // - next kWaiterBits is count of pending signals.
    190   // - remaining bits are ABA counter for the stack.
    191   //   (stored in Waiter node and incremented on push).
    192   static const uint64_t kWaiterBits = 14;
    193   static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
    194   static const uint64_t kWaiterShift = kWaiterBits;
    195   static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
    196                                       << kWaiterShift;
    197   static const uint64_t kWaiterInc = 1ull << kWaiterShift;
    198   static const uint64_t kSignalShift = 2 * kWaiterBits;
    199   static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1)
    200                                       << kSignalShift;
    201   static const uint64_t kSignalInc = 1ull << kSignalShift;
    202   static const uint64_t kEpochShift = 3 * kWaiterBits;
    203   static const uint64_t kEpochBits = 64 - kEpochShift;
    204   static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
    205   static const uint64_t kEpochInc = 1ull << kEpochShift;
          // Packed state word (see layout above) and the externally owned waiters.
    206   std::atomic<uint64_t> state_;
    207   MaxSizeVector<Waiter>& waiters_;
    208 
          // Asserts the invariants of a state word: pending signals never exceed
          // the prewait count, the prewait counter is not saturated, and, when
          // `waiter` is true (caller is itself in prewait), the count is non-zero.
    209   static void CheckState(uint64_t state, bool waiter = false) {
    210     static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem");
    211     const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
    212     const uint64_t signals = (state & kSignalMask) >> kSignalShift;
    213     eigen_plain_assert(waiters >= signals);
    214     eigen_plain_assert(waiters < (1 << kWaiterBits) - 1);
    215     eigen_plain_assert(!waiter || waiters > 0);
            // Silence unused-variable warnings when the asserts expand to nothing.
    216     (void)waiters;
    217     (void)signals;
    218   }
    219 
          // Blocks the calling thread on w->cv until w->state becomes kSignaled.
          // While blocked the state is kWaiting so that Unpark knows a
          // notification is required; the loop re-checks after every wakeup.
    220   void Park(Waiter* w) {
    221     std::unique_lock<std::mutex> lock(w->mu);
    222     while (w->state != Waiter::kSignaled) {
    223       w->state = Waiter::kWaiting;
    224       w->cv.wait(lock);
    225     }
    226   }
    227 
          // Wakes `w` and every waiter reachable from it through `next` links
          // (the chain ends at a link whose index bits equal kStackMask).
    228   void Unpark(Waiter* w) {
    229     for (Waiter* next; w; w = next) {
              // The link is read before the waiter is marked signaled.
              // NOTE(review): presumably because a signaled waiter may push itself
              // again and overwrite `next` - confirm.
    230       uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask;
    231       next = wnext == kStackMask ? nullptr : &waiters_[wnext];
    232       unsigned state;
    233       {
    234         std::unique_lock<std::mutex> lock(w->mu);
    235         state = w->state;
    236         w->state = Waiter::kSignaled;
    237       }
    238       // Avoid notifying if it wasn't waiting.
    239       if (state == Waiter::kWaiting) w->cv.notify_one();
    240     }
    241   }
    242 
          // Non-copyable.
    243   EventCount(const EventCount&) = delete;
    244   void operator=(const EventCount&) = delete;
    245 };
    246 
    247 }  // namespace Eigen
    248 
    249 #endif  // EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_