[yukicoder] No. 837 Noelちゃんと星々2

問題

方針

絶対値の和の最小化問題

配列 \(a \) に対して、次の関数を最小化することを考えます。

\[ \sum_{k=1}^{N} |x – a_k|\]

答えから言うと、\( x \) が \( a \) の中央値のとき最小となります。証明は下のサイトから参照できます。

The Median Minimizes the Sum of Absolute Deviations (The L1 Norm)

データ数が偶数のときも、中央値の要素番号は、\( \lfloor \dfrac{N}{2} \rfloor\) として大丈夫です。

全探索

最終的に配列は二種類の値を取ることになるので、全ての分け方を考えます。配列の順序は関係ないので、昇順に並んでいるものとします。また、要素が \( 0 \) から始めるものとします。整数 \( k \ ( 0 \leq k \leq N – 2) \) を用いて次の二つの集合に分けます。

\[
\begin{eqnarray}
A &=& (Y_0, Y_1, \cdots , Y_k)\\
B &=& (Y_{k+1}, Y_{k+2}, \cdots , Y_{N-1})
\end{eqnarray}\]

このとき、\( A \) の中央値 \(m_a \) は \( i_a = \dfrac{k}{2}\) として、\( m_a = Y_{i_a} \) となり、\( B \) の中央値 \(m_b \) は \( i_b = \dfrac{N+k}{2}\) として、\( m_b = Y_{i_b} \) となります。

また、\( Y \) の累積和を \( S_i \) を

\[S_i = Y_0 + Y_1 + \cdots + Y_{i}\]

とします。これらを用いて、\( i = k \) のときのコストを \( c_k \) とすると、次のようになります。

\[\begin{eqnarray}
c_k &=& \sum_{j=0}^{k} |Y_j – m_a| + \sum_{j=k+1}^{N-1} |Y_j – m_b|\\
&=& \sum_{j=0}^{i_a} (m_a – Y_j) + \sum_{j=i_a + 1}^{k}(Y_j – m_a) + \sum_{j=k+1}^{i_b} (m_b – Y_j) + \sum_{j=i_b+1}^{N-1} (Y_j – m_b)\\
&=& (i_a + 1)Y_{i_a} – S_{i_a} + (S_k – S_{i_a}) – (k – i_a)Y_{i_a} + (i_b – k)Y_{i_b} – (S_{i_b} – S_k) + (S_{N-1} – S_{i_b}) – (N-i_b)Y_{i_b}
\end{eqnarray}\]

となります。

コード

提出したコード

全探索

  vector<ll> Y(N);
  for (int i = 0; i < N; i++) {
    cin >> Y[i];
  }
  if (N == 2) {
    if (Y[0] == Y[1]) {
      cout << 1 << "\n";
    } else {
      cout << 0 << "\n";
    }
    return 0;
  }
  sort(Y.begin(), Y.end());
  if (Y[0] == Y[N - 1]) {
    cout << "1\n";
    return 0;
  }
  if (Y[0] == Y[N - 2]) {
    cout << "0\n";
    return 0;
  }
  ll ans = pow(10, 15);
  ll S[N + 1]{};
  for (int i = 0; i < N; i++) {
    S[i + 1] += S[i] + Y[i];
  }
  for (int i = 0; i < N - 1; i++) {
    int m1 = i / 2;
    int m2 = (N + i) / 2;
    ll cost = (m1 + 1) * Y[m1] - S[m1 + 1] + S[i + 1] - S[m1 + 1] - (i - m1) * Y[m1];
    cost += (m2 - i) * Y[m2] - (S[m2 + 1] - S[i + 1]) + (S[N] - S[m2 + 1]) - (N - 1 - m2) * Y[m2];
    ans = min(ans, cost);
  }
  cout << ans << "\n";