1 条题解

  • 1
    @ 2022-8-24 9:19:53
    #include<bits/stdc++.h>
    
    // For(v,lo,hi): v takes lo, lo+1, ..., hi (inclusive, ascending); hi is evaluated once.
    #define For(v,lo,hi) for(int v=(lo),v##_stop=(hi);v<=v##_stop;++v)
    // Rep(v,hi,lo): v takes hi, hi-1, ..., lo (inclusive, descending); lo is evaluated once.
    #define Rep(v,hi,lo) for(int v=(hi),v##_stop=(lo);v>=v##_stop;--v)
    
    // Max/Min: return the larger/smaller of x and y by value, compared with operator<.
    // On a tie Max returns x and Min returns y.
    template<typename T>T Max(const T &x,const T &y){
    	if(x<y) return y;
    	return x;
    }
    template<typename T>T Min(const T &x,const T &y){
    	if(x<y) return x;
    	return y;
    }
    // chkmax: if y is strictly greater (x<y), store it in x and return 1; otherwise return 0.
    template<typename T>int chkmax(T &x,const T &y){
    	if(!(x<y)) return 0;
    	x=y;
    	return 1;
    }
    // chkmin: if y is strictly smaller (x>y), store it in x and return 1; otherwise return 0.
    template<typename T>int chkmin(T &x,const T &y){
    	if(!(x>y)) return 0;
    	x=y;
    	return 1;
    }
    // read(x): parse one decimal integer from stdin into x.
    // Non-digit characters before the number are skipped; any '-' among them makes
    // the result negative. If input is exhausted before a digit appears, x becomes 0
    // (previously the skip loop spun forever on EOF).
    template<typename T>void read(T &x){
    	T f=1;
    	// Must be int, not char: getchar() returns EOF (a negative int), which a
    	// plain char cannot hold reliably (char is unsigned on some platforms).
    	int ch=getchar();
    	for(;ch!=EOF&&(ch<'0'||ch>'9');ch=getchar())if(ch=='-')f=-1;
    	for(x=0;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
    	x*=f;
    }
    
    using LL=long long;
    // Table bound. Solve() touches g[..][k+2] with k up to m, so m <= N-3 is required
    // (presumably guaranteed by the problem constraints — confirm).
    constexpr int N=1010;
    constexpr int mod=998244353;   // prime modulus for all arithmetic
    int n,m;                       // read in main()
    LL a,b;                        // read in main(); p = a * b^{-1} (mod `mod`)
    LL p,q;                        // q = 1 - p (mod `mod`)
    LL dp[N][N];                   // DP tables, reset and filled by Solve()
    LL g[N][N];
    LL pw[N];                      // pw[i] = p^i, filled in main()
    LL A[N];                       // recurrence coefficients A[1..k+1], filled by Solve()
    LL f[N<<1];                    // sequence values f[0..2k], filled by Solve()

    LL power(LL base,LL e);        // base^e (mod `mod`)
    LL Solve(int k);
    
    int main(){
    	read(n);read(m);read(a);read(b);
    	p=a*power(b,mod-2)%mod;q=(mod+1-p)%mod;
    	pw[0]=1;
    	For(i,1,m) pw[i]=pw[i-1]*p%mod;
    	
    	printf("%lld\n",(Solve(m)-Solve(m-1)+mod)%mod);
    	return 0;
    }
    
    // Solve(k): for threshold k, run an O(k^2) DP, derive a linear recurrence of
    // order k+1 with coefficients A[1..k+1], and return f[n] (mod `mod`).
    // n can be large, so f[n] is obtained by computing x^(n-k) modulo the
    // recurrence's characteristic polynomial (Kitamasa-style) with naive
    // O(k^2) polynomial multiplication and repeated squaring.
    // Reads globals n, q, pw (pw[i] = p^i); overwrites dp, g, A, f.
    // NOTE(review): main() calls this with k = m and k = m-1; the exact meaning of
    // f[n] (presumably a probability that some statistic is <= k) is not visible
    // here — confirm against the problem statement.
    LL Solve(int k){
    	memset(dp,0,sizeof dp);
    	memset(g,0,sizeof g);
    	// Base row i = 0: dp[0][j] = g[0][j] = 1 for j = 1..k+2.
    	For(i,1,k+2) g[0][i]=dp[0][i]=1;
    	// For each length i, j runs from k/i+1 down to 2, i.e. only j with i*(j-1) <= k.
    	// Descending j lets g[i][j] reuse g[i][j+1].
    	For(i,1,k) Rep(j,k/i+1,2){
    		// Split position l: left part of length l-1 taken from g[.][j+1], right part
    		// of length i-l from g[.][j], weighted by p^(l-1) * q.
    		For(l,1,i) dp[i][j]=(dp[i][j]+g[l-1][j+1]*g[i-l][j]%mod*pw[l-1]%mod*q)%mod;
    		// Suffix accumulation over j: g[i][j] = g[i][j+1] * p^i + dp[i][j].
    		g[i][j]=(g[i][j+1]*pw[i]+dp[i][j])%mod;
    	}
    	// Recurrence coefficients: A[i+1] = q * g[i][2] * p^i for i = 0..k.
    	memset(A,0,sizeof A);
    	For(i,0,k) A[i+1]=q*g[i][2]%mod*pw[i]%mod;
    	// Initial terms f[0..k]: f[i] = g[i][2]*p^i + sum_{j=1..i} A[j]*f[i-j].
    	memset(f,0,sizeof f);
    	f[0]=1;
    	For(i,1,k){
    		f[i]=g[i][2]*pw[i]%mod;
    		For(j,1,i) f[i]=(f[i]+A[j]*f[i-j])%mod;
    	}
    	// From here on k is the recurrence order (old k + 1). Extend f up to index 2k
    	// with the pure recurrence f[i] = sum_{j=1..k} A[j]*f[i-j].
    	k++;
    	For(i,k,k<<1) For(j,1,k) f[i]=(f[i]+A[j]*f[i-j])%mod;
    	if(n<=k)return f[n];
    	// Compute x^y mod P(x), where y = n-k and P reduces x^i to sum_j A[j]*x^(i-j).
    	// res: accumulated power (degree bound L); x: current x^(2^t) (degree bound len).
    	// NOTE(review): res/tmp/x are ~32 KB each on the stack.
    	int y=n-k,len=1,L=0;
    	LL res[N<<2],tmp[N<<2],x[N<<2];
    	memset(res,0,sizeof res);
    	memset(x,0,sizeof x);
    	res[0]=1;x[1]=1;
    	for(;y;y>>=1){
    		if(y&1){
    			// res <- res * x, then reduce every term of degree >= k.
    			For(i,0,L+len) tmp[i]=0;
    			For(i,0,L) For(j,0,len) tmp[i+j]=(tmp[i+j]+res[i]*x[j])%mod;
    			L+=len;
    			Rep(i,L,k) For(j,1,k) tmp[i-j]=(tmp[i-j]+tmp[i]*A[j])%mod;
    			// After reduction the degree is at most k-1.
    			chkmin(L,k-1);
    			For(i,0,L) res[i]=tmp[i];
    		}
    		// x <- x * x, reduced the same way.
    		For(i,0,len+len) tmp[i]=0;
    		For(i,0,len) For(j,0,len) tmp[i+j]=(tmp[i+j]+x[i]*x[j])%mod;
    		// NOTE(review): if this is called with k == -1 (main passes m-1 when m == 0),
    		// the cap below makes len == -1, and this shift of a negative int is UB
    		// before C++20. Presumably m >= 1 by the constraints — confirm.
    		len<<=1;
    		Rep(i,len,k) For(j,1,k) tmp[i-j]=(tmp[i-j]+tmp[i]*A[j])%mod;
    		chkmin(len,k-1);
    		For(i,0,len) x[i]=tmp[i];
    	}
    	// x^y == sum res[i]*x^i, so f[y+k] = f[n] = sum res[i]*f[i+k].
    	LL ans=0;
    	For(i,0,k-1) ans=(ans+res[i]*f[i+k])%mod;
    	return ans;
    }
    // power(x, y): x^y modulo `mod` by binary exponentiation; expects y >= 0.
    // The base is first reduced into [0, mod), so every product below is < mod^2
    // and fits in long long. Without this, a base >= ~3.04e9 overflows on the first
    // squaring (UB), and a negative base yields a negative result.
    LL power(LL x,LL y){
    	x%=mod;
    	if(x<0) x+=mod;
    	LL res=1;
    	for(;y;y>>=1,x=x*x%mod) if(y&1) res=res*x%mod;
    	return res;
    }
    
    • 1

    信息

    ID
    2757
    时间
    1000~3000ms
    内存
    125MiB
    难度
    7
    标签
    递交数
    2
    已通过
    2
    上传者