Skip to content

TSP

template <class T>
class TSP {
    int N;
    vector<vector<T>> G;

    struct State {
        T dist_sum;
        vector<int> order;
    };

    class Transition {
       public:
        Transition() = default;
        virtual ~Transition() = default;
        virtual T calc() = 0;
        virtual void apply() = 0;
    };

    class TransitionSelector {
        vector<int> table;
        vector<unique_ptr<Transition>> transitions;
        int select_index;

       public:
        void add(unique_ptr<Transition>&& transition, int weight) {
            int idx = transitions.size();
            for (int i = 0; i < weight; i++) {
                table.push_back(idx);
            }
            transitions.emplace_back(move(transition));
        }
        inline void select_transition() {
            select_index = table[xor128() % table.size()];
        }
        inline T calc() {
            return transitions[select_index]->calc();
        }
        inline void apply() {
            transitions[select_index]->apply();
        }
    };

    class Transition2opt : public Transition {
       private:
        State& state;
        const vector<vector<T>>& G;
        T new_dist_sum;
        int l, r;

       public:
        Transition2opt(State& state, const vector<vector<T>>& G) : state(state), G(G) {
        }
        virtual T calc() override {
            new_dist_sum = state.dist_sum;

            const int N = state.order.size();
            l = 1 + xor128() % (N - 1);
            r = l;
            while (l == r) r = 1 + xor128() % (N - 1);
            if (l > r) swap(l, r);

            int pl = (l - 1 + N) % N;
            int nr = (r + 1) % N;

            new_dist_sum -= G[state.order[pl]][state.order[l]];
            new_dist_sum -= G[state.order[r]][state.order[nr]];

            new_dist_sum += G[state.order[pl]][state.order[r]];
            new_dist_sum += G[state.order[l]][state.order[nr]];

            return new_dist_sum;
        }
        virtual void apply() override {
            state.dist_sum = new_dist_sum;
            while (l < r) {
                swap(state.order[l], state.order[r]);
                l++;
                r--;
            }
        }
    };

    class Transition3opt : public Transition {
       private:
        State& state;
        const vector<vector<T>>& G;
        T new_dist_sum;
        vector<int> pos;
        int recon_type;
        vector<int> new_order;

        void init_pos() {
            const int N = state.order.size();
            pos[0] = -1;
            while (true) {
                for (int i = 0; i < 3; i++) pos[i] = xor128() % N;
                sort(pos.begin(), pos.end());
                // if (pos[0] + 1 < pos[1] && pos[1] + 1 < pos[2]) break;
                if (pos[0] < pos[1] && pos[1] < pos[2]) break;
            }
        }

       public:
        Transition3opt(State& state, const vector<vector<T>>& G) : state(state), G(G), pos(3, -1) {
            new_order.reserve(state.order.size());
        }
        virtual T calc() override {
            new_dist_sum = state.dist_sum;
            init_pos();
            if (pos.size() == 0) return new_dist_sum;
            recon_type = xor128() % 7;
            const int N = state.order.size();

            if (recon_type == 0) {
                new_dist_sum -= G[state.order[pos[0]]][state.order[(pos[0] + 1) % N]];
                new_dist_sum -= G[state.order[pos[1]]][state.order[(pos[1] + 1) % N]];
                new_dist_sum += G[state.order[pos[0]]][state.order[pos[1]]];
                new_dist_sum += G[state.order[(pos[1] + 1) % N]][state.order[(pos[0] + 1) % N]];
            } else if (recon_type == 1) {
                new_dist_sum -= G[state.order[pos[1]]][state.order[(pos[1] + 1) % N]];
                new_dist_sum -= G[state.order[pos[2]]][state.order[(pos[2] + 1) % N]];
                new_dist_sum += G[state.order[pos[1]]][state.order[pos[2]]];
                new_dist_sum += G[state.order[(pos[1] + 1) % N]][state.order[(pos[2] + 1) % N]];
            } else if (recon_type == 2) {
                new_dist_sum -= G[state.order[pos[0]]][state.order[(pos[0] + 1) % N]];
                new_dist_sum -= G[state.order[pos[2]]][state.order[(pos[2] + 1) % N]];
                new_dist_sum += G[state.order[pos[0]]][state.order[pos[2]]];
                new_dist_sum += G[state.order[(pos[0] + 1) % N]][state.order[(pos[2] + 1) % N]];
            } else if (recon_type == 3) {
                new_dist_sum -= G[state.order[pos[0]]][state.order[(pos[0] + 1) % N]];
                new_dist_sum -= G[state.order[pos[1]]][state.order[(pos[1] + 1) % N]];
                new_dist_sum -= G[state.order[pos[2]]][state.order[(pos[2] + 1) % N]];
                new_dist_sum += G[state.order[pos[0]]][state.order[pos[1]]];
                new_dist_sum += G[state.order[(pos[0] + 1) % N]][state.order[pos[2]]];
                new_dist_sum += G[state.order[(pos[1] + 1) % N]][state.order[(pos[2] + 1) % N]];
            } else if (recon_type == 4) {
                new_dist_sum -= G[state.order[pos[0]]][state.order[(pos[0] + 1) % N]];
                new_dist_sum -= G[state.order[pos[1]]][state.order[(pos[1] + 1) % N]];
                new_dist_sum -= G[state.order[pos[2]]][state.order[(pos[2] + 1) % N]];
                new_dist_sum += G[state.order[pos[0]]][state.order[pos[2]]];
                new_dist_sum += G[state.order[(pos[1] + 1) % N]][state.order[(pos[0] + 1) % N]];
                new_dist_sum += G[state.order[pos[1]]][state.order[(pos[2] + 1) % N]];
            } else if (recon_type == 5) {
                new_dist_sum -= G[state.order[pos[0]]][state.order[(pos[0] + 1) % N]];
                new_dist_sum -= G[state.order[pos[1]]][state.order[(pos[1] + 1) % N]];
                new_dist_sum -= G[state.order[pos[2]]][state.order[(pos[2] + 1) % N]];
                new_dist_sum += G[state.order[pos[0]]][state.order[(pos[1] + 1) % N]];
                new_dist_sum += G[state.order[pos[2]]][state.order[pos[1]]];
                new_dist_sum += G[state.order[(pos[0] + 1) % N]][state.order[(pos[2] + 1) % N]];
            } else if (recon_type == 6) {
                new_dist_sum -= G[state.order[pos[0]]][state.order[(pos[0] + 1) % N]];
                new_dist_sum -= G[state.order[pos[1]]][state.order[(pos[1] + 1) % N]];
                new_dist_sum -= G[state.order[pos[2]]][state.order[(pos[2] + 1) % N]];
                new_dist_sum += G[state.order[pos[0]]][state.order[(pos[1] + 1) % N]];
                new_dist_sum += G[state.order[pos[2]]][state.order[(pos[0] + 1) % N]];
                new_dist_sum += G[state.order[pos[1]]][state.order[(pos[2] + 1) % N]];
            }

            return new_dist_sum;
        }
        virtual void apply() override {
            if (pos[0] == -1) return;
            state.dist_sum = new_dist_sum;
            const int N = state.order.size();

            if (recon_type == 0) {
                int l = pos[0] + 1, r = pos[1];
                while (l < r) {
                    swap(state.order[l], state.order[r]);
                    l++;
                    r--;
                }
            } else if (recon_type == 1) {
                int l = pos[1] + 1, r = pos[2];
                while (l < r) {
                    swap(state.order[l], state.order[r]);
                    l++;
                    r--;
                }
            } else if (recon_type == 2) {
                int l = pos[0] + 1, r = pos[2];
                while (l < r) {
                    swap(state.order[l], state.order[r]);
                    l++;
                    r--;
                }
            } else if (recon_type == 3) {
                new_order.clear();
                for (int i = pos[1]; i >= pos[0] + 1; i--) new_order.push_back(state.order[i]);
                for (int i = pos[2]; i >= pos[1] + 1; i--) new_order.push_back(state.order[i]);
                for (int i = 0; i < new_order.size(); i++)
                    state.order[pos[0] + 1 + i] = new_order[i];
            } else if (recon_type == 4) {
                new_order.clear();
                for (int i = pos[2]; i >= pos[1] + 1; i--) new_order.push_back(state.order[i]);
                for (int i = pos[0] + 1; i <= pos[1]; i++) new_order.push_back(state.order[i]);
                for (int i = 0; i < new_order.size(); i++)
                    state.order[pos[0] + 1 + i] = new_order[i];
            } else if (recon_type == 5) {
                new_order.clear();
                for (int i = pos[1] + 1; i <= pos[2]; i++) new_order.push_back(state.order[i]);
                for (int i = pos[1]; i >= pos[0] + 1; i--) new_order.push_back(state.order[i]);
                for (int i = 0; i < new_order.size(); i++)
                    state.order[pos[0] + 1 + i] = new_order[i];
            } else if (recon_type == 6) {
                new_order.clear();
                for (int i = pos[1] + 1; i <= pos[2]; i++) new_order.push_back(state.order[i]);
                for (int i = pos[0] + 1; i <= pos[1]; i++) new_order.push_back(state.order[i]);
                for (int i = 0; i < new_order.size(); i++)
                    state.order[pos[0] + 1 + i] = new_order[i];
            }

            assert(state.order.size() == N);

            /*
            {
                double dsum = 0;
                for (int i = 0; i < N; i++) {
                    dsum += G[state.order[i]][state.order[(i + 1) % N]];
                }
                if (abs(dsum - state.dist_sum) > 1e-5) {
                    cerr << "! " << recon_type << " " << dsum << " " << state.dist_sum << endl;
                }
                assert(abs(dsum - state.dist_sum) < 1e-5);
            }
            //*/
        }
    };

    double calc_temp(double sec, double time_limit) {
        // 初期温度は近傍によって最適値が違いそうなので都度調整
        static const double START_TEMP = 1e-4;
        static const double END_TEMP = 1e-9;
        return START_TEMP + (END_TEMP - START_TEMP) * sec / time_limit;
    }

   public:
    TSP(int N) : N(N), G(N, vector<T>(N, 1e9)) {
    }
    void add_dist(int i, int j, const T& a) {
        G[i][j] = a;
        G[j][i] = a;
    }
    vector<int> calc(const double time_limit) {
        State state;
        for (int i = 0; i < N; i++) {
            state.order.push_back(i);
        }
        state.dist_sum = 0;
        for (int i = 0; i < N; i++) {
            state.dist_sum += G[state.order[i]][state.order[(i + 1) % N]];
        }

        State best_state = state;

        TransitionSelector selector;
        selector.add(make_unique<Transition2opt>(state, G), 50);
        selector.add(make_unique<Transition3opt>(state, G), 50);

        Timer timer;
        timer.start();
        double sec;
        int step = 0;
        while (true) {
            step++;
            sec = timer.getSec();
            if (sec > time_limit) break;
            double temp = calc_temp(sec, time_limit);

            selector.select_transition();

            T prev_score = state.dist_sum;
            T new_score = selector.calc();

            T delta = prev_score - new_score;
            bool undo = false;
            if (delta < 0) {
                if (exp(delta / temp) >= frand()) {
                    ;
                } else {
                    undo = true;
                }
            }
            if (undo) {
            } else {
                selector.apply();
                if (best_state.dist_sum > state.dist_sum) {
                    best_state = state;
                }
            }
        }
        cerr << "Step: " << step << endl;
        cerr << best_state.dist_sum << endl;

        return state.order;
    }
};