AcWing_1221 四平方和(哈希/二分、枚举、sort函数的重载运算符写法)

AcWing

1221. 四平方和

哈希:枚举cd,计算c^2 + d^2,同一个和,只保留最先出现的cd(为了保证字典序优先)。这样,再按字典序枚举ab,对于一组ab,如果n - a^2 + b^2存在,那么其对应的abcd就是题解。

这样做的正确性在于:题解一定存在,是按字典序排列的,按字典序枚举ab,一定能枚举到题解的ab;那么此时题解的cd也一定就是此时枚举到的ab对应的cd,因为cd也是按字典序优先的。

#pragma GCC optimize(3,"Ofast","inline")
#include <bits/stdc++.h>
using namespace std;
int n;
const int N = 2240 * 2240;
int cc[N], dd[N];
bool flag[N];
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n;
    for(int c = 0; c * c <= n; c++)
        for(int d = c; c * c + d * d <= n; d++){
            int t = c * c + d * d;
            if(!flag[t]){
                flag[t] = true;
                cc[t] = c, dd[t] = d;
            }
        }
    for(int a = 0; a * a <= n; a++)
        for(int b = a; a * a + b * b <= n; b++){
            int t = n - a * a - b * b;
            if(flag[t]){
                cout << a << ' ' << b << ' ' << cc[t] << ' ' << dd[t] << "\n";
                return 0;
            }
        }
    return 0;
}

二分:使用结构体维护所有枚举出的c^2 + d^2cd,并按照c^2 + d^2优先(小)于c优先(小)于d排序。这样,按字典序枚举ab,搜索结构体数组中sn - a^2 + b^2的最左端点(如果存在,否则继续下一次枚举),其对应的cd即为题解中的cd.

#pragma GCC optimize(3,"Ofast","inline")
#include <bits/stdc++.h>
using namespace std;
int n;
const int N = 2240 * 2240;
struct Sum{
    int s, c, d;
    bool operator < (const Sum &t) const{
        if(s != t.s) return s < t.s;
        if(c != t.c) return c < t.c;
        return d < t.d;
    }
}cdsum[N];
int pt = 0;
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n;
    for(int c = 0; c * c <= n; c++)
        for(int d = c; c * c + d * d <= n; d++){
            int t = c * c + d * d;
            cdsum[pt++] = {t, c, d};
        }
    sort(cdsum, cdsum + pt);
    for(int a = 0; a * a <= n; a++)
        for(int b = a; a * a + b * b <= n; b++){
            int t = n - a * a - b * b;
            int l = 0, r = pt - 1;
            while(l < r){
                int mid = l + r >> 1;
                if(cdsum[mid].s >= t) r = mid;
                else l = mid + 1;
            }
            if(cdsum[r].s == t){
                cout << a << ' ' << b << ' ' << cdsum[r].c << ' ' << cdsum[r].d << "\n";
                return 0;
            }
        }
    return 0;
}