/**
 * @file   ivf_flat_index.h
 *
 * @section LICENSE
 *
 * The MIT License
 *
 * @copyright Copyright (c) 2023 TileDB, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * @section DESCRIPTION
 *
 * Header-only library of functions for building an inverted file (IVF) index,
 * generated by kmeans algorithm.
 *
 * The basic use case is:
 * - Create an instance of the index
 * - Call train() to build the index
 * - OR Call load() to load the index from TileDB arrays
 * - Call add() to add vectors to the index (alt. add with ids)
 * - Call search() to query the index, returning the ids of the nearest vectors,
 *   and optionally the distances.
 * - Compute the recall of the search results.
 *
 * - Call save() to save the index to disk
 * - Call reset() to clear the index
 *
 * Still WIP.
 */

#ifndef TILEDB_ivf_flat_index_H
#define TILEDB_ivf_flat_index_H

#include <atomic>
#include <random>
#include <thread>

#include "algorithm.h"
#include "concepts.h"
#include "cpos.h"
#include "index/index_defs.h"
#include "index/ivf_flat_group.h"
#include "index/kmeans.h"
#include "linalg.h"

#include "detail/flat/qv.h"
#include "detail/ivf/index.h"
#include "detail/ivf/partition.h"
#include "detail/ivf/qv.h"

#include <tiledb/group_experimental.h>
#include <tiledb/tiledb>

/**
 * Class representing an inverted file (IVF) index for flat (non-compressed)
 * feature vectors.  The class simply holds the index data itself, it is
 * unaware of where the data comes from -- reading and writing data is done
 * via an ivf_flat_group.  Thus, this class does not hold information
 * about the group (neither the group members, nor the group metadata).
 *
 * @tparam partitioned_vectors_feature_type
 * @tparam partitioned_ids_type
 * @tparam partitioning_index_type
 */
template <
    class partitioned_vectors_feature_type,
    class partitioned_ids_type = uint64_t,
    class partitioning_index_type = uint64_t>
class ivf_flat_index {
 public:
  // Element type of the stored (partitioned) feature vectors.
  using feature_type = partitioned_vectors_feature_type;
  // Type of the id recorded for each stored vector.
  using id_type = partitioned_ids_type;
  // Type of the entries of the partitioning index (partition boundaries).
  using indices_type = partitioning_index_type;
  using score_type = float;  // @todo -- this should be a parameter?
  // Centroids are stored using the score type, not the feature type.
  using centroid_feature_type = score_type;

  // Group type through which this index is read from / written to storage.
  using group_type = ivf_flat_group<ivf_flat_index>;
  using metadata_type = ivf_flat_index_metadata;

 private:
  // Partitioned storage for the vectors, their ids and the partition
  // indices.  `add()` builds an instance of this type in memory.
  using storage_type = ColMajorPartitionedMatrix<
      feature_type,
      partitioned_ids_type,
      indices_type>;

  // Array-backed partitioned storage, constructed from array URIs by the
  // `read_index_*` methods.  Instances are held through a
  // `std::unique_ptr<storage_type>`, so a pointer to this type must be
  // convertible to a pointer to `storage_type`.
  using tdb_storage_type = tdbColMajorPartitionedMatrix<
      feature_type,
      partitioned_ids_type,
      indices_type>;

  // Storage for the centroids; `train()` sizes it as
  // `dimensions_` x `num_partitions_`.
  using centroids_storage_type = ColMajorMatrix<centroid_feature_type>;

  // Tag identifying the kind of index implemented by this class.
  constexpr static const IndexKind index_kind_ = IndexKind::IVFFlat;

  /****************************************************************************
   * Index group information
   ****************************************************************************/

  /**
   * The temporal policy (timestamp) of the index.  It is declared before
   * `group_` on purpose: the constructor that opens an existing index
   * initializes `group_` from this member, and members are initialized in
   * declaration order.
   */
  TemporalPolicy temporal_policy_{TimeTravel, 0};

  // Group the index was opened from.  The training-parameter constructor does
  // not initialize it, so it is null for an index built purely in memory.
  std::unique_ptr<ivf_flat_group<ivf_flat_index>> group_;

  /****************************************************************************
   * Index representation
   ****************************************************************************/

  // Cached information about the partitioned vectors in the index
  uint64_t dimensions_{0};
  uint64_t num_partitions_{0};

  // The PartitionedMatrix has indices and ids internally and is also
  // responsible for dealing with loading itself, etc.  The pointer is null
  // until `add()` or one of the `read_index_*` methods creates the storage.
  std::unique_ptr<storage_type> partitioned_vectors_;
  // The centroids (`dimensions_` x `num_partitions_` after `train()`).
  centroids_storage_type centroids_;

  // Some parameters for performing kmeans clustering.  Note that the
  // training-parameter constructor overwrites `max_iter_` and `tol_` with its
  // own argument defaults (2 and 0.000025), so the initializers here only
  // survive for an index opened from a URI.
  uint64_t max_iter_{1};
  float tol_{1.e-4};
  float reassign_ratio_{0.075};

  // NOTE(review): `std::thread::hardware_concurrency()` is permitted to
  // return 0 -- confirm that the kernels receiving `num_threads_` handle that.
  uint64_t num_threads_{std::thread::hardware_concurrency()};
  uint64_t seed_{std::random_device{}()};

 public:
  using value_type = feature_type;
  using index_type = partitioning_index_type;  // @todo This isn't right

  /****************************************************************************
   * Constructors (et al)
   ****************************************************************************/

  // The index is move-only: copying is deleted, moving is defaulted.  There
  // is no separate default constructor; the training-parameter constructor
  // has a default for every argument and therefore serves as one.
  // ivf_flat_index() = delete;
  ivf_flat_index(const ivf_flat_index& index) = delete;
  ivf_flat_index& operator=(const ivf_flat_index& index) = delete;
  ivf_flat_index(ivf_flat_index&& index) = default;
  ivf_flat_index& operator=(ivf_flat_index&& index) = default;

  /**
   * @brief Construct a new `ivf_flat_index` object, setting a number of
   * parameters to be used subsequently in training.  To fully create an index
   * we will need to call `train()` and `add()`.
   *
   * The dimension of the vectors is not a constructor argument (the `dim`
   * parameter is commented out); `train()` sets it from the training set.
   * Because every argument has a default, this also serves as the default
   * constructor.
   *
   * @param nlist Number of centroids / partitions to compute.  If 0,
   * `train()` picks the square root of the number of training vectors.
   * @param max_iter Maximum number of iterations for kmeans algorithm.
   * @param tol Convergence tolerance for kmeans algorithm.
   * @param temporal_policy Temporal policy for the index.  If its
   * `timestamp_end()` is 0, it is replaced by a `TimeTravel` policy whose
   * timestamp is the current system time in milliseconds since the epoch.
   * @param seed Random seed for kmeans algorithm.
   *
   * @todo Use chained parameter technique for arguments
   * @todo -- Need something equivalent to "None" since user could pass 0
   * @todo -- Or something equivalent to "Use current time" and something
   * to indicate "no time traveling"
   * @todo -- May also want start/stop?  Use a variant?  TemporalPolicy?
   */
  ivf_flat_index(
      // size_t dim,
      size_t nlist = 0,
      size_t max_iter = 2,
      float tol = 0.000025,
      TemporalPolicy temporal_policy = TemporalPolicy{TimeTravel, 0},
      uint64_t seed = std::random_device{}())
      :  // , dimensions_(dim)
      temporal_policy_{
          temporal_policy.timestamp_end() != 0 ?
              temporal_policy :
              TemporalPolicy{
                  TimeTravel,
                  static_cast<uint64_t>(
                      std::chrono::duration_cast<std::chrono::milliseconds>(
                          std::chrono::system_clock::now().time_since_epoch())
                          .count())}}
      , num_partitions_(nlist)
      , max_iter_(max_iter)
      , tol_(tol)
      , seed_{seed} {
    // NOTE(review): `gen_` is not declared in the part of the class visible
    // here -- presumably the random number generator member; confirm.
    gen_.seed(seed_);
  }

  /**
   * @brief Open a previously created index, stored as a TileDB group.  This
   * class does not deal with the group itself, but rather calls the group
   * constructor.  The group constructor will initialize itself with information
   * about the different constituent arrays needed for operation of this class,
   * but will not initialize any member data of the class.
   *
   * The group is opened with a timestamp, so the correct values of base_size
   * and num_partitions will be set.
   *
   * We go ahead and load the centroids here.  The partitioned vectors are
   * not loaded by this constructor.
   * @todo Is this the right place to load the centroids?
   *
   * @param ctx Context passed to the group constructor.
   * @param uri URI of the group holding the index.
   * @param temporal_policy Temporal policy at which to open the group.  When
   * not supplied, a default-constructed `TemporalPolicy` is used.
   *
   */
  ivf_flat_index(
      const tiledb::Context& ctx,
      const std::string& uri,
      std::optional<TemporalPolicy> temporal_policy = std::nullopt)
      : temporal_policy_{temporal_policy.has_value() ? *temporal_policy : TemporalPolicy()}
      // The group is opened with the *member* `temporal_policy_`, which is
      // only valid here because that member is declared before `group_`.
      , group_{std::make_unique<ivf_flat_group<ivf_flat_index>>(
            ctx, uri, TILEDB_READ, temporal_policy_)} {
    /**
     * Read the centroids.  How the partitioned_vectors_ are read in will be
     * determined by the type of query we are doing.  But they will be read
     * in at this same timestamp.
     */
    dimensions_ = group_->get_dimensions();
    num_partitions_ = group_->get_num_partitions();
    // Read all rows from column 0 -> `num_partitions_`. Set no upper_bound.
    // NOTE(review): the `std::move` wraps a temporary and looks redundant.
    centroids_ =
        std::move(tdbPreLoadMatrix<centroid_feature_type, stdx::layout_left>(
            group_->cached_ctx(),
            group_->centroids_uri(),
            std::nullopt,
            num_partitions_,
            0,
            temporal_policy_));
  }

  /****************************************************************************
   * Methods for building, writing, and reading the complete index
   * @todo Create single function that trains and adds (ingests)
   * @todo Provide interface that takes URI rather than vectors
   * @todo Provide "kernel" interface for use in distributed computation
   * @todo Do we need an out-of-core version of this?
   ****************************************************************************/

  template <feature_vector_array V>
  void kmeans_random_init(const V& training_set) {
    ::kmeans_random_init(training_set, centroids_, num_partitions_);
  }

  template <feature_vector_array V, class Distance = sum_of_squares_distance>
  void kmeans_pp(const V& training_set) {
    ::kmeans_pp<std::remove_cvref_t<V>, decltype(centroids_), Distance>(
        training_set, centroids_, num_partitions_, num_threads_);
  }

  /**
   * Compute centroids of the training set data, using the kmeans algorithm.
   * The initialization algorithm used to generate the starting centroids
   * for kmeans is specified by the `init` parameter.  Either random
   * initialization or kmeans++ initialization can be used.
   *
   * @param training_set Array of vectors to cluster.
   * @param init Specify which initialization algorithm to use,
   * random (`random`) or kmeans++ (`kmeanspp`).
   */
  template <feature_vector_array V, class Distance = sum_of_squares_distance>
  void train(
      //       ColMajorMatrix<feature_type>& training_set,
      const V& training_set,
      kmeans_init init = kmeans_init::random) {
    dimensions_ = ::dimensions(training_set);
    if (num_partitions_ == 0) {
      num_partitions_ = std::sqrt(num_vectors(training_set));
    }

    centroids_ =
        ColMajorMatrix<centroid_feature_type>(dimensions_, num_partitions_);
    switch (init) {
      case (kmeans_init::none):
        break;
      case (kmeans_init::kmeanspp):
        kmeans_pp<std::remove_cvref_t<decltype(training_set)>, Distance>(
            training_set);
        break;
      case (kmeans_init::random):
        kmeans_random_init(training_set);
        break;
    };

// Temporary printf debugging
#if 0
    std::cout << "\nCentroids Before:\n" << std::endl;
    for (size_t j = 0; j < centroids_.num_cols(); ++j) {
      for (size_t i = 0; i < dimensions_; ++i) {
        std::cout << centroids_[j][i] << " ";
      }
      std::cout << std::endl;
    }
    std::cout << std::endl;
#endif

    // @todo Or we can pass a `distance` parameter as an argument and use
    // argument deduction
    train_no_init<
        std::remove_cvref_t<decltype(training_set)>,
        // std::remove_cvref<decltype(centroids_)>,
        decltype(centroids_),
        Distance>(
        training_set,
        centroids_,
        dimensions_,
        num_partitions_,
        max_iter_,
        tol_,
        num_threads_,
        reassign_ratio_);

// Temporary printf debugging
#if 0
    std::cout << "\nCentroids After:\n" << std::endl;
    for (size_t j = 0; j < centroids_.num_cols(); ++j) {
      for (size_t i = 0; i < dimensions_; ++i) {
        std::cout << centroids_[j][i] << " ";
      }
      std::cout << std::endl;
    }
    std::cout << std::endl;
#endif
  }

  /**
   * @brief Build the index from a training set, given the centroids.  This
   * will partition the training set into a contiguous array, with one
   * partition per centroid.  It will also create an array to record the
   * original ids locations of each vector (their locations in the original
   * training set) as well as a partitioning index array demarcating the
   * boundaries of each partition (including the very end of the array).
   *
   * Any previously held partitioned storage is replaced.
   *
   * @tparam V Type of the training set.
   * @param training_set Array of vectors to partition.
   *
   * @todo Create and write index that is larger than RAM
   */
  template <feature_vector_array V>
  void add(const V& training_set) {
    // The number of partitions is the number of centroids.
    auto partition_count = ::num_vectors(centroids_);

    // One partition label per vector in `training_set`.
    auto labels =
        detail::flat::qv_partition(centroids_, training_set, num_threads_);

    // @todo Should we have a move here?
    partitioned_vectors_ =
        std::make_unique<storage_type>(training_set, labels, partition_count);
  }

  /*****************************************************************************
   * Methods for reading and writing the index from / to a group.
   *****************************************************************************/

  /**
   * @brief Read the complete set of partitioned vectors, together with their
   * ids and partitioning indices, into ("infinite") memory.
   *
   * The arrays are taken from the group this index was opened from and are
   * read with the index's temporal policy.  Every partition (one per
   * centroid) is requested.  The centroids themselves are not read here;
   * the constructor that opens the group loads them.
   *
   * @throws std::runtime_error if vectors or ids are already loaded, or if
   * the index has no group to read from (it was not opened from a URI).
   */
  auto read_index_infinite() {
    if (partitioned_vectors_ &&
        (::num_vectors(*partitioned_vectors_) != 0 ||
         ::num_vectors(partitioned_vectors_->ids()) != 0)) {
      throw std::runtime_error("Index already loaded");
    }

    // An index built purely in memory has no group.  Fail with a clear error
    // instead of dereferencing a null pointer below.
    if (!group_) {
      throw std::runtime_error(
          "Cannot read index: no group is associated with this index");
    }

    // Load all partitions for infinite query
    // Note that the constructor will move the infinite_parts vector
    auto infinite_parts = std::vector<indices_type>(::num_vectors(centroids_));
    std::iota(begin(infinite_parts), end(infinite_parts), 0);
    partitioned_vectors_ = std::make_unique<tdb_storage_type>(
        group_->cached_ctx(),
        group_->parts_uri(),
        group_->indices_uri(),
        group_->get_num_partitions() + 1,
        group_->ids_uri(),
        infinite_parts,
        0,
        temporal_policy_);

    partitioned_vectors_->load();

    // Sanity checks: one id per vector, and one more index entry than there
    // are partitions (the extra entry marks the end of the last partition).
    assert(
        ::num_vectors(*partitioned_vectors_) ==
        size(partitioned_vectors_->ids()));
    assert(
        size(partitioned_vectors_->indices()) == ::num_vectors(centroids_) + 1);
  }

  /**
   * @brief Open the index from the arrays contained in the group.
   * The "finite" queries only load as much data (ids and vectors) as are
   * necessary for a given query -- so we can't load any data until we
   * know what the query is.  So, here we would have read the centroids
   * into memory when creating the index, but would not have read
   * the partitioned_ids or partitioned_vectors.
   *
   * This determines which partitions the given queries need and creates the
   * array-backed storage restricted to those partitions.  The vectors
   * themselves are not loaded here.
   *
   * @tparam Q Type of the query vectors.
   * @param query_vectors Array of vectors to query.
   * @param nprobe Number of centroids to search for each query.
   * @param upper_bound Upper bound handed to the partitioned storage.
   * @return A tuple of the active partitions and the active queries, as
   * computed by `detail::ivf::partition_ivf_flat_index`.
   * @throws std::runtime_error if vectors or ids are already loaded, or if
   * the index has no group to read from (it was not opened from a URI).
   */
  template <feature_vector_array Q>
  auto read_index_finite(
      const Q& query_vectors, size_t nprobe, size_t upper_bound) {
    if (partitioned_vectors_ &&
        (::num_vectors(*partitioned_vectors_) != 0 ||
         ::num_vectors(partitioned_vectors_->ids()) != 0)) {
      throw std::runtime_error("Index already loaded");
    }

    // An index built purely in memory has no group.  Fail with a clear error
    // instead of dereferencing a null pointer below.
    if (!group_) {
      throw std::runtime_error(
          "Cannot read index: no group is associated with this index");
    }

    auto&& [active_partitions, active_queries] =
        detail::ivf::partition_ivf_flat_index<indices_type>(
            centroids_, query_vectors, nprobe, num_threads_);

    partitioned_vectors_ = std::make_unique<tdb_storage_type>(
        group_->cached_ctx(),
        group_->parts_uri(),
        group_->indices_uri(),
        group_->get_num_partitions() + 1,
        group_->ids_uri(),
        active_partitions,
        upper_bound,
        temporal_policy_);

    // NB: We don't load the partitioned_vectors here.  We will load them
    // when we do the query.
    return std::make_tuple(
        std::move(active_partitions), std::move(active_queries));
  }

  /**
   * @brief Write the index to storage.  This would typically be done after a
   * set of input vectors has been read and a new group is created.  Or after
   * consolidation.
   *
   * We assume we have all of the data in memory, and that we are writing
   * all of it to a TileDB group.  Since we have all of it in memory,
   * we write from the PartitionedMatrix base class.
   *
   * The group metadata (dimensions, ingestion timestamp, base size, number
   * of partitions) is recorded first, followed by the centroids, the
   * partitioned vectors, their ids and the partition indices.  All writes
   * use the index's temporal policy.
   *
   * @param ctx The TileDB context to use for the group and the arrays.
   * @param group_uri The URI of the TileDB group where the index will be saved
   * @param storage_version The storage version to use. If empty, the group
   * chooses its default version.
   * @return Always `true`.
   * @throws std::runtime_error if the index holds no partitioned vectors
   * (i.e. neither `add()` nor a read has populated it).
   */
  auto write_index(
      const tiledb::Context& ctx,
      const std::string& group_uri,
      const std::string& storage_version = "") const {
    // Check before the group is opened for writing: without partitioned
    // vectors the dereferences below would be undefined behavior.
    if (!partitioned_vectors_) {
      throw std::runtime_error(
          "Cannot write index: no partitioned vectors (call add() first)");
    }

    // Write the group
    auto write_group = ivf_flat_group<ivf_flat_index>(
        ctx,
        group_uri,
        TILEDB_WRITE,
        temporal_policy_,
        storage_version,
        dimensions_);

    write_group.set_dimensions(dimensions_);

    write_group.append_ingestion_timestamp(temporal_policy_.timestamp_end());
    write_group.append_base_size(::num_vectors(*partitioned_vectors_));
    write_group.append_num_partitions(num_partitions_);

    write_matrix(
        ctx,
        centroids_,
        write_group.centroids_uri(),
        0,
        false,
        temporal_policy_);

    write_matrix(
        ctx,
        *partitioned_vectors_,
        write_group.parts_uri(),
        0,
        false,
        temporal_policy_);

    write_vector(
        ctx,
        partitioned_vectors_->ids(),
        write_group.ids_uri(),
        0,
        false,
        temporal_policy_);

    write_vector(
        ctx,
        partitioned_vectors_->indices(),
        write_group.indices_uri(),
        0,
        false,
        temporal_policy_);

    return true;
  }

  /**
   * @brief Write the four constituent arrays of the index directly to the
   * given URIs, without going through a group and without recording any
   * group metadata.
   *
   * NOTE(review): unlike `write_index`, the writes pass `true` as the
   * boolean argument and do not pass the temporal policy -- presumably this
   * asks for the arrays to be created; confirm against `write_matrix` /
   * `write_vector`.
   *
   * @param ctx The TileDB context to use.
   * @param centroids_uri Destination of the centroids.
   * @param parts_uri Destination of the partitioned vectors.
   * @param ids_uri Destination of the ids of the partitioned vectors.
   * @param indices_uri Destination of the partition indices.
   * @return Always `true`.
   * @throws std::runtime_error if the index holds no partitioned vectors
   * (i.e. neither `add()` nor a read has populated it).
   */
  auto write_index_arrays(
      const tiledb::Context& ctx,
      const std::string& centroids_uri,
      const std::string& parts_uri,
      const std::string& ids_uri,
      const std::string& indices_uri) const {
    // Check before anything is written: without partitioned vectors the
    // dereferences below would be undefined behavior.
    if (!partitioned_vectors_) {
      throw std::runtime_error(
          "Cannot write index: no partitioned vectors (call add() first)");
    }

    write_matrix(ctx, centroids_, centroids_uri, 0, true);
    write_matrix(ctx, *partitioned_vectors_, parts_uri, 0, true);
    write_vector(ctx, partitioned_vectors_->ids(), ids_uri, 0, true);
    write_vector(ctx, partitioned_vectors_->indices(), indices_uri, 0, true);

    return true;
  }

  /*****************************************************************************
   *
   * Queries, infinite and finite.
   *
   * An "infinite" query assumes there is enough RAM to load the entire array
   * of partitioned vectors into memory.  The query function then searches in
   * the appropriate partitions of the array for the query vectors.
   *
   * A "finite" query, on the other hand, examines the query and only loads
   * the partitions that are necessary for that particular search.  A finite
   * query also supports out of core operation, meaning that only a subset of
   * the necessary partitions are loaded into memory at any one time.  The
   * query is applied to each subset until all of the necessary partitions to
   * satisfy the query have been read in.  The number of partitions to be held
   * in memory is controlled by an upper bound parameter that the user can set.
   * The upper bound limits the total number of vectors that will be held in
   * memory as the partitions are loaded.  Only complete partitions are loaded,
   * so the actual number of vectors in memory at any one time will generally
   * be less than the upper bound.
   *
   * @todo Add vq and dist queries (should dist be its own index?)
   * @todo Order queries so that partitions are queried in order
   *
   ****************************************************************************/

  /**
   * @brief Perform a query on the index, returning the nearest neighbors
   * and distances.  The function returns a matrix containing k_nn nearest
   * neighbors for each given query and a matrix containing the distances
   * corresponding to each returned neighbor.
   *
   * This function searches for the nearest neighbors using "infinite RAM",
   * that is, it loads the entire IVF index into memory (if it is not loaded
   * already) and then applies the query.
   *
   * @tparam Q Type of query vectors.
   * @tparam Distance Type of the distance functor.
   * @param query_vectors Array of vectors to query.
   * @param k_nn Number of nearest neighbors to return.
   * @param nprobe Number of centroids to search.
   * @param distance Distance functor used to score candidate vectors.
   *
   * @return A tuple containing a matrix of nearest neighbors and a matrix
   * of the corresponding distances.
   *
   */
  template <feature_vector_array Q, class Distance = sum_of_squares_distance>
  auto query_infinite_ram(
      const Q& query_vectors,
      size_t k_nn,
      size_t nprobe,
      Distance distance = Distance{}) {
    if (!partitioned_vectors_ || ::num_vectors(*partitioned_vectors_) == 0) {
      read_index_infinite();
    }
    auto&& [active_partitions, active_queries] =
        detail::ivf::partition_ivf_flat_index<indices_type>(
            centroids_, query_vectors, nprobe, num_threads_);
    // Forward the caller's distance functor, as every other query method of
    // this class does.  Previously it was accepted but silently ignored.
    return detail::ivf::query_infinite_ram(
        *partitioned_vectors_,
        active_partitions,
        query_vectors,
        active_queries,
        k_nn,
        num_threads_,
        distance);
  }

  /**
   * See the documentation for that function in detail/ivf/qv.h
   * for more details.
   */
  template <feature_vector_array Q, class Distance = sum_of_squares_distance>
  auto qv_query_heap_infinite_ram(
      const Q& query_vectors,
      size_t k_nn,
      size_t nprobe,
      Distance distance = Distance{}) {
    if (!partitioned_vectors_ || ::num_vectors(*partitioned_vectors_) == 0) {
      read_index_infinite();
    }

    auto top_centroids = detail::ivf::ivf_top_centroids(
        centroids_, query_vectors, nprobe, num_threads_);
    return detail::ivf::qv_query_heap_infinite_ram(
        top_centroids,
        *partitioned_vectors_,
        query_vectors,
        nprobe,
        k_nn,
        num_threads_,
        distance);
  }

  /**
   * @brief Same as query_infinite_ram, but using the
   * nuv_query_heap_infinite_ram function.
   * See the documentation for that function in detail/ivf/qv.h
   * for more details.
   *
   * The complete index is loaded first if it is not already in memory.
   *
   * @tparam Q Type of query vectors.
   * @tparam Distance Type of the distance functor.
   * @param query_vectors Array of vectors to query.
   * @param k_nn Number of nearest neighbors to return.
   * @param nprobe Number of centroids to search.
   * @param distance Distance functor forwarded to the query kernel.
   * @return Whatever detail::ivf::nuv_query_heap_infinite_ram returns.
   */
  template <feature_vector_array Q, class Distance = sum_of_squares_distance>
  auto nuv_query_heap_infinite_ram(
      const Q& query_vectors,
      size_t k_nn,
      size_t nprobe,
      Distance distance = Distance{}) {
    const bool vectors_loaded =
        partitioned_vectors_ && ::num_vectors(*partitioned_vectors_) != 0;
    if (!vectors_loaded) {
      read_index_infinite();
    }

    // Work out which partitions are needed, and by which queries.
    auto&& [partitions, queries] =
        detail::ivf::partition_ivf_flat_index<indices_type>(
            centroids_, query_vectors, nprobe, num_threads_);
    auto& vectors = *partitioned_vectors_;

    return detail::ivf::nuv_query_heap_infinite_ram(
        vectors,
        partitions,
        query_vectors,
        queries,
        k_nn,
        num_threads_,
        distance);
  }

  /**
   * @brief Same as query_infinite_ram, but using the
   * nuv_query_heap_infinite_ram_reg_blocked function.
   * See the documentation for that function in detail/ivf/qv.h
   * for more details.
   *
   * The complete index is loaded first if it is not already in memory.
   *
   * @tparam Q Type of query vectors.
   * @tparam Distance Type of the distance functor.
   * @param query_vectors Array of vectors to query.
   * @param k_nn Number of nearest neighbors to return.
   * @param nprobe Number of centroids to search.
   * @param distance Distance functor forwarded to the query kernel.
   * @return Whatever detail::ivf::nuv_query_heap_infinite_ram_reg_blocked
   * returns.
   */
  template <feature_vector_array Q, class Distance = sum_of_squares_distance>
  auto nuv_query_heap_infinite_ram_reg_blocked(
      const Q& query_vectors,
      size_t k_nn,
      size_t nprobe,
      Distance distance = Distance{}) {
    const bool vectors_loaded =
        partitioned_vectors_ && ::num_vectors(*partitioned_vectors_) != 0;
    if (!vectors_loaded) {
      read_index_infinite();
    }

    // Work out which partitions are needed, and by which queries.
    auto&& [partitions, queries] =
        detail::ivf::partition_ivf_flat_index<indices_type>(
            centroids_, query_vectors, nprobe, num_threads_);
    auto& vectors = *partitioned_vectors_;

    return detail::ivf::nuv_query_heap_infinite_ram_reg_blocked(
        vectors,
        partitions,
        query_vectors,
        queries,
        k_nn,
        num_threads_,
        distance);
  }

  // WIP
  // Finite-RAM counterpart of qv_query_heap_infinite_ram.  The whole
  // definition is compiled out by the `#if 0` below, so it is neither
  // type-checked nor callable.
  // NOTE(review): if this is re-enabled, the error path should use a plain
  // `throw` (as query_finite_ram does); `std::throw_with_nested` is meant to
  // be called while another exception is being handled.
#if 0
  template <feature_vector_array Q, class Distance = sum_of_squares_distance>
  auto qv_query_heap_finite_ram(
      const Q& query_vectors,
      size_t k_nn,
      size_t nprobe,
      size_t upper_bound = 0, Distance distance = Distance{}) {
    if (partitioned_vectors_ && ::num_vectors(*partitioned_vectors_) != 0) {
      std::throw_with_nested(
          std::runtime_error("Vectors are already loaded. Cannot load twice. "
                             "Cannot do finite query on in-memory index."));
    }
    auto&& [active_partitions, active_queries] =
        read_index_finite(query_vectors, nprobe, upper_bound);

    return detail::ivf::qv_query_heap_finite_ram(
        centroids_,
        *partitioned_vectors_,
        query_vectors,
        active_queries,
        nprobe,
        k_nn,
        upper_bound,
        num_threads_, distance);
  }
#endif  // 0

  /**
   * @brief Perform a query on the index, returning the nearest neighbors
   * and distances.  The function returns a matrix containing k_nn nearest
   * neighbors for each given query and a matrix containing the distances
   * corresponding to each returned neighbor.
   *
   * This function searches for the nearest neighbors using "finite RAM",
   * that is, it only loads that portion of the IVF index into memory that
   * is necessary for the given query.  In addition, it supports out of core
   * operation, meaning that only a subset of the necessary partitions are
   * loaded into memory at any one time.
   *
   * See the documentation for that function in detail/ivf/qv.h
   * for more details.
   *
   * @tparam Q Type of query vectors.
   * @tparam Distance Type of the distance functor.
   * @param query_vectors Array of vectors to query.
   * @param k_nn Number of nearest neighbors to return.
   * @param nprobe Number of centroids to search.
   * @param upper_bound Upper bound on the number of vectors held in memory
   * at one time.
   * @param distance Distance functor forwarded to the query kernel.
   *
   * @return A tuple containing a matrix of nearest neighbors and a matrix
   * of the corresponding distances.
   *
   * @throws std::runtime_error if vectors are already loaded in memory.
   */
  template <feature_vector_array Q, class Distance = sum_of_squares_distance>
  auto query_finite_ram(
      const Q& query_vectors,
      size_t k_nn,
      size_t nprobe,
      size_t upper_bound = 0,
      Distance distance = Distance{}) {
    // A finite query loads its own partitions, so it cannot run on an index
    // whose vectors are already in memory.
    const bool already_loaded =
        partitioned_vectors_ && ::num_vectors(*partitioned_vectors_) != 0;
    if (already_loaded) {
      throw std::runtime_error(
          "Vectors are already loaded. Cannot load twice. "
          "Cannot do finite query on in-memory index.");
    }

    // Sets up `partitioned_vectors_` for the partitions these queries need.
    auto&& [partitions, queries] =
        read_index_finite(query_vectors, nprobe, upper_bound);
    auto& vectors = *partitioned_vectors_;

    return detail::ivf::query_finite_ram(
        vectors,
        query_vectors,
        queries,
        k_nn,
        upper_bound,
        num_threads_,
        distance);
  }

  /**
   * @brief Same as query_finite_ram, but using the
   * nuv_query_heap_finite_ram function.
   * See the documentation for that function in detail/ivf/qv.h
   * for more details.
   *
   * @tparam Q Type of query vectors.
   * @tparam Distance Type of the distance functor.
   * @param query_vectors Array of vectors to query.
   * @param k_nn Number of nearest neighbors to return.
   * @param nprobe Number of centroids to search.
   * @param upper_bound Upper bound on the number of vectors held in memory
   * at one time.
   * @param distance Distance functor forwarded to the query kernel.
   * @return Whatever detail::ivf::nuv_query_heap_finite_ram returns.
   * @throws std::runtime_error if vectors are already loaded in memory.
   */
  template <feature_vector_array Q, class Distance = sum_of_squares_distance>
  auto nuv_query_heap_finite_ram(
      const Q& query_vectors,
      size_t k_nn,
      size_t nprobe,
      size_t upper_bound = 0,
      Distance distance = Distance{}) {
    if (partitioned_vectors_ && ::num_vectors(*partitioned_vectors_) != 0) {
      // Plain `throw`, matching query_finite_ram: no exception is being
      // handled here, so there is nothing to nest.
      throw std::runtime_error(
          "Vectors are already loaded. Cannot load twice. "
          "Cannot do finite query on in-memory index.");
    }
    auto&& [active_partitions, active_queries] =
        read_index_finite(query_vectors, nprobe, upper_bound);

    return detail::ivf::nuv_query_heap_finite_ram(
        *partitioned_vectors_,
        query_vectors,
        active_queries,
        k_nn,
        upper_bound,
        num_threads_,
        distance);
  }

  /**
   * @brief Same as query_finite_ram, but using the
   * nuv_query_heap_finite_ram_reg_blocked function.
   * See the documentation for that function in detail/ivf/qv.h
   * for more details.
   *
   * @tparam Q Type of query vectors.
   * @tparam Distance Type of the distance functor.
   * @param query_vectors Array of vectors to query.
   * @param k_nn Number of nearest neighbors to return.
   * @param nprobe Number of centroids to search.
   * @param upper_bound Upper bound on the number of vectors held in memory
   * at one time.
   * @param distance Distance functor forwarded to the query kernel.
   * @return Whatever detail::ivf::nuv_query_heap_finite_ram_reg_blocked
   * returns.
   * @throws std::runtime_error if vectors are already loaded in memory.
   */
  template <feature_vector_array Q, class Distance = sum_of_squares_distance>
  auto nuv_query_heap_finite_ram_reg_blocked(
      const Q& query_vectors,
      size_t k_nn,
      size_t nprobe,
      size_t upper_bound = 0,
      Distance distance = Distance{}) {
    if (partitioned_vectors_ && ::num_vectors(*partitioned_vectors_) != 0) {
      // Plain `throw`, matching query_finite_ram: no exception is being
      // handled here, so there is nothing to nest.
      throw std::runtime_error(
          "Vectors are already loaded. Cannot load twice. "
          "Cannot do finite query on in-memory index.");
    }
    auto&& [active_partitions, active_queries] =
        read_index_finite(query_vectors, nprobe, upper_bound);

    return detail::ivf::nuv_query_heap_finite_ram_reg_blocked(
        *partitioned_vectors_,
        query_vectors,
        active_queries,
        k_nn,
        upper_bound,
        num_threads_,
        distance);
  }

  /***************************************************************************
   * Getters (copilot weirded me out again -- it suggested "Getters" based
   * only on the two character "/ *" that I typed to begin a comment,
   * and with the two functions below.)
   * Note that we don't have a `num_vectors` because it isn't clear what
   * that means for a partitioned (possibly out-of-core) index.
   ***************************************************************************/
  auto dimensions() const {
    return dimensions_;
  }

  auto num_partitions() const {
    return ::num_vectors(centroids_);
  }

  /***************************************************************************
   * Methods to aid Testing and Debugging
   *
   * @todo -- As elsewhere in this class, there is huge code duplication here
   *
   **************************************************************************/

  /**
   * @brief Compare groups associated with two ivf_flat_index objects for
   * equality.  Both indexes must have a group associated with them; an
   * index created from partitioning (and not opened from a URI) will not
   * yet have one.
   *
   * Comparing groups will also compare metadata associated with each group.
   *
   * @param rhs the index against which to compare
   * @return bool indicating equality of the groups
   * @throws std::runtime_error if either index has no associated group
   */
  bool compare_group(const ivf_flat_index& rhs) const {
    // Report the missing group instead of dereferencing a null pointer.
    if (!group_ || !rhs.group_) {
      throw std::runtime_error(
          "Cannot compare groups: index has no associated group");
    }
    return group_->compare_group(*(rhs.group_));
  }

  /**
   * @brief Compare metadata associated with two ivf_flat_index objects for
   * equality.  Thi is not the same as the metadata associated with the index
   * group.  Rather, it is the metadata associated with the index itself and is
   * only a small number of cached quantities.
   *
   * Note that `max_iter` et al are only relevant for partitioning an index
   * and are not stored (and would not be meaningful to compare at any rate).
   *
   * @param rhs the index against which to compare
   * @return bool indicating equality of the index metadata
   */
  bool compare_cached_metadata(const ivf_flat_index& rhs) {
    if (dimensions_ != rhs.dimensions_) {
      return false;
    }
    if (num_partitions_ != rhs.num_partitions_) {
      return false;
    }

    return true;
  }

  /**
   * @brief Compare two `feature_vector_arrays` for equality.
   *
   * The arrays are equal when they have the same number of vectors, the same
   * dimension, and element-wise equal vectors.  On the first difference
   * found, a diagnostic describing it is printed to `std::cout` and the
   * comparison stops.
   *
   * @tparam L Type of the lhs `feature_vector_array`
   * @tparam R Type of the rhs `feature_vector_array`
   * @param lhs The lhs `feature_vector_array`
   * @param rhs The rhs `feature_vector_array`
   * @return bool indicating equality of the `feature_vector_arrays`
   */
  template <feature_vector_array L, feature_vector_array R>
  auto compare_feature_vector_arrays(const L& lhs, const R& rhs) const {
    if (::num_vectors(lhs) != ::num_vectors(rhs) ||
        ::dimensions(lhs) != ::dimensions(rhs)) {
      std::cout << "num_vectors(lhs) != num_vectors(rhs) || dimensions(lhs) != "
                   "dimensions(rhs)"
                << std::endl;
      std::cout << "num_vectors(lhs): " << ::num_vectors(lhs)
                << " num_vectors(rhs): " << ::num_vectors(rhs) << std::endl;
      std::cout << "dimensions(lhs): " << ::dimensions(lhs)
                << " dimensions(rhs): " << ::dimensions(rhs) << std::endl;
      return false;
    }
    for (size_t i = 0; i < ::num_vectors(lhs); ++i) {
      // The dimensions were checked above, so the three-iterator form of
      // std::equal stays within rhs[i].
      if (!std::equal(begin(lhs[i]), end(lhs[i]), begin(rhs[i]))) {
        std::cout << "lhs[" << i << "] != rhs[" << i << "]" << std::endl;
        std::cout << "lhs[" << i << "]: ";
        for (size_t j = 0; j < ::dimensions(lhs); ++j) {
          std::cout << lhs[i][j] << " ";
        }
        std::cout << std::endl;
        std::cout << "rhs[" << i << "]: ";
        for (size_t j = 0; j < ::dimensions(rhs); ++j) {
          std::cout << rhs[i][j] << " ";
        }
        // Terminate (and flush) the rhs line, as is done for the lhs line.
        std::cout << std::endl;
        return false;
      }
    }
    return true;
  }

  /**
   * @brief Compare two `feature_vectors` for equality
   * @tparam L Type of the lhs `feature_vector`
   * @tparam R Type of the rhs `feature_vector`
   * @param lhs The lhs `feature_vector`
   * @param rhs The rhs `feature_vector`
   * @return bool indicating equality of the `feature_vectors`
   */
  template <feature_vector L, feature_vector R>
  auto compare_vectors(const L& lhs, const R& rhs) const {
    const auto lhs_dim = ::dimensions(lhs);
    const auto rhs_dim = ::dimensions(rhs);

    // Same length: the result is simply the element-wise comparison.
    if (lhs_dim == rhs_dim) {
      return std::equal(begin(lhs), end(lhs), begin(rhs));
    }

    // Length mismatch: report both lengths and treat the vectors as unequal.
    std::cout << "dimensions(lhs) != dimensions(rhs) (" << lhs_dim
              << " != " << rhs_dim << ")" << std::endl;
    return false;
  }

  /**
   * @brief Compare `centroids_` against another index
   * @param rhs The index against which to compare
   * @return bool indicating equality of the centroids
   */
  auto compare_centroids(const ivf_flat_index& rhs) {
    return compare_feature_vector_arrays(centroids_, rhs.centroids_);
  }

  /**
   * @brief Compare all of the `feature_vector_arrays` against another index
   * @param rhs The other index
   * @return bool indicating equality of the `feature_vector_arrays`
   */
  auto compare_feature_vectors(const ivf_flat_index& rhs) {
    if (partitioned_vectors_->num_vectors() !=
        rhs.partitioned_vectors_->num_vectors()) {
      std::cout << "partitioned_vectors_->num_vectors() != "
                   "rhs.partitioned_vectors_->num_vectors() ("
                << partitioned_vectors_->num_vectors()
                << " != " << rhs.partitioned_vectors_->num_vectors() << ")"
                << std::endl;
      return false;
    }
    if (partitioned_vectors_->num_partitions() !=
        rhs.partitioned_vectors_->num_partitions()) {
      std::cout << "partitioned_vectors_->num_parts() != "
                   "rhs.partitioned_vectors_->num_parts() ("
                << partitioned_vectors_->num_partitions()
                << " != " << rhs.partitioned_vectors_->num_partitions() << ")"
                << std::endl;
      return false;
    }

    return compare_feature_vector_arrays(
        *partitioned_vectors_, *(rhs.partitioned_vectors_));
  }

  /** Compare the stored `indices_` vector */
  auto compare_indices(const ivf_flat_index& rhs) {
    return compare_vectors(
        partitioned_vectors_->indices(), rhs.partitioned_vectors_->indices());
  }

  /** Compare the stored `partitioned_ids_` vector */
  auto compare_partitioned_ids(const ivf_flat_index& rhs) {
    return compare_vectors(
        partitioned_vectors_->ids(), rhs.partitioned_vectors_->ids());
  }

  /**
   * @brief Replace `centroids_` with a copy of the given centroids.
   *
   * A new `ColMajorMatrix<centroid_feature_type>` is constructed from the
   * dimensions and number of vectors of `centroids`, and every element is
   * then copied into it with `std::copy`.
   *
   * @param centroids The centroids to copy into the index
   */
  auto set_centroids(const ColMajorMatrix<feature_type>& centroids) {
    // Size the destination first; the source is only read afterwards.
    centroids_ = ColMajorMatrix<centroid_feature_type>(
        ::dimensions(centroids), ::num_vectors(centroids));

    const auto num_elements = centroids.num_rows() * centroids.num_cols();
    auto first = centroids.data();
    auto last = centroids.data() + num_elements;
    std::copy(first, last, centroids_.data());
  }

  auto& get_centroids() {
    return centroids_;
  }

  /**
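   * @brief Used for evaluating quality of partitioning.  Assigns each vector
   * to the centroid with the smallest `l2_distance`.
   * @param centroids Matrix of centroids, one centroid per column
   * @param vectors Matrix of vectors to assign, one vector per column
   * @return For each vector, the index of its nearest centroid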
   */
  static std::vector<indices_type> predict(
      const ColMajorMatrix<feature_type>& centroids,
      const ColMajorMatrix<feature_type>& vectors) {
    // Return a vector of indices of the nearest centroid for each vector in
    // the matrix. Write the code below:
    auto nClusters = centroids.num_cols();
    std::vector<indices_type> indices(vectors.num_cols());
    std::vector<score_type> distances(nClusters);
    for (size_t i = 0; i < vectors.num_cols(); ++i) {
      for (size_t j = 0; j < nClusters; ++j) {
        distances[j] = l2_distance(vectors[i], centroids[j]);
      }
      indices[i] =
          std::min_element(begin(distances), end(distances)) - begin(distances);
    }
    return indices;
  }

  void dump_group(const std::string& msg) {
    group_->dump(msg);
  }

  /**
   * @brief Dump the metadata of the group associated with this index by
   * forwarding to `dump()` on the group's `metadata` member.
   *
   * NOTE(review): `group_` is dereferenced without a null check; presumably
   * callers only invoke this once a read or write has associated a group with
   * the index -- confirm.
   *
   * @param msg Message passed through to the metadata's `dump()`
   */
  void dump_metadata(const std::string& msg) {
    auto& group_metadata = group_->metadata;
    group_metadata.dump(msg);
  }
};

#endif  // TILEDB_ivf_flat_index_H
