// Copyright (c) by respective owners including Yahoo!, Microsoft, and
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.

#pragma once

#include "vw/common/future_compat.h"
#include "vw/explore/explore_error_codes.h"

#include <cstdint>

// Integer return codes used by the exploration functions declared below. Each of those functions documents its
// result as "0 on success, otherwise an error code as defined by E_EXPLORATION_*", so S_EXPLORATION_OK (0) is
// success and every E_EXPLORATION_* value is a failure.
// NOTE(review): the per-code meanings are only implied by the macro names (bad range, pmf/ranking size mismatch,
// bad pdf, bad epsilon) -- confirm against the implementations in explore_internal.h.
// NOTE(review): vw/explore/explore_error_codes.h is included above and may define these same macros; if it does,
// these replacement lists must stay token-identical to avoid macro-redefinition diagnostics -- confirm.
#define S_EXPLORATION_OK 0
#define E_EXPLORATION_BAD_RANGE 1
#define E_EXPLORATION_PMF_RANKING_SIZE_MISMATCH 2
#define E_EXPLORATION_BAD_PDF 3
#define E_EXPLORATION_BAD_EPSILON 4

namespace VW
{

namespace explore
{
/**
 * @brief Experimental: Generates an epsilon-greedy style exploration distribution.
 *
 * @tparam It Iterator type of the pre-allocated pmf. Must be a RandomAccessIterator.
 * @param epsilon Minimum probability used to explore among options. Each action is explored with at least
 * epsilon/num_actions.
 * @param top_action Index of the exploit action. This action will get a probability mass of 1-epsilon +
 * (epsilon/num_actions).
 * @param pmf_first Iterator pointing to the pre-allocated beginning of the pmf to be generated by this function.
 * @param pmf_last Iterator pointing to the pre-allocated end of the pmf to be generated by this function.
 * @return int returns 0 (S_EXPLORATION_OK) on success, otherwise an error code as defined by E_EXPLORATION_*.
 */
template <typename It>
int generate_epsilon_greedy(float epsilon, uint32_t top_action, It pmf_first, It pmf_last);

/**
 * @brief Generates a softmax style exploration distribution.
 *
 * @tparam InputIt Iterator type of the input scores. Must be an InputIterator.
 * @tparam OutputIt Iterator type of the pre-allocated pmf. Must be a RandomAccessIterator.
 * @param lambda Lambda parameter of softmax.
 * @param scores_first Iterator pointing to the beginning of the scores.
 * @param scores_last Iterator pointing to the end of the scores.
 * @param pmf_first Iterator pointing to the pre-allocated beginning of the pmf to be generated by this function.
 * @param pmf_last Iterator pointing to the pre-allocated end of the pmf to be generated by this function.
 * @return int returns 0 (S_EXPLORATION_OK) on success, otherwise an error code as defined by E_EXPLORATION_*.
 */
template <typename InputIt, typename OutputIt>
int generate_softmax(float lambda, InputIt scores_first, InputIt scores_last, OutputIt pmf_first, OutputIt pmf_last);

/**
 * @brief Generates an exploration distribution according to votes on actions.
 *
 * @tparam InputIt Iterator type of the input actions. Must be an InputIterator.
 * @tparam OutputIt Iterator type of the pre-allocated pmf. Must be a RandomAccessIterator.
 * @param top_actions_first Iterator pointing to the beginning of the top actions.
 * @param top_actions_last Iterator pointing to the end of the top actions.
 * @param pmf_first Iterator pointing to the pre-allocated beginning of the pmf to be generated by this function.
 * @param pmf_last Iterator pointing to the pre-allocated end of the pmf to be generated by this function.
 * @return int returns 0 (S_EXPLORATION_OK) on success, otherwise an error code as defined by E_EXPLORATION_*.
 */
template <typename InputIt, typename OutputIt>
int generate_bag(InputIt top_actions_first, InputIt top_actions_last, OutputIt pmf_first, OutputIt pmf_last);

/**
 * @brief Updates the pmf in-place to ensure each action is explored with at least uniform_epsilon/num_actions.
 *
 * @tparam It Iterator type of the pmf. Must be a RandomAccessIterator.
 * @param uniform_epsilon The minimum amount of uniform distribution to impose on the pmf.
 * @param consider_zero_valued_elements If true, elements with zero probability are updated; otherwise those actions
 * will be left unchanged.
 * @param pmf_first Iterator pointing to the beginning of the pmf to be updated in-place by this function.
 * @param pmf_last Iterator pointing to the end of the pmf to be updated in-place by this function.
 * @return int returns 0 (S_EXPLORATION_OK) on success, otherwise an error code as defined by E_EXPLORATION_*.
 */
template <typename It>
int enforce_minimum_probability(float uniform_epsilon, bool consider_zero_valued_elements, It pmf_first, It pmf_last);

/**
 * @brief Mixes the original pmf with a uniform distribution.
 *
 * @tparam It Iterator type of the pmf. Must be a RandomAccessIterator.
 * @param uniform_epsilon The minimum amount of uniform distribution to be mixed with the pmf.
 * @param pmf_first Iterator pointing to the beginning of the pmf to be updated.
 * @param pmf_last Iterator pointing to the end of the pmf to be updated.
 * @return int returns 0 (S_EXPLORATION_OK) on success, otherwise an error code as defined by E_EXPLORATION_*.
 */
template <typename It>
int mix_with_uniform(float uniform_epsilon, It pmf_first, It pmf_last);

/**
 * @brief Samples an index from the provided pmf. If the pmf is not normalized it will be updated in-place.
 *
 * @tparam It Iterator type of the pmf. Must be a RandomAccessIterator.
 * @param seed The seed for the pseudo-random generator.
 * @param pmf_first Iterator pointing to the beginning of the pmf.
 * @param pmf_last Iterator pointing to the end of the pmf.
 * @param chosen_index Output parameter; receives the chosen index.
 * @return int returns 0 (S_EXPLORATION_OK) on success, otherwise an error code as defined by E_EXPLORATION_*.
 */
template <typename It>
int sample_after_normalizing(uint64_t seed, It pmf_first, It pmf_last, uint32_t& chosen_index);

/**
 * @brief Samples an index from the provided pmf. If the pmf is not normalized it will be updated in-place.
 *
 * @tparam It Iterator type of the pmf. Must be a RandomAccessIterator.
 * @param seed The seed for the pseudo-random generator. Will be hashed using MURMUR hash.
 * @param pmf_first Iterator pointing to the beginning of the pmf.
 * @param pmf_last Iterator pointing to the end of the pmf.
 * @param chosen_index Output parameter; receives the chosen index.
 * @return int returns 0 (S_EXPLORATION_OK) on success, otherwise an error code as defined by E_EXPLORATION_*.
 */
template <typename It>
int sample_after_normalizing(const char* seed, It pmf_first, It pmf_last, uint32_t& chosen_index);

/**
 * @brief Swaps the first element of the action range with the element at the chosen index.
 *
 * @tparam ActionIt Iterator type of the actions. Must be a ForwardIterator.
 * @param action_first Iterator pointing to the beginning of the actions.
 * @param action_last Iterator pointing to the end of the actions.
 * @param chosen_index The index of the element that should be swapped with the first element.
 * @return int returns 0 (S_EXPLORATION_OK) on success, otherwise an error code as defined by E_EXPLORATION_*.
 */
template <typename ActionIt>
int swap_chosen(ActionIt action_first, ActionIt action_last, uint32_t chosen_index);

/**
 * @brief Samples a continuous value from the provided pdf.
 *
 * Warning: `seed` must be sufficiently random for the PRNG to produce uniform random values. Using sequential seeds
 * will result in a very biased distribution. If unsure how to update the seed between calls, merand48 (in
 * random_details.h) can be used to mutate it in-place.
 *
 * @tparam It Iterator type of the pdf. Must be a RandomAccessIterator.
 * @param p_seed Pointer to the seed for the pseudo-random generator. The seed state will be advanced in-place.
 * NOTE(review): an earlier version of this comment said this seed is hashed using MURMUR hash; the sequential-seed
 * warning above suggests it is fed to the PRNG directly -- confirm against explore_internal.h.
 * @param pdf_first Iterator pointing to the beginning of the pdf.
 * @param pdf_last Iterator pointing to the end of the pdf.
 * @param chosen_value Output parameter; receives the sampled continuous value.
 * @param pdf_value Output parameter; receives the probability density at the sampled location.
 * @return int returns 0 (S_EXPLORATION_OK) on success, otherwise an error code as defined by E_EXPLORATION_*.
 */
template <typename It>
int sample_pdf(uint64_t* p_seed, It pdf_first, It pdf_last, float& chosen_value, float& pdf_value);

}  // namespace explore
}  // namespace VW

// Legacy namespace kept only for source compatibility. Every function here is a deprecated, zero-logic wrapper
// that forwards its arguments unchanged to the identically named function in VW::explore and returns its result.
namespace exploration
{
/// @deprecated Moved to VW::explore::generate_epsilon_greedy(); forwards all arguments unchanged.
template <typename It>
VW_DEPRECATED("Moved to VW::explore::generate_epsilon_greedy")
int generate_epsilon_greedy(float epsilon, uint32_t top_action, It pmf_first, It pmf_last)
{
  return VW::explore::generate_epsilon_greedy(epsilon, top_action, pmf_first, pmf_last);
}

/// @deprecated Moved to VW::explore::generate_softmax(); forwards all arguments unchanged.
template <typename InputIt, typename OutputIt>
VW_DEPRECATED("Moved to VW::explore::generate_softmax")
int generate_softmax(float lambda, InputIt scores_first, InputIt scores_last, OutputIt pmf_first, OutputIt pmf_last)
{
  return VW::explore::generate_softmax(lambda, scores_first, scores_last, pmf_first, pmf_last);
}

/// @deprecated Moved to VW::explore::generate_bag(); forwards all arguments unchanged.
template <typename InputIt, typename OutputIt>
VW_DEPRECATED("Moved to VW::explore::generate_bag")
int generate_bag(InputIt top_actions_first, InputIt top_actions_last, OutputIt pmf_first, OutputIt pmf_last)
{
  return VW::explore::generate_bag(top_actions_first, top_actions_last, pmf_first, pmf_last);
}

/// @deprecated Moved to VW::explore::enforce_minimum_probability(); forwards all arguments unchanged.
template <typename It>
VW_DEPRECATED("Moved to VW::explore::enforce_minimum_probability")
int enforce_minimum_probability(float uniform_epsilon, bool consider_zero_valued_elements, It pmf_first, It pmf_last)
{
  return VW::explore::enforce_minimum_probability(uniform_epsilon, consider_zero_valued_elements, pmf_first, pmf_last);
}

/// @deprecated Moved to VW::explore::mix_with_uniform(); forwards all arguments unchanged.
template <typename It>
VW_DEPRECATED("Moved to VW::explore::mix_with_uniform")
int mix_with_uniform(float uniform_epsilon, It pmf_first, It pmf_last)
{
  return VW::explore::mix_with_uniform(uniform_epsilon, pmf_first, pmf_last);
}

/// @deprecated Moved to VW::explore::sample_after_normalizing() (uint64_t seed overload); forwards all arguments
/// unchanged.
template <typename It>
VW_DEPRECATED("Moved to VW::explore::sample_after_normalizing")
int sample_after_normalizing(uint64_t seed, It pmf_first, It pmf_last, uint32_t& chosen_index)
{
  return VW::explore::sample_after_normalizing(seed, pmf_first, pmf_last, chosen_index);
}

/// @deprecated Moved to VW::explore::sample_after_normalizing() (const char* seed overload); forwards all arguments
/// unchanged.
template <typename It>
VW_DEPRECATED("Moved to VW::explore::sample_after_normalizing")
int sample_after_normalizing(const char* seed, It pmf_first, It pmf_last, uint32_t& chosen_index)
{
  return VW::explore::sample_after_normalizing(seed, pmf_first, pmf_last, chosen_index);
}

/// @deprecated Moved to VW::explore::swap_chosen(); forwards all arguments unchanged.
template <typename ActionIt>
VW_DEPRECATED("Moved to VW::explore::swap_chosen")
int swap_chosen(ActionIt action_first, ActionIt action_last, uint32_t chosen_index)
{
  return VW::explore::swap_chosen(action_first, action_last, chosen_index);
}

/// @deprecated Moved to VW::explore::sample_pdf(); forwards all arguments unchanged.
template <typename It>
VW_DEPRECATED("Moved to VW::explore::sample_pdf")
int sample_pdf(uint64_t* p_seed, It pdf_first, It pdf_last, float& chosen_value, float& pdf_value)
{
  return VW::explore::sample_pdf(p_seed, pdf_first, pdf_last, chosen_value, pdf_value);
}
}  // namespace exploration

// Implementations can be found in the internal header.
#include "explore_internal.h"
