クルトンのプログラミング教室

Pythonの使い方やPythonを使った競技プログラミングの解法などを解説しています。

Pythonで理解する蟻本「3-3 バブルソートの交換回数」(p.162)

この記事は「プログラミングコンテストチャレンジブック第2版」(蟻本)の
「3-3 バブルソートの交換回数」(p.162)
のコードをPythonで書き直したものとなっています。

入力

n\\a_1\,…\,a_n

入力例



4
3 1 4 2


解答

# 入力
n = int(input())
a = list(map(int,input().split()))

# BITのソースコード
class BinaryIndexedTree:
    def __init__(self, n):
        self.n = n
        self.bit = [0] * (n + 1)
        
    def sum(self, i):
        s = 0
        while i > 0:
            s += self.bit[i]
            i -= i & -i
        return s

    def add(self, i, x):
        while i <= self.n:
            self.bit[i] += x
            i += i & -i

BIT = BinaryIndexedTree(n)
ans = 0
for j in range(n):
    ans += j - BIT.sum(a[j])
    BIT.add(a[j], 1)
print(ans)

Pythonで理解する蟻本「3-3 BITの実装」(p.161)

この記事は「プログラミングコンテストチャレンジブック第2版」(蟻本)の
「3-3 BITの実装」(p.161)
のコードをPythonで書き直したものとなっています。

蟻本のコード

# [1, n]
n = int(input())
bit = [0] * (n + 1)

def sum_(i):
    s = 0
    while i > 0:
        s += bit[i]
        i -= i & -i
    return s

def add(i, x):
    while i <= n:
        bit[i] += x
        i += i & -i

classを使ったコード

class BinaryIndexedTree:
    def __init__(self, n):
        self.n = n
        self.bit = [0] * (n + 1)
        
    def sum(self, i):
        s = 0
        while i > 0:
            s += self.bit[i]
            i -= i & -i
        return s

    def add(self, i, x):
        while i <= self.n:
            self.bit[i] += x
            i += i & -i

Pythonで理解する蟻本「3-3 Crane(POJ 2991)」(p.156)

この記事は「プログラミングコンテストチャレンジブック第2版」(蟻本)の
「3-3 Crane(POJ 2991)」(p.156)
のコードをPythonで書き直したものとなっています。

入力

N\,C\\L_1\,…\,L_N\\S_1\,…\,S_C\\A_1\,…\,A_C

入力例1



2 1
10 5
1
90

入力例2



3 2
5 5 5
1 2
270 90


解答

import sys
sys.setrecursionlimit(4100000)
from math import sin, cos, pi

ST_SIZE = 1 << 15 - 1

# 入力
N, C = map(int,input().split())
L = list(map(int,input().split()))
S = list(map(int,input().split()))
A = list(map(int,input().split()))

# セグメント木のデータ
vx = [0] * ST_SIZE    #各節点のベクトル
vy = [0] * ST_SIZE
ang = [0] * ST_SIZE

# 角度の変化を調べるため、現在の角度を保存しておく
prv = [0] * N


# セグメント木を初期化する
# kは接点の番号、l, rはその節点が[l, r)に対応づいていることを表す
def init(k, l, r):
    ang[k] = vx[k] = 0.0
    if r - l == 1:
        # 葉
        vy[k] = L[l]
    else:
        # 葉でない節点
        chL = k * 2 + 1
        chR = k * 2 + 2
        init(chL, l, (l + r) // 2)
        init(chR, (l + r) // 2, r)
        vy[k] = vy[chL] + vy[chR]

# 場所sの角度がaだけ変更になった
# vは節点の番号、l, rはその節点が[l, r)に対応づいていることを表す
def change(s, a, v, l, r):
    if s <= l:
        return
    elif s < r:
        chL = v * 2 + 1
        chR = v * 2 + 2
        m = (l + r) // 2
        change(s, a, chL, l, m)
        change(s, a, chR, m, r)
        if s <= m:
            ang[v] += a
        
        s = sin(ang[v])
        c = cos(ang[v])
        vx[v] = vx[chL] + (c * vx[chR] - s * vy[chR])
        vy[v] = vy[chL] + (s * vx[chR] + c * vy[chR])

# 初期化
init(0, 0, N)
for i in range(1, N):
    prv[i] = pi

# 各クエリを処理
for i in range(C):
    s = S[i]
    a = A[i] / 360.0 * 2 * pi    # ラジアンに直す
    
    change(s, a - prv[s], 0, 0, N)
    prv[s] = a
    
    print("%.2f %.2f" % (vx[0], vy[0]))

Pythonで理解する蟻本「3-3 セグメント木によるRMQの実装」(p.155)

この記事は「プログラミングコンテストチャレンジブック第2版」(蟻本)の
「3-3 セグメント木によるRMQの実装」(p.155)
のコードをPythonで書き直したものとなっています。

コード

import sys
sys.setrecursionlimit(4100000)

MAX_N = 1 << 17
INT_MAX = (1 << 31) - 1

# セグメント木を持つグローバル配列
n = int(input())    # セグ木のサイズ
dat = [0] * (2 * MAX_N - 1)

# 初期化
def init(n_):
    # 簡単のため、要素数を2のべき乗に
    n = 1
    while n < n_:
        n *= 2
    
    # すべての値をINT_MAXに
    for i in range(2 * n - 1):
        dat[i] = INT_MAX
    return n
n = init(n)

# k番目の値(0-indexed)をaに変更
def update(k, a):
    # 葉の節点
    k += n - 1
    dat[k] = a
    # 登りながら更新
    while k > 0:
        k = (k - 1) // 2
        dat[k] = min(dat[k * 2 + 1], dat[k * 2 + 2])

# [a, b)の最小値を求める
# 後ろのほうの引数は、計算の簡単のための引数。
# kは節点の番号、l, rはその節点が[l, r)に対応づいていることを表す。
# したがって、外からはquery(a, b, 0, 0, n)として呼ぶ。
def query(a, b, k, l, r):
    # [a, b)と[l, r)が交差しなければ、INT_MAX
    if r <= a or b <= l:
        return INT_MAX
    
    # [a, b)と[l, r)を完全に含んでいれば、この節点の値
    if a <= l and r <= b:
        return dat[k]
    else:
        # そうでなければ、2つの子の最小値
        vl = query(a, b, k * 2 + 1, l, (l + r) // 2)
        vr = query(a, b, k * 2 + 2, (l + r) // 2, r)
        return min(vl, vr)

Python版 AtCoder Library (Fenwick Tree)

この記事では、AtCoder Library (ACL)のfenwicktreeをPythonで書き直したものを公開しています。

Fenwick Tree

長さ N の配列に対し、

  • 要素の1点変更
  • 区間の要素の総和

O(log\,N) で求めることが出来るデータ構造です。

コード

class fenwick_tree:
    def __init__(self, n):
        self.n = n
        self.bit = [0] * (n + 1)
    
    def add(self, p, x):
        p += 1
        while p <= self.n:
            self.bit[p] += x
            p += p & -p
    
    def sum_(self, p):
        s = 0
        while p > 0:
            s += bit[p]
            p -= p & -p
        return s
    
    def sum(self, l, r):
        return sum_(r) - sum_(l)

コンストラク

fw = fenwick_tree(n)
  • 長さnの配列a_0,a_1,…,a_{n-1}を作ります。初期値はすべて0です。

計算量

  • O(n)

add

fw.add(p, x)

a[p] += x を行う

制約

  • {0}\leq{p}<{n}

計算量

  • O(log\,n)

sum

fw.sum(l, r)

a[l] + a[l + 1] + ... + a[r - 1] を返す。

制約

  • {0}\leq{l}\leq{r}\leq{n}

計算量

  • O(log\,n)

使用例

AC code of https://atcoder.jp/contests/practice2/tasks/practice2_b

############################################

class fenwick_tree:
    def __init__(self, n):
        self.n = n
        self.bit = [0] * (n + 1)
    
    def add(self, p, x):
        p += 1
        while p <= self.n:
            self.bit[p] += x
            p += p & -p
    
    def sum_(self, p):
        s = 0
        while p > 0:
            s += self.bit[p]
            p -= p & -p
        return s
    
    def sum(self, l, r):
        return self.sum_(r) - self.sum_(l)

############################################

N, Q = map(int,input().split())
a = list(map(int,input().split()))

fw = fenwick_tree(N)

for i in range(N):
    fw.add(i, a[i])

for _ in range(Q):
    q = list(map(int,input().split()))
    if q[0] == 0:
        fw.add(q[1], q[2])
    else:
        print(fw.sum(q[1], q[2]))

Pythonで理解する蟻本「3-2 領域の個数」(p.150)

この記事は「プログラミングコンテストチャレンジブック第2版」(蟻本)の
「3-2 領域の個数」(p.150)
のコードをPythonで書き直したものとなっています。

入力

w\,h\,n\\x1_1\,…\,x1_n\\x2_1\,…\,x2_n\\y1_1\,…\,y1_n\\y2_1\,…\,y2_n

入力例



10 10 5
1 1 4 9 10
6 10 4 9 10
4 8 1 1 6
4 8 10 5 10


解答

from collections import deque

# 入力
W, H, N = map(int,input().split())
X1 = list(map(int,input().split()))
X2 = list(map(int,input().split()))
Y1 = list(map(int,input().split()))
Y2 = list(map(int,input().split()))

# 塗りつぶし用
fld = [[False] * (N * 6) for _ in range(N * 6)]

# x1, x2を座標圧縮し、座標圧縮した際の幅を返す
def compress(x1, x2, w):
    xs = []
    
    for i in range(N):
        for d in range(-1, 2):
            tx1 = x1[i] + d
            tx2 = x2[i] + d
            if 1 <= tx1 <= W:
                xs.append(tx1)
            if 1 <= tx2 <= W:
                xs.append(tx2)
    xs = list(set(xs))
    xs.sort()
    
    for i in range(N):
        x1[i] = xs.index(x1[i])
        x2[i] = xs.index(x2[i])
    
    return x1, x2, len(xs)

# 座標圧縮
X1, X2, W = compress(X1, X2, W)
Y1, Y2, H = compress(Y1, Y2, H)

# 腺のある部分を塗りつぶし
for i in range(N):
    for y in range(Y1[i], Y2[i] + 1):
        for x in range(X1[i], X2[i] + 1):
            fld[y][x] = True

# 領域を数える
ans = 0
dx = [1, 0, -1, 0]
dy = [0, 1, 0, -1]
for y in range(H):
    for x in range(W):
        if fld[y][x]:
            continue
        ans += 1
        
        # 幅優先探索
        que = deque([(x, y)])
        while len(que) != 0:
            sx, sy = que.popleft()
            
            for i in range(4):
                tx = sx + dx[i]
                ty = sy + dy[i]
                if tx < 0 or W <= tx or ty < 0 or H <= ty:
                    continue
                if fld[ty][tx]:
                    continue
                que.append((tx, ty))
                fld[ty][tx] = True

print(ans)

Pythonで理解する蟻本「3-2 巨大ナップサック」(p.148)

この記事は「プログラミングコンテストチャレンジブック第2版」(蟻本)の
「3-2 巨大ナップサック」(p.148)
のコードをPythonで書き直したものとなっています。

入力

n\\w_1\,…\,w_n\\v_1\,…\,v_n\\W

入力例



4
2 1 3 2
3 2 4 2
5


解答

from bisect import bisect_left
INF = float('inf')

# 入力
n = int(input())
w = list(map(int,input().split()))
v = list(map(int,input().split()))
W = int(input())

ps = []    # (重さ、価値)

# 前半分を全列挙
n2 = n // 2
for i in range(1 << n2):
    sw = 0
    sv = 0
    for j in range(n2):
        if (i >> j) & 1:
            sw += w[j]
            sv += v[j]
    ps.append((sw, sv))
    
# 無駄な要素を取り除く
ps.sort()
m = 1
for i in range(1, 1 << n2):
    if ps[m - 1][1] < ps[i][1]:
        ps[m] = ps[i]
        m += 1

# 後ろ半分を全列挙し解を求める
res = 0
for i in range(1 << (n - n2)):
    sw = 0
    sv = 0
    for j in range(n - n2):
        if (i >> j) & 1:
            sw += w[n2 + j]
            sv += v[n2 + j]
    if sw <= W:
        tv = ps[bisect_left(ps[:m], (W - sw, INF)) - 1][1]
        res = max(res, sv + tv)

print(res)