クルトンのプログラミング教室

Pythonの使い方やPythonを使った競技プログラミングの解法などを解説しています。

Pythonで理解する蟻本(プログラミングコンテストチャレンジブック)

 

この記事のコンセプト

プログラミングコンテストチャレンジブック(以下蟻本)は競技プログラミングの参考書として圧倒的な知名度を誇っており、まさに競技プログラマーにとっての必需品といえます。

その一方で、蟻本のコードはすべてC++で書かれており、Python競技プログラミングをしている人にとってはコードの理解が困難なものになっています。

蟻本を理解するためにC++を学ぶのも一つの手ではありますが、それだけで多大な労力が必要となってしまいます。

そこでこの記事では、Pythonコーダーが蟻本を理解する際の補助となるように、蟻本のコードをできるだけ忠実にPythonで書き直したものを記載しようと思います。

この記事が、Pythonコーダーが蟻本を理解する一助となれば幸いです。


記事一覧

1 いざチャレンジ!でもその前に―準備編

1-6 気楽にウォーミングアップ

kuruton.hatenablog.com

kuruton.hatenablog.com

Pythonで理解する蟻本「3-3 バブルソートの交換回数」(p.162)

この記事は「プログラミングコンテストチャレンジブック第2版」(蟻本)の
「3-3 バブルソートの交換回数」(p.162)
のコードをPythonで書き直したものとなっています。

入力

n\\a_1\,…\,a_n

入力例



4
3 1 4 2


解答

# 入力
n = int(input())
a = list(map(int,input().split()))

# BITのソースコード
class BinaryIndexedTree:
    def __init__(self, n):
        self.n = n
        self.bit = [0] * (n + 1)
        
    def sum(self, i):
        s = 0
        while i > 0:
            s += self.bit[i]
            i -= i & -i
        return s

    def add(self, i, x):
        while i <= self.n:
            self.bit[i] += x
            i += i & -i

BIT = BinaryIndexedTree(n)
ans = 0
for j in range(n):
    ans += j - BIT.sum(a[j])
    BIT.add(a[j], 1)
print(ans)

Pythonで理解する蟻本「3-3 BITの実装」(p.161)

この記事は「プログラミングコンテストチャレンジブック第2版」(蟻本)の
「3-3 BITの実装」(p.161)
のコードをPythonで書き直したものとなっています。

蟻本のコード

# [1, n]
n = int(input())
bit = [0] * (n + 1)

def sum_(i):
    s = 0
    while i > 0:
        s += bit[i]
        i -= i & -i
    return s

def add(i, x):
    while i <= n:
        bit[i] += x
        i += i & -i

classを使ったコード

class BinaryIndexedTree:
    def __init__(self, n):
        self.n = n
        self.bit = [0] * (n + 1)
        
    def sum(self, i):
        s = 0
        while i > 0:
            s += self.bit[i]
            i -= i & -i
        return s

    def add(self, i, x):
        while i <= self.n:
            self.bit[i] += x
            i += i & -i

Pythonで理解する蟻本「3-3 Crane(POJ 2991)」(p.156)

この記事は「プログラミングコンテストチャレンジブック第2版」(蟻本)の
「3-3 Crane(POJ 2991)」(p.156)
のコードをPythonで書き直したものとなっています。

入力

N\,C\\L_1\,…\,L_N\\S_1\,…\,S_C\\A_1\,…\,A_C

入力例1



2 1
10 5
1
90

入力例2



3 2
5 5 5
1 2
270 90


解答

import sys
sys.setrecursionlimit(4100000)
from math import sin, cos, pi

ST_SIZE = 1 << 15 - 1

# 入力
N, C = map(int,input().split())
L = list(map(int,input().split()))
S = list(map(int,input().split()))
A = list(map(int,input().split()))

# セグメント木のデータ
vx = [0] * ST_SIZE    #各節点のベクトル
vy = [0] * ST_SIZE
ang = [0] * ST_SIZE

# 角度の変化を調べるため、現在の角度を保存しておく
prv = [0] * N


# セグメント木を初期化する
# kは接点の番号、l, rはその節点が[l, r)に対応づいていることを表す
def init(k, l, r):
    ang[k] = vx[k] = 0.0
    if r - l == 1:
        # 葉
        vy[k] = L[l]
    else:
        # 葉でない節点
        chL = k * 2 + 1
        chR = k * 2 + 2
        init(chL, l, (l + r) // 2)
        init(chR, (l + r) // 2, r)
        vy[k] = vy[chL] + vy[chR]

# 場所sの角度がaだけ変更になった
# vは節点の番号、l, rはその節点が[l, r)に対応づいていることを表す
def change(s, a, v, l, r):
    if s <= l:
        return
    elif s < r:
        chL = v * 2 + 1
        chR = v * 2 + 2
        m = (l + r) // 2
        change(s, a, chL, l, m)
        change(s, a, chR, m, r)
        if s <= m:
            ang[v] += a
        
        s = sin(ang[v])
        c = cos(ang[v])
        vx[v] = vx[chL] + (c * vx[chR] - s * vy[chR])
        vy[v] = vy[chL] + (s * vx[chR] + c * vy[chR])

# 初期化
init(0, 0, N)
for i in range(1, N):
    prv[i] = pi

# 各クエリを処理
for i in range(C):
    s = S[i]
    a = A[i] / 360.0 * 2 * pi    # ラジアンに直す
    
    change(s, a - prv[s], 0, 0, N)
    prv[s] = a
    
    print("%.2f %.2f" % (vx[0], vy[0]))

Pythonで理解する蟻本「3-3 セグメント木によるRMQの実装」(p.155)

この記事は「プログラミングコンテストチャレンジブック第2版」(蟻本)の
「3-3 セグメント木によるRMQの実装」(p.155)
のコードをPythonで書き直したものとなっています。

コード

import sys
sys.setrecursionlimit(4100000)

MAX_N = 1 << 17
INT_MAX = (1 << 31) - 1

# セグメント木を持つグローバル配列
n = int(input())    # セグ木のサイズ
dat = [0] * (2 * MAX_N - 1)

# 初期化
def init(n_):
    # 簡単のため、要素数を2のべき乗に
    n = 1
    while n < n_:
        n *= 2
    
    # すべての値をINT_MAXに
    for i in range(2 * n - 1):
        dat[i] = INT_MAX
    return n
n = init(n)

# k番目の値(0-indexed)をaに変更
def update(k, a):
    # 葉の節点
    k += n - 1
    dat[k] = a
    # 登りながら更新
    while k > 0:
        k = (k - 1) // 2
        dat[k] = min(dat[k * 2 + 1], dat[k * 2 + 2])

# [a, b)の最小値を求める
# 後ろのほうの引数は、計算の簡単のための引数。
# kは節点の番号、l, rはその節点が[l, r)に対応づいていることを表す。
# したがって、外からはquery(a, b, 0, 0, n)として呼ぶ。
def query(a, b, k, l, r):
    # [a, b)と[l, r)が交差しなければ、INT_MAX
    if r <= a or b <= l:
        return INT_MAX
    
    # [a, b)と[l, r)を完全に含んでいれば、この節点の値
    if a <= l and r <= b:
        return dat[k]
    else:
        # そうでなければ、2つの子の最小値
        vl = query(a, b, k * 2 + 1, l, (l + r) // 2)
        vr = query(a, b, k * 2 + 2, (l + r) // 2, r)
        return min(vl, vr)

Python版 AtCoder Library (Fenwick Tree)

この記事では、AtCoder Library (ACL)のfenwicktreeをPythonで書き直したものを公開しています。

Fenwick Tree

長さ N の配列に対し、

  • 要素の1点変更
  • 区間の要素の総和

O(log\,N) で求めることが出来るデータ構造です。

コード

class fenwick_tree:
    def __init__(self, n):
        self.n = n
        self.bit = [0] * (n + 1)
    
    def add(self, p, x):
        p += 1
        while p <= self.n:
            self.bit[p] += x
            p += p & -p
    
    def sum_(self, p):
        s = 0
        while p > 0:
            s += bit[p]
            p -= p & -p
        return s
    
    def sum(self, l, r):
        return sum_(r) - sum_(l)

コンストラク

fw = fenwick_tree(n)
  • 長さnの配列a_0,a_1,…,a_{n-1}を作ります。初期値はすべて0です。

計算量

  • O(n)

add

fw.add(p, x)

a[p] += x を行う

制約

  • {0}\leq{p}<{n}

計算量

  • O(log\,n)

sum

fw.sum(l, r)

a[l] + a[l + 1] + ... + a[r - 1] を返す。

制約

  • {0}\leq{l}\leq{r}\leq{n}

計算量

  • O(log\,n)

使用例

AC code of https://atcoder.jp/contests/practice2/tasks/practice2_b

############################################

class fenwick_tree:
    def __init__(self, n):
        self.n = n
        self.bit = [0] * (n + 1)
    
    def add(self, p, x):
        p += 1
        while p <= self.n:
            self.bit[p] += x
            p += p & -p
    
    def sum_(self, p):
        s = 0
        while p > 0:
            s += self.bit[p]
            p -= p & -p
        return s
    
    def sum(self, l, r):
        return self.sum_(r) - self.sum_(l)

############################################

N, Q = map(int,input().split())
a = list(map(int,input().split()))

fw = fenwick_tree(N)

for i in range(N):
    fw.add(i, a[i])

for _ in range(Q):
    q = list(map(int,input().split()))
    if q[0] == 0:
        fw.add(q[1], q[2])
    else:
        print(fw.sum(q[1], q[2]))

Pythonで理解する蟻本「3-2 領域の個数」(p.150)

この記事は「プログラミングコンテストチャレンジブック第2版」(蟻本)の
「3-2 領域の個数」(p.150)
のコードをPythonで書き直したものとなっています。

入力

w\,h\,n\\x1_1\,…\,x1_n\\x2_1\,…\,x2_n\\y1_1\,…\,y1_n\\y2_1\,…\,y2_n

入力例



10 10 5
1 1 4 9 10
6 10 4 9 10
4 8 1 1 6
4 8 10 5 10


解答

from collections import deque

# 入力
W, H, N = map(int,input().split())
X1 = list(map(int,input().split()))
X2 = list(map(int,input().split()))
Y1 = list(map(int,input().split()))
Y2 = list(map(int,input().split()))

# 塗りつぶし用
fld = [[False] * (N * 6) for _ in range(N * 6)]

# x1, x2を座標圧縮し、座標圧縮した際の幅を返す
def compress(x1, x2, w):
    xs = []
    
    for i in range(N):
        for d in range(-1, 2):
            tx1 = x1[i] + d
            tx2 = x2[i] + d
            if 1 <= tx1 <= W:
                xs.append(tx1)
            if 1 <= tx2 <= W:
                xs.append(tx2)
    xs = list(set(xs))
    xs.sort()
    
    for i in range(N):
        x1[i] = xs.index(x1[i])
        x2[i] = xs.index(x2[i])
    
    return x1, x2, len(xs)

# 座標圧縮
X1, X2, W = compress(X1, X2, W)
Y1, Y2, H = compress(Y1, Y2, H)

# 腺のある部分を塗りつぶし
for i in range(N):
    for y in range(Y1[i], Y2[i] + 1):
        for x in range(X1[i], X2[i] + 1):
            fld[y][x] = True

# 領域を数える
ans = 0
dx = [1, 0, -1, 0]
dy = [0, 1, 0, -1]
for y in range(H):
    for x in range(W):
        if fld[y][x]:
            continue
        ans += 1
        
        # 幅優先探索
        que = deque([(x, y)])
        while len(que) != 0:
            sx, sy = que.popleft()
            
            for i in range(4):
                tx = sx + dx[i]
                ty = sy + dy[i]
                if tx < 0 or W <= tx or ty < 0 or H <= ty:
                    continue
                if fld[ty][tx]:
                    continue
                que.append((tx, ty))
                fld[ty][tx] = True

print(ans)

Pythonで理解する蟻本「3-2 巨大ナップサック」(p.148)

この記事は「プログラミングコンテストチャレンジブック第2版」(蟻本)の
「3-2 巨大ナップサック」(p.148)
のコードをPythonで書き直したものとなっています。

入力

n\\w_1\,…\,w_n\\v_1\,…\,v_n\\W

入力例



4
2 1 3 2
3 2 4 2
5


解答

from bisect import bisect_left
INF = float('inf')

# 入力
n = int(input())
w = list(map(int,input().split()))
v = list(map(int,input().split()))
W = int(input())

ps = []    # (重さ、価値)

# 前半分を全列挙
n2 = n // 2
for i in range(1 << n2):
    sw = 0
    sv = 0
    for j in range(n2):
        if (i >> j) & 1:
            sw += w[j]
            sv += v[j]
    ps.append((sw, sv))
    
# 無駄な要素を取り除く
ps.sort()
m = 1
for i in range(1, 1 << n2):
    if ps[m - 1][1] < ps[i][1]:
        ps[m] = ps[i]
        m += 1

# 後ろ半分を全列挙し解を求める
res = 0
for i in range(1 << (n - n2)):
    sw = 0
    sv = 0
    for j in range(n - n2):
        if (i >> j) & 1:
            sw += w[n2 + j]
            sv += v[n2 + j]
    if sw <= W:
        tv = ps[bisect_left(ps[:m], (W - sw, INF)) - 1][1]
        res = max(res, sv + tv)

print(res)