くろたんく雑記帳

日常とか、わんちゃんとか、機械学習とか、競プロとか、

MENU

Python3で実装 問題解決力を鍛える!アルゴリズムとデータ構造【第3章】設計技法(1):全探索

「問題解決力を鍛える!アルゴリズムとデータ構造」の章末問題をPython3で実装していくシリーズ。(続くかわからんが)

今回は【第3章】設計技法(1):全探索について扱う

問題解決力を鍛える!アルゴリズムとデータ構造 (KS情報科学専門書)

問題解決力を鍛える!アルゴリズムとデータ構造 (KS情報科学専門書)

  • 作者:大槻 兼資
  • 発売日: 2020/10/02
  • メディア: 単行本(ソフトカバー)


内容

  • Python3でやっている。
  • 簡単に自分の理解を整理する用。
  • そこまで厳密ではない可能性が高い。
  • 記事を書いたら、書評の方に記事を載せていく。
  • コード等はGitHubにおいた。 github.com

【第3章】設計技法(1):全探索の概要

ここの章で特に大変というか慣れないときついのはbitの考え方。組み合わせ探索のところがちょっと大変。

章末問題

3.1 線形探索で一致する最大のindex探し

こんな入力を想定して

10 7
3 5 8 9 7 10 22 1 2 100
  • 配列{\displaystyle a}を左から見ていって、{\displaystyle v}と一致する値があるか確認する。
  • 一致したら、indexを代入とする。
def main():
    # 入力受け取り
    N, v = map(int, input().split())
    a = [int(x) for x in input().split()]
    # 初期化としてあり得ない数値を代入
    found_id = -1
    for i in range(N):
        if a[i] == v:
            found_id = i
    print(found_id)
if __name__ == '__main__':
    main()
3.2 線形探索で一致する数をカウント

こんな入力を想定して

10 7 
3 5 8 9 7 7 7 7 7 1

3.1を少し変えて、 * 配列{\displaystyle a}を左から見ていって、{\displaystyle v}と一致するか確認する。 * 値が一致したら、カウントをインクリメントする。

def main():
    # 入力受け取り
    N, v = map(int, input().split())
    a = [int(x) for x in input().split()]
    # 線形探索
    cnt = 0
    for i in range(N):
        if a[i] == v:
            cnt += 1
    print(cnt)
if __name__ == '__main__':
    main()
3.3 線形探索で最も小さいとその次に小さい探す

こんな入力を想定して

10 7
3 5 8 9 7 10 22 1 2 100
  • 二番目に小さい値(second_minv)をどうとるかで、配列{\displaystyle a}を左から見ていく時に、最小値(minv)よりも小さい値を見つけたか、2番目に小さい値を見つけたかで場合分けする。
def main():
    # 入力受け取り
    N = int(input())
    a = [int(x) for x in input().split()]
    # 初期化として無限大
    minv = float('inf')
    second_minv = float('inf')
    for i in range(N):
        # 最小値を見つけた場合
        if a[i] < minv:
            # minvを2番目に
            second_minv = minv
            # 探索した値をminvに
            minv = a[i]
        # 2番目に小さい値を見つけた場合
        elif a[i] < second_minv:
            second_minv = a[i]
    print(second_minv)
if __name__ == '__main__':
    main()
3.4 maxとmin

こんな入力を想定して

10 7
3 5 8 9 7 10 22 1 2 100
  • for文で探索するというよりは、関数的にmaxとminを求めて、引き算する。
  • 別の方法として、maxvとminvを保持しながらforで回してもいい。
def main():
    # 入力受け取り
    N = int(input())
    a = [int(x) for x in input().split()]
    # 最大・最小を求めて引き算
    # max, minの計算量はO(n)
    maxv = max(a)
    minv = min(a)
    ans = maxv - minv
    print(ans)
if __name__ == '__main__':
    main()
3.5 半分にし続けられる回数

こんな入力を想定して

6
382253568 723152896 37802240 379425024 404894720 471526144

リスト内包表記で2で割り切れるかどうかでboolを返すようにして、all()ですべてがTrueならカウントをインクリメントして、リストの中身を半分にしたものを再度代入と言うのを繰り返す。配列で扱っている時は、リスト内包表記が便利。

def main():
    # 入力受け取り
    N = int(input())
    a = [int(x) for x in input().split()]
    cnt = 0
    # aがすべて2で割り切れている(全部Trueなら回り続ける)
    while all([True if i % 2 == 0 else False for i in a]):
        cnt += 1
        a = [i / 2 for i in a]
    print(cnt)
if __name__ == '__main__':
    main()
3.6 2重for文で3つ変数の全探索

こんな入力を想定して

2 2

{\displaystyle O(n^2)}で実装しろってことなので、{\displaystyle x, y, z}の3重のfor文はできない。そこで、{\displaystyle z = n - x - y}であることを用いて、zが{\displaystyle 0 \leq z \leq}を満たすものをカウントすればいい。そうすれば2重のfor文で済むので{\displaystyle O(n^2)}で実装できた。

def main():
    k, n = map(int, input().split())
    cnt = 0
    # 0-kまで探索 
    for x in range(k+1):
        # 0-kまで探索
        for y in range(k+1):
            # z = n - x - yなので0以上k以下ならカウントする
            if 0 <= n - x - y <= k:
                cnt += 1
    print(cnt)
if __name__ == '__main__':
    main()
3.7 bitで全探索

こんな入力を想定して

125

いったんの目標として、以下のような式を作って評価して、その合計を足し合わせればいい。
{\displaystyle 125}
{\displaystyle 1 + 25}
{\displaystyle 12 + 5}
{\displaystyle 1+ 2 + 5}
どのように作るかと言うと、数字の文字の間に'+'を入れ込むと考える。数字の文字数を{\displaystyle n}とすると{\displaystyle n-1}個間があって、そこに'+'が入るかどうかの2通りだから{\displaystyle 2^{n-1}}パターンあることになる。
そして、そのパターンを作る時に色々やり方はあるが、書籍に倣うと、bitを用いることで全パターンを生み出す。上記の例だと、文字の間が2つなので、 [0, 0] → {\displaystyle 125}
[1, 0] → {\displaystyle 1 + 25}
[0, 1] → {\displaystyle 12 + 5}
[1, 1] → {\displaystyle 1+ 2 + 5}
の4パターンで、1が立っているところに'+'が入っている状態が、目標とする状態であることがわかる。それで、どう作るかと言うと、bit演算子を使いながら実装する。

  • 最初のfor文の範囲は{\displaystyle 2^{n-1}}でいいんだけど、あえてbit演算子で書く以下のようになる。 1 << (n-1) これはなんぞってなるんだけど、砕くと、

1 << i の意味は
2進数表現で右からi番目(一番右は0番目として)のみ1がである値を10進数表現に戻した値。

なので例えば、今回だと、{\displaystyle n-1 = 2}なので、右から2番目のみが1である値(一番右は0番目として)は、'0b100' なので、これを10進数表現に戻すと{\displaystyle 2^2 = 4}である。

  • 肝心の全パターンのbitを立たせる方法は以下である。
for bit in range(1 << (n-1)):
    for i in range(n-1):
        if ((bit>>i)&1) == 1:
        ~~~

ここで、(bit>>i)&1って言う初心者泣かせの演算が出てくるんだけど。砕くと、

(bit>>i) の意味は
bit >> i はbitを2進数表現でi個右にずらした(右からi個削った)値を10進数表現に戻した値

例えば、
4 >> 1は '0b100'→'0b10'っていうことで10進数表現に戻されて、2となる。
3 >> 1は '0b11'→'0b1'っていうこと10進数表現に戻されて、1となる。 これは実際にコンソールで叩いてみながらやると良い。確認で2進数表現にしたければbin()でできるので。

ほんでもう一つ、(bit>>i)&1&1ってなんぞっていうことなんだけど、砕くと、

a & bは、
aとbを2進数表現でみた時に、同じ桁がどちらも1の時に1を返しその他の場合は0をかえした値を10進数表現に戻した値

なので、x & 1っていうのはxを2進数表現でみた時に一番右が1なら1そうでないなら0って言うことになる。これらを全て使うと、全パターンのbitをたてられる。 実装上はフラグの立ち方をわかりやすい形でにみることはできないので、あえて配列に残るように書いておいた。(bitsっていうリスト)

  • あとは出来上がった式をeval()で評価していけばいい。
  • もしくは、'+'でsplitして配列にして足せばいい。(この場合は'+'っていう文字はただの飾りにはなる)
def main():
    # 入力受け取り
    s = input()
    n = len(s)
    ans = 0
    # 1 << i は2進数表現で右からi番目(一番右は0番目)のみ1がである値のことなので
    # 2**iと同義
    # 数文字の間に'+'を入れるかどうかなので2**(n-1)パターンある
    for bit in range(1 << (n-1)):
        formula = s[0]
        # bitの立ち方確認用
        bits = []
        for i in range(n-1):
            bits.append((bit>>i)&1)
            if ((bit>>i)&1) == 1:
                formula += '+'
            formula += s[i+1]
        # 状況確認用
        # print(bits)
        # print(formula)
        # split sum方式
        # ans += sum(map(int, formula.split('+')))
        # eval方式
        ans += eval(formula)
    print(ans)
if __name__ == '__main__':
    main()

まとめ

  • 全探索は計算量が多くなることがあるので制約に注意。
  • maxvとかminvで一時的にそこまでで最大とか最小っていう値をとっておいて比較しながらforを回すように使うことが多い。
  • bit演算は覚えることたくさんだけど、使えるようになると便利っぽい。(わかんなくなって、itertoolsで代用しちゃうけど)

終わりに

結構ちゃんと書くと大変。まとめる方が時間かかるなぁ。続けるか未定。

問題解決力を鍛える!アルゴリズムとデータ構造 (KS情報科学専門書)

問題解決力を鍛える!アルゴリズムとデータ構造 (KS情報科学専門書)

  • 作者:大槻 兼資
  • 発売日: 2020/10/02
  • メディア: 単行本(ソフトカバー)