クルトンのプログラミング教室

Pythonの使い方やPythonを使った競技プログラミングの解法などを解説しています。

Pythonで理解する蟻本「2-3 01 ナップサック問題」(p.52)

この記事は「プログラミングコンテストチャレンジブック第2版」(蟻本)の
「2-3 01 ナップサック問題」(p.52)
のコードをPythonで書き直したものとなっています。

入力

n\,W\\w_1\,…\,w_n\\v_1\,…\,v_n

入力例



4 5
2 1 3 2
3 2 4 2


解答

再帰による愚直解(O(2^N))

# 入力
n, W = map(int,input().split())
w = list(map(int,input().split()))
v = list(map(int,input().split()))

# i番目以降の品物から重さの総和がj以下となるように選ぶ
def rec(i, j):
    if i == n:
        # もう品物は残っていない
        res = 0
    elif j < w[i]:
        # この品物は入らない
        res = rec(i + 1, j)
    else:
        # 入れない場合と入れる場合を両方試す
        res = max(rec(i + 1, j), rec(i + 1, j - w[i]) + v[i])
    return res

print(rec(0, W))

メモ化再帰(O(nW))

# 入力
n, W = map(int,input().split())
w = list(map(int,input().split()))
v = list(map(int,input().split()))

MAX_N = 100
MAX_W = 10000
dp = [[-1] * (MAX_W + 1) for _ in range(MAX_N + 1)]    # メモ化テーブル

def rec(i, j):
    if dp[i][j] >= 0:
        # すでに調べたことがあるならばその結果を再利用
        return dp[i][j]
    if i == n:
        res = 0
    elif j < w[i]:
        res = rec(i + 1, j)
    else:
        res = max(rec(i + 1, j), rec(i + 1, j - w[i]) + v[i])
    # 結果をテーブルに記憶する
    dp[i][j] = res
    return res

print(rec(0, W))

動的計画法(DP)(O(nW))

# 入力
n, W = map(int,input().split())
w = list(map(int,input().split()))
v = list(map(int,input().split()))

MAX_N = 100
MAX_W = 10000
dp = [[0] * (MAX_W + 1) for _ in range(MAX_N + 1)]    # dpテーブル

for i in range(n - 1, -1, -1):
    for j in range(W + 1):
        if j < w[i]:
            dp[i][j] = dp[i + 1][j]
        else:
            dp[i][j] = max(dp[i + 1][j], dp[i + 1][j - w[i]] + v[i])
print(dp[0][W])

iに関するループの向きを順方向に直した動的計画法(O(nW))

# 入力
n, W = map(int,input().split())
w = list(map(int,input().split()))
v = list(map(int,input().split()))

MAX_N = 100
MAX_W = 10000
dp = [[0] * (MAX_W + 1) for _ in range(MAX_N + 1)]    # dpテーブル

for i in range(n):
    for j in range(W + 1):
        if j < w[i]:
            dp[i + 1][j] = dp[i][j]
        else:
            dp[i + 1][j] = max(dp[i][j], dp[i][j - w[i]] + v[i])
print(dp[n][W])

p.56の解法(O(nW))

# 入力
n, W = map(int,input().split())
w = list(map(int,input().split()))
v = list(map(int,input().split()))

MAX_N = 100
MAX_W = 10000
dp = [[0] * (MAX_W + 1) for _ in range(MAX_N + 1)]    # dpテーブル

for i in range(n):
    for j in range(W + 1):
        dp[i + 1][j] = max(dp[i + 1][j], dp[i][j])
        if j + w[i] <= W:
            dp[i + 1][j + w[i]] = max(dp[i + 1][j + w[i]], dp[i][j] + v[i])
print(dp[n][W])