2016年1月31日 星期日

[SPOJ APIO10A][DP] Commando

題目連結
最單純的想法是1D/1D的DP,但是顯然不夠快。
這時候看看SPOJ的tag上寫著Convex Hull,然後就可以無恥的去做凸包優化了。
實際上是這樣子的:
    dp[i]=max(dp[j]+a*sum(j+1,i)*sum(j+1,i)+b*sum(j+1,i)+c),0<=j< i,其中sum(a,b)代表從第a個數到第b個數的總和
用力把它展開,化簡:
    dp[i]=max(dp[j] + a*(pre[i]-pre[j])^2 + b*(pre[j]-pre[i])+c)  ,0<=j<i
       ==>(a*pre[i]^2+b*pre[i]+c)+max((-2)*a*pre[j]*pre[i]+(dp[j]+a*pre[j]^2-b*pre[j]))  ,0<=j<i
與決策點j有關的項是(dp[j]+a*pre[j]^2-b*pre[j])和(-2)*a*pre[j]*pre[i],前者是已經計算好的常數,而後者與pre[i]有著線性的關係。

所以可以把決策點j看成是條斜率是pre[j],截距是(dp[j]+a*pre[j]^2-b*pre[j])的直線,i如果從j轉移相當於把pre[i]代入j這條直線上得到一個值dp[i]。
我們的任務變成維護這些直線,使我們能快速查詢最佳的j。

首先觀察到pre陣列是遞增的,所以這些直線的斜率遞增,也就是最新的線一定會在凸包上,只需要往前面刪除那些因為這條線的出現而從凸包消失的線。
查詢可以二分搜實現,O(NlgN)。

最後一個優化是由於pre[i]也是遞增的,也就是我們要查詢的x會遞增,那麼經過的直線就可以直接刪掉了。達到了均攤O(1)查詢最大值。總時間O(n)。
#define PB push_back
#define LL long long
#include<bits/stdc++.h>
using namespace std;
const int maxn=1000000+5;
struct Line
{
    LL slope,inter;
    LL getVal(LL x) {return slope*x+inter;}
    Line(LL a,LL b):inter(a),slope(b){}
};
deque<Line> dq;
LL N,a,b,c,pre[maxn];
LL dp[maxn];
bool check(Line& L0,Line& L1,Line& L2)
{
    return (L0.inter-L2.inter)*(L1.slope-L0.slope)<=(L0.inter-L1.inter)*(L2.slope-L0.slope);
}
void join(Line L)
{
    while(dq.size()>1 && check(dq[dq.size()-2],dq[dq.size()-1],L)) dq.pop_back();
    dq.PB(L);
}
LL getMax(int i){return dq[0].getVal(pre[i]);}
void solve()
{
    memset(dp,0,sizeof dp);
    dq.clear();
    dq.PB((Line){0,0});
    for(int i=1;i<=N;i++)
    {
        while(dq.size()>1 && dq[0].getVal(pre[i])<dq[1].getVal(pre[i])) dq.pop_front();
        dp[i]=a*pre[i]*pre[i]+b*pre[i]+c+getMax(i);// Max(dp[j]+a*pre[j]*pre[j]-b*pre[j]-2*a*pre[j]*pre[i]),0<=j<i
        // inter:dp[i]+a*pre[i]*pre[i]-b*pre[i],slope:(-2)*a*pre[i]
        join(Line(dp[i]+a*pre[i]*pre[i]-b*pre[i],(-2)*a*pre[i]));
    }
    printf("%lld\n",dp[N]);
}
int main()
{
    int T;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%lld%lld%lld%lld",&N,&a,&b,&c);
        pre[0]=0;
        for(LL i=1,x;i<=N;i++) {scanf("%lld",&x);pre[i]=pre[i-1]+x;}
        solve();
    }
}

沒有留言:

張貼留言