跳转至

ABC332(A-G)

这场质量太高了!

A - Online Shopping

题目大意

有一个网店的有 \(N\) 件物品,第 \(i\) 件物品的单价是 \(P_i\),你打算买 \(Q_i\) 件。总价达到 \(S\) 可以免运费,低于 \(S\) 需要支付 \(K\) 的运费。问:你需要支付多少钱?

参考代码
#include <iostream>

using namespace std;

int main()
{
    int n, s, k, sum = 0, p, q;
    cin >> n >> s >> k;
    while(n--)
    {
        cin >> p >> q;
        sum += p * q;
    }
    if(sum < s)
        sum += k;
    cout << sum << endl;
    return 0;
}
import java.util.Scanner;

public class Main {

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int s = in.nextInt();
        int k = in.nextInt();
        int sum = 0, p, q;
        while (n-- > 0) {
            p = in.nextInt();
            q = in.nextInt();
            sum += p * q;
        }
        if (sum < s)
            sum += k;
        System.out.println(sum);
        in.close();
    }
}

B - Glass and Mug

题目大意

你有一个容量为 \(G\) 毫升的玻璃杯,和一个容量为 \(M(G < M)\) 毫升的马克杯,初始的时候两个杯子都没有水,执行 \(K\) 次下列的操作:

  • 如果玻璃杯的水是满的,将玻璃杯的全部倒掉。
  • 否则,如果马克杯是空的,将马克杯的水补充到满。
  • 否则,将马克杯的水倒向玻璃杯,直到玻璃杯的水满或者马克杯的水为空。

问:\(K\) 次操作之后,两个杯子还剩多少水?

参考代码
#include <iostream>

using namespace std;

int main()
{
    int K, G, M, g = 0, m = 0;
    cin >> K >> G >> M;
    while(K--)
    {
        if(g == G)
            g = 0;
        else if(m == 0)
            m = M;
        else
        {
            int d = min(G-g, m);
            g += d;
            m -= d;
        }
    }
    cout << g << ' ' << m;
    return 0;
}
import java.util.Scanner;

public class Main {

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int K, G, M, g = 0, m = 0;
        K = in.nextInt();
        G = in.nextInt();
        M = in.nextInt();
        while (K-- > 0) {
            if (g == G)
                g = 0;
            else if (m == 0)
                m = M;
            else {
                int d = Math.min(G - g, m);
                g += d;
                m -= d;
            }
        }
        System.out.printf("%d %d", g, m);
        in.close();
    }
}

C - T-shirts

题目大意

在你外出时,你会穿一件衬衫,衬衫分为有 logo 的衬衫和没有 logo 的衬衫。穿过的衬衫会变脏,脏的衬衫必须洗过才能重新穿。你有一份接下来 \(N\) 天的行程,用一个字符串来表示,具体来说:

  • 如果 s[i] == '0',表示第 \(i\) 天你在家休息,在家休息的时候,你会洗掉所有的脏衬衫,洗过的衬衫从第二天开始就可以穿。
  • 如果 s[i] == '1',表示第 \(i\) 天你打算穿任意一件(有 logo 或者没有 logo 的)衬衫外出。
  • 如果 s[i] == '2',表示第 \(i\)​​ 天你打算穿一件有 logo 的衬衫外出。

初始的时候你拥有 \(M\) 件没有 logo 的衬衫,你可以额外购买一些带有 logo 的衬衫,在第 \(0\) 天的时候所有你拥有的(包括额外购买的)衬衫都已经洗过了。问:为了满足出行的需要,最少需要额外购买多少件带有 logo 的衬衫?

参考代码
#include <iostream>

using namespace std;

const int maxn = 1000 + 10;

int main()
{
    int n, m, M = 0, x = 0, y = 0;
    char s[maxn];
    cin >> n >> m >> s;
    for(int i = 0; i <= n; i++)
    {
        if(i == n || s[i] == '0' )
        {
            if(x > m)
                y += x - m;
            M = max(y, M);
            x = y = 0;
        }
        else if(s[i] == '1')
            x++;
        else
            y++;
    }
    cout << M << endl;
    return 0;
}
import java.util.Scanner;

public class Main {

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n, m, M = 0, x = 0, y = 0;
        String s;
        n = in.nextInt();
        m = in.nextInt();
        s = in.next();
        for (int i = 0; i <= n; i++) {
            if (i == n || s.charAt(i) == '0') {
                if (x > m)
                    y += x - m;
                M = Math.max(y, M);
                x = y = 0;
            } else if (s.charAt(i) == '1')
                x++;
            else
                y++;
        }
        System.out.println(M);
        in.close();
    }
}

D - Swapping Puzzle

题目大意

给你两个 \(H(H \le 5)\)\(W(W \le 5)\) 列的二维矩阵 \(A\)\(B\),你可以任意交换矩阵 \(A\) 相邻的两行或者相邻的两列。问:最少需要多少次交换才能让 \(A\)\(B\) 完全相等?如果不可能让 \(A\)\(B\) 完全相等输出 -1

解题思路

注意到数据量很小,要想枚举所有的交换方案只需要枚举行和列的全排列,而交换次数等于逆序数,直接搜索即可,事件复杂度是 \(O(H! \times W!)\)

如果要进一步剪枝的话,只需要保证枚举排列的时候是按照逆序对递增的顺序来进行的即可,事实上 C++ STLnext_permutation() 就是按照这样的顺序枚举的,所以下面的 C++ 代码会简单一些。

Java 代码是自己实现的 next_permutation(),注意如果是用 dfs 搜索出来的排列并不是逆序对递增的顺序。

参考代码
#include <iostream>
#include <algorithm>

using namespace std;

const int maxn = 5 + 5;

int A[maxn][maxn], B[maxn][maxn];
int rm[] = {0, 1, 2, 3, 4};
int cm[] = {0, 1, 2, 3, 4};
int n, m, ans = 1E8;

bool check()
{
    for (int i = 0; i < n; i++)
        for (int j = 0; j < m; j++)
            if (A[rm[i]][cm[j]] != B[i][j])
                return 0;
    return 1;
}

int main()
{
    cin >> n >> m;
    for (int i = 0; i < n; i++)
        for (int j = 0; j < m; j++)
            cin >> A[i][j];
    for (int i = 0; i < n; i++)
        for (int j = 0; j < m; j++)
            cin >> B[i][j];
    int x = 0;
    do
    {
        x = 0;
        for (int i = 1; i < n; i++)
            for (int j = 0; j < i; j++)
                x += rm[i] < rm[j];
        int y = 0;
        do
        {
            y = 0;
            for (int i = 1; i < m; i++)
                for (int j = 0; j < i; j++)
                    y += cm[i] < cm[j];
            if (check()) {
                ans = min(ans, x + y);
                for(int i = 0; i < m; i++)
                    cm[i] = i;
            }
        } while (next_permutation(cm, cm + m) && x + y < ans);
    } while (next_permutation(rm, rm + n) && x < ans);
    if (ans == 1E8)
        ans = -1;
    cout << ans << endl;
    return 0;
}
import java.util.Scanner;

public class Main {
    private static int[][] A, B;
    private static int n, m;
    private static int rm[];
    private static int cm[];
    private static int ans = (int) 1E8;

    private static boolean check() {
        for (int i = 0; i < n; i++)
            for (int j = 0; j < m; j++)
                if (A[rm[i]][cm[j]] != B[i][j])
                    return false;
        return true;
    }

    private static void reverse(int[] nums, int l, int r) {
        for (int i = l, j = r, temp; i < j; i++, j--) {
            temp = nums[i];
            nums[i] = nums[j];
            nums[j] = temp;
        }
    }

    public static boolean nextPermutation(int[] nums) {
        int n = nums.length, i = n - 2;
        while (i >= 0 && nums[i] >= nums[i + 1])
            i--;
        if (i == -1) {
            reverse(nums, 0, n - 1);
            return false;
        }
        int j = n - 1;
        while (j > i && nums[j] <= nums[i])
            j--;
        int temp = nums[i];
        nums[i] = nums[j];
        nums[j] = temp;
        reverse(nums, i + 1, n - 1);
        return true;
    }

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        n = in.nextInt();
        m = in.nextInt();
        A = new int[n][m];
        B = new int[n][m];
        rm = new int[n];
        cm = new int[m];
        for (int i = 0; i < n; i++)
            rm[i] = i;
        for (int i = 0; i < m; i++)
            cm[i] = i;
        for (int i = 0; i < n; i++)
            for (int j = 0; j < m; j++)
                A[i][j] = in.nextInt();
        for (int i = 0; i < n; i++)
            for (int j = 0; j < m; j++)
                B[i][j] = in.nextInt();
        int x = 0;
        do {
            x = 0;
            for (int i = 1; i < n; i++)
                for (int j = 0; j < i; j++)
                    x += rm[i] < rm[j] ? 1 : 0;
            int y = 0;
            do {
                y = 0;
                for (int i = 1; i < m; i++)
                    for (int j = 0; j < i; j++)
                        y += cm[i] < cm[j] ? 1 : 0;
                if (check()) {
                    ans = Math.min(ans, x + y);
                    for (int i = 0; i < m; i++)
                        cm[i] = i;
                }
            } while (nextPermutation(cm) && x + y < ans);
        } while (nextPermutation(rm) && x < ans);
        if (ans == 1E8)
            ans = -1;
        System.out.println(ans);
        in.close();
    }
}

E - Lucky bag

题目大意

你有 \(N(N \le 15)\) 个物品,第 \(i\) 个物品的重量是 \(W_i(W_i \le 10^8)\),你需要将这些物品分成 \(D\) 组,每一组的重量为所有物品的重量之和,每个物品必须属于一个组,允许存在空的组,空组的重量是 \(0\)。记每个组的物品重量之和为 \(x_1, x_2, \cdots x_D\),找到一种最小化 \(x_i\) 的方差的分组方案。换言之,你需要最小化 \(V = \frac{1}{D}\sum\limits_{i = 1}^D(x_i - \overline{x})^2\),其中 \(\overline{x} = \frac{1}{D}\sum\limits_{i = 1}^Dx_i\)

解题思路

考虑化简目标式子:

\[ \begin{align*} V & = \frac{1}{D}\sum_{i = 1}^D(x_i - \overline{x})^2\\ & = \frac{1}{D}(\sum_{i = 1}^Dx_i^2 + D\overline{x}^2 - 2\overline{x}\sum_{i = 1}^Dx_i)\\ & = \frac{1}{D}(\sum_{i = 1}^Dx_i^2 + D\overline{x}^2 - 2D\overline{x})\\ & = \frac{1}{D}\sum_{i = 1}^Dx_i^2 - \overline{x}^2 \end{align*} \]

注意到 \(\overline{x} = \frac{1}{D}\sum\limits_{i = 1}^Dx_i = \frac{1}{D}\sum\limits_{i = 1}^NW_i\),因此 \(\overline{x}\) 实际是一个常数,则在 \(V = \frac{1}{D}\sum_{i = 1}^Dx_i^2 - \overline{x}^2\\\) 中,只需要最小化 \(\sum\limits_{i = 1}^Dx_i^2\) 的值。注意到 \(N \le 15\),可以用 状压dp 求解。

定义 \(dp[s][j]\):将集合 \(s\) 的物品分成 \(j\) 个组,所得到的每个组的重量平方和的最小值。

显然,转移时只需要枚举一个子集作为上一个分组,设 \(ss\)\(s\) 的子集, 则有: $$ dp[s][j] = \min_{ss \sub s}{dp[s \backslash ss][j-1]+(\sum_{i \in ss}W_i)^2} $$ 其中 \(s \backslash ss\) 表示 \(s\) 去掉子集 \(ss\) 之后剩余的集合,只需要用异或运算即可。\((\sum_{i \in ss}W_i)^2\) 表示将 \(ss\) 集合中所有的物品分成一组产生的重量之和的平方,可以在 \(O(N\times2^N)\) 的时间内预处理出来。

在转移的时候还需要枚举 \(s\) 的所有子集 \(ss\),这可以使用 位运算技巧 里面的方法枚举出来,可以证明枚举所有集合的子集需要的时间复杂度是 \(O(3^N)\),因此 dp 总的时间复杂度是 \(O(N\times 3^N)\)

另外,本题的数据非常毒瘤,dp 算出来的结果是 long long,转成 long double 之后要很容易在除的时候丢失精度,需要减少除法的次数。设 dp 的结果为 \(X_{\min}\),所有物品的重量之和为 \(W_s\) 则有:

\[ \begin{align*} V & = \frac{1}{D}\sum_{i = 1}^Dx_i^2 - \overline{x}^2 \\ & = \frac{1}{D}X_{\min} - \frac{1}{D^2}W_s^2 \\ & = \frac{DX_{\min} - W_s^2}{D^2} \end{align*} \]

这样就只需要进行一次 long double 的除法,确保小数不会丢失。

另外,本题还存在一种随机化的做法,只需要随机打乱物品的顺序,然后每次将物品加入到重量最小的分组中即可(很直观的做法对吧)。

参考代码
#include <iostream>
#include <climits>

using namespace std;

typedef long long LL;
typedef long double LD;
const int maxn = 15 + 10;
const int maxs = (1 << 15) + 10;

int n, d, w[maxn];
LL sum[maxs], dp[maxs][maxn];

int main()
{
    cin >> n >> d;
    for(int i = 0; i < n; i++)
        cin >> w[i];
    int mask = (1 << n) - 1;
    for(int s = 0; s <= mask; s++) 
    {
        for(int i = 0; i < n; i++)
            if(s >> i & 1)
                sum[s] += w[i];
        sum[s] *= sum[s];
    }
    for(int s = 0; s <= mask; s++)
    {
        dp[s][1] = sum[s];
        for(int j = 2; j <= d; j++)
        {
            dp[s][j] = LLONG_MAX;
            int ss = s;
            do {
                int ls = s ^ ss;
                dp[s][j] = min(dp[s][j], dp[ls][j-1]+sum[ss]);
                ss = (ss - 1) & s;
            }while(ss != s);
        }
    }
    LL temp = dp[mask][d] * d - sum[mask];
    LD ans = (LD)temp / (d * d);
    printf("%.8Lf", ans);
    return 0;
}
import java.util.Scanner;

public class Main {

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);

        int n = in.nextInt();
        int d = in.nextInt();
        int[] w = new int[n];
        int len = 1 << n, mask = len - 1;
        long[] sum = new long[len];
        long[][] dp = new long[len][d + 1];
        for (int i = 0; i < n; i++)
            w[i] = in.nextInt();
        for (int s = 0; s <= mask; s++) {
            for (int i = 0; i < n; i++)
                if ((s >> i & 1) == 1)
                    sum[s] += w[i];
            sum[s] *= sum[s];
        }
        for (int s = 0; s <= mask; s++) {
            dp[s][1] = sum[s];
            for (int j = 2; j <= d; j++) {
                dp[s][j] = Long.MAX_VALUE;
                int ss = s;
                do {
                    int ls = s ^ ss;
                    dp[s][j] = Math.min(dp[s][j], dp[ls][j - 1] + sum[ss]);
                    ss = (ss - 1) & s;
                } while (ss != s);
            }
        }
        long temp = dp[mask][d] * d - sum[mask];
        double ans = (double) temp / (d * d);
        System.out.printf("%.8f", ans);
        in.close();
    }
}

F - Random Update Query

题目大意

给你 \(N\) 个数,第 \(i\) 个数是 \(A_i\),执行 \(M\) 次操作,每次操作有 \(L_i \ R_i \ X_i\) 三个数,表示:

  • 在区间 \([L_i, R_i]\) 中随机等概率的选择一个位置,将这个位置的数字替换成 \(X_i\)

求出经过 \(M\) 次操作后,每个位置的数字的数学期望,答案对 \(998244353\) 取模。

解题思路

假设某次操作为 \(L_i \ R_i \ X_i\),设 \(p_i = \frac{1}{R_i - L_i + 1}\),区间 \([L_i, R_i]\) 某个位置的数字为 \(a\),则操作后这个位置的期望为 \(E_1(X) = (1-p_i)a + p_iX_i\)

考虑后一个操作 \(L_j \ R_j \ X_j\),设 \(p_j = \frac{1}{R_j - L_j + 1}\),经过两次操作,这个位置有三种可能:

数字 概率
\(a\) \((1-p_i)(1-p_j)\)
\(X_i\) \(p_i(1-p_j)\)
\(X_j\) \(p_j\)

这种情况的数学期望为:

\[ \begin{align*} E_2(X) & = a(1-p_i)(1-p_j) + X_ip_i(1-p_j) + X_jp_j \\ & = ((1-p_i)a + p_iX_i)(1-p_j) + X_jp_j \\ & = E_1(X)\times (1-p_j) + X_jp_j \end{align*} \]

\(k_j = 1-p_j\)\(b_j = X_jp_j\),显然,\(k_j\)\(b_j\) 都是容易计算的。

这表明 \(E_i(X)\) 实际是 \(k_iE_{i-1}(X) + b_i\) 的线性结构,对于一个点,在多次修改的时候,只需要每次算出 \(k_i\)\(b_i\) 然后叠加计算即可。由于涉及到区间修改,很容易想到懒标记的线段树,修改完所有操作之后再一个个单点查询即可。

参考代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

using namespace std;

typedef long long LL;
const int maxn = 2E5 + 10;
const int mod = 998244353;

void exgcd(LL a, LL b, LL &x, LL &y)
{
    if(b == 0)
    {
        x = 1, y = 0;
        return;
    }
    exgcd(b, a % b, y, x);
    y -= a / b * x;
}

LL inv(LL a, LL p)
{
    LL x, y;
    exgcd(a, p, x, y);
    if(x < 0)
        x += p;
    return x;
}

int n, m; 
LL a[maxn];

struct Node {
    LL mul, add;
    Node() : mul(1), add(0) {}
}t[maxn * 4];

void pushdown(int p)
{
    LL m = t[p].mul, a = t[p].add;
    if(m == 1 && a == 0)
        return;
    int lch = p * 2, rch = p * 2 + 1;
    t[lch].add = (m * t[lch].add + a) % mod;
    t[lch].mul = m * t[lch].mul % mod;
    t[rch].add = (m * t[rch].add + a) % mod;
    t[rch].mul = m * t[rch].mul % mod;
    t[p].mul = 1, t[p].add = 0;
}

void modify(int p, int beg, int end, int l, int r, LL m, LL a)
{
    if(end < l || beg > r)
        return;
    if(beg >= l && end <= r)
    {
        t[p].mul = (t[p].mul * m) % mod;
        t[p].add = (t[p].add * m + a) % mod;
        return;
    }
    pushdown(p);
    int mid = (beg + end) / 2;
    int lch = p * 2, rch = p * 2 + 1;
    modify(lch, beg, mid, l, r, m, a);
    modify(rch, mid+1, end, l, r, m, a);
}

LL query(int p, int beg, int end, int pos)
{
    if(beg == end)
        return (a[beg] * t[p].mul + t[p].add) % mod;
    pushdown(p);
    int mid = (beg + end) / 2;
    int lch = p * 2, rch = p * 2 + 1;
    return pos <= mid ? query(lch, beg, mid, pos) : query(rch, mid+1, end, pos);
}

int main()
{
    cin >> n >> m;
    for(int i = 1; i <= n; i++)
        scanf("%lld", &a[i]);
    while(m--)
    {
        int l, r;
        LL x;
        scanf("%d %d %lld", &l, &r, &x);
        LL a = inv(r - l + 1, mod), m = (r - l) * a % mod;
        a = a * x % mod;
        modify(1, 1, n, l, r, m, a);
    }
    for(int i = 1; i <= n; i++)
        printf("%lld ", query(1, 1, n, i));
    return 0;
}
import java.util.Scanner;

class Node {
    public long mul, add;

    public Node() {
        mul = 1;
        add = 0;
    }
}

public class Main {
    private static final int mod = 998244353;

    private static long[] exgcd(long a, long b) {
        if (b == 0)
            return new long[] { 1, 0 };
        long[] xy = exgcd(b, a % b);
        long x = xy[0], y = xy[1];
        xy[0] = y;
        xy[1] = x - a / b * y;
        return xy;
    }

    private static long inv(long a, long p) {
        long x = exgcd(a, p)[0];
        if (x < 0)
            x += p;
        return x;
    }

    private static Node[] t;
    private static long[] a;

    private static void build(int p, int beg, int end) {
        t[p] = new Node();
        if (beg == end)
            return;
        int mid = (beg + end) / 2;
        int lch = p * 2, rch = p * 2 + 1;
        build(lch, beg, mid);
        build(rch, mid + 1, end);
    }

    private static void pushdown(int p) {
        long m = t[p].mul, a = t[p].add;
        if (m == 1 && a == 0)
            return;
        int lch = p * 2, rch = p * 2 + 1;
        t[lch].add = (m * t[lch].add + a) % mod;
        t[lch].mul = m * t[lch].mul % mod;
        t[rch].add = (m * t[rch].add + a) % mod;
        t[rch].mul = m * t[rch].mul % mod;
        t[p].mul = 1;
        t[p].add = 0;
    }

    private static void modify(int p, int beg, int end, int l, int r, long m, long a) {
        if (end < l || beg > r)
            return;
        if (beg >= l && end <= r) {
            t[p].mul = (t[p].mul * m) % mod;
            t[p].add = (t[p].add * m + a) % mod;
            return;
        }
        pushdown(p);
        int mid = (beg + end) / 2;
        int lch = p * 2, rch = p * 2 + 1;
        modify(lch, beg, mid, l, r, m, a);
        modify(rch, mid + 1, end, l, r, m, a);
    }

    private static long query(int p, int beg, int end, int pos) {
        if (beg == end)
            return (a[beg] * t[p].mul + t[p].add) % mod;
        pushdown(p);
        int mid = (beg + end) / 2;
        int lch = p * 2, rch = p * 2 + 1;
        return pos <= mid ? query(lch, beg, mid, pos) : query(rch, mid + 1, end, pos);
    }

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int M = in.nextInt();
        t = new Node[4 * n + 4];
        a = new long[n + 1];
        build(1, 1, n);
        for (int i = 1; i <= n; i++)
            a[i] = in.nextInt();
        while (M-- > 0) {
            int l = in.nextInt();
            int r = in.nextInt();
            long x = in.nextLong();
            long a = inv(r - l + 1, mod), m = (r - l) * a % mod;
            a = a * x % mod;
            modify(1, 1, n, l, r, m, a);
        }
        for (int i = 1; i <= n; i++)
            System.out.printf("%d ", query(1, 1, n, i));
        in.close();
    }
}

G - Not Too Many Balls

题目大意

\(N(N \le 500)\) 种颜色的球,颜色的值为 \(1\)\(N\),颜色为 \(i\) 的球有 \(A_i\) 个,有 \(M(M \le 5 \times 10^5)\) 个盒子,第 \(j\) 个盒子最多可以放 \(B_j\) 个球。此外,颜色为 \(i\) 的球在第 \(j\) 个盒子中最多能放 \(i \times j\) 个。问:最多可以将多少个球放在盒子中?

解题思路

可以使用网络流解决这个问题,建立一个二分图,其中超级源点记为 \(S\),超级汇点记为 \(T\),球作为二分图的左部 \(X = (x_1, x_2, \cdots, x_n)\),盒子作为二分图的右部 \(Y = (y_1, y_2, \cdots, y_m)\)。则图中有三种边:

  1. \(S\)\(x_i\) 的连边,流值为 \(A_i\)
  2. \(x_i\)\(y_j\) 的连边,流值为 \(i \times j\)
  3. \(y_j\)\(T\) 的连边,流值为 \(B_j\)

直接跑最大流就是答案,这样时间复杂度是没问题的,不过看看空间复杂度:第二种边有 \(NM\) 条,这样肯定是存不下的。

考虑转化成最小割的问题,假设最小割 \(S\) 可达的集合为 \(P \cup Q\),其中 \(P \sub X\)\(Q \sub Y\),则最小割的组成为:

\[ \begin{align*} \min_{P\sub X}\min_{Q\sub Y} \{ \sum_{x_i \in X\backslash P}A_i + \sum_{x_i \in X}i\sum_{y_j \in Y \backslash Q}j + \sum_{y_j \in Y}B_j \} \end{align*} \]

解释一下这个公式的组成:

  • \(\sum\limits_{x_i \in X\backslash P}A_i\):对于 \(X\backslash P\) 集合中的点,\(S\) 不可达,设 \(x_i \in X \backslash P\),则 \((S, x_i)\) 对应第一种边,这些边一定会被割掉(否则就存在 \(S \rightarrow x_i\))。
  • \(\sum\limits_{x_i \in X}i\sum\limits_{y_j \in Y \backslash Q}j\):对于 \(X\) 集合中的点,\(S\) 可达,对于 \(Y \backslash Q\) 集合中的点,\(S\) 不可达,设 \(x_i \in X\)\(y_j \in Y \backslash Q\),则边 \((x_i, y_j)\) 对应第二种边,这些边一定会被割掉(否则就存在 \(S \rightarrow x_i \rightarrow y_j\) )。
  • \(\sum\limits_{y_j \in Y}B_j\):对于 \(Y\) 集合中的点,\(S\) 不可达,设 \(y_j \in Y\)\((y_j, T)\) 对应第三种边,这些点一定会被割掉(否则就存在 \(S \rightarrow \cdots \rightarrow y_j \rightarrow T\) )。

如何算出最小割?肯定不能枚举两个子集。官方题解给出了一种非常巧妙的方法绕开了子集的枚举,注意到 \(N \le 500\)(而且题目的名字也在提示球并不多),令 \(k = \sum\limits_{x_i \in X}i\),由于 \(i\) 的取值是 \(1\)\(N\),则 \(k\) 的取值就是 \(0\)\(\frac{N(N+1)}{2}\)(其中 \(0\) 仅当 \(X\) 为空集的时候取得),于是考虑枚举 \(k\) 的值,则有:

\[ \begin{align*} & \min_{P\sub X}\min_{Q\sub Y} \{ \sum_{x_i \in X\backslash P}A_i + \sum_{x_i \in X}i\sum_{y_j \in Y \backslash Q}j + \sum_{y_j \in Y}B_j \} \\ = & \min_{k = 0}^{\frac{N(N+1)}{2}}\{\min_{P\sub X}\{ \sum_{x_i \in X\backslash P}A_i\} + \min_{Q\sub Y}\{ \sum_{y_j \in Y \backslash Q}kj + \sum_{y_j \in Y}B_j \}\} \\ \end{align*} \]

为什么要这样做?\(\sum\limits_{x_i \in X}i\sum\limits_{y_j \in Y \backslash Q}j\) 这个式子耦合了两个集合,这样不方便处理,而固定前者之后就只剩下一个集合,进一步就能拆成两部分。

考虑如何算出 \(\min\limits_{k = 0}^{\frac{N(N+1)}{2}}\min\limits_{P\sub X}\{ \sum\limits_{x_i \in X\backslash P}A_i\}\)\(k\) 实际限制了集合中每个球颜色值的和,而这个式子可以使用 dp 预处理出来,定义 \(dp[k]\):颜色值之和为 \(k\) 的球的集合对应的 \(\sum A_i\) 的最小值,这是一个背包问题,可以在 \(O(N^2\times N) = O(N^3)\) 的时间内完成。预处理出所有第 \(dp[k]\) 值之后,对于特定的 \(k\) 值, \(\min\limits_{P\sub X}\{ \sum\limits_{x_i \in X\backslash P}A_i\} = dp[\frac{N(N+1)}{2}-k]\)

考虑如何算出 \(\min\limits_{k = 0}^{\frac{N(N+1)}{2}}\min\limits_{Q\sub Y}\{\sum\limits_{y_j \in Y \backslash Q}kj + \sum\limits_{y_j \in Y}B_j\}\)\(k\)\(\sum\limits_{y_j \in Y}B_j\) 的值是没有影响的,对于某个子集 \(Q \sub Y\),如果 \(Q\) 的某个元素 \(y_j\)\(kj < B_j\),则将 \(y_j\) 放在 \(Y \backslash Q\) 会更小。进一步可以发现,\(\min\limits_{Q\sub Y}\{ \sum\limits_{y_j \in Y \backslash Q}kj + \sum\limits_{y_j \in Y}B_j\} = \sum\limits_{j = 1}^m\min(kj, B_j)\),这样就不需要枚举子集 \(Q\) 了。如何计算 \(\min\limits_{k = 0}^{\frac{N(N+1)}{2}}\{\sum\limits_{j = 1}^m\min(kj, B_j) \} \\\)?这个式子是 \(O(N^2M)\) 的,还是会超时。\(\min(kj, B_j)\) 实际是一个分段函数,前半段是 \(kj\),后半段是 \(B_j\),两者都是直线,这样从 \(k\)\(k+1\) 自增变化时,\(\sum\limits_{j = 1}^m\min(kj, B_j)\) 的增长也是有规律的,使用一阶差分可以维护区间加同一个数,使用 二阶差分 可以实现区间加一个等差数列(也就是一条直线),所以只需要针对所有的分段函数作出二阶差分,就可以在 \(O(1)\) 的时间内算出每个 \(\sum\limits_{j = 1}^m\min(kj, B_j)\)

参考代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>

using namespace std;

typedef long long LL;
const int maxn = 500 + 10;
const int maxl = maxn * maxn / 2;
const int maxm = 5E5 + 10;

int n, m, l;
LL A[maxn], B[maxm], dp[maxl], c[maxl];

void add(int l, int r, LL a, LL d)
{
    LL e = a + d * (r - l);
    c[l] += a;
    c[l+1] += d - a;
    c[r+1] -= e + d;
    c[r+2] += e;
}

int main()
{
    cin >> n >> m;
    for(int i = 1; i <= n; i++)
        scanf("%lld", &A[i]);
    l = n * (n + 1) / 2;
    memset(dp, 0x3f, sizeof dp);
    dp[0] = 0;
    for(int i = 1; i <= n; i++)
        for(int j = l; j >= i; j--)
            dp[j] = min(dp[j], dp[j-i]+A[i]);
    for(int j = 1; j <= m; j++) 
    {
        scanf("%lld", &B[j]);
        LL K = B[j] / j;
        add(0, min(K, (LL)l), 0, j);
        if(K+1 <= l)
            add(K+1, l, B[j], 0);
    }
    for(int k = 1; k <= l; k++)
        c[k] += c[k-1];
    for(int k = 1; k <= l; k++)
        c[k] += c[k-1];
    LL ans = dp[l];
    for(int k = 1; k <= l; k++) 
        ans = min(ans, dp[l-k]+c[k]);
    cout << ans << endl;
    return 0;
}
import java.util.Arrays;
import java.util.Scanner;

public class Main {
    private static int n, m, l;
    private static long[] A, B, dp, c;

    private static void add(int l, int r, long a, long d) {
        long e = a + d * (r - l);
        c[l] += a;
        c[l + 1] += d - a;
        c[r + 1] -= e + d;
        c[r + 2] += e;
    }

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        n = in.nextInt();
        m = in.nextInt();
        l = n * (n + 1) / 2;
        A = new long[n + 1];
        B = new long[m + 1];
        dp = new long[l + 1];
        c = new long[l + 3];
        for (int i = 1; i <= n; i++)
            A[i] = in.nextLong();
        Arrays.fill(dp, (long) 1E18);
        dp[0] = 0;
        for (int i = 1; i <= n; i++)
            for (int j = l; j >= i; j--)
                dp[j] = Math.min(dp[j], dp[j - i] + A[i]);
        for (int j = 1; j <= m; j++) {
            B[j] = in.nextLong();
            long K = B[j] / j;
            add(0, (int) Math.min(K, (long) l), 0, j);
            if (K + 1 <= l)
                add((int) K + 1, l, B[j], 0);
        }
        for (int k = 1; k <= l; k++)
            c[k] += c[k - 1];
        for (int k = 1; k <= l; k++)
            c[k] += c[k - 1];
        long ans = dp[l];
        for (int k = 1; k <= l; k++)
            ans = Math.min(ans, dp[l - k] + c[k]);
        System.out.println(ans);
        in.close();
    }
}

最后更新: 2024-03-04
创建日期: 2024-03-04