ABC332(A-G)
这场质量太高了!
A - Online Shopping¶
题目大意
有一个网店的有 \(N\) 件物品,第 \(i\) 件物品的单价是 \(P_i\),你打算买 \(Q_i\) 件。总价达到 \(S\) 可以免运费,低于 \(S\) 需要支付 \(K\) 的运费。问:你需要支付多少钱?
参考代码
#include <iostream>
using namespace std;
int main()
{
int n, s, k, sum = 0, p, q;
cin >> n >> s >> k;
while(n--)
{
cin >> p >> q;
sum += p * q;
}
if(sum < s)
sum += k;
cout << sum << endl;
return 0;
}
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner in = new Scanner(System.in);
int n = in.nextInt();
int s = in.nextInt();
int k = in.nextInt();
int sum = 0, p, q;
while (n-- > 0) {
p = in.nextInt();
q = in.nextInt();
sum += p * q;
}
if (sum < s)
sum += k;
System.out.println(sum);
in.close();
}
}
B - Glass and Mug¶
题目大意
你有一个容量为 \(G\) 毫升的玻璃杯,和一个容量为 \(M(G < M)\) 毫升的马克杯,初始的时候两个杯子都没有水,执行 \(K\) 次下列的操作:
- 如果玻璃杯的水是满的,将玻璃杯的全部倒掉。
- 否则,如果马克杯是空的,将马克杯的水补充到满。
- 否则,将马克杯的水倒向玻璃杯,直到玻璃杯的水满或者马克杯的水为空。
问:\(K\) 次操作之后,两个杯子还剩多少水?
参考代码
#include <iostream>
using namespace std;
int main()
{
int K, G, M, g = 0, m = 0;
cin >> K >> G >> M;
while(K--)
{
if(g == G)
g = 0;
else if(m == 0)
m = M;
else
{
int d = min(G-g, m);
g += d;
m -= d;
}
}
cout << g << ' ' << m;
return 0;
}
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner in = new Scanner(System.in);
int K, G, M, g = 0, m = 0;
K = in.nextInt();
G = in.nextInt();
M = in.nextInt();
while (K-- > 0) {
if (g == G)
g = 0;
else if (m == 0)
m = M;
else {
int d = Math.min(G - g, m);
g += d;
m -= d;
}
}
System.out.printf("%d %d", g, m);
in.close();
}
}
C - T-shirts¶
题目大意
在你外出时,你会穿一件衬衫,衬衫分为有 logo 的衬衫和没有 logo 的衬衫。穿过的衬衫会变脏,脏的衬衫必须洗过才能重新穿。你有一份接下来 \(N\) 天的行程,用一个字符串来表示,具体来说:
- 如果
s[i] == '0'
,表示第 \(i\) 天你在家休息,在家休息的时候,你会洗掉所有的脏衬衫,洗过的衬衫从第二天开始就可以穿。 - 如果
s[i] == '1'
,表示第 \(i\) 天你打算穿任意一件(有 logo 或者没有 logo 的)衬衫外出。 - 如果
s[i] == '2'
,表示第 \(i\) 天你打算穿一件有 logo 的衬衫外出。
初始的时候你拥有 \(M\) 件没有 logo 的衬衫,你可以额外购买一些带有 logo 的衬衫,在第 \(0\) 天的时候所有你拥有的(包括额外购买的)衬衫都已经洗过了。问:为了满足出行的需要,最少需要额外购买多少件带有 logo 的衬衫?
参考代码
#include <iostream>
using namespace std;
const int maxn = 1000 + 10;
int main()
{
int n, m, M = 0, x = 0, y = 0;
char s[maxn];
cin >> n >> m >> s;
for(int i = 0; i <= n; i++)
{
if(i == n || s[i] == '0' )
{
if(x > m)
y += x - m;
M = max(y, M);
x = y = 0;
}
else if(s[i] == '1')
x++;
else
y++;
}
cout << M << endl;
return 0;
}
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner in = new Scanner(System.in);
int n, m, M = 0, x = 0, y = 0;
String s;
n = in.nextInt();
m = in.nextInt();
s = in.next();
for (int i = 0; i <= n; i++) {
if (i == n || s.charAt(i) == '0') {
if (x > m)
y += x - m;
M = Math.max(y, M);
x = y = 0;
} else if (s.charAt(i) == '1')
x++;
else
y++;
}
System.out.println(M);
in.close();
}
}
D - Swapping Puzzle¶
题目大意
给你两个 \(H(H \le 5)\) 行 \(W(W \le 5)\) 列的二维矩阵 \(A\) 和 \(B\),你可以任意交换矩阵 \(A\) 相邻的两行或者相邻的两列。问:最少需要多少次交换才能让 \(A\) 和 \(B\) 完全相等?如果不可能让 \(A\) 和 \(B\) 完全相等输出 -1
。
解题思路
注意到数据量很小,要想枚举所有的交换方案只需要枚举行和列的全排列,而交换次数等于逆序数,直接搜索即可,事件复杂度是 \(O(H! \times W!)\)。
如果要进一步剪枝的话,只需要保证枚举排列的时候是按照逆序对递增的顺序来进行的即可,事实上 C++ STL 的 next_permutation()
就是按照这样的顺序枚举的,所以下面的 C++ 代码会简单一些。
Java 代码是自己实现的 next_permutation()
,注意如果是用 dfs 搜索出来的排列并不是逆序对递增的顺序。
参考代码
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn = 5 + 5;
int A[maxn][maxn], B[maxn][maxn];
int rm[] = {0, 1, 2, 3, 4};
int cm[] = {0, 1, 2, 3, 4};
int n, m, ans = 1E8;
bool check()
{
for (int i = 0; i < n; i++)
for (int j = 0; j < m; j++)
if (A[rm[i]][cm[j]] != B[i][j])
return 0;
return 1;
}
int main()
{
cin >> n >> m;
for (int i = 0; i < n; i++)
for (int j = 0; j < m; j++)
cin >> A[i][j];
for (int i = 0; i < n; i++)
for (int j = 0; j < m; j++)
cin >> B[i][j];
int x = 0;
do
{
x = 0;
for (int i = 1; i < n; i++)
for (int j = 0; j < i; j++)
x += rm[i] < rm[j];
int y = 0;
do
{
y = 0;
for (int i = 1; i < m; i++)
for (int j = 0; j < i; j++)
y += cm[i] < cm[j];
if (check()) {
ans = min(ans, x + y);
for(int i = 0; i < m; i++)
cm[i] = i;
}
} while (next_permutation(cm, cm + m) && x + y < ans);
} while (next_permutation(rm, rm + n) && x < ans);
if (ans == 1E8)
ans = -1;
cout << ans << endl;
return 0;
}
import java.util.Scanner;
public class Main {
private static int[][] A, B;
private static int n, m;
private static int rm[];
private static int cm[];
private static int ans = (int) 1E8;
private static boolean check() {
for (int i = 0; i < n; i++)
for (int j = 0; j < m; j++)
if (A[rm[i]][cm[j]] != B[i][j])
return false;
return true;
}
private static void reverse(int[] nums, int l, int r) {
for (int i = l, j = r, temp; i < j; i++, j--) {
temp = nums[i];
nums[i] = nums[j];
nums[j] = temp;
}
}
public static boolean nextPermutation(int[] nums) {
int n = nums.length, i = n - 2;
while (i >= 0 && nums[i] >= nums[i + 1])
i--;
if (i == -1) {
reverse(nums, 0, n - 1);
return false;
}
int j = n - 1;
while (j > i && nums[j] <= nums[i])
j--;
int temp = nums[i];
nums[i] = nums[j];
nums[j] = temp;
reverse(nums, i + 1, n - 1);
return true;
}
public static void main(String[] args) {
Scanner in = new Scanner(System.in);
n = in.nextInt();
m = in.nextInt();
A = new int[n][m];
B = new int[n][m];
rm = new int[n];
cm = new int[m];
for (int i = 0; i < n; i++)
rm[i] = i;
for (int i = 0; i < m; i++)
cm[i] = i;
for (int i = 0; i < n; i++)
for (int j = 0; j < m; j++)
A[i][j] = in.nextInt();
for (int i = 0; i < n; i++)
for (int j = 0; j < m; j++)
B[i][j] = in.nextInt();
int x = 0;
do {
x = 0;
for (int i = 1; i < n; i++)
for (int j = 0; j < i; j++)
x += rm[i] < rm[j] ? 1 : 0;
int y = 0;
do {
y = 0;
for (int i = 1; i < m; i++)
for (int j = 0; j < i; j++)
y += cm[i] < cm[j] ? 1 : 0;
if (check()) {
ans = Math.min(ans, x + y);
for (int i = 0; i < m; i++)
cm[i] = i;
}
} while (nextPermutation(cm) && x + y < ans);
} while (nextPermutation(rm) && x < ans);
if (ans == 1E8)
ans = -1;
System.out.println(ans);
in.close();
}
}
E - Lucky bag¶
题目大意
你有 \(N(N \le 15)\) 个物品,第 \(i\) 个物品的重量是 \(W_i(W_i \le 10^8)\),你需要将这些物品分成 \(D\) 组,每一组的重量为所有物品的重量之和,每个物品必须属于一个组,允许存在空的组,空组的重量是 \(0\)。记每个组的物品重量之和为 \(x_1, x_2, \cdots x_D\),找到一种最小化 \(x_i\) 的方差的分组方案。换言之,你需要最小化 \(V = \frac{1}{D}\sum\limits_{i = 1}^D(x_i - \overline{x})^2\),其中 \(\overline{x} = \frac{1}{D}\sum\limits_{i = 1}^Dx_i\)
解题思路
考虑化简目标式子:
注意到 \(\overline{x} = \frac{1}{D}\sum\limits_{i = 1}^Dx_i = \frac{1}{D}\sum\limits_{i = 1}^NW_i\),因此 \(\overline{x}\) 实际是一个常数,则在 \(V = \frac{1}{D}\sum_{i = 1}^Dx_i^2 - \overline{x}^2\\\) 中,只需要最小化 \(\sum\limits_{i = 1}^Dx_i^2\) 的值。注意到 \(N \le 15\),可以用 状压dp 求解。
定义 \(dp[s][j]\):将集合 \(s\) 的物品分成 \(j\) 个组,所得到的每个组的重量平方和的最小值。
显然,转移时只需要枚举一个子集作为上一个分组,设 \(ss\) 是 \(s\) 的子集, 则有: $$ dp[s][j] = \min_{ss \sub s}{dp[s \backslash ss][j-1]+(\sum_{i \in ss}W_i)^2} $$ 其中 \(s \backslash ss\) 表示 \(s\) 去掉子集 \(ss\) 之后剩余的集合,只需要用异或运算即可。\((\sum_{i \in ss}W_i)^2\) 表示将 \(ss\) 集合中所有的物品分成一组产生的重量之和的平方,可以在 \(O(N\times2^N)\) 的时间内预处理出来。
在转移的时候还需要枚举 \(s\) 的所有子集 \(ss\),这可以使用 位运算技巧 里面的方法枚举出来,可以证明枚举所有集合的子集需要的时间复杂度是 \(O(3^N)\),因此 dp 总的时间复杂度是 \(O(N\times 3^N)\)。
另外,本题的数据非常毒瘤,dp 算出来的结果是 long long
,转成 long double
之后要很容易在除的时候丢失精度,需要减少除法的次数。设 dp 的结果为 \(X_{\min}\),所有物品的重量之和为 \(W_s\) 则有:
这样就只需要进行一次 long double
的除法,确保小数不会丢失。
另外,本题还存在一种随机化的做法,只需要随机打乱物品的顺序,然后每次将物品加入到重量最小的分组中即可(很直观的做法对吧)。
参考代码
#include <iostream>
#include <climits>
using namespace std;
typedef long long LL;
typedef long double LD;
const int maxn = 15 + 10;
const int maxs = (1 << 15) + 10;
int n, d, w[maxn];
LL sum[maxs], dp[maxs][maxn];
int main()
{
cin >> n >> d;
for(int i = 0; i < n; i++)
cin >> w[i];
int mask = (1 << n) - 1;
for(int s = 0; s <= mask; s++)
{
for(int i = 0; i < n; i++)
if(s >> i & 1)
sum[s] += w[i];
sum[s] *= sum[s];
}
for(int s = 0; s <= mask; s++)
{
dp[s][1] = sum[s];
for(int j = 2; j <= d; j++)
{
dp[s][j] = LLONG_MAX;
int ss = s;
do {
int ls = s ^ ss;
dp[s][j] = min(dp[s][j], dp[ls][j-1]+sum[ss]);
ss = (ss - 1) & s;
}while(ss != s);
}
}
LL temp = dp[mask][d] * d - sum[mask];
LD ans = (LD)temp / (d * d);
printf("%.8Lf", ans);
return 0;
}
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner in = new Scanner(System.in);
int n = in.nextInt();
int d = in.nextInt();
int[] w = new int[n];
int len = 1 << n, mask = len - 1;
long[] sum = new long[len];
long[][] dp = new long[len][d + 1];
for (int i = 0; i < n; i++)
w[i] = in.nextInt();
for (int s = 0; s <= mask; s++) {
for (int i = 0; i < n; i++)
if ((s >> i & 1) == 1)
sum[s] += w[i];
sum[s] *= sum[s];
}
for (int s = 0; s <= mask; s++) {
dp[s][1] = sum[s];
for (int j = 2; j <= d; j++) {
dp[s][j] = Long.MAX_VALUE;
int ss = s;
do {
int ls = s ^ ss;
dp[s][j] = Math.min(dp[s][j], dp[ls][j - 1] + sum[ss]);
ss = (ss - 1) & s;
} while (ss != s);
}
}
long temp = dp[mask][d] * d - sum[mask];
double ans = (double) temp / (d * d);
System.out.printf("%.8f", ans);
in.close();
}
}
F - Random Update Query¶
题目大意
给你 \(N\) 个数,第 \(i\) 个数是 \(A_i\),执行 \(M\) 次操作,每次操作有 \(L_i \ R_i \ X_i\) 三个数,表示:
- 在区间 \([L_i, R_i]\) 中随机等概率的选择一个位置,将这个位置的数字替换成 \(X_i\)。
求出经过 \(M\) 次操作后,每个位置的数字的数学期望,答案对 \(998244353\) 取模。
解题思路
假设某次操作为 \(L_i \ R_i \ X_i\),设 \(p_i = \frac{1}{R_i - L_i + 1}\),区间 \([L_i, R_i]\) 某个位置的数字为 \(a\),则操作后这个位置的期望为 \(E_1(X) = (1-p_i)a + p_iX_i\)。
考虑后一个操作 \(L_j \ R_j \ X_j\),设 \(p_j = \frac{1}{R_j - L_j + 1}\),经过两次操作,这个位置有三种可能:
数字 | 概率 |
---|---|
\(a\) | \((1-p_i)(1-p_j)\) |
\(X_i\) | \(p_i(1-p_j)\) |
\(X_j\) | \(p_j\) |
这种情况的数学期望为:
令 \(k_j = 1-p_j\),\(b_j = X_jp_j\),显然,\(k_j\) 和 \(b_j\) 都是容易计算的。
这表明 \(E_i(X)\) 实际是 \(k_iE_{i-1}(X) + b_i\) 的线性结构,对于一个点,在多次修改的时候,只需要每次算出 \(k_i\) 和 \(b_i\) 然后叠加计算即可。由于涉及到区间修改,很容易想到懒标记的线段树,修改完所有操作之后再一个个单点查询即可。
参考代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long LL;
const int maxn = 2E5 + 10;
const int mod = 998244353;
void exgcd(LL a, LL b, LL &x, LL &y)
{
if(b == 0)
{
x = 1, y = 0;
return;
}
exgcd(b, a % b, y, x);
y -= a / b * x;
}
LL inv(LL a, LL p)
{
LL x, y;
exgcd(a, p, x, y);
if(x < 0)
x += p;
return x;
}
int n, m;
LL a[maxn];
struct Node {
LL mul, add;
Node() : mul(1), add(0) {}
}t[maxn * 4];
void pushdown(int p)
{
LL m = t[p].mul, a = t[p].add;
if(m == 1 && a == 0)
return;
int lch = p * 2, rch = p * 2 + 1;
t[lch].add = (m * t[lch].add + a) % mod;
t[lch].mul = m * t[lch].mul % mod;
t[rch].add = (m * t[rch].add + a) % mod;
t[rch].mul = m * t[rch].mul % mod;
t[p].mul = 1, t[p].add = 0;
}
void modify(int p, int beg, int end, int l, int r, LL m, LL a)
{
if(end < l || beg > r)
return;
if(beg >= l && end <= r)
{
t[p].mul = (t[p].mul * m) % mod;
t[p].add = (t[p].add * m + a) % mod;
return;
}
pushdown(p);
int mid = (beg + end) / 2;
int lch = p * 2, rch = p * 2 + 1;
modify(lch, beg, mid, l, r, m, a);
modify(rch, mid+1, end, l, r, m, a);
}
LL query(int p, int beg, int end, int pos)
{
if(beg == end)
return (a[beg] * t[p].mul + t[p].add) % mod;
pushdown(p);
int mid = (beg + end) / 2;
int lch = p * 2, rch = p * 2 + 1;
return pos <= mid ? query(lch, beg, mid, pos) : query(rch, mid+1, end, pos);
}
int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i++)
scanf("%lld", &a[i]);
while(m--)
{
int l, r;
LL x;
scanf("%d %d %lld", &l, &r, &x);
LL a = inv(r - l + 1, mod), m = (r - l) * a % mod;
a = a * x % mod;
modify(1, 1, n, l, r, m, a);
}
for(int i = 1; i <= n; i++)
printf("%lld ", query(1, 1, n, i));
return 0;
}
import java.util.Scanner;
class Node {
public long mul, add;
public Node() {
mul = 1;
add = 0;
}
}
public class Main {
private static final int mod = 998244353;
private static long[] exgcd(long a, long b) {
if (b == 0)
return new long[] { 1, 0 };
long[] xy = exgcd(b, a % b);
long x = xy[0], y = xy[1];
xy[0] = y;
xy[1] = x - a / b * y;
return xy;
}
private static long inv(long a, long p) {
long x = exgcd(a, p)[0];
if (x < 0)
x += p;
return x;
}
private static Node[] t;
private static long[] a;
private static void build(int p, int beg, int end) {
t[p] = new Node();
if (beg == end)
return;
int mid = (beg + end) / 2;
int lch = p * 2, rch = p * 2 + 1;
build(lch, beg, mid);
build(rch, mid + 1, end);
}
private static void pushdown(int p) {
long m = t[p].mul, a = t[p].add;
if (m == 1 && a == 0)
return;
int lch = p * 2, rch = p * 2 + 1;
t[lch].add = (m * t[lch].add + a) % mod;
t[lch].mul = m * t[lch].mul % mod;
t[rch].add = (m * t[rch].add + a) % mod;
t[rch].mul = m * t[rch].mul % mod;
t[p].mul = 1;
t[p].add = 0;
}
private static void modify(int p, int beg, int end, int l, int r, long m, long a) {
if (end < l || beg > r)
return;
if (beg >= l && end <= r) {
t[p].mul = (t[p].mul * m) % mod;
t[p].add = (t[p].add * m + a) % mod;
return;
}
pushdown(p);
int mid = (beg + end) / 2;
int lch = p * 2, rch = p * 2 + 1;
modify(lch, beg, mid, l, r, m, a);
modify(rch, mid + 1, end, l, r, m, a);
}
private static long query(int p, int beg, int end, int pos) {
if (beg == end)
return (a[beg] * t[p].mul + t[p].add) % mod;
pushdown(p);
int mid = (beg + end) / 2;
int lch = p * 2, rch = p * 2 + 1;
return pos <= mid ? query(lch, beg, mid, pos) : query(rch, mid + 1, end, pos);
}
public static void main(String[] args) {
Scanner in = new Scanner(System.in);
int n = in.nextInt();
int M = in.nextInt();
t = new Node[4 * n + 4];
a = new long[n + 1];
build(1, 1, n);
for (int i = 1; i <= n; i++)
a[i] = in.nextInt();
while (M-- > 0) {
int l = in.nextInt();
int r = in.nextInt();
long x = in.nextLong();
long a = inv(r - l + 1, mod), m = (r - l) * a % mod;
a = a * x % mod;
modify(1, 1, n, l, r, m, a);
}
for (int i = 1; i <= n; i++)
System.out.printf("%d ", query(1, 1, n, i));
in.close();
}
}
G - Not Too Many Balls¶
题目大意
有 \(N(N \le 500)\) 种颜色的球,颜色的值为 \(1\) 到 \(N\),颜色为 \(i\) 的球有 \(A_i\) 个,有 \(M(M \le 5 \times 10^5)\) 个盒子,第 \(j\) 个盒子最多可以放 \(B_j\) 个球。此外,颜色为 \(i\) 的球在第 \(j\) 个盒子中最多能放 \(i \times j\) 个。问:最多可以将多少个球放在盒子中?
解题思路
可以使用网络流解决这个问题,建立一个二分图,其中超级源点记为 \(S\),超级汇点记为 \(T\),球作为二分图的左部 \(X = (x_1, x_2, \cdots, x_n)\),盒子作为二分图的右部 \(Y = (y_1, y_2, \cdots, y_m)\)。则图中有三种边:
- \(S\) 与 \(x_i\) 的连边,流值为 \(A_i\)。
- \(x_i\) 与 \(y_j\) 的连边,流值为 \(i \times j\)。
- \(y_j\) 与 \(T\) 的连边,流值为 \(B_j\)。
直接跑最大流就是答案,这样时间复杂度是没问题的,不过看看空间复杂度:第二种边有 \(NM\) 条,这样肯定是存不下的。
考虑转化成最小割的问题,假设最小割 \(S\) 可达的集合为 \(P \cup Q\),其中 \(P \sub X\),\(Q \sub Y\),则最小割的组成为:
解释一下这个公式的组成:
- \(\sum\limits_{x_i \in X\backslash P}A_i\):对于 \(X\backslash P\) 集合中的点,\(S\) 不可达,设 \(x_i \in X \backslash P\),则 \((S, x_i)\) 对应第一种边,这些边一定会被割掉(否则就存在 \(S \rightarrow x_i\))。
- \(\sum\limits_{x_i \in X}i\sum\limits_{y_j \in Y \backslash Q}j\):对于 \(X\) 集合中的点,\(S\) 可达,对于 \(Y \backslash Q\) 集合中的点,\(S\) 不可达,设 \(x_i \in X\),\(y_j \in Y \backslash Q\),则边 \((x_i, y_j)\) 对应第二种边,这些边一定会被割掉(否则就存在 \(S \rightarrow x_i \rightarrow y_j\) )。
- \(\sum\limits_{y_j \in Y}B_j\):对于 \(Y\) 集合中的点,\(S\) 不可达,设 \(y_j \in Y\),\((y_j, T)\) 对应第三种边,这些点一定会被割掉(否则就存在 \(S \rightarrow \cdots \rightarrow y_j \rightarrow T\) )。
如何算出最小割?肯定不能枚举两个子集。官方题解给出了一种非常巧妙的方法绕开了子集的枚举,注意到 \(N \le 500\)(而且题目的名字也在提示球并不多),令 \(k = \sum\limits_{x_i \in X}i\),由于 \(i\) 的取值是 \(1\) 到 \(N\),则 \(k\) 的取值就是 \(0\) 到 \(\frac{N(N+1)}{2}\)(其中 \(0\) 仅当 \(X\) 为空集的时候取得),于是考虑枚举 \(k\) 的值,则有:
为什么要这样做?\(\sum\limits_{x_i \in X}i\sum\limits_{y_j \in Y \backslash Q}j\) 这个式子耦合了两个集合,这样不方便处理,而固定前者之后就只剩下一个集合,进一步就能拆成两部分。
考虑如何算出 \(\min\limits_{k = 0}^{\frac{N(N+1)}{2}}\min\limits_{P\sub X}\{ \sum\limits_{x_i \in X\backslash P}A_i\}\) ,\(k\) 实际限制了集合中每个球颜色值的和,而这个式子可以使用 dp 预处理出来,定义 \(dp[k]\):颜色值之和为 \(k\) 的球的集合对应的 \(\sum A_i\) 的最小值,这是一个背包问题,可以在 \(O(N^2\times N) = O(N^3)\) 的时间内完成。预处理出所有第 \(dp[k]\) 值之后,对于特定的 \(k\) 值, \(\min\limits_{P\sub X}\{ \sum\limits_{x_i \in X\backslash P}A_i\} = dp[\frac{N(N+1)}{2}-k]\)
考虑如何算出 \(\min\limits_{k = 0}^{\frac{N(N+1)}{2}}\min\limits_{Q\sub Y}\{\sum\limits_{y_j \in Y \backslash Q}kj + \sum\limits_{y_j \in Y}B_j\}\),\(k\) 对 \(\sum\limits_{y_j \in Y}B_j\) 的值是没有影响的,对于某个子集 \(Q \sub Y\),如果 \(Q\) 的某个元素 \(y_j\) 有 \(kj < B_j\),则将 \(y_j\) 放在 \(Y \backslash Q\) 会更小。进一步可以发现,\(\min\limits_{Q\sub Y}\{ \sum\limits_{y_j \in Y \backslash Q}kj + \sum\limits_{y_j \in Y}B_j\} = \sum\limits_{j = 1}^m\min(kj, B_j)\),这样就不需要枚举子集 \(Q\) 了。如何计算 \(\min\limits_{k = 0}^{\frac{N(N+1)}{2}}\{\sum\limits_{j = 1}^m\min(kj, B_j) \} \\\)?这个式子是 \(O(N^2M)\) 的,还是会超时。\(\min(kj, B_j)\) 实际是一个分段函数,前半段是 \(kj\),后半段是 \(B_j\),两者都是直线,这样从 \(k\) 到 \(k+1\) 自增变化时,\(\sum\limits_{j = 1}^m\min(kj, B_j)\) 的增长也是有规律的,使用一阶差分可以维护区间加同一个数,使用 二阶差分 可以实现区间加一个等差数列(也就是一条直线),所以只需要针对所有的分段函数作出二阶差分,就可以在 \(O(1)\) 的时间内算出每个 \(\sum\limits_{j = 1}^m\min(kj, B_j)\)。
参考代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
typedef long long LL;
const int maxn = 500 + 10;
const int maxl = maxn * maxn / 2;
const int maxm = 5E5 + 10;
int n, m, l;
LL A[maxn], B[maxm], dp[maxl], c[maxl];
void add(int l, int r, LL a, LL d)
{
LL e = a + d * (r - l);
c[l] += a;
c[l+1] += d - a;
c[r+1] -= e + d;
c[r+2] += e;
}
int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i++)
scanf("%lld", &A[i]);
l = n * (n + 1) / 2;
memset(dp, 0x3f, sizeof dp);
dp[0] = 0;
for(int i = 1; i <= n; i++)
for(int j = l; j >= i; j--)
dp[j] = min(dp[j], dp[j-i]+A[i]);
for(int j = 1; j <= m; j++)
{
scanf("%lld", &B[j]);
LL K = B[j] / j;
add(0, min(K, (LL)l), 0, j);
if(K+1 <= l)
add(K+1, l, B[j], 0);
}
for(int k = 1; k <= l; k++)
c[k] += c[k-1];
for(int k = 1; k <= l; k++)
c[k] += c[k-1];
LL ans = dp[l];
for(int k = 1; k <= l; k++)
ans = min(ans, dp[l-k]+c[k]);
cout << ans << endl;
return 0;
}
import java.util.Arrays;
import java.util.Scanner;
public class Main {
private static int n, m, l;
private static long[] A, B, dp, c;
private static void add(int l, int r, long a, long d) {
long e = a + d * (r - l);
c[l] += a;
c[l + 1] += d - a;
c[r + 1] -= e + d;
c[r + 2] += e;
}
public static void main(String[] args) {
Scanner in = new Scanner(System.in);
n = in.nextInt();
m = in.nextInt();
l = n * (n + 1) / 2;
A = new long[n + 1];
B = new long[m + 1];
dp = new long[l + 1];
c = new long[l + 3];
for (int i = 1; i <= n; i++)
A[i] = in.nextLong();
Arrays.fill(dp, (long) 1E18);
dp[0] = 0;
for (int i = 1; i <= n; i++)
for (int j = l; j >= i; j--)
dp[j] = Math.min(dp[j], dp[j - i] + A[i]);
for (int j = 1; j <= m; j++) {
B[j] = in.nextLong();
long K = B[j] / j;
add(0, (int) Math.min(K, (long) l), 0, j);
if (K + 1 <= l)
add((int) K + 1, l, B[j], 0);
}
for (int k = 1; k <= l; k++)
c[k] += c[k - 1];
for (int k = 1; k <= l; k++)
c[k] += c[k - 1];
long ans = dp[l];
for (int k = 1; k <= l; k++)
ans = Math.min(ans, dp[l - k] + c[k]);
System.out.println(ans);
in.close();
}
}
创建日期: 2024-03-04