// Beam search template (ビームサーチ用テンプレート)
#include <bits/stdc++.h>
#include <atcoder/all>
using namespace std;
using namespace atcoder;
// Loop helpers: REP iterates int i over [a, n); rep iterates int i over [0, n).
#define REP(i, a, n) for (int i = (a); i < (int)(n); i++)
#define rep(i, n) REP(i, 0, n)
// Iterates `it` from c.begin() to c.end() using an explicitly typed iterator.
#define FOR(it, c) for (__typeof((c).begin()) it = (c).begin(); it != (c).end(); ++it)
// Expands to the "begin, end" iterator pair of c, for passing to algorithms.
#define ALLOF(c) (c).begin(), (c).end()
// Shorthand names for the 64-bit integer types.
typedef long long ll;
typedef unsigned long long ull;
/// Minimal stopwatch used to enforce time budgets.
///
/// Measures with std::chrono::steady_clock so the elapsed time can never jump
/// or run backwards when the system wall clock is adjusted.
class Timer {
    using Clock = std::chrono::steady_clock;

    // Reference instant set by start(); it is the clock's epoch until then,
    // so getSec() is only meaningful after start() has been called.
    Clock::time_point start_time;

    Clock::time_point getNow() const {
        return Clock::now();
    }

public:
    /// Marks the current instant as the beginning of the measured interval.
    void start() {
        start_time = getNow();
    }

    /// Returns the number of seconds elapsed since the last start() call.
    /// The interval is computed in double precision and narrowed only once
    /// on return, instead of pushing a microsecond count through a float.
    float getSec() const {
        const std::chrono::duration<double> elapsed = getNow() - start_time;
        return static_cast<float>(elapsed.count());
    }
};
/// xorshift128 pseudo-random generator with a fixed built-in seed.
///
/// The four 32-bit state words live in function-local statics, so every call
/// advances one shared sequence and the sequence is identical on every run.
/// @return the next 32-bit value of the sequence.
uint32_t xor128() {
    static uint32_t x = 123456789;
    static uint32_t y = 362436069;
    static uint32_t z = 521288629;
    static uint32_t w = 88675123;

    // Mix the oldest word before the state is rotated by one position.
    const uint32_t mixed = x ^ (x << 11);
    x = y;
    y = z;
    z = w;

    w = (w ^ (w >> 19)) ^ (mixed ^ (mixed >> 8));
    return w;
}
/// Returns a pseudo-random float derived from xor128().
///
/// The raw value is reduced modulo UINT32_MAX and divided by UINT32_MAX as a
/// float, so the result lies in [0, 1]. Because large 32-bit integers are
/// rounded when converted to float, exactly 1.0f can be produced.
inline float frand() {
    const uint32_t reduced = xor128() % UINT32_MAX;
    const float denominator = static_cast<float>(UINT32_MAX);
    return reduced / denominator;
}
/// Returns x raised to a random exponent drawn from frand(), truncated to int.
/// @param x base of the power.
inline int exprand(int x) {
    const double base = static_cast<double>(x);
    const double exponent = static_cast<double>(frand());
    return static_cast<int>(std::pow(base, exponent));
}
/// Shuffles the range [first, last) in place, drawing randomness from xor128().
///
/// Walks from the last element down to the second one and swaps each element
/// with one chosen at an index in [0, current index]. Ranges with fewer than
/// two elements are left untouched.
template <class RandomAccessIterator>
void xor128_shuffle(RandomAccessIterator first, RandomAccessIterator last) {
    using Diff = typename std::iterator_traits<RandomAccessIterator>::difference_type;
    const Diff count = last - first;
    Diff idx = count - 1;
    while (idx > 0) {
        const Diff pick = xor128() % (idx + 1);
        using std::swap;
        swap(first[idx], first[pick]);
        --idx;
    }
}
// Total time budget in seconds; compared against Timer::getSec() readings.
static constexpr double GLOBAL_TIME_LIMIT = 1.97;
// Program-wide timer; main() starts it before any other work is done.
Timer global_timer;
// An action is identified by a plain integer; Actions is a list of them.
using Action = int;
using Actions = std::vector<Action>;

/// One node of the search: the problem state plus its evaluation value.
///
/// The bodies marked "// ..." are placeholders to be filled in per problem.
struct State {
    int score;  // evaluation value used to rank states
    // First action taken from the search root on the path to this state.
    // Starts at -1 ("none recorded yet") so that copying or reading a fresh
    // State never touches an uninitialised value.
    Action first_action = -1;

    /// Builds a state in its initial configuration (see reset()).
    State() {
        reset();
    }

    /// Puts the state into its initial configuration and recomputes score.
    void reset() {
        // ...
        score = compute_score();
    }

    /// Computes the evaluation value from scratch; the placeholder returns 0.
    int compute_score() const {
        // ...
        return 0;
    }

    /// Debug check that the stored score matches a from-scratch computation.
    void check_score() const {
        assert(compute_score() == score);
    }

    // For turn-based problems the state itself carries the transition operation.
    /// Applies `action` to this state.
    void advance(const Action& action) {
        // ...
    }

    /// Returns the actions that can be applied to this state; the placeholder
    /// returns an empty list.
    Actions legal_actions() const {
        Actions actions;
        // ...
        return actions;
    }

    /// Prints the state; placeholder.
    void print() const {
        // ...
    }

    /// States are ordered by score only.
    inline bool operator<(const State& other) const {
        return score < other.score;
    }
    inline bool operator>(const State& other) const {
        return score > other.score;
    }
};
/// Picks one of the state's legal actions using xor128().
/// Precondition (asserted): the state has at least one legal action.
Action random_action(const State& state) {
    const auto candidates = state.legal_actions();
    assert(candidates.size() > 0);
    const auto pick = xor128() % candidates.size();
    return candidates[pick];
}
/// One-step lookahead: returns the legal action whose successor state has the
/// highest score. Ties are resolved in favour of the earliest action.
/// Precondition (asserted): the state has at least one legal action.
Action greedy_action(const State& state) {
    const auto legal_actions = state.legal_actions();
    assert(legal_actions.size() > 0);
    // Seed with the first legal action and the lowest possible score so the
    // result is always a real action, even when every successor score is
    // negative.
    Action best_action = legal_actions.front();
    int best_score = std::numeric_limits<int>::min();
    for (const auto& action : legal_actions) {
        State next = state;
        next.advance(action);
        if (best_score < next.score) {
            best_score = next.score;
            best_action = action;
        }
    }
    return best_action;
}
/// Runs a beam search from `state` and returns the first action on the path
/// to the best-scoring state found.
///
/// @param state       search root (copied, never modified).
/// @param beam_width  maximum number of states kept after each turn.
/// @param beam_depth  maximum number of turns to look ahead.
/// @return first_action of the best state in the final beam. When no
///         expansion is possible (beam_depth <= 0, beam_width <= 0, or the
///         root has no legal actions) the beam still holds only the root, so
///         the root's own first_action is returned unchanged.
Action beam_search_action(const State& state, const int beam_width, const int beam_depth) {
    // size() is unsigned; convert the width once so the comparison below is
    // well defined. A non-positive width keeps nothing.
    const std::size_t width = beam_width > 0 ? static_cast<std::size_t>(beam_width) : 0;

    std::vector<std::shared_ptr<State>> states;
    states.emplace_back(std::make_shared<State>(state));
    for (int turn = 0; turn < beam_depth; turn++) {
        std::vector<std::shared_ptr<State>> next_states;
        for (const std::shared_ptr<State>& now_state : states) {
            const auto legal_actions = now_state->legal_actions();
            for (const auto& action : legal_actions) {
                std::shared_ptr<State> next_state = std::make_shared<State>(*now_state);
                next_state->advance(action);
                // Only the move made at the root is remembered; deeper
                // states inherit it through the copy above.
                if (turn == 0) {
                    next_state->first_action = action;
                }
                next_states.emplace_back(std::move(next_state));
            }
        }
        if (next_states.size() > width) {
            // Compare the pointed-to states, not the pointers themselves, and
            // move the `width` highest-scoring states to the front.
            std::nth_element(next_states.begin(),
                             next_states.begin() + static_cast<std::ptrdiff_t>(width),
                             next_states.end(),
                             [](const auto& lhs, const auto& rhs) { return *lhs > *rhs; });
            next_states.resize(width);
        }
        // Nothing could be expanded or kept: stop and keep the last non-empty
        // beam instead of throwing the results away.
        if (next_states.empty()) {
            break;
        }
        states.swap(next_states);
    }

    // `states` always holds at least the root here. max_element returns the
    // first of several equally good states.
    const auto best = std::max_element(
        states.begin(), states.end(),
        [](const auto& lhs, const auto& rhs) { return *lhs < *rhs; });
    return (*best)->first_action;
}
// Reads the problem input. Currently an empty placeholder.
void input() {
}
// Hook for preparing `state` before a solve starts. Currently an empty placeholder.
void init(State& state) {
}
void solve_look_ahead() {
State state;
init(state);
const double TIME_LIMIT = GLOBAL_TIME_LIMIT - global_timer.getSec();
Timer timer;
timer.start();
double sec;
while (true) {
sec = timer.getSec();
if (sec > TIME_LIMIT) break;
Action action = random_action(state);
state.advance(action);
}
}
void solve_all(const int beam_width, const int max_turn) {
State state;
init(state);
vector<State> states;
states.emplace_back(state);
for (int turn = 0; turn < max_turn; turn++) {
vector<State> next_states;
for (const State& now_state : states) {
auto legal_actions = now_state.legal_actions();
for (const auto& action : legal_actions) {
State next_state = now_state;
next_state.advance(action);
if (turn == 0) {
next_state.first_action = action;
}
next_states.emplace_back(next_state);
}
}
if (next_states.size() > beam_width) {
nth_element(next_states.begin(), next_states.begin() + beam_width, next_states.end(),
greater<>());
next_states.resize(beam_width);
}
swap(states, next_states);
}
}
/// Entry point: starts the global timer, reads the input, then runs the
/// turn-by-turn solver followed by the whole-game beam search.
int main() {
    global_timer.start();
    input();

    solve_look_ahead();

    constexpr int kBeamWidth = 10;
    constexpr int kMaxTurn = 100;
    solve_all(kBeamWidth, kMaxTurn);
}