题目链接

这题的方法口糊一下没有很难,没达到3500的水准。但是写起来才发现是真的恶心(主要是容易写错),没写过这么累的题,可能难度就体现在这里吧。

计数的时候是要分类讨论的,但是核心算法都一样:启发式合并,线段树合并。把\(m^2\)对路径分成以下三类,分别统计合法的:

  • 两条路径的LCA不同(路径的LCA指的是两个端点的LCA)。发现这两条路径的LCA必须是祖先和后代的关系,不然两条路径不可能有重合。

    比如图中的红蓝两条路径就属于这一类,考虑在×处(下面两个端点的LCA)把它们统计进答案。可以在dfs的同时用线段树合并维护所有 有端点在子树内的路径的LCA的深度。在合并两个儿子的时候,把线段树中值的数量较小的拿出来,遍历其中所有的元素,并在大的那个儿子的线段树中询问得到能和当前元素匹配的数量。这部分的复杂度是\(O(nlog^2n)\)。由于n和m同阶,都用n表示了。

  • 两条路径的LCA相同,且它们重合的部分分布在LCA的两个子树中。像下面这样:

    这种情况和下面的一种情况都需要把所有LCA为x的路径都放到点x处,统一处理它们之间产生的贡献。假设现在处理LCA为root的所有的路径。把这些路径的端点以及root都拿出来建一棵虚树。为了避免重复计数,对于任意两条需要被计数的路径,我们都在它们在原树中dfs序较小的两个端点的LCA处统计,比如上面图中的×处。还是用线段树合并+启发式合并,但这次线段树中只维护每条路径dfs序较小的那个端点的信息。令当前点为pos,在遍历较小的儿子线段树中的一条路径(x,y)时,假设x在pos子树内,y在root的另外一个子树内,则如果我们沿着x→y的方向走k步到点z,那么合法的匹配路径的端点都在z的子树内。同样可以在线段树上查询来统计。

  • 两条路径的LCA相同,且它们重合的部分分布在LCA的一个子树中。

    这种情况的统计方法和上面是类似的。为了保证重合部分只在一个子树内,需要一次额外dfs对每个点求出它在root的哪个子树里。

总时间复杂度\(O(nlog^2n)\)

调试太痛苦了

点击查看代码
#include <bits/stdc++.h>

#define rep(i,n) for(int i=0;i<n;++i)
#define repn(i,n) for(int i=1;i<=n;++i)
#define LL long long
#define pii pair <LL,LL>
#define fi first
#define se second
#define mpr make_pair
#define pb push_back

void fileio()
{
  #ifdef LGS
  freopen("in.txt","r",stdin);
  freopen("out.txt","w",stdout);
  #endif
}
void termin()
{
  #ifdef LGS
  std::cout<<"\n\nEXECUTION TERMINATED";
  #endif
  exit(0);
}

using namespace std;

LL n,q,t,fa[150010][23],dep[150010],dfn[150010],ed[150010],ans=0,X[150010],Y[150010],LCA[150010];
vector <LL> g[150010],tg[150010],dford;

LL ll=0;
void dfsPre(int pos,int par,int d)
{
  fa[pos][0]=par;dep[pos]=d;dford.pb(pos);
  dfn[pos]=ll++;
  rep(i,g[pos].size()) if(g[pos][i]!=par) dfsPre(g[pos][i],pos,d+1);
  ed[pos]=ll-1;
}

int getLCA(int x,int y)
{
  for(int i=19;i>=0;--i) if(fa[x][i]>0&&dep[fa[x][i]]>=dep[y]) x=fa[x][i];
  for(int i=19;i>=0;--i) if(fa[y][i]>0&&dep[fa[y][i]]>=dep[x]) y=fa[y][i];
  if(x==y) return x;
  for(int i=19;i>=0;--i) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
  return fa[x][0];
}
LL getAnces(LL x,LL y){rep(i,20) if(y&(1<<i)) x=fa[x][i];return x;}

namespace st//线段树合并
{
  LL n2,dat[10000000],len;
  int ls[10000000],rs[10000000];
  void init(LL nn)
  {
    n2=1;while(n2<nn) n2*=2;
    len=0;
  }
  LL newNode()
  {
    dat[++len]=0;ls[len]=rs[len]=0;
    return len;
  }
  LL newTree(LL lb,LL ub,LL to)
  {
    LL ret=newNode();dat[ret]=1;
    if(lb==ub) return ret;
    LL mid=(lb+ub)>>1;
    if(to<=mid) ls[ret]=newTree(lb,mid,to);
    else rs[ret]=newTree(mid+1,ub,to);
    return ret;
  }
  LL upd(LL k,LL lb,LL ub,LL to)
  {
    if(k==0) k=newNode();
    ++dat[k];
    if(lb==ub) return k;
    LL mid=(lb+ub)>>1;
    if(to<=mid) ls[k]=upd(ls[k],lb,mid,to);
    else rs[k]=upd(rs[k],mid+1,ub,to);
    return k;
  }
  vector <LL> res;
  void getAll(LL k,LL lb,LL ub)
  {
    if(k==0) return;
    if(lb==ub)
    {
      rep(i,dat[k]) res.pb(lb);
      return;
    }
    LL mid=(lb+ub)>>1;
    getAll(ls[k],lb,mid);getAll(rs[k],mid+1,ub);
  }
  vector <LL> getAll(LL root)
  {
    res.clear();
    getAll(root,0,n2-1);
    return res;
  }
  LL qry(LL k,LL lb,LL ub,LL tlb,LL tub)
  {
    if(k==0||ub<tlb||tub<lb) return 0;
    if(tlb<=lb&&ub<=tub) return dat[k];
    return qry(ls[k],lb,(lb+ub)>>1,tlb,tub)+qry(rs[k],((lb+ub)>>1)+1,ub,tlb,tub);
  }
  LL merge(LL a,LL b)
  {
    if(a==0||b==0) return a|b;
    dat[a]+=dat[b];
    ls[a]=merge(ls[a],ls[b]);rs[a]=merge(rs[a],rs[b]);
    return a;
  }
}

namespace part1
{
  vector <LL> v[150010];
  LL combine(LL a,LL b,LL curdep)
  {
    if(a==0||b==0) return a|b;
    if(st::dat[a]<st::dat[b]) swap(a,b);
    vector <LL> vec=st::getAll(b);
    rep(i,vec.size())
    {
      if(vec[i]>curdep-t) continue;
      LL v1=st::qry(a,0,st::n2-1,0,vec[i]-1),v2=st::qry(a,0,st::n2-1,vec[i]+1,curdep-t);
      ans+=v1+v2;
    }
    a=st::merge(a,b);
    return a;
  }
  LL dfs(LL pos,LL par)
  {
    LL ret=0;
    rep(i,v[pos].size())
    {
      LL nxt=st::newTree(0,st::n2-1,v[pos][i]);
      ret=combine(ret,nxt,dep[pos]);
    }
    rep(i,g[pos].size()) if(g[pos][i]!=par)
    {
      LL nxt=dfs(g[pos][i],pos);
      ret=combine(ret,nxt,dep[pos]);
    }
    return ret;
  }
  void countDiffLCA()
  {
    rep(i,q)
    {
      v[X[i]].pb(dep[LCA[i]]);
      v[Y[i]].pb(dep[LCA[i]]);
    }
    st::init(n);
    dfs(1,0);
  }
}

namespace part2
{
  vector <pii> pths[150010];
  LL curroot,rootdep;
  vector <LL> realver;
  void buildVT(vector <LL> vers)
  {
    realver.clear();
    rep(i,vers.size()) tg[vers[i]].clear();
    sort(vers.begin(),vers.end());vers.erase(unique(vers.begin(),vers.end()),vers.end());
    sort(vers.begin(),vers.end(),[](LL xx,LL yy){return dfn[xx]<dfn[yy];});
    stack <LL> stk;stk.push(vers[0]);
    realver=vers;
    repn(i,vers.size()-1)
    {
      LL pos=vers[i],lca=getLCA(pos,stk.top());
      if(lca==stk.top()) stk.push(pos);
      else
      {
        while(dep[stk.top()]>dep[lca])
        {
          int pp=stk.top();stk.pop();
          int nn=stk.top();if(dep[nn]<dep[lca]) nn=lca,tg[lca].clear(),realver.pb(lca);
          tg[nn].pb(pp);
        }
        if(stk.top()!=lca) stk.push(lca);
        stk.push(pos);
      }
    }
    while(stk.size()>1)
    {
      int pp=stk.top();stk.pop();
      tg[stk.top()].pb(pp);
    }
  }
  vector <LL> v[150010];
  LL fr[150010];
  LL walk(LL curpos,LL to,LL stp)
  {
    LL rd=dep[getLCA(curpos,to)];
    LL tot=dep[curpos]+dep[to]-rd*2;
    if(tot<stp) return -1;
    if(stp<=dep[curpos]-rd) return getAnces(curpos,stp);
    return getAnces(to,tot-stp);
  }
  LL combineTwo(LL a,LL b,LL curpos)
  {
    if(a==0||b==0) return a|b;
    if(st::dat[a]<st::dat[b]) swap(a,b);
    vector <LL> vec=st::getAll(b);rep(i,vec.size()) vec[i]=dford[vec[i]];
    rep(i,vec.size())
    {
      LL walkdist=max(t,dep[curpos]-rootdep+1),to=walk(curpos,vec[i],walkdist);
      if(to==-1) continue;
      LL vv=st::qry(a,0,st::n2-1,dfn[to],ed[to]);
      ans+=vv;
    }
    a=st::merge(a,b);
    return a;
  }
  LL dfsTwo(LL pos)
  {
    LL ret=0;
    rep(i,v[pos].size())
    {
      LL nxt=st::newTree(0,st::n2-1,dfn[v[pos][i]]);
      if(pos!=curroot) ret=combineTwo(ret,nxt,pos);
    }
    rep(i,tg[pos].size())
    {
      LL nxt=dfsTwo(tg[pos][i]);
      if(pos!=curroot) ret=combineTwo(ret,nxt,pos);
    }
    return ret;
  }
  void dfsMarkFr(LL pos,LL mk)
  {
    if(mk==-1&&pos!=curroot) mk=dfn[pos];
    fr[pos]=mk;
    rep(i,tg[pos].size()) dfsMarkFr(tg[pos][i],mk);
  }
  LL combineOne(LL a,LL b)
  {
    if(a==0||b==0) return a|b;
    if(st::dat[a]<st::dat[b]) swap(a,b);
    vector <LL> vec=st::getAll(b);
    rep(i,vec.size())
    {
      if(vec[i]==dfn[curroot])
      {
        ans+=st::dat[a];
        continue;
      }
      LL v1=st::qry(a,0,st::n2-1,0,vec[i]-1),v2=st::qry(a,0,st::n2-1,vec[i]+1,st::n2-1);
      ans+=v1+v2;
    }
    a=st::merge(a,b);
    return a;
  }
  LL dfsOne(LL pos)
  {
    LL ret=0;
    rep(i,v[pos].size())
    {
      LL nxt=st::newTree(0,st::n2-1,fr[v[pos][i]]);
      if(dep[pos]-rootdep>=t) ret=combineOne(ret,nxt);
    }
    rep(i,tg[pos].size())
    {
      LL nxt=dfsOne(tg[pos][i]);
      if(dep[pos]-rootdep>=t) ret=combineOne(ret,nxt);
    }
    return ret;
  }
  void countSameLCA()
  {
    rep(i,q)
    {
      if(dfn[X[i]]>dfn[Y[i]]) swap(X[i],Y[i]);
      pths[LCA[i]].pb(mpr(X[i],Y[i]));
    }
    repn(root,n) if(pths[root].size())
    {
      curroot=root;rootdep=dep[root];
      vector <LL> vers={root};
      rep(i,pths[root].size()) vers.pb(pths[root][i].fi),vers.pb(pths[root][i].se);
      buildVT(vers);
      rep(i,realver.size()) v[realver[i]].clear();
      rep(i,pths[root].size()) if(pths[root][i].fi!=root&&pths[root][i].se!=root) v[pths[root][i].fi].pb(pths[root][i].se);
      st::init(n);
      dfsTwo(root);

      dfsMarkFr(root,-1);fr[root]=dfn[root];
      rep(i,realver.size()) v[realver[i]].clear();
      rep(i,pths[root].size()) v[pths[root][i].fi].pb(pths[root][i].se),v[pths[root][i].se].pb(pths[root][i].fi);
      st::init(n);
      dfsOne(root);
    }
  }
}

int main()
{
  fileio();

  cin>>n>>q>>t;
  LL x,y;
  rep(i,n-1)
  {
    scanf("%lld%lld",&x,&y);
    g[x].pb(y);g[y].pb(x);
  }
  dfsPre(1,0,0);
  rep(i,20) repn(j,n) fa[j][i+1]=fa[fa[j][i]][i];
  rep(i,q)
  {
    scanf("%lld%lld",&X[i],&Y[i]);
    LCA[i]=getLCA(X[i],Y[i]);
  }
  part1::countDiffLCA();
  part2::countSameLCA();
  cout<<ans<<endl;

  termin();
}