列表

详情


NC20360. [SDOI2013]直径

描述

小Q最近学习了一些图论知识。根据课本,有如下定义。
树:无回路且连通的无向图,每条边都有正整数的权值来表示其长度。如果一棵树有N个节点,可以证明其有且仅有N-1 条边。 
路径:一棵树上,任意两个节点之间最多有一条简单路径。我们用 dis(a,b) 表示点a和点b的路径上各边长度之和。称dis(a,b)为a、b两个节点间的距离。  
直径:一棵树上,最长的路径为树的直径。树的直径可能不是唯一的。 
 现在小Q想知道,对于给定的一棵树,其直径的长度是多少,以及有多少条边满足所有的直径都经过该边。

输入描述

第一行包含一个整数N,表示节点数。
接下来N-1行,每行三个整数a, b, c ,表示点a和点b之间有一条长度为c的无向边。

输出描述

共两行。第一行一个整数,表示直径的长度。
第二行一个整数,表示被所有 直径经过的边的数量。

示例1

输入:

6 
3  1 1000
1  4 10
4  2 100
4  5 50
4  6 100

输出:

1110
2

说明:

【样例说明】
直径共有两条,3 到2的路径和3到6的路径。这两条直径都经过边(3, 1)和边(1, 4)。

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(g++ 7.5.0) 解法, 执行用时: 322ms, 内存消耗: 54792K, 提交时间: 2022-08-20 17:03:55

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=5e5+10;
struct Edge{
    Edge(int to,ll len):to(to),len(len){}
    int to;
    ll len;
};
 
vector<Edge> G[N];
ll L[N],cnt[N],d,ans;
ll cntd;
typedef pair<ll,ll> P;
void dfs(int x,int fa){
    L[x]=0;cnt[x]=1;
    vector<P> v;
    for(Edge &t:G[x]){
        int y=t.to;ll len=t.len;
        if(y==fa)continue;
        dfs(y,x);
        ll d=L[y]+len;
        v.push_back(P(d,cnt[y]));
        if(L[x]>d)continue;
        if(L[x]==d)cnt[x]+=cnt[y];
        else{
            L[x]=d;cnt[x]=cnt[y];
        }
    }
    if(v.size()==0)return;
    sort(v.begin(),v.end(),greater<P>());
    if(v.size()==1){
        if(v[0].first>d){
            d=v[0].first;
            cntd=v[0].second;
        }
        else if(v[0].first==d)
            cntd+=v[0].second;
    }
    else{
        if(v[0].first!=v[1].first){
            ll sum=v[1].second;
            for(int i=2;i<v.size();++i){
                if(v[i].first!=v[i-1].first)break;
                sum+=v[i].second;
            }
            if(v[0].first+v[1].first>d){
                d=v[0].first+v[1].first;
                cntd=v[0].second*sum;
            }else if(v[0].first+v[1].first==d)
                cntd+=v[0].second*sum;
        }
        else if(v[0].first+v[1].first>=d){
            int i=0;ll sum=v[0].second;
            while(i+1<v.size()&&v[i+1].first==v[0].first)
                sum+=v[++i].second;
            ll s=(sum*(sum-1))/2;
            while(i>=0){
                s-=(v[i].second*(v[i].second-1))/2;
                --i;
            }
            if(v[0].first+v[1].first>d){
                d=v[0].first+v[1].first;
                cntd=s;
            }
            else cntd+=s;
        }
    }
}
void dfs2(int x,int fa,ll dd,ll c){
    map<ll,ll> mp;
    mp[dd]=c;
    for(Edge &t:G[x]){
        if(t.to==fa)continue;
        mp[t.len+L[t.to]]+=cnt[t.to];
    }
    for(Edge &t:G[x]){
        if(t.to==fa)continue;
        int y=t.to;
        ll C;
        ll tmp=d-t.len-L[y];
        if(tmp==0)C=cnt[y];
        else if(tmp==t.len+L[y])C=(mp[tmp]-cnt[y])*cnt[y];
        else C=mp[tmp]*cnt[y];
        if(C==cntd)++ans;
    }
    mp.clear();
    mp[dd]=c;
    for(Edge &t:G[x]){
        if(t.to==fa)continue;
        mp[t.len+L[t.to]]+=cnt[t.to];
    }
    for(Edge &t:G[x]){
        if(t.to==fa)continue;
        int y=t.to;
        ll tmp=t.len+L[y];
        auto it=prev(mp.end());
        if(it->first!=tmp)
            dfs2(y,x,it->first+t.len,it->second);
        else{
            if(it->second-cnt[y]!=0)
                dfs2(y,x,it->first+t.len,it->second-cnt[y]);
            else{
                it=prev(it);
                dfs2(y,x,it->first+t.len,it->second);
            }
        }
    }
}
int main(){
    int n,x,y;
    ll len;
    scanf("%d",&n);
    for(int i=1;i<n;++i){
        scanf("%d%d%lld",&x,&y,&len);
        G[x].push_back(Edge(y,len));
        G[y].push_back(Edge(x,len));
    }
    dfs(1,0);
    dfs2(1,0,0,1);
    printf("%lld\n%lld",d,ans);
    return 0;
}

C++11(clang++ 3.9) 解法, 执行用时: 227ms, 内存消耗: 13796K, 提交时间: 2020-08-20 10:51:02

#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10;
long long dis[N], val[N], sum[N];
int n, pre[N], p, q, vis[N];
vector<int> vv;
struct edge{
    int to, next, vi;
}e[N<<1];

int cnt, h[N];

void add(int u, int v, int w) {
    e[cnt].to = v;
    e[cnt].vi = w;
    e[cnt].next = h[u];
    h[u] = cnt++;
}

int dfs(int u, int fa) {
    int maxpos = u;
    for (int i = h[u]; ~i; i = e[i].next) {
        int v = e[i].to;
        if (vis[v] || v == fa) continue;
        dis[v] = dis[u] + e[i].vi;
        pre[v] = u;
        int z = dfs(v, u);
        if (dis[z] > dis[maxpos]) maxpos = z;
    }
    return maxpos;
}
int main() {
    scanf("%d", &n);
    memset(h, -1, sizeof h);
    for (int i = 1; i <= n-1; i++) {
        int x, y, z;
        scanf("%d %d %d", &x, &y, &z);
        add(x, y, z), add(y, x, z);
    }
    p = dfs(1, 0);
    dis[p] = 0;
    q = dfs(p, 0);
    int now = q;
    while (now != p) {
        vv.push_back(now);
        vis[now] = 1;
        now = pre[now];
    }
    vv.push_back(p);
    vis[p] = 1;

    reverse(vv.begin(), vv.end());

    for (int i = 0; i < vv.size(); i++) sum[i] = dis[vv[i]];
    int l = 0, r = vv.size() - 1;
    for (int i = 0; i < vv.size(); i++) dis[vv[i]] = 0, val[i] = dis[dfs(vv[i], 0)];

    for (int i = 0; i < vv.size(); i++) if (sum[i] == val[i]) l = i;
    for (int i = vv.size()-1; i >= 0; i--) if (val[i] == sum[vv.size()-1]-sum[i]) r = i;
    printf("%lld\n%d", sum[vv.size()-1], r - l);
}

上一题