[AtCoder] AGC 039 A – Connection and Disconnection

2019年12月25日

問題

方針

同じ文字が連続する文字列のコスト

長さが \( n \) の単一の文字からなる文字列をどの隣り合う \( 2 \) 文字を相異なる文字列にするために必要なコストは、

\[ \lfloor \dfrac{n}{2} \rfloor\]

となります。例えば、\( n = 3 \) のとき、aaa は aba とするとコストは \(1 \) で、\( n = 4 \) のとき、aaaa は abab とするとコストは \( 2 \) となります。

連結された文字列

\( S \) の長さを \( n \) とします。\( S \) の \( i \) 番目の文字を \( c_i \) とすると、\( S = c_1c_2\cdots c_n \) と表現できます。

\( c_1 \neq c_n\) のとき

\( T \) は \( S \) を \(K \) 回連結させてできた文字列であり、\( c_1 \neq c_n \) であることから、連結によって先頭と末尾の間に新たに連続する文字列は出現しません。よって、\( S \) のコストを \( K \) 倍したものが \( T \) のコストとなります。

\( c_1 = c_n \) のとき

ここで、\( S_l \) を \( S \) の先頭から \( c_1 \) が連続してできる文字列とし、\( S_r \) を \( S \) の末尾から \( c_n \) が連続してできる文字列とします。次に、\( S_c \) を

\[S = S_lS_cS_r\]

を満たす文字列とします。例えば、\( S = aabbccdd\)  では、\( S_l = aa \) , \( S_c = bbcc\), \( S_r = dd \) となり、\( S = aaa\) では、\( S_l = aaa \) となり、\( S_c, S_r \) は空文字となります。

\( T \) は、\(T = S_lS_cS_r S_lS_cS_r \cdots S_lS_cS_r \) と表現できます。ここで、\( f(S) \) を文字列 \( S \) のコストとすると、

\[ f(T) = f(S_l) + Kf(S_c) + (K – 1)f(S_lS_r) + f(S_r) \]

と計算することができます。 これは、連結によって新たに \( K \) 個の \( S_c \) のコストがかかり、\(S_lS_c\) という文字列が \(K – 1\) 個出現することを考えると分かりやすいと思います。注意として、\( S_lS_r = S_rS_l \) であり、\( S_lS_r \) は単一の文字からなる文字列です。また、\( S_c \) は任意の文字列であり、空文字となることもあるので、コストが掛からない場合もあります。

コード

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int main() {
    string S;
    ll K;
    cin >> S >> K;
    ll n = S.length();
    if (n == 1) {
        cout << K / 2 << "\n";
        return 0;
    }
    ll l = 0;
    ll r = 0;
    for (int i = 0; i < n; i++) {
        if (S[0] == S[i]) {
            l++;
        } else {
            break;
        }
    }
    if (l == n) {
        cout << n * K / 2 << "\n";
        return 0;
    }
    for (int i = n - 1; i >= 0; i--) {
        if (S[n - 1] == S[i]) {
            r++;
        } else {
            break;
        }
    }
    if (l + r == n) {
        // S[0] != S[n - 1]
        ll ans = l / 2 * K + r / 2 * K;
        cout << ans << "\n";
        return 0;
    }
    ll ans = 0;
    ll c = 1;
    for (int i = l; i < n - r; i++) {
        if (S[i] == S[i + 1]) {
            c++;
        } else {
            ans += c / 2 * K;
            c = 1;
        }
    }
    ans += c / 2 * (K - 1);
   
    if (S[0] == S[n - 1]) {
        ans += l / 2 + (l + r) / 2 * (K - 1) + r / 2;
    } else {
        ans += l / 2 * K + r / 2 * K;
    }
    //cout << l << " " << r << "\n";
    cout << ans << "\n";
    return 0;
}