JOI公式テキスト『たくさんの数』模範解答プログラム完全理解のための解説です。

ここではJOI公式テキストp127ビット演算 たくさんお数のAtcoderに掲載されていた模範解答プログラムを完全に理解できる詳しく分かりやすい解説を載せています。
このプログラムはC++で作られていたので、そのまま載せています。

Python言語版が欲しい場合はBingやChatGPTなどの生成系AIに変換してもらうとあっという間に入手できるでしょう。

まずはテキストの模範解答プログラムです

#include <iostream>
#include <string>
using namespace std;
typedef long long ll;
string S;
int N;
typedef long long ll;
string S;
int N;
#define rep(i, a, n) for (int i = (a); i < (n); i++)

int main() { cin >> S;
N = S.length();

ll ans = 0;
rep(i, 0, 1 << (N - 1)) {
ll sm = 0;
ll a = S[0] - '0';
rep(j, 0, N - 1) {
if (i & (1 << j)) {
sm += a;
a = 0;
}
a = a * 10 + S[j + 1] - '0';
}
sm += a;
ans += sm;
}
cout << ans << endl;
}

模範解答プログラムの解説です

このプログラムのif (i & (1 << j)) は何をしているのか?

入力例  125で見てみよう。

125の間に+記号を記入する組み合わせは

125

1+25

12+5

1+2+5

です

なので、1の後に+を置くかどうか?で2通り。2の後に+を置くかどうか?で2通り。5の後には置かないので結局2の2乗通りとなるので4通りです。

これがrep(i, 0, 1 << (N - 1))   N=3なので1 << (N - 1)=2の2乗通り。

◆次に125の各数字の横に+を入れたら、その数を総和に足します。

1*25なら総和ansに1を足す。

ところが12*5ならどうでしょう?

12をansに足しますね。では12はどうやったら求まるか?

12を求めるのにsmといった、各数字列の値を求める変数作ります。それで

1回目 i=0 j=0 これは下で&=論理積を計算するので2進数になります。

(i & (1 << j))

論理積は 001 & 001 と両方の2進数が同じでなければ 答001 つまり1でtureを返さない。

1回目は両方とも0なので0

2回目 i=0 j=1 これも 答は0 でfalse

これで内側のrep()から飛び出し、外側のrep()でiが1増えて1になります。

3回目 i=1 j=0 でこれも 答は0 でfalse jの値はその数字の横に+を置くかどうかを示します。jが1なら置く。

4回目 i=1 j=1 でこれで 1+と1の隣に+記号を置くことになるので、文字列の最初の値1がaに入っているので、

ll a = S[0] - '0'; (ll はLong Long型の変数)

aをsmに記憶させる。

★smに足したのでa=0;とする。

i=2になった場合どうなるか?2進数で表すと 010である。

最初の右端の値が0なので、jの値が010になると一致する。

その前にjの値が000 001の時に12を足すために

aの値に10掛けて次の2の値がs[1]に入っているので

anに10掛けて、2と足してaを12にする。

a = a * 10 + S[j + 1] - '0';

その桁のif文での判断と処理が終了するごとにansにmsを足す。

これはiの値が1でjの値が1の時とする。

◆それからiが2になるので、jが2になった時

12をansに足します。

◆if (i & (1 << j)) {
sm += a;
a = 0;

jの値は

rep(j, 0, N - 1) {

で今回は3まで動く。

入力例が125の場合、N=3 jはoからN-1 つまり2まで増える。

となり、1 << jが取る2進数は 001 010 100

i=0 000: 125

i=1 001: 1+25

i=2 010: 12+5

i=3 011: 1+2+5

でi=1の場合 2進数は001 で 1 << j の最初の2進数 001と論理積をとって001となり

trueとなる。のでaに入ってる1をsmに記憶させaを0にする。

つまり、1 << jが取る2進数は 001 010 100  は

001は一桁目で+を入れて一桁目をansに足す。

同様に

010は二桁目で+を入れて二桁目をansに足す。

100は三桁目で+を入れて三桁目をansに足す。

◆さて、i=0 000: 125の場合はansには125を足さなくてはならない。それはどういうメカニズムか?

iは0でjは0から2まで変わる。

i=0の場合、一度もif (i & (1 << j)) がtrueになることはない。

なので 初回で a = a * 10 + S[j + 1] - '0';

の計算でaが12となる。

さらにjが1となりa = a * 10 + S[j + 1] - '0'; でaは125となる。

さらにaが2となるがa = a * 10 + S[j + 1] - '0';でs[3]は何も入っていない。0なのでaは125

これが 下の文でansに加算される。

sm += a;
ans += sm;

★if (i & (1 << j)) がtrueになるi=0の場合

初回 1 << j が001になりtrueとなる。

smにaを足し、aを0する。

sm += a;
a=0;

その次にa = a * 10 + S[j + 1] - '0';を通るので

aが0なのでaは2となる。

j=1でaは25となる。

j=2でもaは25のままで、

sm += a;
ans += sm;

でsmに1が入っていて、aは25なのでansに1と25が足ささる。

別解1 C++言語でpow()関数を利用して作りました!

今解説したプログラムは2進数のビット演算を使用してます。
このプログラムを理解するには、またビット演算を使用してプログラムを作りにはテキストp84『ビット演算の基本』の内容を2進数も含めて完全に理解していなければなりません。

独学でも腕を上げていける人は、模範解答プログラムのテキストに載っている解説をみて、ネットでビット演算について調べたりp84の基本解説を見つけ出して理解を図って、完全に理解してくと思います。

ところで、この課題はビット演算を使用しなくても、何ということも無く以下のようプログラムを作れてしまいます。
しかも、計算量の多くならず、Atcoderでも合格しました。

ですから、この課題はビット演算を使用しなくてはならない。と思い込むと得点力が伸びないのです。

この別解ではpow()関数を使用しました。 言語はC++を使用しました。


#include <iostream>
#include <string>
#include <vector>
#include <cmath>
using namespace std;

typedef long long ll;
string S;
int N;
#define rep(i, a, n) for (int i = (a); i < (n); i++)

int main() { cin >> S;
N = S.length();

ll ans = 0;
vector split(N - 1, false);
for (int i = 0; i < pow(2, N - 1); i++) {

for (int j = 0; j < N - 1; j++) {
split[j] = ((i / (int)pow(2, j)) % 2 == 1);
}

ll sm = 0;
ll a = S[0] - '0';
rep(j, 0, N - 1) {
if (split[j]) {
sm += a;
a = 0;
}
a = a * 10 + S[j + 1] - '0';
}
sm += a;
ans += sm;
}
cout << ans << endl;
}

別解2 pow()関数とシフト演算を使わないバージョンのプログラム

今度はpow()関数とシフト演算を使わないバージョンのプログラムを作りました。ここでは、自分で2のべき乗を計算するための繰り返し処理を書きました。

#include <iostream>
#include <string>
#include <vector>
#include <cmath>
using namespace std;

typedef long long ll;
string S;
int N;
#define rep(i, a, n) for (int i = (a); i < (n); i++) 

int main() { cin >> S;
N = S.length();

ll ans = 0;
vector split(N - 1, false);
for (int i = 0; i < (1 << (N - 1)); i++) {

for (int j = 0; j < N - 1; j++) {
int power_of_two = 1;
for (int k = 0; k < j; k++) {
power_of_two *= 2;
}
split[j] = ((i / power_of_two) % 2 == 1);
}

ll sm = 0;
ll a = S[0] - '0';
rep(j, 0, N - 1) {
if (split[j]) {
sm += a;
a = 0;
}
a = a * 10 + S[j + 1] - '0';
}
sm += a;
ans += sm;
}
cout << ans << endl;
}