【LOJ3042】【ZJOI2019】麻将

题目链接

【LOJ3042】【ZJOI2019】麻将

做法

我们可以用 $ dp[i][j][k] $ 表示枚举到第 $ i $ 种牌,没有组成面子的 $ i - 1 $ 种牌有 $ j $ 个, $ i $ 种牌有 $ k $ 个。
然后再开一维表示是否有雀头,七对子再开一维特判即可。
然后暴力搜索所有 $ dp $ 状态,发现状态数只有 $ S = 3956 $ 种。
考虑摸 $ i $ 牌,计算所有大小为 $ 13+i $ 的牌集中不能胡牌的集合数 $ X $ 和总集合数 $ Y $ ,那么 $ \frac{X}{Y} $ 就是权值大于 $ i $ 的概率, $ \sum{\frac{X}{Y}} $ 即为权值的期望。
设 $ f[i][j][t] $ 表示处理前 $ i $ 种牌,选了 $ j $ 张牌, $ dp $ 状态编号为 $ t $ ,转移就枚举下一种牌张数即可。
时间复杂度 $ O(n^2 S) $ 。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include <bits/stdc++.h>
#define mp make_pair
#define fst first
#define snd second
using namespace std;
typedef long long ll;
const int mod = 998244353;
const int N = 110, S = 4010;

inline int add(const int &x, const int &y) {
return x + y < mod ? x + y : x + y - mod;
}
inline int sub(const int &x, const int &y) {
return x - y < 0 ? x - y + mod : x - y;
}
inline int mul(const int &x, const int &y) { return (int)((ll)x * y % mod); }
int ksm(int x, int y = mod - 2) {
int ss = 1; for(; y; y >>= 1, x = mul(x, x)) if(y & 1) ss = mul(ss, x);
return ss;
}
inline int Min(const int &x, const int &y) { return x < y ? x : y; }
inline int Max(const int &x, const int &y) { return x > y ? x : y; }

struct state { int f[3][3]; }; int Count = 0; map<state, int> ma;
bool operator<(const state &x, const state &y) {
for(int i = 0; i < 3; i++) for(int j = 0; j < 3; j++)
if(x.f[i][j] ^ y.f[i][j]) return x.f[i][j] < y.f[i][j];
return 0;
}
state cipher() {
state res;
for(int i = 0; i < 3; i++) for(int j = 0; j < 3; j++) res.f[i][j] = -1;
return res;
}
state starter() { state res = cipher(); res.f[0][0] = 0; return res; }
state operator+(const state &x, const state &y) {
state res;
for(int i = 0; i < 3; i++) for(int j = 0; j < 3; j++)
res.f[i][j] = Max(x.f[i][j], y.f[i][j]);
return res;
}
state operator+(const state &x, const int &y) {
state res = cipher();
for(int i = 0; i <= 2 && i <= y; i++)
for(int j = 0; j <= 2 && i + j <= y; j++) {
if(x.f[i][j] == -1) continue;
int tt = x.f[i][j];
for(int k = 0; k <= 2 && i + j + k <= y; k++)
res.f[j][k] = Max(res.f[j][k], Min(tt + i + (y - i - j - k) / 3, 4));
}
return res;
}
void dfs(state u) {
if(ma.count(u)) return ;
ma[u] = ++Count; for(int i = 0; i <= 4; i++) dfs(u + i);
}
typedef pair<pair<state, state>, int> mahjong;
bool ed[S]; mahjong states[S]; map<mahjong, int> id;
int tot = 0, nxt[5][S];

pair<state, state> operator+(const pair<state, state> &x, const int &y) {
if(y >= 2) return mp(x.fst + y, (x.snd + y) + (x.fst + (y - 2)));
return mp(x.fst + y, x.snd + y);
}
mahjong operator+(const mahjong &x, const int &y) {
return mp(x.fst + y, Min(x.snd + (y > 1), 7));
}
mahjong inception() { return mp(mp(starter(), cipher()), 0); }
void getstate(mahjong u) {
if(id.count(u)) return ; id[u] = ++tot, states[tot] = u;
for(int i = 0; i <= 4; i++) getstate(u + i);
}
bool check(mahjong u) {
if(u.snd >= 7) return 1;
for(int i = 0; i < 3; i++) for(int j = 0; j < 3; j++)
if(u.fst.snd.f[i][j] >= 4) return 1;
return 0;
}
void init() {
dfs(starter()), getstate(inception());
for(int i = 1; i <= tot; i++) {
ed[i] = check(states[i]);
for(int j = 0; j < 5; j++) nxt[j][i] = id[states[i] + j];
}
}

int n, used[N], ans = 0, C[5][5], f[N][4 * N][S];
int main() {
init(), scanf("%d", &n);
for(int i = 1, x, y; i <= 13; i++) scanf("%d%d", &x, &y), ++used[x];
C[0][0] = 1;
for(int i = 1; i < 5; i++) {
C[i][0] = 1;
for(int j = 1; j <= i; j++) C[i][j] = C[i - 1][j] + C[i - 1][j - 1];
}
f[0][0][1] = 1;
for(int i = 0; i < n; i++) for(int j = 0; j <= 4 * i; j++)
for(int k = 1; k <= tot; k++) {
if(!f[i][j][k]) continue;
for(int t = used[i + 1]; t < 5; t++)
f[i + 1][j + t][nxt[t][k]] = add(f[i + 1][j + t][nxt[t][k]],
mul(f[i][j][k], C[4 - used[i + 1]][t - used[i + 1]]));
}
for(int i = 13; i <= 4 * n; i++) {
int sum = 0, cnt = 0;
for(int j = 1; j <= tot; j++) {
sum = add(sum, f[n][i][j]); if(!ed[j]) cnt = add(cnt, f[n][i][j]);
}
ans = add(ans, mul(ksm(sum), cnt));
}
printf("%d\n", ans);
return 0;
}