[ARC195E] Ramdom Tree Distance

好像大家都在推 \(\operatorname{lca}\) 深度固定什么的,这篇题解来个比较粗暴的推法。


首先你考虑直接 dp:对于 \(i<j\),令 \(f(i,j)\) 表示 \(i\)\(j\) 的期望距离。

这里我们将 \(j\) 不断地跳父亲,直到跳到 \(\le i\) 的点。

那么贡献变成了两部分:

  • 编号 \(>i\) 的贡献:显然 \(a_j\) 一定有贡献。考虑 \(a_k\;(i<k<j)\),其当且仅当曾经跳到过 \(k\) 时产生贡献。那么这个等价于「第一次跳到 \(\le k\) 的点时恰好跳到了 \(k\)」,这个概率显然是 \(\dfrac1k\),所以期望贡献是 \(\dfrac{a_k}k\)
  • 编号 \(\le i\) 的贡献:如果最后跳到 \(i\) 就是 \(0\),否则跳到 \(k\;(k<i)\) 就是 \(f(k,i)\)

所以有转移方程:

\[f(i,j)=a_j+\left(\sum_{k=i+1}^{j-1}\frac{a_k}k\right)+ \left(\frac1i\sum_{k=1}^{i-1}f(k,i)\right)\]

你发现最后面那坨和 \(j\) 完全没关系,那么可以考虑将 \(j\) 这一维去掉。

观察 \(f\)\(j\) 这一维的差分:

\[f(i,j)-f(i,j-1)=a_j-\frac{j-2}{j-1}a_{j-1}\]

你发现这玩意只跟 \(j\) 有关,不妨记为 \(C_j\)。另外这个式子在 \(j\ge i+2\) 时成立,所以有:

\[f(i,j)=f(i,i+1)+\sum_{k=i+2}^jC_k\]

现在 \(f\) 里面已经没有 \(j\) 了,那么令 \(F_i=f(i,i+1)\)

\[\begin{aligned}F_i&=a_i+\frac1i\sum_{k=1}^{i-1}f(k,i)\\ &=a_i+\frac1i\sum_{k=1}^{i-1}\left(F_k+\sum_{l=k+2}^iC_l\right) \\&=a_i+\frac1i\left[\left(\sum_{k=1}^{i-1}F_k\right)+\left(\sum_{l=3}^i\sum_{k=1}^{l-2}C_l\right)\right] \\&=a_i+\frac1i\left[\left(\sum_{k=1}^{i-1}F_k\right)+\left(\sum_{l=3}^i(l-2)\cdot C_l\right)\right]\end{aligned}\]

那么你发现后面的两个求和都可以边求边推,所以 \(F\) 就可以 \(\mathcal O(n)\) 求出了。

那你每次询问就是一个 \(F\) 值加一段 \(C\) 的和(再乘一个 \((n-1)!\),因为是求和),可以 \(\mathcal O(1)\) 求出。

于是做完了。时间复杂度 \(\mathcal O(n+q)\)

代码很短。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
// Author: YE Minghan
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define endl '\n'
#define PB emplace_back
#define PPB pop_back
#define MP make_pair
#define ALL(Name) Name.begin(),Name.end()
#define PII pair<int,int>
#define VI vector<int>
#define GI greater<int>
#define fi first
#define se second

const int N=200005,MOD=998244353;
int n,q,a[N];
int inv[N],tfc;
int f[N],s[N],ss[N];

int main()
{
ios::sync_with_stdio(false),cin.tie(nullptr);
// int _;cin>>_;while(_--)

cin>>n>>q,inv[1]=tfc=1;
for(int i=2;i<=n;i++)cin>>a[i],a[i]%=MOD,inv[i]=1ll*(MOD-MOD/i)*inv[MOD%i]%MOD,tfc=(i-1ll)*tfc%MOD;
for(int i=2;i<=n;i++)ss[i]=(ss[i-1]+(s[i]=(a[i]+(MOD-i+2ll)*a[i-1]%MOD*inv[i-1])%MOD))%MOD;
int t=0;
for(int i=1;i<n;i++)
{
f[i]=(1ll*t*inv[i]+a[i+1])%MOD;
t=(t+(i-1ll)*s[i+1]+f[i])%MOD;
}
while(q--)
{
int x,y;
cin>>x>>y;
cout<<1ll*tfc*(0ll+f[x]+(ss[y]-ss[x+1]+MOD))%MOD<<endl;
}

return 0;
}