1 条题解
-
1
贴,看不懂算了#include<bits/stdc++.h> using namespace std; typedef long long ll; typedef vector<int> vi; const int mod=998244353; int add(int x,int y){ x+=y; return (x<mod ? x : x-mod); } int qpow(int n,int m){ int s=n,ans=1; while (m){ if (m&1)ans=(ll)ans*s%mod; s=(ll)s*s%mod,m>>=1; } return ans; } namespace Poly{ const int N=19; int n,tn,inv[1<<N],w[N][1<<N],iw[N][1<<N]; void init(int g){ inv[0]=inv[1]=1; for(int i=2;i<(1<<N);i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; for(int i=0;i<N;i++){ w[i][0]=iw[i][0]=1; w[i][1]=qpow(g,(mod-1>>i+1)); iw[i][1]=qpow(w[i][1],mod-2); for(int j=2;j<(1<<i);j++){ w[i][j]=(ll)w[i][1]*w[i][j-1]%mod; iw[i][j]=(ll)iw[i][1]*iw[i][j-1]%mod; } } } void get_n(int m){ n=1,tn=0; while (n<m)n<<=1,tn++; } void dft(int *a){ for(int i=n,t=0;i>1;i>>=1,t++){ int *W=w[tn-t-1]; for(int j=0;j<n;j+=i) for(int k=0;k<(i>>1);k++){ int x=a[j+k],y=a[j+k+(i>>1)]; a[j+k]=add(x,y); a[j+k+(i>>1)]=(ll)(x-y+mod)*W[k]%mod; } } } void idft(int *a){ for(int i=2,t=0;i<=n;i<<=1,t++){ int *W=iw[t]; for(int j=0;j<n;j+=i) for(int k=0;k<(i>>1);k++){ int x=a[j+k],y=(ll)W[k]*a[j+k+(i>>1)]%mod; a[j+k]=add(x,y),a[j+k+(i>>1)]=add(x,mod-y); } } int inv=qpow(n,mod-2); for(int i=0;i<n;i++)a[i]=(ll)inv*a[i]%mod; } vi mul(vi a,vi b,int ma,int mb,int m){ if (ma<0)ma=a.size(); if (mb<0)mb=b.size(); if (m<0)m=ma+mb-1; ma=min(ma,m),mb=min(mb,m); get_n(ma+mb-1); a.resize(n),b.resize(n); for(int i=ma;i<n;i++)a[i]=0; for(int i=mb;i<n;i++)b[i]=0; dft(a.data()),dft(b.data()); for(int i=0;i<n;i++)a[i]=(ll)a[i]*b[i]%mod; idft(a.data()),a.resize(m); return a; } vi get_inv(vi a,int m){ if (m==1)return vi{qpow(a[0],mod-2)}; vi s=get_inv(a,(m+1>>1)),ans; get_n(m<<1); a.resize(n),s.resize(n),ans.resize(n); for(int i=m;i<n;i++)a[i]=0; dft(a.data()),dft(s.data()); for(int i=0;i<n;i++)ans[i]=(ll)s[i]*(mod+2-(ll)a[i]*s[i]%mod)%mod; idft(ans.data()),ans.resize(m); return ans; } vi get_ln(vi a,int m){ if (m==1)return vi{0}; vi ans(m-1); for(int i=1;i<m;i++)ans[i-1]=(ll)i*a[i]%mod; ans=mul(ans,get_inv(a,m),m-1,m,m); for(int i=m-1;i;i--)ans[i]=(ll)inv[i]*ans[i-1]%mod; ans[0]=0; return ans; } vi get_exp(vi a,int m){ if (m==1)return vi{1}; vi s=get_exp(a,(m+1>>1)),ans; s.resize(m),ans=get_ln(s,m); ans[0]=add(1,mod-ans[0]); for(int i=1;i<m;i++)ans[i]=add(a[i],mod-ans[i]); return mul(s,ans,(m+1>>1),m,m); } vi get_pow(vi a,int m,int k){ if (!k){ for(int i=0;i<m;i++)a[i]=(!i); return a; } int t=0; while ((t<m)&&(!a[t]))t++; if ((ll)t*k>=m)return vi(m,0); int s1=qpow(a[t],mod-2),s2=qpow(a[t],k); for(int i=t;i<m;i++)a[i-t]=(ll)s1*a[i]%mod; a=get_ln(a,m-t); for(int i=0;i<m-t;i++)a[i]=(ll)k*a[i]%mod; a=get_exp(a,m-t),a.resize(m); t*=k; for(int i=m-1;i>=t;i--)a[i]=(ll)s2*a[i-t]%mod; for(int i=0;i<t;i++)a[i]=0; return a; } }; int n,m,k,ans; int main(){ Poly::init(3); scanf("%d%d",&n,&k); m=(n+1>>1); if (k>m){puts("0");return 0;} if ((n==1)&&(k==1)){puts("1");return 0;} int s=1;vi v0(n,0),v1(n,0); for(int i=1;i<m;i++){ s=(ll)s*i%mod; v0[i]=s,v1[i]=(ll)s*i%mod; } vi v=Poly::mul(Poly::get_pow(v0,n,k),Poly::get_pow(v1,n,m-k),n,n,n); ans=2LL*s*v[n-1]%mod,s=1; for(int i=1;i<=k;i++)s=(ll)s*i%mod; for(int i=1;i<=m-k;i++)s=(ll)s*i%mod; ans=(ll)ans*qpow(s,mod-2)%mod; printf("%d\n",ans); return 0; }
信息
- ID
- 19611
- 时间
- 8000ms
- 内存
- 1024MiB
- 难度
- 7
- 标签
- (无)
- 递交数
- 4
- 已通过
- 2
- 上传者