题目链接
题意简述
定义若二元组 \((a,b)<(c,d)\),则有 \(a>c\) 或者 \(a=c \text{ and } b<d\)。
给你 \(n\) 个数,第 \(i\) 个数是 \(a_i\)。
现在系统会生成 \(n\) 个二元组,第 \(i\) 个二元组是 \((a_i,i)\) 或者是 \((a_i+1,i)\),然后给每个二元组一个排名,设排名数列为 \(rk\)。
求有多少种不同的 \(rk\) 数列,两个 \(rk\) 序列不同当且仅当存在至少一个位置 \(x\),使得两个 \(rk\) 数列的 \(rk_x\) 不同。
\(n\leq 5\times 10^5,a_i\leq 10^6\),时空限制:2s 256M。
样例
输入:
输出:
解释:有以下八种二元组的搭配方式,以及其相应的 \(rk\) 序列。
\((1,1),(2,2),(2,3)\ \ \ \ rk:3,1,2\)
\((1,1),(2,2),(3,3)\ \ \ \ rk:3,2,1\)
\((1,1),(3,2),(2,3)\ \ \ \ rk:3,1,2\)
\((1,1),(3,2),(3,3)\ \ \ \ rk:3,1,2\)
\((2,1),(2,2),(2,3)\ \ \ \ rk:1,2,3\)
\((2,1),(2,2),(3,3)\ \ \ \ rk:2,3,1\)
\((2,1),(3,2),(2,3)\ \ \ \ rk:2,1,3\)
\((2,1),(3,2),(3,3)\ \ \ \ rk:3,1,2\)
不同的 \(rk\) 序列一共有 \(5\) 个。
题解
比赛的时候一直认为是一个组合数学题目,没想到是一个 DP。
先下一个定义:若两个由前 \(i-1\) 个二元组形成的序列,无论 \([i,n]\) 中的每个二元组是 \((a_x,x)\) 还是 \((a_x+1,x)\),这两个序列最终形成的排名序列都是相同的,那么我们把它们称作不可区分的。
首先把所有二元组按照 \((a_i,i)\) 排序。为了方便,下文默认输入就已经保证了 \((a_i,i)<(a_{i+1},i+1)\)。
考虑 DP,设 \(f_i\) 表示前 \(i-1\) 个二元组组成的可区分的序列个数。
那么第 \(i\) 个位置既可以放 \((a_i,i)\) 也可以放 \((a_i+1,i)\),所以有 \(f_{i+1}=2f_i\),但是这样显然会算重,所以考虑容斥。
考虑对于一个位置 \(i\),找到一个极短的区间 \([l,r](l\leq i\leq r)\),它需要满足存在一种在 \([l,i-1],[i+1,r]\) 放置二元组的方案,使得第 \(i\) 个位置放 \((a_i,i)\) 和放 \((a_i+1,i)\) 是不可区分的。这样的话,这种放置方法的贡献本来只有 \(1\) 但是被算了两次,所以 \(f_r\) 就需要减去 \(f_l\)。
因为这个区间是极短的,所以可以恰好不重不漏地把所有的不合法的方案都减掉,最终的 \(f_{n+1}\) 即为答案。
不难发现,因为我们一开始排好了序,所以 \(l,r\) 是单调递增的,\(i\) 也是我们单调枚举的,所以 DP 的总复杂度是 \(O(n)\) 的。(有趣的是,复杂度的瓶颈在于排序)。
代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
| #include <bits/stdc++.h> #define debug(...) fprintf(stderr, __VA_ARGS__) #define RI register int typedef long long LL;
#define FILEIO(name) freopen(name".in", "r", stdin), freopen(name".out", "w", stdout);
using namespace std;
namespace IO { char buf[1000000], *p1 = buf, *p2 = buf; inline char gc() { if (p1 == p2) p2 = (p1 = buf) + fread(buf, 1, 1000000, stdin); return p1 == p2 ? EOF : *(p1++); } template <class T> inline void read(T &n) { n = 0; RI ch = gc(), f; while ((ch < '0' || ch > '9') && ch != '-') ch = gc(); f = (ch == '-' ? ch = gc(), -1 : 1); while (ch >= '0' && ch <= '9') n = n * 10 + (ch ^ 48), ch = gc(); n *= f; } char Of[105], *O1 = Of, *O2 = Of; template <class T> inline void print(T n, char ch = '\n') { if (n < 0) putchar('-'), n = -n; if (n == 0) putchar('0'); while (n) *(O1++) = (n % 10) ^ 48, n /= 10; while (O1 != O2) putchar(*(--O1)); putchar(ch); } }
using IO :: read; using IO :: print;
int const MAXN = 5e5 + 5; int const mod = 998244353; struct Node { int x, id; Node (int _x = 0, int _id = 0) { x = _x, id = _id; } bool operator < (const Node &A) const { return x ^ A.x ? x > A.x : id < A.id; }; } a[MAXN]; int f[MAXN];
int main() { #ifdef LOCAL FILEIO("a"); #endif
int n; read(n); for (RI i = 1; i <= n; ++i) read(a[i].x), a[i].id = i; sort(a + 1, a + 1 + n); f[1] = 1; int l = 1, r = 1; for (RI i = 1; i <= n; ++i) { while (l < i && (Node(a[l].x, a[l].id) < Node(a[i].x + 1, a[i].id))) ++l; while (r <= n && (Node(a[r].x + 1, a[r].id) < Node(a[i].x, a[i].id))) ++r; f[r] = (f[r] + mod - f[l]) % mod; f[i + 1] = ((f[i] + f[i]) % mod + f[i + 1]) % mod; } print(f[n + 1]);
return 0; }
|