上次编辑到这里,代码来自缓存 点击恢复默认模板
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
public:
int countPairs(TreeNode* root, int distance) {
}
};
golang 解法, 执行用时: 20 ms, 内存消耗: 6.8 MB, 提交时间: 2022-12-07 18:12:04
/**
* Definition for a binary tree node.
* type TreeNode struct {
* Val int
* Left *TreeNode
* Right *TreeNode
* }
*/
// countPairs returns the number of good leaf pairs in the tree rooted at
// root: unordered pairs of distinct leaves whose shortest path has at most
// distance edges. A nil root (empty tree) yields 0.
//
// A post-order DFS returns, for every node, a histogram cnt where cnt[d] is
// the number of leaves exactly d edges below that node, for d in
// [0, distance]. Each pair is counted at its lowest common ancestor, with
// one leaf taken from the left subtree and the other from the right. Leaves
// farther than distance can never complete a good pair, so they are dropped
// rather than carried further up the tree.
func countPairs(root *TreeNode, distance int) (res int) {
	// Two distinct leaves are always at least 2 edges apart.
	if root == nil || distance < 2 {
		return 0
	}
	var dfs func(node *TreeNode) []int
	dfs = func(node *TreeNode) []int {
		cnt := make([]int, distance+1)
		if node == nil {
			return cnt
		}
		if node.Left == nil && node.Right == nil {
			cnt[0] = 1
			return cnt
		}
		left, right := dfs(node.Left), dfs(node.Right)
		// left[i] leaves are i+1 edges from node via the left child and
		// right[j] leaves are j+1 edges away via the right child; they form
		// a good pair when (i+1)+(j+1) <= distance.
		for i := 0; i+2 <= distance; i++ {
			for j := 0; i+j+2 <= distance; j++ {
				res += left[i] * right[j]
			}
		}
		// Shift one edge down for the parent; the bucket that would become
		// distance+1 is discarded.
		for d := 1; d <= distance; d++ {
			cnt[d] = left[d-1] + right[d-1]
		}
		return cnt
	}
	dfs(root)
	return res
}
java 解法, 执行用时: 2 ms, 内存消耗: 41.9 MB, 提交时间: 2022-12-07 18:11:35
/**
* Definition for a binary tree node.
* public class TreeNode {
* int val;
* TreeNode left;
* TreeNode right;
* TreeNode() {}
* TreeNode(int val) { this.val = val; }
* TreeNode(int val, TreeNode left, TreeNode right) {
* this.val = val;
* this.left = left;
* this.right = right;
* }
* }
*/
class Solution {
    /**
     * Counts good leaf pairs: unordered pairs of distinct leaves whose
     * shortest path has at most {@code distance} edges.
     *
     * @param root     root of the tree; {@code null} (empty tree) yields 0
     * @param distance maximum path length, in edges, for a pair to be good
     * @return number of good leaf pairs
     */
    public int countPairs(TreeNode root, int distance) {
        // Two distinct leaves are always at least 2 edges apart.
        if (root == null || distance < 2) {
            return 0;
        }
        return dfs(root, distance).count;
    }

    /**
     * Post-order step. For the subtree rooted at {@code root} it returns
     * both of the following:
     * 1) {@code depths[d]}: the number of leaves exactly d edges below root,
     *    for 0 <= d <= distance (farther leaves are dropped);
     * 2) {@code count}: the number of good leaf pairs inside the subtree.
     * A {@code null} root is treated as an empty subtree.
     */
    public Pair dfs(TreeNode root, int distance) {
        int[] depths = new int[distance + 1];
        if (root == null) {
            return new Pair(depths, 0);
        }
        if (root.left == null && root.right == null) {
            depths[0] = 1;
            return new Pair(depths, 0);
        }
        Pair left = dfs(root.left, distance);
        Pair right = dfs(root.right, distance);
        int[] leftDepths = left.depths;
        int[] rightDepths = right.depths;
        // A leaf i edges below the left child and one j edges below the right
        // child are (i + 1) + (j + 1) edges apart through root.
        int cnt = 0;
        for (int i = 0; i + 2 <= distance; i++) {
            for (int j = 0; i + j + 2 <= distance; j++) {
                cnt += leftDepths[i] * rightDepths[j];
            }
        }
        // Shift one edge down; the bucket that would exceed distance is dropped.
        for (int i = 0; i < distance; i++) {
            depths[i + 1] = leftDepths[i] + rightDepths[i];
        }
        return new Pair(depths, cnt + left.count + right.count);
    }
}
/**
 * Result of one DFS step: a histogram of leaf distances below a node
 * ({@code depths}) and the number of good leaf pairs in its subtree
 * ({@code count}).
 */
class Pair {
    int[] depths;
    int count;

    public Pair(int[] leafDepths, int pairCount) {
        depths = leafDepths;
        count = pairCount;
    }
}
python3 解法, 执行用时: 168 ms, 内存消耗: 16.5 MB, 提交时间: 2022-12-07 18:11:08
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
    def countPairs(self, root: "Optional[TreeNode]", distance: int) -> int:
        """Count good leaf pairs in a binary tree.

        A pair of distinct leaves is good when the shortest path between
        them has at most ``distance`` edges.

        Args:
            root: Root of the tree; ``None`` (an empty tree) yields 0.
            distance: Maximum path length, in edges, for a pair to be good.

        Returns:
            The number of unordered good leaf pairs.
        """
        # Two distinct leaves are always at least 2 edges apart.
        if root is None or distance < 2:
            return 0

        def dfs(node: "Optional[TreeNode]") -> "Tuple[List[int], int]":
            # Returns (depths, pairs), where:
            #   depths[d] is the number of leaves exactly d edges below node,
            #     for 0 <= d <= distance (farther leaves are dropped);
            #   pairs is the number of good leaf pairs inside node's subtree.
            depths = [0] * (distance + 1)
            if node is None:
                return depths, 0
            if node.left is None and node.right is None:
                depths[0] = 1
                return depths, 0
            left_depths, left_pairs = dfs(node.left)
            right_depths, right_pairs = dfs(node.right)
            # A leaf i edges below the left child and one j edges below the
            # right child are (i + 1) + (j + 1) edges apart through node.
            cnt = 0
            for i in range(distance - 1):
                for j in range(distance - 1 - i):
                    cnt += left_depths[i] * right_depths[j]
            # Shift one edge down; the bucket that would exceed distance is dropped.
            for d in range(distance):
                depths[d + 1] = left_depths[d] + right_depths[d]
            return depths, cnt + left_pairs + right_pairs

        return dfs(root)[1]