Raw File
thread_local_test.cc
//  Copyright (c) 2011-present, Facebook, Inc.  All rights reserved.
//  This source code is licensed under both the GPLv2 (found in the
//  COPYING file in the root directory) and Apache 2.0 License
//  (found in the LICENSE.Apache file in the root directory).

#include "util/thread_local.h"

#include <atomic>
#include <string>
#include <thread>

#include "port/port.h"
#include "rocksdb/env.h"
#include "test_util/sync_point.h"
#include "test_util/testharness.h"
#include "test_util/testutil.h"
#include "util/autovector.h"

namespace ROCKSDB_NAMESPACE {

class ThreadLocalTest : public testing::Test {
 public:
  ThreadLocalTest() : env_(Env::Default()) {}

  Env* env_;
};

namespace {

struct Params {
  Params(port::Mutex* m, port::CondVar* c, int* u, int n,
         UnrefHandler handler = nullptr)
      : mu(m),
        cv(c),
        unref(u),
        total(n),
        started(0),
        completed(0),
        doWrite(false),
        tls1(handler),
        tls2(nullptr) {}

  port::Mutex* mu;
  port::CondVar* cv;
  int* unref;
  int total;
  int started;
  int completed;
  bool doWrite;
  ThreadLocalPtr tls1;
  ThreadLocalPtr* tls2;
};

class IDChecker : public ThreadLocalPtr {
 public:
  static uint32_t PeekId() { return TEST_PeekId(); }
};

}  // anonymous namespace

// Suppress false positive clang analyzer warnings.
#ifndef __clang_analyzer__
TEST_F(ThreadLocalTest, UniqueIdTest) {
  port::Mutex mu;
  port::CondVar cv(&mu);

  uint32_t base_id = IDChecker::PeekId();
  // New ThreadLocal instance bumps id by 1
  {
    // Id used 0
    Params p1(&mu, &cv, nullptr, 1u);
    ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
    // Id used 1
    Params p2(&mu, &cv, nullptr, 1u);
    ASSERT_EQ(IDChecker::PeekId(), base_id + 2u);
    // Id used 2
    Params p3(&mu, &cv, nullptr, 1u);
    ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
    // Id used 3
    Params p4(&mu, &cv, nullptr, 1u);
    ASSERT_EQ(IDChecker::PeekId(), base_id + 4u);
  }
  // id 3, 2, 1, 0 are in the free queue in order
  ASSERT_EQ(IDChecker::PeekId(), base_id + 0u);

  // pick up 0
  Params p1(&mu, &cv, nullptr, 1u);
  ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
  // pick up 1
  Params* p2 = new Params(&mu, &cv, nullptr, 1u);
  ASSERT_EQ(IDChecker::PeekId(), base_id + 2u);
  // pick up 2
  Params p3(&mu, &cv, nullptr, 1u);
  ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
  // return up 1
  delete p2;
  ASSERT_EQ(IDChecker::PeekId(), base_id + 1u);
  // Now we have 3, 1 in queue
  // pick up 1
  Params p4(&mu, &cv, nullptr, 1u);
  ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
  // pick up 3
  Params p5(&mu, &cv, nullptr, 1u);
  // next new id
  ASSERT_EQ(IDChecker::PeekId(), base_id + 4u);
  // After exit, id sequence in queue:
  // 3, 1, 2, 0
}
#endif  // __clang_analyzer__

TEST_F(ThreadLocalTest, SequentialReadWriteTest) {
  // global id list carries over 3, 1, 2, 0
  uint32_t base_id = IDChecker::PeekId();

  port::Mutex mu;
  port::CondVar cv(&mu);
  Params p(&mu, &cv, nullptr, 1);
  ThreadLocalPtr tls2;
  p.tls2 = &tls2;

  ASSERT_GT(IDChecker::PeekId(), base_id);
  base_id = IDChecker::PeekId();

  auto func = [](Params* ptr) {
    Params& params = *ptr;
    ASSERT_TRUE(params.tls1.Get() == nullptr);
    params.tls1.Reset(reinterpret_cast<int*>(1));
    ASSERT_TRUE(params.tls1.Get() == reinterpret_cast<int*>(1));
    params.tls1.Reset(reinterpret_cast<int*>(2));
    ASSERT_TRUE(params.tls1.Get() == reinterpret_cast<int*>(2));

    ASSERT_TRUE(params.tls2->Get() == nullptr);
    params.tls2->Reset(reinterpret_cast<int*>(1));
    ASSERT_TRUE(params.tls2->Get() == reinterpret_cast<int*>(1));
    params.tls2->Reset(reinterpret_cast<int*>(2));
    ASSERT_TRUE(params.tls2->Get() == reinterpret_cast<int*>(2));

    params.mu->Lock();
    ++(params.completed);
    params.cv->SignalAll();
    params.mu->Unlock();
  };

  for (int iter = 0; iter < 1024; ++iter) {
    ASSERT_EQ(IDChecker::PeekId(), base_id);
    // Another new thread, read/write should not see value from previous thread
    env_->StartThreadTyped(func, &p);

    mu.Lock();
    while (p.completed != iter + 1) {
      cv.Wait();
    }
    mu.Unlock();
    ASSERT_EQ(IDChecker::PeekId(), base_id);
  }
}

TEST_F(ThreadLocalTest, ConcurrentReadWriteTest) {
  // global id list carries over 3, 1, 2, 0
  uint32_t base_id = IDChecker::PeekId();

  ThreadLocalPtr tls2;
  port::Mutex mu1;
  port::CondVar cv1(&mu1);
  Params p1(&mu1, &cv1, nullptr, 16);
  p1.tls2 = &tls2;

  port::Mutex mu2;
  port::CondVar cv2(&mu2);
  Params p2(&mu2, &cv2, nullptr, 16);
  p2.doWrite = true;
  p2.tls2 = &tls2;

  auto func = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);

    p.mu->Lock();
    // Size_T switches size along with the ptr size
    // we want to cast to.
    size_t own = ++(p.started);
    p.cv->SignalAll();
    while (p.started != p.total) {
      p.cv->Wait();
    }
    p.mu->Unlock();

    // Let write threads write a different value from the read threads
    if (p.doWrite) {
      own += 8192;
    }

    ASSERT_TRUE(p.tls1.Get() == nullptr);
    ASSERT_TRUE(p.tls2->Get() == nullptr);

    auto* env = Env::Default();
    auto start = env->NowMicros();

    p.tls1.Reset(reinterpret_cast<size_t*>(own));
    p.tls2->Reset(reinterpret_cast<size_t*>(own + 1));
    // Loop for 1 second
    while (env->NowMicros() - start < 1000 * 1000) {
      for (int iter = 0; iter < 100000; ++iter) {
        ASSERT_TRUE(p.tls1.Get() == reinterpret_cast<size_t*>(own));
        ASSERT_TRUE(p.tls2->Get() == reinterpret_cast<size_t*>(own + 1));
        if (p.doWrite) {
          p.tls1.Reset(reinterpret_cast<size_t*>(own));
          p.tls2->Reset(reinterpret_cast<size_t*>(own + 1));
        }
      }
    }

    p.mu->Lock();
    ++(p.completed);
    p.cv->SignalAll();
    p.mu->Unlock();
  };

  // Initiate 2 instnaces: one keeps writing and one keeps reading.
  // The read instance should not see data from the write instance.
  // Each thread local copy of the value are also different from each
  // other.
  for (int th = 0; th < p1.total; ++th) {
    env_->StartThreadTyped(func, &p1);
  }
  for (int th = 0; th < p2.total; ++th) {
    env_->StartThreadTyped(func, &p2);
  }

  mu1.Lock();
  while (p1.completed != p1.total) {
    cv1.Wait();
  }
  mu1.Unlock();

  mu2.Lock();
  while (p2.completed != p2.total) {
    cv2.Wait();
  }
  mu2.Unlock();

  ASSERT_EQ(IDChecker::PeekId(), base_id + 3u);
}

TEST_F(ThreadLocalTest, Unref) {
  auto unref = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);
    p.mu->Lock();
    ++(*p.unref);
    p.mu->Unlock();
  };

  // Case 0: no unref triggered if ThreadLocalPtr is never accessed
  auto func0 = [](Params* ptr) {
    auto& p = *ptr;
    p.mu->Lock();
    ++(p.started);
    p.cv->SignalAll();
    while (p.started != p.total) {
      p.cv->Wait();
    }
    p.mu->Unlock();
  };

  for (int th = 1; th <= 128; th += th) {
    port::Mutex mu;
    port::CondVar cv(&mu);
    int unref_count = 0;
    Params p(&mu, &cv, &unref_count, th, unref);

    for (int i = 0; i < p.total; ++i) {
      env_->StartThreadTyped(func0, &p);
    }
    env_->WaitForJoin();
    ASSERT_EQ(unref_count, 0);
  }

  // Case 1: unref triggered by thread exit
  auto func1 = [](Params* ptr) {
    auto& p = *ptr;

    p.mu->Lock();
    ++(p.started);
    p.cv->SignalAll();
    while (p.started != p.total) {
      p.cv->Wait();
    }
    p.mu->Unlock();

    ASSERT_TRUE(p.tls1.Get() == nullptr);
    ASSERT_TRUE(p.tls2->Get() == nullptr);

    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);

    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);
  };

  for (int th = 1; th <= 128; th += th) {
    port::Mutex mu;
    port::CondVar cv(&mu);
    int unref_count = 0;
    ThreadLocalPtr tls2(unref);
    Params p(&mu, &cv, &unref_count, th, unref);
    p.tls2 = &tls2;

    for (int i = 0; i < p.total; ++i) {
      env_->StartThreadTyped(func1, &p);
    }

    env_->WaitForJoin();

    // N threads x 2 ThreadLocal instance cleanup on thread exit
    ASSERT_EQ(unref_count, 2 * p.total);
  }

  // Case 2: unref triggered by ThreadLocal instance destruction
  auto func2 = [](Params* ptr) {
    auto& p = *ptr;

    p.mu->Lock();
    ++(p.started);
    p.cv->SignalAll();
    while (p.started != p.total) {
      p.cv->Wait();
    }
    p.mu->Unlock();

    ASSERT_TRUE(p.tls1.Get() == nullptr);
    ASSERT_TRUE(p.tls2->Get() == nullptr);

    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);

    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);

    p.mu->Lock();
    ++(p.completed);
    p.cv->SignalAll();

    // Waiting for instruction to exit thread
    while (p.completed != 0) {
      p.cv->Wait();
    }
    p.mu->Unlock();
  };

  for (int th = 1; th <= 128; th += th) {
    port::Mutex mu;
    port::CondVar cv(&mu);
    int unref_count = 0;
    Params p(&mu, &cv, &unref_count, th, unref);
    p.tls2 = new ThreadLocalPtr(unref);

    for (int i = 0; i < p.total; ++i) {
      env_->StartThreadTyped(func2, &p);
    }

    // Wait for all threads to finish using Params
    mu.Lock();
    while (p.completed != p.total) {
      cv.Wait();
    }
    mu.Unlock();

    // Now destroy one ThreadLocal instance
    delete p.tls2;
    p.tls2 = nullptr;
    // instance destroy for N threads
    ASSERT_EQ(unref_count, p.total);

    // Signal to exit
    mu.Lock();
    p.completed = 0;
    cv.SignalAll();
    mu.Unlock();
    env_->WaitForJoin();
    // additional N threads exit unref for the left instance
    ASSERT_EQ(unref_count, 2 * p.total);
  }
}

TEST_F(ThreadLocalTest, Swap) {
  ThreadLocalPtr tls;
  tls.Reset(reinterpret_cast<void*>(1));
  ASSERT_EQ(reinterpret_cast<int64_t>(tls.Swap(nullptr)), 1);
  ASSERT_TRUE(tls.Swap(reinterpret_cast<void*>(2)) == nullptr);
  ASSERT_EQ(reinterpret_cast<int64_t>(tls.Get()), 2);
  ASSERT_EQ(reinterpret_cast<int64_t>(tls.Swap(reinterpret_cast<void*>(3))), 2);
}

TEST_F(ThreadLocalTest, Scrape) {
  auto unref = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);
    p.mu->Lock();
    ++(*p.unref);
    p.mu->Unlock();
  };

  auto func = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);

    ASSERT_TRUE(p.tls1.Get() == nullptr);
    ASSERT_TRUE(p.tls2->Get() == nullptr);

    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);

    p.tls1.Reset(ptr);
    p.tls2->Reset(ptr);

    p.mu->Lock();
    ++(p.completed);
    p.cv->SignalAll();

    // Waiting for instruction to exit thread
    while (p.completed != 0) {
      p.cv->Wait();
    }
    p.mu->Unlock();
  };

  for (int th = 1; th <= 128; th += th) {
    port::Mutex mu;
    port::CondVar cv(&mu);
    int unref_count = 0;
    Params p(&mu, &cv, &unref_count, th, unref);
    p.tls2 = new ThreadLocalPtr(unref);

    for (int i = 0; i < p.total; ++i) {
      env_->StartThreadTyped(func, &p);
    }

    // Wait for all threads to finish using Params
    mu.Lock();
    while (p.completed != p.total) {
      cv.Wait();
    }
    mu.Unlock();

    ASSERT_EQ(unref_count, 0);

    // Scrape all thread local data. No unref at thread
    // exit or ThreadLocalPtr destruction
    autovector<void*> ptrs;
    p.tls1.Scrape(&ptrs, nullptr);
    p.tls2->Scrape(&ptrs, nullptr);
    delete p.tls2;
    // Signal to exit
    mu.Lock();
    p.completed = 0;
    cv.SignalAll();
    mu.Unlock();
    env_->WaitForJoin();

    ASSERT_EQ(unref_count, 0);
  }
}

TEST_F(ThreadLocalTest, Fold) {
  auto unref = [](void* ptr) {
    delete static_cast<std::atomic<int64_t>*>(ptr);
  };
  static const int kNumThreads = 16;
  static const int kItersPerThread = 10;
  port::Mutex mu;
  port::CondVar cv(&mu);
  Params params(&mu, &cv, nullptr, kNumThreads, unref);
  auto func = [](void* ptr) {
    auto& p = *static_cast<Params*>(ptr);
    ASSERT_TRUE(p.tls1.Get() == nullptr);
    p.tls1.Reset(new std::atomic<int64_t>(0));

    for (int i = 0; i < kItersPerThread; ++i) {
      static_cast<std::atomic<int64_t>*>(p.tls1.Get())->fetch_add(1);
    }

    p.mu->Lock();
    ++(p.completed);
    p.cv->SignalAll();

    // Waiting for instruction to exit thread
    while (p.completed != 0) {
      p.cv->Wait();
    }
    p.mu->Unlock();
  };

  for (int th = 0; th < params.total; ++th) {
    env_->StartThread(func, &params);
  }

  // Wait for all threads to finish using Params
  mu.Lock();
  while (params.completed != params.total) {
    cv.Wait();
  }
  mu.Unlock();

  // Verify Fold() behavior
  int64_t sum = 0;
  params.tls1.Fold(
      [](void* ptr, void* res) {
        auto sum_ptr = static_cast<int64_t*>(res);
        *sum_ptr += static_cast<std::atomic<int64_t>*>(ptr)->load();
      },
      &sum);
  ASSERT_EQ(sum, kNumThreads * kItersPerThread);

  // Signal to exit
  mu.Lock();
  params.completed = 0;
  cv.SignalAll();
  mu.Unlock();
  env_->WaitForJoin();
}

TEST_F(ThreadLocalTest, CompareAndSwap) {
  ThreadLocalPtr tls;
  ASSERT_TRUE(tls.Swap(reinterpret_cast<void*>(1)) == nullptr);
  void* expected = reinterpret_cast<void*>(1);
  // Swap in 2
  ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast<void*>(2), expected));
  expected = reinterpret_cast<void*>(100);
  // Fail Swap, still 2
  ASSERT_TRUE(!tls.CompareAndSwap(reinterpret_cast<void*>(2), expected));
  ASSERT_EQ(expected, reinterpret_cast<void*>(2));
  // Swap in 3
  expected = reinterpret_cast<void*>(2);
  ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast<void*>(3), expected));
  ASSERT_EQ(tls.Get(), reinterpret_cast<void*>(3));
}

namespace {

void* AccessThreadLocal(void* /*arg*/) {
  TEST_SYNC_POINT("AccessThreadLocal:Start");
  ThreadLocalPtr tlp;
  tlp.Reset(new std::string("hello RocksDB"));
  TEST_SYNC_POINT("AccessThreadLocal:End");
  return nullptr;
}

}  // namespace

// The following test is disabled as it requires manual steps to run it
// correctly.
//
// Currently we have no way to acess SyncPoint w/o ASAN error when the
// child thread dies after the main thread dies.  So if you manually enable
// this test and only see an ASAN error on SyncPoint, it means you pass the
// test.
TEST_F(ThreadLocalTest, DISABLED_MainThreadDiesFirst) {
  ROCKSDB_NAMESPACE::SyncPoint::GetInstance()->LoadDependency(
      {{"AccessThreadLocal:Start", "MainThreadDiesFirst:End"},
       {"PosixEnv::~PosixEnv():End", "AccessThreadLocal:End"}});

  // Triggers the initialization of singletons.
  Env::Default();

  try {
    ROCKSDB_NAMESPACE::port::Thread th(&AccessThreadLocal, nullptr);
    th.detach();
    TEST_SYNC_POINT("MainThreadDiesFirst:End");
  } catch (const std::system_error& ex) {
    std::cerr << "Start thread: " << ex.code() << std::endl;
    FAIL();
  }
}

}  // namespace ROCKSDB_NAMESPACE

int main(int argc, char** argv) {
  ROCKSDB_NAMESPACE::port::InstallStackTraceHandler();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
back to top