1 条题解

  • 1
    @ 2022-8-27 9:16:38
    #include <bits/stdc++.h>
    using namespace std;
    // Array bounds. main() writes sentinels at row n+1, column m+1 and label
    // index n*m+1, so each bound must exceed those indices.
    const int maxn = 5;
    const int maxm = 55;
    const int maxl = 155;
    const int mod = 11192869;   // modulus for all path counts
    const int mo = 500000;      // number of hash buckets (key = stt % mo)
    using ll = long long;
    using uint = unsigned int;
    int a[maxn][maxm];          // grid labels; cell (i, j) may take value x only if a[i][j] == L[x]
    int L[maxl];                // L[x]: label that value x requires
    int n, m;                   // grid rows, columns
    int pos[maxn];              // decoded frontier: last value placed in each row
    int plug[maxn];             // decoded frontier: plug slots 0..n
    int head[2][500010];        // bucket heads for the two rolling hash layers
    int tot[2];                 // number of stored states per layer
    int cur, pre;               // layer being built / layer being read
    int chc[maxl];              // scratch list of candidate values for one cell
    int ans;                    // accumulated answer, mod `mod`
    // One stored DP state: the set of values already placed, the packed frontier
    // key, the number of ways reaching it (mod `mod`), and the next node index in
    // its hash bucket chain.
    struct State{
    	bitset<maxl> used;  // default-constructed bitset: every bit clear
    	uint stt = 0;       // packed key produced by encode()
    	int val = 0;        // number of ways, mod `mod`
    	int nxt = 0;        // next node in the same bucket; 0 ends the chain
    }ptr[2][1001000];       // node storage for the two rolling layers (cur / pre)
    void hah(uint stt, int val, bitset<maxl> &used) {
    	int x = stt % mo;
    	for(int i = head[cur][x]; i; i = ptr[cur][i].nxt) if(ptr[cur][i].stt == stt) {
    		ptr[cur][i].val = (ptr[cur][i].val + val) % mod; return;
    	}
    	ptr[cur][++tot[cur]].stt = stt;
    	ptr[cur][tot[cur]].val = val;
    	ptr[cur][tot[cur]].used = used;
    	ptr[cur][tot[cur]].nxt = head[cur][x];
    	head[cur][x] = tot[cur];
    }
    // Pack the decoded frontier into one key.
    // pos[1..n] occupy 8 bits each (most significant first), followed by
    // plug[0..n] at 2 bits each. The layout needs 8n + 2(n+1) bits, which is
    // exactly 32 when n == 3. Assuming a 32-bit unsigned int, a larger n would
    // wrap and lose the leading pos fields. Fields are assumed to fit
    // (pos < 256, plug < 4).
    uint encode() {
    	uint key = 0;
    	for (int row = 1; row <= n; ++row) key = key * 256u + pos[row];
    	for (int k = 0; k <= n; ++k) key = key * 4u + plug[k];
    	return key;
    }
    // Inverse of encode(): unpack `stt` into the globals plug[0..n]
    // (2 bits each, least significant end) and pos[1..n] (8 bits each).
    void decode(uint stt) {
    	uint rest = stt;
    	for (int k = n; k >= 0; --k) {
    		plug[k] = rest % 4u;
    		rest /= 4u;
    	}
    	for (int row = n; row >= 1; --row) {
    		pos[row] = rest % 256u;
    		rest /= 256u;
    	}
    }
    // Plug (broken-profile) DP. The grid is swept column by column (j = 1..m)
    // and, within a column, row by row (i = 1..n).
    // Each cell receives a distinct value x in 1..n*m with a[i][j] == L[x].
    // Cells holding x and x +/- 1 are linked through plugs:
    //   - values 1 and n*m have exactly one plug;
    //   - every other value has two.
    // Completed assignments are added into the global `ans` (mod `mod`).
    //
    // Frontier state (packed by encode() / unpacked by decode()):
    //   pos[k]  - value most recently placed in row k;
    //   plug[k] - 0: no link; 1: the linked cell holds (this value - 1);
    //             2: the linked cell holds (this value + 1).
    // At cell (i, j):
    //   r = plug[i-1] is the plug arriving from the cell above (value pos[i-1]);
    //   d = plug[i]   is the plug arriving from the left neighbour (value pos[i]).
    // The outgoing plugs are written back into the same two slots:
    //   npr goes to the right (checked against a[i][j+1]);
    //   npd goes downward (checked against a[i+1][j]).
    // `used` is the set of values already placed along a state's history.
    void solve() {
    	bitset<maxl> used;
    	used.reset();
    	// Seed: empty frontier (key 0), one way, nothing used.
    	cur = 0; pre = 1; hah(0, 1, used);
    	for(int j = 1; j <= m; j++) {
    		// Starting a new column: shift every stored state's plug array by one
    		// slot (plug[k] -> plug[k+1]) and clear plug[0], so that plug[i] again
    		// holds the right-going plug of row i from the previous column.
    		for(int t = 1; t <= tot[cur]; t++) {
    			decode(ptr[cur][t].stt);
    			for(int i = n - 1; i >= 0; i--) plug[i + 1] = plug[i]; 
    			plug[0] = 0;
    			ptr[cur][t].stt = encode();
    		}
    		for(int i = 1; i <= n; i++) {
    			// Roll layers: read states from `pre`, build `cur` from scratch
    			// (all bucket heads are cleared for every cell).
    			swap(cur, pre); tot[cur] = 0;
    			memset(head[cur], 0, sizeof(head[cur]));
    			for(int t = 1; t <= tot[pre]; t++) {
    				uint stt = ptr[pre][t].stt; 
    				int val = ptr[pre][t].val;
    				used = ptr[pre][t].used;
    				decode(stt);
    				int r = plug[i - 1], d = plug[i]; // r: from above, d: from the left
    				// Collect candidate values for this cell. With no incoming plug
    				// any value 1..n*m is a candidate; otherwise each incoming plug
    				// forces its neighbour's value -/+ 1.
    				// NOTE: the loop variable `i` below shadows the row index `i`.
    				int cnt = 0;
    				if(!r && !d) for(int i = 1; i <= n * m; i++) chc[++cnt] = i;
    				else {
    					if(r == 1) chc[++cnt] = pos[i-1] - 1;
    					else if(r == 2) chc[++cnt] = pos[i-1] + 1;
    					if(d == 1) chc[++cnt] = pos[i] - 1;
    					else if(d == 2) chc[++cnt] = pos[i] + 1;
    				}
    				// Sort and de-duplicate the candidates (r and d may propose the same value).
    				sort(chc + 1, chc + 1 + cnt);
    				cnt = unique(chc + 1, chc + 1 + cnt) - chc - 1;
    				for(int hh = 1; hh <= cnt; hh++) {
    					int x = chc[hh]; // try placing value x here and check that it is legal
    					// The label must match, and x must not be placed already.
    					if(a[i][j] != L[x]) continue; if(used[x]) continue;
    					// Every incoming plug must agree with x.
    					if(r == 1 && x != pos[i - 1] - 1) continue;
    					if(r == 2 && x != pos[i - 1] + 1) continue;
    					if(d == 1 && x != pos[i] - 1) continue;
    					if(d == 2 && x != pos[i] + 1) continue;
    					// Value 1 may not be placed in a cell off the grid border.
    					if(x == 1 && i > 1 && i < n && j > 1 && j < m) continue;
    					// Last cell: count this completion. This happens before the
    					// plug-count checks below.
    					// NOTE(review): presumably consistent degrees on all other
    					// cells already force the last cell's degree -- confirm.
    					if(i == n && j == m) ans = (ans + val) % mod;
    					// Tentatively place x (undone after the plug enumeration).
    					used[x] = 1; int od = pos[i]; pos[i] = x;
    					// cout << x << endl;
    					for(int npr = 0; npr <= 2; npr++) 
    						for(int npd = 0; npd <= 2; npd++) {
    							// Enumerate the outgoing plugs and check that they are legal
    							// (the author notes this part is verbose and could be simplified).
    							int pnum = (r > 0) + (d > 0) + (npr > 0) + (npd > 0);
    							// Endpoints 1 and n*m have exactly one plug; every other value has two.
    							if(x != 1 && x != n * m && pnum != 2) continue;
    							if((x == 1 || x == n * m) && pnum != 1) continue;
    							// Both outgoing plugs cannot target the same neighbouring value.
    							if(npr == npd && npr) continue;
    							// No plug may leave through the right or bottom edge.
    							if(j == m && npr) continue; if(i == n && npd) continue;
    							// The targeted value must still be free.
    							if((npr == 1 || npd == 1) && used[x - 1]) continue;
    							if((npr == 2 || npd == 2) && used[x + 1]) continue;
    							// The receiving cell's label must match the targeted value.
    							if(npr == 1 && a[i][j+1] != L[x - 1]) continue;
    							if(npr == 2 && a[i][j+1] != L[x + 1]) continue;
    							if(npd == 1 && a[i+1][j] != L[x - 1]) continue;
    							if(npd == 2 && a[i+1][j] != L[x + 1]) continue;
    							// Legal transition: store the successor state and its count.
    							plug[i - 1] = npr; plug[i] = npd;
    							hah(encode(), val, used);
    							plug[i - 1] = r; plug[i] = d;
    						}
    					used[x] = 0; pos[i] = od;
    				}
    			}
    		}
    	}
    }
    // Reads n, m, the n x m label grid and the labels L[1..n*m].
    // Fills the sentinel border, runs the DP and prints the count.
    int main() {
    	scanf("%d%d", &n, &m);
    	for (int row = 1; row <= n; ++row)
    		for (int col = 1; col <= m; ++col)
    			scanf("%d", &a[row][col]);
    	const int cells = n * m;
    	for (int x = 1; x <= cells; ++x)
    		scanf("%d", &L[x]);
    	// Sentinels just outside the value range and around the grid.
    	// NOTE(review): 521 and 233 are assumed never to occur as real labels --
    	// confirm against the problem's label range.
    	L[cells + 1] = 521;
    	L[0] = 521;
    	for (int col = 0; col <= m + 1; ++col) {
    		a[0][col] = 233;
    		a[n + 1][col] = 233;
    	}
    	for (int row = 1; row <= n; ++row) {
    		a[row][0] = 233;
    		a[row][m + 1] = 233;
    	}
    	solve();
    	printf("%d\n", ans);
    	return 0;
    }
    
    • 1

    信息

    ID
    897
    时间
    3000~6000ms
    内存
    125MiB
    难度
    7
    标签
    递交数
    4
    已通过
    3
    上传者