2016年1月25日 星期一

[Codeforces 452E][MemSQL Start[c]UP 2.0 - Round 1][Suffix Array] Three strings

原題連結

首先老規矩,把三個字串用神秘的符號分隔後拼接起來,構造SuffixArray和Height陣列

相似的後綴會被排在一起,而且我們可以很快知道在SuffixArray的一群連續後綴的LCP(by H陣列)

假設在SuffixArray中有一群連續的後綴,他們的LCP=K,其中有a個後綴屬於第一個字串,b個屬於第二個字串,c個屬於第三個字串,那麼這堆後綴對於第K個答案的貢獻就是a*b*c

如果枚舉L,每次重新掃描一次H陣列算答案的話Complexity會是O(n^2),顯然是要TLE的

觀察之後可以發現對於當前的L而言,有用的H是那些>=L的H,而這樣的H在L遞減的情況下只會慢慢變多不會變少,對應到程式中就是一段段的有用區間會隨著L遞減不斷合併,剛好disjoint set可以維護這樣的操作,就不客氣的抓來用吧(?

細節上比較麻煩的就是答案的計算和小心SuffixArray不要寫爛(?

#define LL long long
#include<bits/stdc++.h>
using namespace std;
const LL maxn=1000000+5;
const LL mod=1000000007;
struct SuffixArray
{
    char s[maxn];
    LL minlen,n,c[maxn],rank[maxn],pri[maxn],sa[maxn];
    void build_sa(LL m)
    {
        n=strlen(s);s[n++]=0;
        for(LL i=0;i<m;i++) c[i]=0;
        for(LL i=0;i<n;i++) c[rank[i]=s[i]]++;
        for(LL i=1;i<m;i++) c[i]+=c[i-1];
        for(LL i=n-1;i>=0;i--) sa[--c[rank[i]]]=i;
        for(LL k=1;k<=n;k<<=1)
        {
            LL p=0;
            for(LL i=n-k;i<n;i++) pri[p++]=i;
            for(LL i=0;i<n;i++) if(sa[i]>=k) pri[p++]=sa[i]-k;

            for(LL i=0;i<m;i++) c[i]=0;
            for(LL i=0;i<n;i++) c[rank[pri[i]]]++;
            for(LL i=1;i<m;i++) c[i]+=c[i-1];
            for(LL i=n-1;i>=0;i--) sa[--c[rank[pri[i]]]]=pri[i];
            
            swap(rank,pri);
            p=0;rank[sa[0]]=p++;
            for(LL i=1;i<n;i++) rank[sa[i]]=(pri[sa[i]]==pri[sa[i-1]])&&(pri[sa[i]+k]==pri[sa[i-1]+k])?p-1:p++;
            if(p==n) break;
            m=p;
        }
    }

    LL H[maxn];
    void buildH()
    {
        for(LL i=0;i<n;i++) rank[sa[i]]=i;
        for(LL i=0,ans=0;i<n;i++)
        {
            if(ans) ans--;
            if(rank[i]==0) {H[rank[i]]=0;continue;}
            LL j=sa[rank[i]-1];
            while(s[i+ans]==s[j+ans]) ans++;
            H[rank[i]]=ans;
        }
    }
    LL l[3],ans[maxn],fa[maxn],cnt[maxn][3],vis[maxn];
    vector<LL> p[maxn],used;
    inline LL findset(LL x) {return fa[x]==x?x:fa[x]=findset(fa[x]);}
    inline LL getX(LL x){return x<l[1]?0:(x<l[2]?1:2);}
    void solve()
    {
        LL now=0;
        for(LL i=0;i<n;i++) {fa[i]=i;p[H[i]].push_back(i);}
        for(LL L=n;L>=1;L--)
        {
            for(LL j=0;j<p[L].size();j++)
            {
                LL P=p[L][j];vis[P]=1;used.push_back(P);  
                cnt[P][getX(sa[P])]++;
                if(P>0 && vis[P-1])
                {
                    LL front=findset(P-1);
                    cnt[front][getX(sa[front-1])]++;
                    (now-=(cnt[front][0]*cnt[front][1]*cnt[front][2])%mod)%=mod;
                    cnt[front][getX(sa[front-1])]--;
                    for(LL i=0;i<3;i++) (cnt[front][i]+=cnt[P][i])%=mod;
                    cnt[front][getX(sa[front-1])]++;
                    (((now+=(cnt[front][0]*cnt[front][1]*cnt[front][2])%mod)%=mod)+=mod)%=mod;
                    cnt[front][getX(sa[front-1])]--;
                    fa[P]=front;
                }
                if(P<n-1 && vis[P+1])// merge to back
                {
                    LL me=findset(P); // to sa[me-1]
                    cnt[P+1][getX(sa[P])]++;cnt[me][getX(sa[me-1])]++;
                    (now-=((cnt[P+1][0]*cnt[P+1][1]*cnt[P+1][2])%mod+(cnt[me][0]*cnt[me][1]*cnt[me][2])%mod)%mod)%=mod;
                    cnt[P+1][getX(sa[P])]--;cnt[me][getX(sa[me-1])]--;
                    for(LL i=0;i<3;i++) (cnt[me][i]+=cnt[P+1][i])%=mod;
                    cnt[me][getX(sa[me-1])]++;
                    (((now+=(cnt[me][0]*cnt[me][1]*cnt[me][2])%mod)%=mod)+=mod)%=mod;
                    cnt[me][getX(sa[me-1])]--;
                    fa[P+1]=me;
                }
            }
            ans[L]=now;
        }
        for(LL i=1;i<=minlen;i++) printf("%I64d%c",ans[i],i==minlen?'\n':' ');
    }
}SA;
char tmp[maxn];
int main()
{
    LL trash='W';SA.n=0;
    scanf("%s",SA.s);SA.l[0]=SA.n;SA.n=strlen(SA.s);SA.minlen=SA.n;
    for(LL i=1;i<=2;i++)
    {
        SA.s[SA.n++]=trash++;
        SA.l[i]=SA.n;
        scanf("%s",tmp);
        LL n=strlen(tmp);SA.minlen=min(SA.minlen,n);
        for(LL j=0;j<n;j++) SA.s[SA.n++]=tmp[j];
    }
    SA.build_sa('z'+5);SA.buildH();SA.solve();
}

沒有留言:

張貼留言