Program Listing for File closest_pair2D.hpp

Return to documentation for file (rcppsw/algorithm/closest_pair2D.hpp)

#pragma once

/*******************************************************************************
 * Includes
 ******************************************************************************/
#include <algorithm>
#include <cassert>
#include <cmath>
#include <functional>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

#include "rcppsw/rcppsw.hpp"

/*******************************************************************************
 * Namespaces/Decls
 ******************************************************************************/
namespace rcppsw::algorithm {

/*******************************************************************************
 * Structure Definitions
 ******************************************************************************/
template <typename T>
struct result_type2D {
  /** First point of the pair. */
  T p1{};
  /** Second point of the pair. */
  T p2{};
  /** Distance between \c p1 and \c p2. */
  double dist{ 0.0 };

  /** @return \c true iff this result's distance is strictly below \p d. */
  bool operator<(double d) const { return d > dist; }

  /** @return \c true iff this result is strictly closer than \p other. */
  bool operator<(const result_type2D<T>& other) const {
    return *this < other.dist;
  }
};

/*******************************************************************************
 * Class Definitions
 ******************************************************************************/
/**
 * @brief Find the closest pair of points in a 2D point set.
 *
 * @tparam T Point type. Must provide \c x() and \c y() accessors callable on
 *           a const object, and be copyable.
 *
 * Two methods are available: \ref kBruteForce (all-pairs search) and
 * \ref kRecursive (divide-and-conquer on x with a y-sorted strip merge).
 */
template <typename T>
class closest_pair2D {
 public:
  static inline const std::string kBruteForce = "brute_force";

  static inline const std::string kRecursive = "recursive";

  /**
   * @brief Compute the closest pair of \p points using \p method.
   *
   * @param method Either \ref kBruteForce or \ref kRecursive.
   * @param points The points to search. Taken by value because the recursive
   *               method sorts them by x first.
   * @param dist_func Callable returning the distance between two points.
   *
   * @return The closest pair and their distance. With fewer than two points,
   *         \c dist is \c std::numeric_limits<double>::max(). If \p method is
   *         not recognized, a default-constructed result (dist 0.0) is
   *         returned.
   */
  template <typename TDistFunc>
  result_type2D<T> operator()(const std::string& method,
                              std::vector<T> points,
                              const TDistFunc& dist_func) {
    if (kBruteForce == method) {
      return brute_force(points, dist_func);
    } else if (kRecursive == method) {
      std::sort(points.begin(), points.end(), [](const T& a, const T& b) {
        return a.x() < b.x();
      });
      std::vector<T> strip;
      return recursive(points, strip, dist_func);
    }
    /* Should never be hit: unrecognized method */
    return result_type2D<T>();
  }

  /**
   * @brief Exhaustively compare every pair of \p points.
   *
   * @return The first strictly-closest pair found. If \p points has fewer than
   *         two elements, \c dist is \c std::numeric_limits<double>::max() and
   *         \c p1/p2 are value-initialized.
   */
  template <typename TDistFunc>
  result_type2D<T> brute_force(const std::vector<T>& points,
                               const TDistFunc& dist_func) {
    result_type2D<T> r;
    r.dist = std::numeric_limits<double>::max();

    for (size_t i = 0; i < points.size(); ++i) {
      for (size_t j = i + 1; j < points.size(); ++j) {
        /* Evaluate once: dist_func may be expensive or stateful */
        const auto d = dist_func(points[i], points[j]);
        if (d < r.dist) {
          r.dist = d;
          r.p1 = points[i];
          r.p2 = points[j];
        }
      } /* for(j..) */
    } /* for(i..) */
    return r;
  }

  /**
   * @brief Divide-and-conquer closest pair search.
   *
   * Precondition: \p points must be sorted by ascending x (operator() does
   * this before calling). The set is split at the middle index, each half is
   * solved recursively, and points near the dividing x are then checked.
   *
   * @param points x-sorted points to search.
   * @param strip Scratch buffer. Any contents on entry are discarded; it is
   *              left empty on normal return.
   * @param dist_func Callable returning the distance between two points.
   */
  template <typename TDistFunc>
  result_type2D<T> recursive(const std::vector<T>& points,
                             std::vector<T>& strip,
                             const TDistFunc& dist_func) {
    /* base case */
    if (points.size() <= 3) {
      return brute_force(points, dist_func);
    }

    /* mid point: only its x coordinate is needed for the dividing line */
    const size_t mid = points.size() / 2;
    const auto mid_x = points[mid].x();

    /*
     * Calculate the smallest distance
     * dl: left of mid point
     * dr: right side of the mid point
     */
    result_type2D<T> dl = recursive(
        std::vector<T>(points.begin(), points.begin() + mid), strip, dist_func);
    result_type2D<T> dr = recursive(
        std::vector<T>(points.begin() + mid, points.end()), strip, dist_func);
    result_type2D<T> dmin = std::min(dl, dr);

    /*
     * strip is shared scratch space. Clear it so stale contents (e.g. from a
     * caller passing a non-empty vector) never contribute points that are not
     * in this sub-problem.
     */
    strip.clear();
    for (const T& p : points) {
      if (std::fabs(p.x() - mid_x) < dmin.dist) {
        strip.push_back(p);
      }
    } /* for(p..) */
    auto res = strip_points(strip, dmin, dist_func);
    strip.clear();
    return res;
  }

 private:
  /**
   * @brief Search the strip around the dividing line for a pair closer than
   * \p dmin.
   *
   * @param strip Candidate points near the dividing line. Sorted in place by y.
   * @param dmin Best pair found so far by the recursive halves.
   *
   * @return The closer of \p dmin and the best pair found in the strip.
   */
  template <typename TDistFunc>
  result_type2D<T> strip_points(std::vector<T>& strip,
                                const result_type2D<T>& dmin,
                                const TDistFunc& dist_func) {
    result_type2D<T> min = dmin;

    std::sort(strip.begin(), strip.end(), [](const T& a, const T& b) {
      return a.y() < b.y();
    });

    /*
     * For each point, examine only the following points whose y-difference
     * is below the current best distance. This pruning is valid only if
     * dist_func(a, b) >= |a.y() - b.y()| (true for Euclidean distance).
     * NOTE(review): confirm this holds for any custom metric.
     */
    for (size_t i = 0; i < strip.size(); ++i) {
      for (size_t j = i + 1;
           j < strip.size() && (strip[j].y() - strip[i].y()) < min.dist;
           ++j) {
        /* Evaluate once: dist_func may be expensive or stateful */
        const auto d = dist_func(strip[i], strip[j]);
        if (d < min.dist) {
          min.dist = d;
          min.p1 = strip[i];
          min.p2 = strip[j];
        }
      } /* for(j..) */
    } /* for(i..) */
    return min;
  }
};

} /* namespace rcppsw::algorithm */