您的当前位置:首页正文

Palindromic Tree

2024-11-07 来源:个人技术集锦

一个很厉害的算法,具体怎么实现的不是很清楚,本着不求甚解的态度,知道怎么用就好了。

next[][]:类似于字典树,指向当前字符串在两段同时加上一个字符

fail[] fail指针:类似于AC自动机,返回失配后与当前i结尾的最长回文串本质上不同的最长回文后缀

cnt[]: 在最后统计后它可以表示形如以i为结尾的回文串中最长的那个串个数

num[]: 表示以i结尾的回文串的种类数

len[]: 表示以i为结尾的最长回文串长度

s[]: 存放添加的字符

last: 表示上一个添加的字符的位置

n: 表示字符数组的第几位

p: 表示树中节点的指针

本质不同的回文字符串:p-2 (减去两个根节点)
统计所有回文串的个数 ∑ 2 p − 1 n u m [ i ] \sum^{p-1}_{2} num[i] 2p1num[i]


对于每一个节点p来说,他的nextt节点都是以p节点为子串的回文串,他的fail节点都是p的回文子串,那么假设对于p来说,如果它向下有numn[p],向上有numc[p]个,那么这个节点的贡献就是numn[p]*numc[p];

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 1000 * 100 + 10;
const int MAXN = 100005;
char s[MAXN];

struct Palindromic_Tree{
	int nextt[N][26];
	int cnt[N];
	int fail[N];
	int len[N];
	int num[N];
	int s[N];
	int n;
	int p;
	int last;

	int newnode(int lent) {
		for (int i = 0; i < 26; i++)nextt[p][i] = 0;
		cnt[p] = 0;
		len[p] = lent;
		num[p] = 0;
		return p++;
	}

	void init() {
		p = 0; 
		newnode(0); newnode(-1);
		last = 0;
		n = 0;
		s[n] = -1;
		fail[0] = 1;
	}

	int get_fail(int x) {
		while (s[n - len[x] - 1] != s[n])x = fail[x];
		return x;
	}

	void add(int c) {
		c -= 'a';
		s[++n] = c;
		int cur = get_fail(last);
		if (!nextt[cur][c]) {
			int now = newnode(len[cur]+2);
			fail[now] = nextt[get_fail(fail[cur])][c];
			nextt[cur][c] = now;
			num[now] = num[fail[now]] + 1;
		}
		last = nextt[cur][c];
		cnt[last]++;
	}

	void count() {
		for (int i = p - 1; i >= 0; i--)cnt[fail[i]] += cnt[i];
	}
	int numc[N], numn[N],vis[N];

	int dfs(int x) {
		numn[x] = 1;
		numc[x] = 0;
		for (int t = x; !vis[t] && t >1; t = fail[t])vis[t] = x,numc[x]++;
		for (int i = 0; i < 26; i++) {
			if (nextt[x][i] == 0)continue;
			numn[x] += dfs(nextt[x][i]);
		}
		for (int t = x; vis[t] ==x && t >1; t = fail[t]) vis[t] = 0;
		return numn[x];
	}

	ll solve() {
		ll ans = 0;
		dfs(0); dfs(1);
		for (int i = 2; i < p; i++)ans += 1LL * numc[i] * numn[i];
		return ans-p+2;
	}
}T;

int main() {
	int t; scanf("%d",&t);
	int casen=1;
	while (t--){
		scanf("%s",s);
		int len = strlen(s);
		T.init();
		for (int i = 0; i < len; i++)T.add(s[i]);
		printf("Case #%d: %lld\n",casen++,T.solve());
	}
	return 0;
}
Top