1 条题解

  • 1
    @ 2022-9-5 15:25:38
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    using namespace std;
    typedef long long ll;
    const int mod=1000000007;
    const int maxn=7;
    const int maxv=1<<12;
    int n,m,c,Q,maxx,a1[maxn],a2[maxn],nxt1[maxn],nxt2[maxn],t1[maxn][3],t2[maxn][3],f[2][maxv][maxn][maxn],cur,ans;
    char s1[maxn],s2[maxn];
    inline int getid(char x)
    {
    	if(x=='W')return 0;
    	if(x=='B')return 1;
    	if(x=='X')return 2;
    }
    int powmod(int a,int k)
    {
    	ll ret=1,x=a;
    	while(k)
    	{
    		if(k&1)ret=ret*x%mod;
    		x=x*x%mod;
    		k>>=1;
    	}
    	return (int)ret;
    }
    int main()
    {
    	scanf("%d%d%d%d",&n,&m,&c,&Q);
    	maxx=1<<(m-c+1);
    	while(Q--)
    	{
    		scanf("%s%s",s1+1,s2+1);
    		for(int i=1;i<=c;i++)a1[i]=getid(s1[i]),a2[i]=getid(s2[i]);
    		for(int i=2,j=0;i<=c;i++)
    		{
    			while(j&&a1[j+1]!=a1[i])j=nxt1[j];
    			if(a1[j+1]==a1[i])j++;
    			nxt1[i]=j;
    		}
    		for(int i=2,j=0;i<=c;i++)
    		{
    			while(j&&a2[j+1]!=a2[i])j=nxt2[j];
    			if(a2[j+1]==a2[i])j++;
    			nxt2[i]=j;	
    		}
    		for(int i=0;i<c;i++)for(int j=0,k=i;j<3;j++,k=i)
    		{
    			while(k&&a1[k+1]!=j)k=nxt1[k];
    			if(a1[k+1]==j)k++;
    			t1[i][j]=k;
    		}
    		for(int i=0;i<c;i++)for(int j=0,k=i;j<3;j++,k=i)
    		{
    			while(k&&a2[k+1]!=j)k=nxt2[k];
    			if(a2[k+1]==j)k++;
    			t2[i][j]=k;
    		}
    		memset(f[0],0,sizeof f[0]);
    		f[0][0][0][0]=1;cur=1;
    		for(int i=1;i<=n;i++)
    		{
    			memset(f[cur],0,sizeof f[cur]);
    			for(int j=0;j<maxx;j++)for(int a=0;a<c;a++)
    			for(int b=0;b<c;b++)f[cur][j][0][0]+=f[1-cur][j][a][b],f[cur][j][0][0]%=mod;
    			cur=1-cur;
    			for(int j=1;j<=m;j++)
    			{
    				memset(f[cur],0,sizeof f[cur]);
    				for(int k=0;k<maxx;k++)for(int a=0;a<c;a++)
    				for(int b=0;b<c;b++)if(f[1-cur][k][a][b])
    				for(int col=0;col<3;col++)
    				{
    					int pa=t1[a][col],pb=t2[b][col],S=k;
    					if(j>=c)if((S>>j-c)&1)S^=1<<j-c;
    					if(pa==c){S^=1<<j-c;pa=nxt1[c];}
    					if(pb==c){if((k>>j-c)&1)continue;pb=nxt2[c];}
    					f[cur][S][pa][pb]+=f[1-cur][k][a][b];
    					f[cur][S][pa][pb]%=mod;
    				}
    				cur=1-cur;
    			}
    		}
    		ans=powmod(3,n*m);
    		for(int i=0;i<maxx;i++)for(int j=0;j<c;j++)for(int k=0;k<c;k++)
    		ans=(ans-f[1-cur][i][j][k]%mod+mod)%mod;
    		printf("%d\n",ans);
    	}
    	return 0;
    }
    
    • 1

    信息

    ID
    2226
    时间
    6000ms
    内存
    512MiB
    难度
    7
    标签
    递交数
    4
    已通过
    3
    上传者