ビームサーチ用テンプレート
単純なビームサーチ
- 実装
- Stateクラス
- 状態を定義
- 評価関数を定義
- ハッシュ関数を定義
- デバッグ用printを実装
- generate_next_states()関数
- 重複排除
- 改善
- 行動の改善
- 評価関数の改善
- 出力を永続stackにする
- 状態のコピーコスト削減(状態のサイズを小さくする、スコアだけを先に計算して上位のみ状態を生成する)
- priority_queue部分の改善(逆にしてBEAM_WIDTH個以上ならpopする、など)
- 重複排除
- 多様性確保
- 評価値、ハッシュ値の差分計算
- ビーム幅の調整、自動調整
- 差分更新可能な場合は、オイラーツアービームサーチなどに変更
- など
struct BeamConfig {
int MAX_TURN;
int BEAM_WIDTH;
};
struct State {
// ...
double eval; // 評価値
uint64_t hash; // ハッシュ値
vector<int> actions; // 行動列
State() : eval(compute_eval()), hash(compute_hash()) {
}
// 現在の状態から愚直に評価値を計算
double compute_eval() {
// ...
return 0;
}
// 現在の状態から愚直にハッシュ値を計算
uint64_t compute_hash() {
// ...
return 0;
}
void print() {
// ...
}
inline bool operator<(const State& other) const {
return eval < other.eval;
}
inline bool operator>(const State& other) const {
return eval > other.eval;
}
};
using Heap = priority_queue<State>; // 最大化タスク
// using Heap = priority_queue<State, vector<State>, greater<State>>; // 最小化タスク
void generate_next_states(Heap& next_states, const State& state) {
State next_state = state;
// ...
next_state.eval = 0;
next_state.hash = 0;
next_state.actions.push_back(0);
next_states.push(next_state);
}
void solve_beam(const BeamConfig& config) {
State init_state;
Heap states;
states.push(init_state);
rep(t, config.MAX_TURN) {
Heap next_states;
rep(b, config.BEAM_WIDTH) {
if (states.empty()) break;
generate_next_states(next_states, states.top());
states.pop();
}
swap(states, next_states);
}
assert(!states.empty());
State best_state = states.top();
// ...
}
Candidate生成によるビームサーチ
- Stateが重い場合、上記のように遷移先をStateで生成すると、処理時間的にもメモリ使用量的にも厳しい
- Stateから、評価値や操作差分だけなど必要最低限の情報だけをCandidateとして生成して、それで評価値上位を求めるようにすると軽くできる
前方更新型ビームサーチ
- 各ターンごと、State集合からCandidate集合を生成し、その評価値上位のもののみCandidateからStateを生成する
- 注意
- 毎ターン、最大でビーム幅個のState生成(Stateのコピー+変更操作)が入るので重い
- Stateの差分更新ができる場合は、差分更新型ビームサーチを検討する
namespace ForwardBeamSearch {
namespace Utils {
template <class T>
class CutOffPriorityQueue {
int max_size;
vector<T> v;
bool update;
void sort_and_cutoff() {
sort(v.rbegin(), v.rend());
if (v.size() > max_size) v.resize(max_size);
update = false;
}
public:
CutOffPriorityQueue(int max_size) : max_size(max_size), update(false) {
}
void clear() {
v.clear();
update = false;
}
void push(const T& x) {
v.emplace_back(x);
update = true;
if (v.size() >= 2 * max_size) sort_and_cutoff();
}
vector<T>& all() {
if (update) sort_and_cutoff();
return v;
}
};
}; // namespace Utils
////////// 以下を問題ごとに書き換える //////////
// TODO
using Score = double; // 注: 高いほどよい
using Hash = uint32_t;
enum class OpType {
// TODO
NONE,
OP1,
OP2
};
struct Op {
// 状態に対する変更操作の種類
OpType type;
// 状態に対する変更操作の追加情報
// TODO
Op() {
}
};
struct State {
Score score;
Hash hash;
// 状態に関するもの
// TODO
// 出力に関するもの
// TODO
State() {
}
inline bool operator<(const State& other) const {
return score < other.score;
}
inline bool operator>(const State& other) const {
return score > other.score;
}
// 状態に対する操作をサポートする関数
// TODO
};
ostream& operator<<(ostream& os, const State& state) {
// TODO
os << "stateの解";
return os;
}
// 注: できるだけ軽くする
struct Candidate {
Score score; // 変更後のスコア
int parent_idx; // 元となった親stateのindex
Hash hash; // 重複排除、多様性確保用
Op op; // 変更操作列
bool finished; // 終了したか
inline bool operator<(const Candidate& other) const {
return score < other.score;
}
inline bool operator>(const Candidate& other) const {
return score > other.score;
}
};
using Heap = Utils::CutOffPriorityQueue<Candidate>;
// 指定したStateからCandidateを作成
void expand(State& state, int parent_idx, Heap& cands) {
Candidate cand;
// TODO
cands.push(cand);
}
// CandidateからStateを生成
State make_state_from_candidate(const State& base_state, const Candidate& cand) {
State state = base_state; // !!!Copy!!!
state.score = cand.score;
state.hash = cand.hash;
// TODO
return state;
}
////////// ここまで //////////
struct Config {
int max_turn;
int beam_width;
bool with_timeout; // 各ターンでのタイムアウトありで動かすかどうか
double time_limit; // 処理全体のtime limit
double expand_weight; // expandフェーズに使う時間割合(1.0未満)
};
double get_timeout_sec(double remain_time, int remain_turn, double expand_weight) {
// 残り時間を等分割したものをこのターンで使える時間制限にする
return (remain_time / remain_turn) * expand_weight;
}
void solve(const State& init_state, const Config& config) {
Timer beam_timer;
beam_timer.start();
vector<State> states;
states.emplace_back(init_state);
Heap cands(config.beam_width);
for (int turn = 0; turn < config.max_turn; turn++) {
const double start_time = beam_timer.getSec();
const double timeout_sec = get_timeout_sec(config.time_limit - start_time,
config.max_turn - turn, config.expand_weight);
// expand
cands.clear();
for (size_t b = 0; b < states.size(); b++) {
if (b > config.beam_width) break;
if (b > 0 && config.with_timeout && beam_timer.getSec() - start_time > timeout_sec)
break;
expand(states[b], b, cands);
}
// cand2state
bool finished = false;
vector<State> next_states;
next_states.reserve(config.beam_width);
for (const Candidate& cand : cands.all()) {
next_states.emplace_back(make_state_from_candidate(states[cand.parent_idx], cand));
finished |= cand.finished;
}
swap(states, next_states);
if (finished) break;
}
assert(!states.empty());
State& best = states[0];
cout << best << endl;
}
}; // namespace ForwardBeamSearch
差分更新型ビームサーチ
- Stateに対して、変更操作とその逆操作の処理が軽い場合、Stateのコピー無しで状態生成することで高速化できる可能性がある
- move_forward(op), move_backward(op)
- EulerTourで操作を辺として持つ実装
- 注意
- Opだけの情報ではStateを戻せないような場合、Opを拡張するか、Candidate/Edgeに情報を持たせて、move_forward()/move_back_ward()時に更新する
namespace EulerTourBeamSearch {
namespace Utils {
template <class T>
class CutOffPriorityQueue {
int max_size;
vector<T> v;
bool update;
void sort_and_cutoff() {
sort(v.rbegin(), v.rend());
if (v.size() > max_size) v.resize(max_size);
update = false;
}
public:
CutOffPriorityQueue(int max_size) : max_size(max_size), update(false) {
}
void clear() {
v.clear();
update = false;
}
void push(const T& x) {
v.emplace_back(x);
update = true;
if (v.size() >= 2 * max_size) sort_and_cutoff();
}
vector<T>& all() {
if (update) sort_and_cutoff();
return v;
}
};
}; // namespace Utils
////////// 以下を問題ごとに書き換える //////////
// TODO
using Score = double; // 注: 高いほどよい
using Hash = uint32_t;
enum class OpType {
// TODO
NONE,
OP1,
OP2
};
struct Op {
// 状態に対する変更操作の種類
OpType type;
// 状態に対する変更操作の追加情報
// TODO
Op() {
}
bool operator==(const Op& op) const {
// TODO
return true;
}
};
struct State {
Score score;
Hash hash;
// 状態に関するもの
// TODO
State() {
}
inline bool operator<(const State& other) const {
return score < other.score;
}
inline bool operator>(const State& other) const {
return score > other.score;
}
// 現在の状態からopを適用した状態に遷移する
// 注: Opだけで状態が復元できない場合は、Candidate/Edgeに持たせて更新する
void move_forward(const Op& op) {
// TODO
}
// 現在の状態にopをて供する前の状態に遷移する
void move_backward(const Op& op) {
// TODO
}
};
ostream& operator<<(ostream& os, const State& state) {
// TODO
os << "stateの解";
return os;
}
// 注: できるだけ軽くする
struct Candidate {
Score score; // 変更後のスコア
int parent_idx; // 元となった親stateのindex
Hash hash; // 重複排除、多様性確保用
Op op; // 変更操作列
bool finished; // 終了したか
inline bool operator<(const Candidate& other) const {
return score < other.score;
}
inline bool operator>(const Candidate& other) const {
return score > other.score;
}
};
using Heap = Utils::CutOffPriorityQueue<Candidate>;
// 指定したStateからCandidateを作成
void expand(State& state, int parent_idx, Heap& cands) {
Candidate cand;
// TODO
cands.push(cand);
}
////////// ここまで //////////
struct Config {
int max_turn;
int beam_width;
};
class EulerTour {
struct Edge {
bool is_forward; // 前進辺か
int leaf_idx; // 遷移後が葉ノードなら0以上の番号
Op op; // 操作列
Edge() {
}
Edge(bool is_forward, int leaf_idx, const Op& op)
: is_forward(is_forward), leaf_idx(leaf_idx), op(op) {
}
};
State state;
vector<Op> direct_path;
vector<Edge> current_tour;
vector<Edge> next_tour;
vector<vector<pair<int, Candidate>>> buckets;
public:
EulerTour(const Config& config, const State& init_state)
: state(init_state), buckets(config.beam_width) {
}
void make_cands(Heap& cands) {
cands.clear();
if (current_tour.empty()) {
expand(state, 0, cands);
return;
}
for (const Edge& e : current_tour) {
if (e.is_forward) {
state.move_forward(e.op);
} else {
state.move_backward(e.op);
}
if (e.leaf_idx >= 0) {
expand(state, e.leaf_idx, cands);
}
}
}
bool update(Heap& cands) {
const vector<Candidate>& leaves = cands.all();
bool finished = false;
if (current_tour.empty()) {
for (size_t i = 0; i < leaves.size(); i++) {
const Candidate& c = leaves[i];
current_tour.emplace_back(true, i, c.op);
current_tour.emplace_back(false, -1, c.op);
finished |= c.finished;
}
return finished;
}
for (size_t i = 0; i < leaves.size(); i++) {
const Candidate& c = leaves[i];
buckets[c.parent_idx].emplace_back(i, c);
finished |= c.finished;
}
auto it = current_tour.begin();
while (it->is_forward && it->leaf_idx < 0 && it->op == current_tour.back().op) {
direct_path.emplace_back(it->op);
state.move_forward(it->op);
current_tour.pop_back();
++it;
if (it == current_tour.end()) break;
}
while (it != current_tour.end()) {
if (it->is_forward) {
next_tour.emplace_back(true, -1, it->op);
} else {
if (next_tour.back().is_forward) {
next_tour.pop_back();
} else {
next_tour.emplace_back(false, -1, it->op);
}
}
if (it->leaf_idx >= 0 && !buckets[it->leaf_idx].empty()) {
for (const auto& pr : buckets[it->leaf_idx]) {
const int next_leaf_idx = pr.first;
const Candidate& c = pr.second;
next_tour.emplace_back(true, next_leaf_idx, c.op);
next_tour.emplace_back(false, -1, c.op);
}
buckets[it->leaf_idx].clear();
}
++it;
}
swap(current_tour, next_tour);
next_tour.clear();
return finished;
}
vector<Op> get_path(int leaf_idx) const {
assert(leaf_idx >= 0);
vector<Op> ret = direct_path;
for (const Edge& e : current_tour) {
if (e.is_forward) {
ret.emplace_back(e.op);
} else {
ret.pop_back();
}
if (e.leaf_idx == leaf_idx) {
return ret;
}
}
return vector<Op>();
}
};
void solve(const State& init_state, const Config& config) {
EulerTour et(config, init_state);
Heap cands(config.beam_width);
for (int turn = 0; turn < config.max_turn; turn++) {
// expand
cands.clear();
et.make_cands(cands);
// cand2state
if (et.update(cands)) break;
}
Score best_score = cands.all()[0].score;
vector<Op> best_ops = et.get_path(0);
assert(!best_ops.empty());
// output best_ops
}
}; // namespace EulerTourBeamSearch