Robię takie zadanko: http://www.usaco.org/index.php?page=viewproblem2&cpid=623
Rozwiązaniem wydaje się znalezienie MST, używam algorytmu Kruskala:
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
const ll INF = 1e9 + 7, MAXN = (2000 + 1) * (2000 + 1) + 7;
void setIO(string s){
ios_base::sync_with_stdio(false);
cin.tie(NULL);
if (s != ""){
freopen((s + ".in").c_str(), "r", stdin);
freopen((s + ".out").c_str(), "w", stdout);
}
}
vector<int> rep(MAXN), amount(MAXN, 1);
int Find(int x){
if (rep[x] != x)
rep[x] = Find(rep[x]);
return rep[x];
}
void Union(int a, int b){
a = Find(a);
b = Find(b);
if (a == b)
return;
if (amount[a] < amount[b])
swap(a, b);
rep[b] = a;
amount[a] += amount[b];
}
int main(){
setIO("");
int A, B, n, m;
cin >> A >> B >> n >> m;
vector<int> a(n+1), b(m+1);
for (int i = 1; i <= n; i++)
cin >> a[i];
for (int i = 1; i <= m; i++)
cin >> b[i];
sort(a.begin(), a.end());
sort(b.begin(), b.end());
a.push_back(A);
b.push_back(B);
vector<pair<int, pair<int, int>>> edges;
for (int i = 1; i < (int)a.size(); i++)
for (int id = (i-1)*(m+1); id + 1 < i*(m+1); id++)
edges.push_back({a[i] - a[i-1], {id, id+1}});
for (int i = 1; i < (int)b.size(); i++)
for (int id = i-1; id + m + 1 < (n+1)*(m+1)+(i-1); id += (m+1))
edges.push_back({b[i] - b[i-1], {id, id+m+1}});
sort(edges.begin(), edges.end());
for (int i = 0; i <= (n+1)*(m+1); i++)
rep[i] = i;
ll ans = 0;
for (auto& pp : edges)
if (Find(pp.second.first) != Find(pp.second.second)){
ans += pp.first;
Union(pp.second.first, pp.second.second);
}
cout << ans << "\n";
return 0;
}
Niestety nie przechodzi ostatniego testy ze względu na czas, co wydaje mi się dość dziwne skoro jego złożoność powinna być mniej więcej O(nm * log(nm)), próbowałem też w jakiś sposób wykorzystać to że wierzchołki tworzą tablicę dwuwymiarową z użyciem algorytmu Prima:
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
const ll INF = 1e9 + 7, MAXN = 2000 + 7, MAXM = 2000 + 7;
void setIO(string s){
ios_base::sync_with_stdio(false);
cin.tie(NULL);
if (s != ""){
freopen((s + ".in").c_str(), "r", stdin);
freopen((s + ".out").c_str(), "w", stdout);
}
}
vector<vector<int>> to_left(MAXN, vector<int>(MAXM)), to_down(MAXN, vector<int>(MAXM)), to_right(MAXN, vector<int>(MAXM)), to_up(MAXN, vector<int>(MAXM));
vector<vector<bool>> seen(MAXN, vector<bool>(MAXM));
ll prim(int n, int m){
priority_queue<pair<int, pair<int, int>>, vector<pair<int, pair<int, int>>>, greater<pair<int, pair<int, int>>>> pq;
pq.push({0, {0, 0}});
ll ans = 0;
while (!pq.empty()){
pair<int, int> curr = pq.top().second;
if (seen[curr.first][curr.second] == true){
pq.pop();
continue;
}
seen[curr.first][curr.second] = true;
ans += pq.top().first;
pq.pop();
if (curr.second + 1 <= m && seen[curr.first][curr.second + 1] == false)
pq.push({to_left[curr.first][curr.second+1], {curr.first, curr.second+1}});
if (curr.second - 1 >= 0 && seen[curr.first][curr.second - 1] == false)
pq.push({to_right[curr.first][curr.second-1], {curr.first, curr.second-1}});
if (curr.first + 1 <= n && seen[curr.first+1][curr.second] == false)
pq.push({to_down[curr.first+1][curr.second], {curr.first+1, curr.second}});
if (curr.first - 1 >= 0 && seen[curr.first-1][curr.second] == false)
pq.push({to_up[curr.first-1][curr.second], {curr.first-1, curr.second}});
}
return ans;
}
int main(){
setIO("fencedin");
int A, B, n, m;
cin >> A >> B >> n >> m;
vector<int> a(n+1), b(m+1);
for (int i = 1; i <= n; i++)
cin >> a[i];
for (int i = 1; i <= m; i++)
cin >> b[i];
sort(a.begin(), a.end());
sort(b.begin(), b.end());
a.push_back(A);
b.push_back(B);
for (int i = 1; i < (int)a.size(); i++)
for (int j = 1; j < (m+1); j++)
to_left[i-1][j] = to_right[i-1][j-1] = a[i] - a[i-1];
for (int j = 1; j < (int)b.size(); j++)
for (int i = 1; i < (n+1); i++)
to_down[i][j-1] = to_up[i-1][j-1] = b[j] - b[j-1];
cout << prim(n, m) << "\n";
return 0;
}
Jednak to rozwiązanie jest jeszcze wolniejsze
Proszę o jakąś podpowiedź :D
EDIT:
Przyśpieszyłem kod, największym problemem był push_back() zamiast zadeklarowania miejsca, być może isnieje sprytniejsze rozwiązanie, przyśpieszony kod:
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
const ll INF = 1e9 + 7, MAXN = (2000 + 1) * (2000 + 1) + 7;
void setIO(string s){
ios_base::sync_with_stdio(false);
cin.tie(NULL);
if (s != ""){
freopen((s + ".in").c_str(), "r", stdin);
freopen((s + ".out").c_str(), "w", stdout);
}
}
int rep[MAXN], amount[MAXN], a[MAXN], b[MAXN];
int Find(int x){
if (rep[x] != x)
rep[x] = Find(rep[x]);
return rep[x];
}
void Union(int x, int y){
x = Find(x);
y = Find(y);
if (x == y)
return;
if (amount[x] < amount[y])
swap(x, y);
rep[y] = x;
amount[x] += amount[y];
}
struct edge{
int w;
int x;
int y;
};
bool cmp(const edge& e1, const edge& e2){
return e1.w < e2.w;
}
int main(){
setIO("fencedin");
int A, B, n, m;
cin >> A >> B >> n >> m;
for (int i = 1; i <= n; i++)
cin >> a[i];
for (int i = 1; i <= m; i++)
cin >> b[i];
sort(a, a + n+1);
sort(b, b + m+1);
a[n+1] = A;
b[m+1] = B;
vector<edge> edges(MAXN * 2);
int idx = 0;
for (int i = 1; i < n+2; i++)
for (int id = (i-1)*(m+1); id + 1 < i*(m+1); id++, idx++)
edges[idx] = {a[i] - a[i-1], id, id+1};
for (int i = 1; i < m+2; i++)
for (int id = i-1; id + m + 1 < (n+1)*(m+1)+(i-1); id += (m+1), idx++)
edges[idx] = {b[i] - b[i-1], id, id+m+1};
sort(edges.begin(), edges.begin() + idx, cmp);
for (int i = 0; i <= (n+1)*(m+1); i++){
rep[i] = i;
amount[i] = 1;
}
ll ans = 0;
for (int i = 0; i < idx; i++){
if (Find(edges[i].x) != Find(edges[i].y)){
Union(edges[i].x, edges[i].y);
ans += edges[i].w;
}
}
cout << ans << "\n";
return 0;
}