// Copyright (c) 2022 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef BASE_ALLOCATOR_DISPATCHER_INTERNAL_DISPATCHER_INTERNAL_H_
#define BASE_ALLOCATOR_DISPATCHER_INTERNAL_DISPATCHER_INTERNAL_H_

#include "base/allocator/buildflags.h"
#include "base/allocator/dispatcher/configuration.h"
#include "base/allocator/dispatcher/internal/dispatch_data.h"
#include "base/allocator/dispatcher/internal/tools.h"
#include "base/allocator/dispatcher/reentry_guard.h"
#include "base/allocator/dispatcher/subsystem.h"
#include "base/compiler_specific.h"
#include "build/build_config.h"

#if BUILDFLAG(USE_PARTITION_ALLOC)
#include "base/allocator/partition_allocator/partition_alloc.h"
#endif

#if BUILDFLAG(USE_ALLOCATOR_SHIM)
#include "base/allocator/allocator_shim.h"
#endif

#include <tuple>

namespace base::allocator::dispatcher::internal {

// DCHECKs that |check_observer| returns true for every element of |observers|
// selected by |Indices|. The elements are visited in the order given by
// |Indices|. The predicate call lives inside the DCHECK expression itself, so
// whether it is evaluated at all is decided by the DCHECK macro.
template <typename CheckObserverPredicate,
          typename... ObserverTypes,
          size_t... Indices>
inline void PerformObserverCheck(
    const std::tuple<ObserverTypes...>& observers,
    std::index_sequence<Indices...> /*indices*/,
    CheckObserverPredicate check_observer) {
  (..., (DCHECK(check_observer(std::get<Indices>(observers)))));
}

// Calls OnAllocation(address, size, subSystem, type_name) on every element of
// |observers| selected by |Indices|, in the order given by |Indices|. The
// elements are dereferenced with operator->, i.e. they are expected to be
// pointer-like. The call sequence is expanded at compile time via a fold
// expression, so no loop is involved.
template <typename... ObserverTypes, size_t... Indices>
ALWAYS_INLINE void PerformAllocationNotification(
    const std::tuple<ObserverTypes...>& observers,
    std::index_sequence<Indices...> /*indices*/,
    void* address,
    size_t size,
    AllocationSubsystem subSystem,
    const char* type_name) {
  (..., (std::get<Indices>(observers)->OnAllocation(address, size, subSystem,
                                                    type_name)));
}

// Calls OnFree(address) on every element of |observers| selected by |Indices|,
// in the order given by |Indices|. The elements are dereferenced with
// operator->, i.e. they are expected to be pointer-like.
template <typename... ObserverTypes, size_t... Indices>
ALWAYS_INLINE void PerformFreeNotification(
    const std::tuple<ObserverTypes...>& observers,
    std::index_sequence<Indices...> /*indices*/,
    void* address) {
  (..., (std::get<Indices>(observers)->OnFree(address)));
}

// DispatcherImpl provides hooks into the various memory subsystems. These hooks
// are responsible for dispatching any notification to the observers.
// In order to provide as much information as possible on the exact type of the
// observers and to prevent any conditional jumps in the hot allocation path,
// observers are stored in a std::tuple. DispatcherImpl performs a DCHECK at
// initialization time (see GetNotificationHooks) to ensure they are valid.
//
// All members are static, so there is exactly one set of stored observers per
// distinct list of ObserverTypes; calling GetNotificationHooks() again
// overwrites the previously stored observers.
template <typename... ObserverTypes>
struct DispatcherImpl {
  // Index sequence 0..N-1 covering every stored observer.
  using AllObservers = std::index_sequence_for<ObserverTypes...>;

  // Stores |observers| for later notification, DCHECKs each of them against
  // IsValidObserver and returns the DispatchData describing the hooks that
  // forward to them.
  //
  // The enable_if condition requires
  // internal::LessEqual(<number of observers>, kMaximumNumberOfObservers) to
  // hold (presumably "number of observers <= configured maximum"); otherwise
  // this declaration is rejected at compile time.
  template <std::enable_if_t<
                internal::LessEqual(sizeof...(ObserverTypes),
                                    configuration::kMaximumNumberOfObservers),
                bool> = true>
  static DispatchData GetNotificationHooks(
      std::tuple<ObserverTypes*...> observers) {
    s_observers = std::move(observers);

    PerformObserverCheck(s_observers, AllObservers{}, IsValidObserver{});

    return CreateDispatchData();
  }

 private:
  // Builds the DispatchData for this dispatcher. Depending on the build flags
  // it registers the PartitionAlloc observer hooks and/or the allocator shim
  // dispatch table defined below.
  static DispatchData CreateDispatchData() {
    return DispatchData()
#if BUILDFLAG(USE_PARTITION_ALLOC)
        .SetAllocationObserverHooks(&PartitionAllocatorAllocationHook,
                                    &PartitionAllocatorFreeHook)
#endif
#if BUILDFLAG(USE_ALLOCATOR_SHIM)
        .SetAllocatorDispatch(&allocator_dispatch_)
#endif
        ;
  }

#if BUILDFLAG(USE_PARTITION_ALLOC)
  // Allocation hook registered via SetAllocationObserverHooks(). Forwards the
  // allocation to all observers, tagged as AllocationSubsystem::
  // kPartitionAllocator and including |type_name|. No ReentryGuard is used on
  // this path.
  static void PartitionAllocatorAllocationHook(void* address,
                                               size_t size,
                                               const char* type_name) {
    DoNotifyAllocation(address, size, AllocationSubsystem::kPartitionAllocator,
                       type_name);
  }

  // Free hook registered via SetAllocationObserverHooks(). Forwards the free of
  // |address| to all observers.
  static void PartitionAllocatorFreeHook(void* address) {
    DoNotifyFree(address);
  }
#endif

#if BUILDFLAG(USE_ALLOCATOR_SHIM)
  // The functions below are the entries of |allocator_dispatch_|. Each one
  // forwards the request to the next dispatch in the chain (|self->next|) and
  // notifies the observers about the resulting allocation and/or free.
  //
  // Allocating entries create a ReentryGuard before calling into the next
  // dispatch and only notify the observers if the guard evaluates to true
  // (presumably it evaluates to false when the current thread is already
  // inside a guarded section; see the ReentryGuard documentation). The address
  // returned by the next dispatch is passed to the observers as-is; no null
  // check is performed here. Allocations reported from the shim carry no type
  // name (|type_name| defaults to nullptr).

  // Forwards to |alloc_function| and reports an allocation of |size| bytes.
  static void* AllocFn(const AllocatorDispatch* self,
                       size_t size,
                       void* context) {
    ReentryGuard guard;
    void* const address = self->next->alloc_function(self->next, size, context);
    if (LIKELY(guard)) {
      DoNotifyAllocation(address, size, AllocationSubsystem::kAllocatorShim);
    }
    return address;
  }

  // Forwards to |alloc_unchecked_function| and reports an allocation of |size|
  // bytes.
  static void* AllocUncheckedFn(const AllocatorDispatch* self,
                                size_t size,
                                void* context) {
    ReentryGuard guard;
    void* const address =
        self->next->alloc_unchecked_function(self->next, size, context);
    if (LIKELY(guard)) {
      DoNotifyAllocation(address, size, AllocationSubsystem::kAllocatorShim);
    }
    return address;
  }

  // Forwards to |alloc_zero_initialized_function| and reports an allocation of
  // |n| * |size| bytes. The product is computed here without an overflow
  // check.
  static void* AllocZeroInitializedFn(const AllocatorDispatch* self,
                                      size_t n,
                                      size_t size,
                                      void* context) {
    ReentryGuard guard;
    void* const address = self->next->alloc_zero_initialized_function(
        self->next, n, size, context);
    if (LIKELY(guard)) {
      DoNotifyAllocation(address, n * size,
                         AllocationSubsystem::kAllocatorShim);
    }
    return address;
  }

  // Forwards to |alloc_aligned_function| and reports an allocation of |size|
  // bytes. Note the parameter order: |alignment| comes before |size| here,
  // unlike AlignedMallocFn below.
  static void* AllocAlignedFn(const AllocatorDispatch* self,
                              size_t alignment,
                              size_t size,
                              void* context) {
    ReentryGuard guard;
    void* const address = self->next->alloc_aligned_function(
        self->next, alignment, size, context);
    if (LIKELY(guard)) {
      DoNotifyAllocation(address, size, AllocationSubsystem::kAllocatorShim);
    }
    return address;
  }

  // Reports the free of |address| unconditionally, forwards to
  // |realloc_function| and then, if the guard permits, reports an allocation
  // of |size| bytes at the returned address.
  static void* ReallocFn(const AllocatorDispatch* self,
                         void* address,
                         size_t size,
                         void* context) {
    ReentryGuard guard;
    // Note: size == 0 actually performs free.
    // Note: ReentryGuard protects against recursion introduced by malloc and
    // by the initialization of thread local storage, both of which happen in
    // the allocation path only (please see the docs of ReentryGuard for full
    // details). Therefore, DoNotifyFree doesn't need to be guarded. Leaving it
    // unguarded also ensures that the free is always reported.
    DoNotifyFree(address);
    void* const reallocated_address =
        self->next->realloc_function(self->next, address, size, context);
    if (LIKELY(guard)) {
      DoNotifyAllocation(reallocated_address, size,
                         AllocationSubsystem::kAllocatorShim);
    }
    return reallocated_address;
  }

  // Reports the free of |address| and then forwards to |free_function|.
  static void FreeFn(const AllocatorDispatch* self,
                     void* address,
                     void* context) {
    // Note: DoNotifyFree must be called before free_function (here and in the
    // other freeing entries). Observers need to handle the allocation being
    // freed before free_function runs, because once it has executed the
    // address becomes available and can be handed out to another thread.
    // Notifying afterwards would be racy.
    // Note: The code doesn't need to protect from recursions using
    // ReentryGuard, see ReallocFn for details.
    DoNotifyFree(address);
    self->next->free_function(self->next, address, context);
  }

  // Pure pass-through to |get_size_estimate_function|; observers are not
  // notified.
  static size_t GetSizeEstimateFn(const AllocatorDispatch* self,
                                  void* address,
                                  void* context) {
    return self->next->get_size_estimate_function(self->next, address, context);
  }

  // Forwards to |batch_malloc_function| and reports one allocation of |size|
  // bytes for each of the first |num_allocated| entries of |results|, where
  // |num_allocated| is the value returned by the next dispatch.
  static unsigned BatchMallocFn(const AllocatorDispatch* self,
                                size_t size,
                                void** results,
                                unsigned num_requested,
                                void* context) {
    ReentryGuard guard;
    unsigned const num_allocated = self->next->batch_malloc_function(
        self->next, size, results, num_requested, context);
    if (LIKELY(guard)) {
      for (unsigned i = 0; i < num_allocated; ++i) {
        DoNotifyAllocation(results[i], size,
                           AllocationSubsystem::kAllocatorShim);
      }
    }
    return num_allocated;
  }

  // Reports the free of each of the |num_to_be_freed| entries of |to_be_freed|
  // and then forwards the whole batch to |batch_free_function|.
  static void BatchFreeFn(const AllocatorDispatch* self,
                          void** to_be_freed,
                          unsigned num_to_be_freed,
                          void* context) {
    // Note: The code doesn't need to protect from recursions using
    // ReentryGuard, see ReallocFn for details.
    for (unsigned i = 0; i < num_to_be_freed; ++i) {
      DoNotifyFree(to_be_freed[i]);
    }
    self->next->batch_free_function(self->next, to_be_freed, num_to_be_freed,
                                    context);
  }

  // Reports the free of |address| and then forwards to
  // |free_definite_size_function|. |size| is only passed on to the next
  // dispatch; it is not reported to the observers.
  static void FreeDefiniteSizeFn(const AllocatorDispatch* self,
                                 void* address,
                                 size_t size,
                                 void* context) {
    // Note: The code doesn't need to protect from recursions using
    // ReentryGuard, see ReallocFn for details.
    DoNotifyFree(address);
    self->next->free_definite_size_function(self->next, address, size, context);
  }

  // Forwards to |aligned_malloc_function| and reports an allocation of |size|
  // bytes. Note the parameter order: |size| comes before |alignment| here,
  // unlike AllocAlignedFn above.
  static void* AlignedMallocFn(const AllocatorDispatch* self,
                               size_t size,
                               size_t alignment,
                               void* context) {
    ReentryGuard guard;
    void* const address = self->next->aligned_malloc_function(
        self->next, size, alignment, context);
    if (LIKELY(guard)) {
      DoNotifyAllocation(address, size, AllocationSubsystem::kAllocatorShim);
    }
    return address;
  }

  // Reports the free of |address| unconditionally, forwards to
  // |aligned_realloc_function| and then, if the guard permits, reports an
  // allocation of |size| bytes at the returned address.
  static void* AlignedReallocFn(const AllocatorDispatch* self,
                                void* address,
                                size_t size,
                                size_t alignment,
                                void* context) {
    ReentryGuard guard;
    // Note: size == 0 actually performs free.
    // Note: DoNotifyFree doesn't need to protect from recursions using
    // ReentryGuard, see ReallocFn for details.
    // Leaving it unguarded also ensures that the free portion is always
    // reported.
    DoNotifyFree(address);
    // From here on |address| holds the reallocated address, not the original
    // one.
    address = self->next->aligned_realloc_function(self->next, address, size,
                                                   alignment, context);
    if (LIKELY(guard)) {
      DoNotifyAllocation(address, size, AllocationSubsystem::kAllocatorShim);
    }
    return address;
  }

  // Reports the free of |address| and then forwards to
  // |aligned_free_function|.
  static void AlignedFreeFn(const AllocatorDispatch* self,
                            void* address,
                            void* context) {
    // Note: The code doesn't need to protect from recursions using
    // ReentryGuard, see ReallocFn for details.
    DoNotifyFree(address);
    self->next->aligned_free_function(self->next, address, context);
  }

  // Dispatch table holding the functions above; its address is handed out via
  // CreateDispatchData().
  static AllocatorDispatch allocator_dispatch_;
#endif

  // Notifies every stored observer about an allocation of |size| bytes at
  // |address| originating from |subSystem|. |type_name| is passed through to
  // the observers and is nullptr unless the caller provides one.
  static ALWAYS_INLINE void DoNotifyAllocation(
      void* address,
      size_t size,
      AllocationSubsystem subSystem,
      const char* type_name = nullptr) {
    PerformAllocationNotification(s_observers, AllObservers{}, address, size,
                                  subSystem, type_name);
  }

  // Notifies every stored observer that |address| is being freed.
  static ALWAYS_INLINE void DoNotifyFree(void* address) {
    PerformFreeNotification(s_observers, AllObservers{}, address);
  }

  // The observers to notify; assigned in GetNotificationHooks().
  static std::tuple<ObserverTypes*...> s_observers;
};

// Out-of-line definition of the static observer tuple. As a static member of a
// class template there is one instance per distinct list of ObserverTypes. It
// is assigned in DispatcherImpl::GetNotificationHooks().
template <typename... ObserverTypes>
std::tuple<ObserverTypes*...> DispatcherImpl<ObserverTypes...>::s_observers;

#if BUILDFLAG(USE_ALLOCATOR_SHIM)
// Out-of-line definition of the allocator shim dispatch table. The entries are
// given positionally, so their order must match the member order of
// AllocatorDispatch, whose declaration is not visible in this file.
// NOTE(review): the slot names in the trailing comments are inferred from the
// member names used by the hook functions (self->next-><name>), and the final
// nullptr presumably initializes the |next| pointer, to be set once the
// dispatch is inserted into the shim chain. Confirm both against the
// AllocatorDispatch declaration.
template <typename... ObserverTypes>
AllocatorDispatch DispatcherImpl<ObserverTypes...>::allocator_dispatch_ = {
    &AllocFn,                 // presumably alloc_function
    &AllocUncheckedFn,        // presumably alloc_unchecked_function
    &AllocZeroInitializedFn,  // presumably alloc_zero_initialized_function
    &AllocAlignedFn,          // presumably alloc_aligned_function
    &ReallocFn,               // presumably realloc_function
    &FreeFn,                  // presumably free_function
    &GetSizeEstimateFn,       // presumably get_size_estimate_function
    &BatchMallocFn,           // presumably batch_malloc_function
    &BatchFreeFn,             // presumably batch_free_function
    &FreeDefiniteSizeFn,      // presumably free_definite_size_function
    &AlignedMallocFn,         // presumably aligned_malloc_function
    &AlignedReallocFn,        // presumably aligned_realloc_function
    &AlignedFreeFn,           // presumably aligned_free_function
    nullptr};
#endif

// Specialization of DispatcherImpl for an empty observer list. With nobody to
// notify, the Dispatcher must not install any hooks, so every hook slot of the
// returned DispatchData that exists in the current build configuration is
// explicitly set to nullptr.
template <>
struct DispatcherImpl<> {
  // Returns a DispatchData whose hooks are all null. The (empty) observer
  // tuple is accepted only to mirror the signature of the primary template.
  static DispatchData GetNotificationHooks(std::tuple<> /*observers*/) {
    DispatchData no_hooks = DispatchData()
#if BUILDFLAG(USE_PARTITION_ALLOC)
                                .SetAllocationObserverHooks(nullptr, nullptr)
#endif
#if BUILDFLAG(USE_ALLOCATOR_SHIM)
                                .SetAllocatorDispatch(nullptr)
#endif
        ;
    return no_hooks;
  }
};

// Convenience entry point for DispatcherImpl: deduces ObserverTypes from the
// element types of |observers| so callers don't have to spell out the template
// arguments, then forwards to the matching
// DispatcherImpl<...>::GetNotificationHooks().
template <typename... ObserverTypes>
inline DispatchData GetNotificationHooks(
    std::tuple<ObserverTypes*...> observers) {
  using Dispatcher = DispatcherImpl<ObserverTypes...>;
  return Dispatcher::GetNotificationHooks(std::move(observers));
}

}  // namespace base::allocator::dispatcher::internal

#endif  // BASE_ALLOCATOR_DISPATCHER_INTERNAL_DISPATCHER_INTERNAL_H_