Robię takie zadanko: http://www.usaco.org/index.php?page=viewproblem2&cpid=623
Rozwiązaniem wydaje się znalezienie MST, używam algorytmu Kruskala:

#include <bits/stdc++.h>
    
using namespace std;
 
typedef long long ll;
typedef unsigned long long ull;

// INF = 10^9 + 7.
// MAXN = (2000 + 1) * (2000 + 1) + 7: room for every grid region id used in main
// (ids go up to (n+1)*(m+1); presumably n, m <= 2000 per the task — confirm).
const ll INF = 1000000007;
const ll MAXN = 2001LL * 2001LL + 7;
 
// Unties cin from cout and disables C-stdio synchronisation for faster streams.
// When a task name is given, stdin/stdout are redirected to "<name>.in" and
// "<name>.out"; an empty name keeps the standard streams.
void setIO(string s){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    if (s.empty())
        return;

    const string input_path = s + ".in";
    const string output_path = s + ".out";
    freopen(input_path.c_str(), "r", stdin);
    freopen(output_path.c_str(), "w", stdout);
}

// Disjoint-set forest over the grid regions (region id = r*(m+1) + c in main).
// rep[x]: parent of x; main makes every x its own parent before use.
vector<int> rep(MAXN);
// amount[x]: number of elements in the set rooted at x (only meaningful for roots).
vector<int> amount(MAXN, 1);

// Returns the root of x's set. Every node on the path from x to the root is
// re-pointed directly at the root (full path compression).
int Find(int x){
    int root = x;
    while (rep[root] != root)
        root = rep[root];

    // Second pass: compress the path we just walked.
    while (rep[x] != root){
        const int next = rep[x];
        rep[x] = root;
        x = next;
    }
    return root;
}

// Merges the sets containing a and b using union by size.
// On equal sizes the root of a's set stays the root.
void Union(int a, int b){
    const int ra = Find(a);
    const int rb = Find(b);
    if (ra == rb)
        return;

    const int parent = (amount[ra] < amount[rb]) ? rb : ra;
    const int child = (parent == ra) ? rb : ra;
    rep[child] = parent;
    amount[parent] += amount[child];
}

int main(){
    setIO("");
    int A, B, n, m;
    cin >> A >> B >> n >> m;
    vector<int> a(n+1), b(m+1);
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    for (int i = 1; i <= m; i++)
        cin >> b[i];
    
    sort(a.begin(), a.end());
    sort(b.begin(), b.end());
    a.push_back(A);
    b.push_back(B);

    vector<pair<int, pair<int, int>>> edges;
    for (int i = 1; i < (int)a.size(); i++)
        for (int id = (i-1)*(m+1); id + 1 < i*(m+1); id++)
            edges.push_back({a[i] - a[i-1], {id, id+1}});
    
    for (int i = 1; i < (int)b.size(); i++)
        for (int id = i-1; id + m + 1 < (n+1)*(m+1)+(i-1); id += (m+1))
            edges.push_back({b[i] - b[i-1], {id, id+m+1}});
    
    sort(edges.begin(), edges.end());
    for (int i = 0; i <= (n+1)*(m+1); i++)
        rep[i] = i;

    ll ans = 0;
    for (auto& pp : edges)
        if (Find(pp.second.first) != Find(pp.second.second)){
            ans += pp.first;
            Union(pp.second.first, pp.second.second);
        }

    cout << ans << "\n";

    return 0;
}

Niestety nie przechodzi ostatniego testu ze względu na limit czasu, co wydaje mi się dość dziwne, skoro jego złożoność powinna wynosić mniej więcej O(nm · log(nm)). Próbowałem też w jakiś sposób wykorzystać to, że wierzchołki tworzą tablicę dwuwymiarową, używając algorytmu Prima:

#include <bits/stdc++.h>
    
using namespace std;
 
typedef long long ll;
typedef unsigned long long ull;

// INF = 10^9 + 7.
// MAXN x MAXM bounds the (n+1) x (m+1) grid of regions
// (presumably n, m <= 2000 per the task, plus slack — confirm).
const ll INF = 1000000007;
const ll MAXN = 2000 + 7;
const ll MAXM = 2000 + 7;
 
// Faster iostreams (no C-stdio sync, cin untied from cout). A non-empty task
// name redirects stdin/stdout to "<name>.in" / "<name>.out".
void setIO(string s){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    const bool use_files = !s.empty();
    if (use_files){
        freopen((s + ".in").c_str(), "r", stdin);
        freopen((s + ".out").c_str(), "w", stdout);
    }
}

// Edge weights of the grid of regions, stored at the cell being entered:
//   to_left[r][c]  - edge (r, c-1) -- (r, c), used when stepping right into (r, c)
//   to_down[r][c]  - edge (r-1, c) -- (r, c), used when stepping down into (r, c)
//   to_right[r][c] - edge (r, c) -- (r, c+1), used when stepping left into (r, c)
//   to_up[r][c]    - edge (r, c) -- (r+1, c), used when stepping up into (r, c)
vector<vector<int>> to_left(MAXN, vector<int>(MAXM));
vector<vector<int>> to_down(MAXN, vector<int>(MAXM));
vector<vector<int>> to_right(MAXN, vector<int>(MAXM));
vector<vector<int>> to_up(MAXN, vector<int>(MAXM));
// Cells already added to the spanning tree. Stored as char (0/1) rather than
// vector<bool>: the packed-bit proxy makes every access in Prim's hot loop a
// shift-and-mask; comparisons with true/false and `= true` work unchanged.
vector<vector<char>> seen(MAXN, vector<char>(MAXM));

// Lazy Prim's algorithm over the (n+1) x (m+1) grid of regions, started from
// cell (0, 0). Each time a cell is settled, the edges to its unsettled
// neighbours are pushed; heap entries for cells that are already settled are
// discarded when popped. Returns the total weight of the spanning tree.
ll prim(int n, int m){
    typedef pair<int, pair<int, int>> Item; // (edge weight, (row, col))
    priority_queue<Item, vector<Item>, greater<Item>> heap;
    heap.push(Item(0, make_pair(0, 0)));
    ll total = 0;

    while (!heap.empty()){
        const Item top = heap.top();
        heap.pop();
        const int r = top.second.first;
        const int c = top.second.second;
        if (seen[r][c])
            continue; // stale entry

        seen[r][c] = true;
        total += top.first;

        if (c + 1 <= m && !seen[r][c + 1])
            heap.push(Item(to_left[r][c + 1], make_pair(r, c + 1)));
        if (c - 1 >= 0 && !seen[r][c - 1])
            heap.push(Item(to_right[r][c - 1], make_pair(r, c - 1)));
        if (r + 1 <= n && !seen[r + 1][c])
            heap.push(Item(to_down[r + 1][c], make_pair(r + 1, c)));
        if (r - 1 >= 0 && !seen[r - 1][c])
            heap.push(Item(to_up[r - 1][c], make_pair(r - 1, c)));
    }

    return total;
}

// Reads the instance from fencedin.in, fills the per-direction edge-weight
// tables of the (n+1) x (m+1) grid and prints the MST weight from prim().
int main(){
    setIO("fencedin");
    int A, B, n, m;
    cin >> A >> B >> n >> m;

    // Fence coordinates; index 0 stays 0 and the far border is appended.
    vector<int> xs(n + 1), ys(m + 1);
    for (int k = 1; k <= n; ++k)
        cin >> xs[k];
    for (int k = 1; k <= m; ++k)
        cin >> ys[k];
    sort(xs.begin(), xs.end());
    sort(ys.begin(), ys.end());
    xs.push_back(A);
    ys.push_back(B);

    // Within row r, cells (r, c-1) and (r, c) are separated by a fence piece of
    // the row's width; record it for both directions of travel.
    for (int r = 0; r <= n; ++r){
        const int width = xs[r + 1] - xs[r];
        for (int c = 1; c <= m; ++c){
            to_left[r][c] = width;
            to_right[r][c - 1] = width;
        }
    }

    // Within column c, cells (r-1, c) and (r, c) are separated by a fence piece
    // of the column's height; record it for both directions of travel.
    for (int c = 0; c <= m; ++c){
        const int height = ys[c + 1] - ys[c];
        for (int r = 1; r <= n; ++r){
            to_down[r][c] = height;
            to_up[r - 1][c] = height;
        }
    }

    cout << prim(n, m) << "\n";

    return 0;
}

Jednak to rozwiązanie jest jeszcze wolniejsze.
Proszę o jakąś podpowiedź :D

EDIT:
Przyśpieszyłem kod. Największym problemem był push_back() zamiast wcześniejszego zadeklarowania miejsca. Być może istnieje sprytniejsze rozwiązanie. Przyśpieszony kod:

#include <bits/stdc++.h>

using namespace std;

using ll = long long;
using ull = unsigned long long;

// INF = 10^9 + 7.
const ll INF = 1000000007;
// Upper bound on grid region ids: (2000 + 1)^2 + 7
// (presumably n, m <= 2000 per the task — confirm).
const ll MAXN = 2001LL * 2001LL + 7;

// Speeds up iostreams (no C-stdio sync, cin untied from cout). If a task name is
// given, stdin/stdout are redirected to "<name>.in" and "<name>.out".
void setIO(string s){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    if (s.empty())
        return;

    const string input_path = s + ".in";
    const string output_path = s + ".out";
    freopen(input_path.c_str(), "r", stdin);
    freopen(output_path.c_str(), "w", stdout);
}

// Disjoint-set forest over the grid regions (id = r*(m+1) + c in main).
int rep[MAXN];    // parent pointer; main initialises rep[i] = i
int amount[MAXN]; // set size, meaningful for roots only
// Fence coordinates read by main. Static storage is zero-initialised, so
// a[0] and b[0] act as the near border after sorting.
int a[MAXN];
int b[MAXN];

// Returns the root of x's set, re-pointing every node on the way directly at
// the root (full path compression).
int Find(int x){
    return rep[x] == x ? x : (rep[x] = Find(rep[x]));
}

// Merges the sets of x and y using union by size; on equal sizes the root of
// x's set remains the root.
void Union(int x, int y){
    const int rx = Find(x);
    const int ry = Find(y);
    if (rx != ry){
        if (amount[rx] < amount[ry]){
            rep[rx] = ry;
            amount[ry] += amount[rx];
        } else {
            rep[ry] = rx;
            amount[rx] += amount[ry];
        }
    }
}

// A grid edge for Kruskal: w is the length of the fence piece separating the
// regions with ids x and y.
struct edge{
    int w, x, y;
};

// Strict weak ordering for sorting edges by weight only (endpoints are ignored).
bool cmp(const edge& e1, const edge& e2){
    return e2.w > e1.w;
}


// Reads a "Fenced In" instance from fencedin.in and writes to fencedin.out the
// minimum total length of fence that has to be removed so that all regions are
// connected: Kruskal's MST over the (n+1) x (m+1) grid of regions.
int main(){
    setIO("fencedin");
    int A, B, n, m;
    cin >> A >> B >> n >> m;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    for (int i = 1; i <= m; i++)
        cin >> b[i];

    // a[0] and b[0] are 0 (global arrays), i.e. the near borders; after sorting,
    // append the far borders so consecutive differences are strip sizes.
    sort(a, a + n+1);
    sort(b, b + m+1);
    a[n+1] = A;
    b[m+1] = B;

    // Region (r, c) has id r*(m+1) + c. Each of the n+1 a-strips holds m edges
    // and each of the m+1 b-strips holds n edges. Allocate exactly that many
    // instead of a fixed 2*MAXN (~8M edges, ~96 MB zero-filled on every run).
    const size_t edge_count = static_cast<size_t>(n+1) * m + static_cast<size_t>(m+1) * n;
    vector<edge> edges(edge_count);
    int idx = 0;
    // Edges (r, c) -- (r, c+1) with r = i-1: weight = a-strip size a[i] - a[i-1].
    for (int i = 1; i < n+2; i++)
        for (int id = (i-1)*(m+1); id + 1 < i*(m+1); id++, idx++)
            edges[idx] = {a[i] - a[i-1], id, id+1};

    // Edges (r, c) -- (r+1, c) with c = i-1: weight = b-strip size b[i] - b[i-1].
    for (int i = 1; i < m+2; i++)
        for (int id = i-1; id + m + 1 < (n+1)*(m+1)+(i-1); id += (m+1), idx++)
            edges[idx] = {b[i] - b[i-1], id, id+m+1};

    sort(edges.begin(), edges.begin() + idx, cmp);
    for (int i = 0; i <= (n+1)*(m+1); i++){
        rep[i] = i;
        amount[i] = 1;
    }

    // Kruskal: take an edge whenever it joins two different components.
    ll ans = 0;
    for (int i = 0; i < idx; i++){
        if (Find(edges[i].x) != Find(edges[i].y)){
            Union(edges[i].x, edges[i].y);
            ans += edges[i].w;
        }
    }
    cout << ans << "\n";

    return 0;
}