列表

详情


NC20564. [JSOI2015]字符串树

描述

萌萌买了一颗字符串树的种子,春天种下去以后夏天就能长出一棵很大的字 符串树。字符串树很奇特,树枝上都密密麻麻写满了字符串,看上去很复杂的样子。
【问题描述】 字符串树本质上还是一棵树,即N个节点N-1条边的连通无向无环图,节点从1到N编号。与普通的树不同的是,树上的每条边都对应了一个字符串。萌萌和JYY在树下玩的时候,萌萌决定考一考JYY。每次萌萌都写出一个字符串S和两个节点U,V,需要JYY立即回答U和V之间的最短路径(即,之间边数最少的路径。由于给定的是一棵树,这样的路径是唯一的)上有多少个字符串以为前缀。 JYY虽然精通编程,但对字符串处理却不在行。所以他请你帮他解决萌萌的难题。

输入描述

输入第一行包含一个整数N,代表字符串树的节点数量。
接下来N-1行,每行先是两个数U,V,然后是一个字符串S,表示节点和U节 点V之间有一条直接相连的边,这条边上的字符串是S。输入数据保证给出的是一 棵合法的树。
接下来一行包含一个整数Q,表示萌萌的问题数。 接来下Q行,每行先是两个数U,V,然后是一个字符串S,表示萌萌的一个问 题是节点U和节点V之间的最短路径上有多少字符串以S为前缀。
输入中所有字符串只包含a-z的小写字母。
1 ≤ N,Q ≤ 100,000,且输入所有字符串长度不超过10。

输出描述

输出Q行,每行对应萌萌的一个问题的答案。

示例1

输入:

4
1 2 ab
2 4 ac
1 3 bc
3
1 4 a
3 4 b
3 2 ab

输出:

2
1
1

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++11(clang++ 3.9) 解法, 执行用时: 321ms, 内存消耗: 111984K, 提交时间: 2019-03-16 12:39:03

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
using namespace std;
struct node{
    int x,y,next;
    char s[12];
}a[200010];int len,last[100010];
struct trnode{
    int a[26],c;
}tr[1000005];int tot=0,root[100010];
int fa[100010][20],dep[100010];
int n,m;
void ins(int x,int y,char *s)
{
    a[++len].y=y;strcpy(a[len].s,s);
    a[len].next=last[x];last[x]=len;
}
int solve(int x,int y)
{
    if(dep[x]<dep[y]) swap(x,y);
    for(int i=17;i>=0;i--)
        if((1<<i)<=dep[x]-dep[y]) x=fa[x][i];
    if(x==y) return x;
    for(int i=17;i>=0;i--)
        if((1<<i)<=dep[x]&&fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}
void insert(int &x,int froot,int i,int len,char *s)
{
    x=++tot;tr[x]=tr[froot];
    tr[x].c++;
    if(i==len) return;
    int c=s[i]-'a';
    insert(tr[x].a[c],tr[froot].a[c],i+1,len,s);
}
void dfs(int x,int f)
{
    dep[x]=dep[f]+1;fa[x][0]=f;
    for(int i=1;(1<<i)<=dep[x];i++)
        fa[x][i]=fa[fa[x][i-1]][i-1];
    for(int i=last[x];i;i=a[i].next)
    {
        int y=a[i].y;
        if(y==f) continue;
        insert(root[y],root[x],0,strlen(a[i].s),a[i].s);
        dfs(y,x);
    }
}
char s[15];
int get(int x,int i,int len,char *s)
{
    if(!x) return 0;
    if(i==len) return tr[x].c;
    int c=s[i]-'a';
    return get(tr[x].a[c],i+1,len,s);
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        int x,y;char s[12];
        scanf("%d %d %s",&x,&y,s);
        ins(x,y,s);ins(y,x,s);
    }
    dep[0]=-1;dfs(1,0);
    scanf("%d",&m);
    while(m--)
    {
        int x,y;scanf("%d %d",&x,&y);
        int lca=solve(x,y);
        scanf("%s",s);
        printf("%d\n",get(root[x],0,strlen(s),s)+get(root[y],0,strlen(s),s)-2*get(root[lca],0,strlen(s),s));
    }
}

上一题