题解:CF886E Maximum Element
正难则反,考虑长度为 i i i 的排列得到正确的结果的方案数。
设 d p i dp_i dpi 表示长度为 i i i 的排列直到循环完也没有提前 return
的方案数。考虑 i i i 所放置的位置,由于不会提前 return
,也就说明该数字所在的位置为 [ i − k + 1 , i ] [i - k + 1,i] [i−k+1,i] 的范围中。因此可以枚举 i i i 的位置为 j j j,则相当于将 [ 1 , i ] [1,i] [1,i] 的区间分为 [ 1 , j − 1 ] , [ j ] , [ j + 1 , i ] [1,j - 1],[j],[j + 1,i] [1,j−1],[j],[j+1,i]。
第一段为 i − 1 i - 1 i−1 个数字中选择 j − 1 j - 1 j−1 个,也就是 ( i − 1 j − 1 ) \binom{i-1}{j-1} (j−1i−1),然后合法的方案数为 d p j − 1 dp_{j - 1} dpj−1;第二段放最大值 i i i,第三段还剩下 i − j i - j i−j 个数字,随意放置,也就是 ( i − j ) ! (i - j)! (i−j)!。虽然说 d p i dp_i dpi 的状态考虑的是排列,但是显然我们只需要考虑数字之间的相对大小,因此第一段的方案数是合理的。可以得到以下转移:
d p i = ∑ j = i − k + 1 i ( i − 1 j − 1 ) × d p j − 1 × ( i − j ) ! dp_i=\sum_{j=i-k+1}^i \binom{i-1}{j-1}\times dp_{j-1}\times (i-j)! dpi=j=i−k+1∑i(j−1i−1)×dpj−1×(i−j)!
尝试进行化简,可以得到:
d p i = ∑ j = i − k + 1 i ( i − 1 ) ! ( j − 1 ) ! × ( ( i − 1 ) − ( j − 1 ) ) ! × d p j − 1 × ( i − j ) ! = ∑ j = i − k + 1 i ( i − 1 ) ! ( j − 1 ) ! × d p j − 1 = ( i − 1 ) ! × ∑ j = i − k i − 1 d p j j ! dp_i = \sum_{j=i-k+1}^i \frac{(i - 1)!}{(j - 1)! \times ((i - 1) - (j - 1))!}\times dp_{j-1}\times (i-j)! \\ = \sum_{j=i-k+1}^i\frac{(i - 1)!}{(j - 1)!} \times dp_{j - 1} \\ = (i - 1)! \times \sum_{j=i-k}^{i - 1} \frac{dp_j}{j!} dpi=j=i−k+1∑i(j−1)!×((i−1)−(j−1))!(i−1)!×dpj−1×(i−j)!=j=i−k+1∑i(j−1)!(i−1)!×dpj−1=(i−1)!×j=i−k∑i−1j!dpj
维护一段长度为 k k k 的 d p i i ! \frac{dp_i}{i!} i!dpi 的和即可 O ( n ) O(n) O(n) 求出 d p i dp_i dpi。
最后再考虑答案。若最后求得的答案是正确的,我们只需要枚举 n n n 所在的位置即可。因此总共合法的方案数为:
a n s = ∑ i = 1 n ( n − 1 i − 1 ) × d p i − 1 × ( n − i ) ! ans = \sum_{i = 1}^n \binom {n - 1}{i - 1} \times dp_{i - 1} \times (n - i)! ans=i=1∑n(i−1n−1)×dpi−1×(n−i)!
最后的答案就是 n ! − a n s n!-ans n!−ans。代码如下:
#include <bits/stdc++.h>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 0x3f3f3f3f
#define pii pair <int,int>
using namespace std;
const int MAX = 1e6 + 5;
const int MOD = 1e9 + 7;
inline int read ();
int n,k;ll tot,sum,dp[MAX],f[MAX],inv[MAX];
ll qpow (ll x,ll y)
{ll res = 1;while (y){if (y & 1) res = res * x % MOD;x = x * x % MOD;y >>= 1;}return res;
}
int main ()
{n = read ();k = read ();inv[0] = f[0] = 1;for (int i = 1;i <= n;++i) f[i] = f[i - 1] * i % MOD;inv[n] = qpow (f[n],MOD - 2);for (int i = n - 1;i;--i) inv[i] = inv[i + 1] * (i + 1) % MOD;dp[0] = sum = 1;for (int i = 1;i <= n;++i){ dp[i] = f[i - 1] * sum % MOD;sum = (sum + dp[i] * inv[i] % MOD) % MOD;if (i >= k) sum = (sum - dp[i - k] * inv[i - k] % MOD + MOD) % MOD;}for (int i = 1;i <= n;++i) tot = (tot + dp[i - 1] * f[n - 1] % MOD * inv[i - 1] % MOD) % MOD;printf ("%lld\n",(f[n] - tot + MOD) % MOD);return 0;
}
inline int read ()
{int s = 0;int f = 1;char ch = getchar ();while ((ch < '0' || ch > '9') && ch != EOF){if (ch == '-') f = -1;ch = getchar ();}while (ch >= '0' && ch <= '9'){s = s * 10 + ch - '0';ch = getchar ();}return s * f;
}