[AtCoder] ABC 155 D – Pairs

2020年12月15日

問題

方針

\( A_i = 0 \) となる個数を \(N_0 \) とし、\( A_i < 0 \) となる数列を新たに \( B \) とし、\( A_i > 0 \) となる数列を新たに \( C \) とします。このとき、\( B \) の要素数を \( N_b \) とし、\( C \) の要素数を \( N_c \) とします。つまり、\( N_0 + N_b + N_c = N \) となります。また、\( B, C \) は昇順に整列されてるものとします。

\( K \) 番目の値の符号

積が負となる組み合わせは、\( N_bN_c\) となるので、\( K \leq N_bN_c \) ならば、\( K \) 番目の値は負であることが分かります。次に、積が \( 0 \) となる組み合わせを \( g \) とすると、

\[ g = N_0(N_b + N_c) + \dfrac{N_0 (N_0 – 1)}{2}\]

となります。したがって、\( N_bN_c + g \leq K \) ならば \( K \) 番目の値は \( 0 \) となります。そして、\( N_bN_c + g > K \) ならば、\( K \) 番目の値は正となります。

\( K \) 番目の値が負のとき

整数 \( x \) 以下の積の組み合わせの数を \( f(x) \) とすると、\( f(x) \) は尺取り法を使うことで \( O(N) \) で計算できます。また、\( f(x) \) は単調増加なので二分探索によって、\( f(x) \geq K \) を満たす最小の \( x \) を求めます。ここで、\( B_0 \leq B_1 \leq \cdots \leq B_{N_b – 1} < 0\) と \( 0 < C_0 \leq C_1 \leq \cdots \leq C_{N_c – 1} \) であることに注意すると、次のことが言えます。\( B_{i + 1}C_{j} \leq x \) ならば、\( B_{i}C_{j} \leq x \) を満たします。したがって、尺取り法を用いて次のように数え上げます。まず初めに、\( B_{N_b – 1}C_j \leq x\) を満たす最小の \( j \) を求めます。その \( j \) を \( j_1 \) とすると、 \( N_c – j_1 \) 個の組み合わせがあることが分かります。次の、\( B_{N_b – 2}C_j \leq x \) を満たす最小の \( j \) は \( j_1 \) 以下であることが保証されます。また、その値を \( j_2 \) とすると、\( N_c – j_2 \) 個の組み合わせがあることが分かります。このようにして数え上げます。

\( K \) 番目の値が正のとき

\( K \) 番目の値が負のときと同様に尺取り法と二分探索を行います。このとき、\( B \) と \( C \) を分けて尺取り法を行います。まず、\( B \) について考えます。\( B_{i + 1}B_j \leq B_{i}B_j \) であることを利用します。まず初めに、\( B_0B_j \leq x \) を満たす最小の \( j \) を \( j_1 \) とします。このとき、\( j = 0 \) のペアはカウントしてはいけないので、\( \min(N_b – j_{1} – 1 , N_b – 1)\) の組み合わせがあります。つぎに、\( B_1B_j \leq x \) を満たす最小の \( j \) は、\( j_1 \) 以下から探せばよいので、その値を \( j_2 \) とします。このとき、\( j \leq 1 \) を満たすペアをカウントしてはいけないので、\( \min(N_b – j_2 – 1, N_b – 2)\) の組み合わせがあります。よって、\( B_iB_j  \leq x \) を満たす最小の \( j\) を \( j_{i – 1} \) とすると、\( \min(N_b – j_{i – 1}, N_b – i – 1) \) の組み合わせがあります。同様にして、\( C \) の方も尺取り法で数え上げます。

コード

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int main() {
    int N;
    ll K;
    cin >> N >> K;
    ll A[N];
    vector<ll> B;
    vector<ll> C;
    ll z = 0;
    for (int i = 0; i < N; i++) {
        cin >> A[i];
        if (A[i] == 0) z++;
        else if (A[i] < 0) B.push_back(A[i]);
        else C.push_back(A[i]);
    }
    ll nb = B.size();
    ll nc = C.size();
    ll f0 = B.size() * C.size();
    ll f1 = f0 + z * (B.size() + C.size()) + (z * (z - 1)) / 2;
    sort(B.begin(), B.end());
    sort(C.begin(), C.end());
    if (K <= f0) {
         ll l = -1e18 - 1ll;
         ll r = 0;
         while (r - l > 1) {
             ll m = (l + r) / 2;
             ll sum = 0;
             int j = nc - 1;
             for (int i = B.size() - 1; i >= 0; i--) {
                 while (j >= 0 && B[i] * C[j] <= m) j--;
                 sum += nc - 1 - j;
             }
              cout << sum << " " << m << "\n";
             if (sum < K) l = m;
             else r = m;
         }
         
         cout << r << "\n";
    } else if (K <= f1) {
        cout << "0\n";
    } else {
        ll l = 0;
        ll r = 1e18 + 1ll;
        while (r - l > 1) {
            ll m = (l + r) / 2;
            ll sum = f1;
            int j = nb - 1;
            for (int i = 0; i < nb - 1; i++) {
                while (j >= 0 && B[i] * B[j] <= m) j--;
                sum += min((int)nb - j - 1, (int)nb - i - 1);
            }
            
            j = 0;
            for (int i = nc - 1; i >= 1; i--) {
                while (j < nc && C[i] * C[j] <= m) j++;
                sum += min(i, j);
            }
            // cout << sum << " " << m << "\n";
            if (sum < K) l = m;
            else r = m;
        }
        cout << r << "\n";
    }
    return 0;
}