首先老規矩,把三個字串用神秘的符號分隔後拼接起來,構造SuffixArray和Height陣列
相似的後綴會被排在一起,而且我們可以很快知道在SuffixArray的一群連續後綴的LCP(by H陣列)
假設在SuffixArray中有一群連續的後綴,他們的LCP=K,其中有a個後綴屬於第一個字串,b個屬於第二個字串,c個屬於第三個字串,那麼這堆後綴對於第K個答案的貢獻就是a*b*c
如果枚舉L,每次重新掃描一次H陣列算答案的話Complexity會是O(n^2),顯然是要TLE的
觀察之後可以發現對於當前的L而言,有用的H是那些>=L的H,而這樣的H在L遞減的情況下只會慢慢變多不會變少,對應到程式中就是一段段的有用區間會隨著L遞減不斷合併,剛好disjoint set可以維護這樣的操作,就不客氣的抓來用吧(?
細節上比較麻煩的就是答案的計算和小心SuffixArray不要寫爛(?
#define LL long long #include<bits/stdc++.h> using namespace std; const LL maxn=1000000+5; const LL mod=1000000007; struct SuffixArray { char s[maxn]; LL minlen,n,c[maxn],rank[maxn],pri[maxn],sa[maxn]; void build_sa(LL m) { n=strlen(s);s[n++]=0; for(LL i=0;i<m;i++) c[i]=0; for(LL i=0;i<n;i++) c[rank[i]=s[i]]++; for(LL i=1;i<m;i++) c[i]+=c[i-1]; for(LL i=n-1;i>=0;i--) sa[--c[rank[i]]]=i; for(LL k=1;k<=n;k<<=1) { LL p=0; for(LL i=n-k;i<n;i++) pri[p++]=i; for(LL i=0;i<n;i++) if(sa[i]>=k) pri[p++]=sa[i]-k; for(LL i=0;i<m;i++) c[i]=0; for(LL i=0;i<n;i++) c[rank[pri[i]]]++; for(LL i=1;i<m;i++) c[i]+=c[i-1]; for(LL i=n-1;i>=0;i--) sa[--c[rank[pri[i]]]]=pri[i]; swap(rank,pri); p=0;rank[sa[0]]=p++; for(LL i=1;i<n;i++) rank[sa[i]]=(pri[sa[i]]==pri[sa[i-1]])&&(pri[sa[i]+k]==pri[sa[i-1]+k])?p-1:p++; if(p==n) break; m=p; } } LL H[maxn]; void buildH() { for(LL i=0;i<n;i++) rank[sa[i]]=i; for(LL i=0,ans=0;i<n;i++) { if(ans) ans--; if(rank[i]==0) {H[rank[i]]=0;continue;} LL j=sa[rank[i]-1]; while(s[i+ans]==s[j+ans]) ans++; H[rank[i]]=ans; } } LL l[3],ans[maxn],fa[maxn],cnt[maxn][3],vis[maxn]; vector<LL> p[maxn],used; inline LL findset(LL x) {return fa[x]==x?x:fa[x]=findset(fa[x]);} inline LL getX(LL x){return x<l[1]?0:(x<l[2]?1:2);} void solve() { LL now=0; for(LL i=0;i<n;i++) {fa[i]=i;p[H[i]].push_back(i);} for(LL L=n;L>=1;L--) { for(LL j=0;j<p[L].size();j++) { LL P=p[L][j];vis[P]=1;used.push_back(P); cnt[P][getX(sa[P])]++; if(P>0 && vis[P-1]) { LL front=findset(P-1); cnt[front][getX(sa[front-1])]++; (now-=(cnt[front][0]*cnt[front][1]*cnt[front][2])%mod)%=mod; cnt[front][getX(sa[front-1])]--; for(LL i=0;i<3;i++) (cnt[front][i]+=cnt[P][i])%=mod; cnt[front][getX(sa[front-1])]++; (((now+=(cnt[front][0]*cnt[front][1]*cnt[front][2])%mod)%=mod)+=mod)%=mod; cnt[front][getX(sa[front-1])]--; fa[P]=front; } if(P<n-1 && vis[P+1])// merge to back { LL me=findset(P); // to sa[me-1] cnt[P+1][getX(sa[P])]++;cnt[me][getX(sa[me-1])]++; (now-=((cnt[P+1][0]*cnt[P+1][1]*cnt[P+1][2])%mod+(cnt[me][0]*cnt[me][1]*cnt[me][2])%mod)%mod)%=mod; cnt[P+1][getX(sa[P])]--;cnt[me][getX(sa[me-1])]--; for(LL i=0;i<3;i++) (cnt[me][i]+=cnt[P+1][i])%=mod; cnt[me][getX(sa[me-1])]++; (((now+=(cnt[me][0]*cnt[me][1]*cnt[me][2])%mod)%=mod)+=mod)%=mod; cnt[me][getX(sa[me-1])]--; fa[P+1]=me; } } ans[L]=now; } for(LL i=1;i<=minlen;i++) printf("%I64d%c",ans[i],i==minlen?'\n':' '); } }SA; char tmp[maxn]; int main() { LL trash='W';SA.n=0; scanf("%s",SA.s);SA.l[0]=SA.n;SA.n=strlen(SA.s);SA.minlen=SA.n; for(LL i=1;i<=2;i++) { SA.s[SA.n++]=trash++; SA.l[i]=SA.n; scanf("%s",tmp); LL n=strlen(tmp);SA.minlen=min(SA.minlen,n); for(LL j=0;j<n;j++) SA.s[SA.n++]=tmp[j]; } SA.build_sa('z'+5);SA.buildH();SA.solve(); }
沒有留言:
張貼留言