/**
* Definition for a binary tree node.
* type TreeNode struct {
* Val int
* Left *TreeNode
* Right *TreeNode
* }
*/
// averageOfSubtree counts the nodes whose value equals the average of the
// values in their subtree (the node itself plus all of its descendants),
// where the average is the subtree sum divided by its node count using
// integer division.
//
// A nil root is treated as an empty tree and yields 0.
func averageOfSubtree(root *TreeNode) int {
	ans := 0

	// dfs returns the sum of values and the number of nodes in the subtree
	// rooted at node; a nil node contributes (0, 0). Both children are
	// summarized before the node itself is checked (post-order).
	var dfs func(node *TreeNode) (int, int)
	dfs = func(node *TreeNode) (int, int) {
		if node == nil {
			return 0, 0
		}
		leftSum, leftCnt := dfs(node.Left)
		rightSum, rightCnt := dfs(node.Right)
		sum := node.Val + leftSum + rightSum
		cnt := 1 + leftCnt + rightCnt
		// NOTE(review): Go's / truncates toward zero, which equals rounding
		// down only for non-negative sums; this assumes node values are
		// non-negative — confirm if negative values are possible.
		if node.Val == sum/cnt {
			ans++
		}
		return sum, cnt
	}

	dfs(root)
	return ans
}
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
    def averageOfSubtree(self, root: Optional[TreeNode]) -> int:
        """Count nodes whose value equals the floored average of their subtree.

        A node's subtree is the node itself plus all of its descendants. The
        average is the subtree's value sum divided by its node count using
        floor division (``//``).

        Args:
            root: Root of the binary tree, or ``None`` for an empty tree.

        Returns:
            The number of qualifying nodes (0 for an empty tree).
        """
        ans = 0

        # Post-order traversal: both children are summarized before their
        # parent is checked, so every node is visited exactly once.
        def dfs(node: Optional[TreeNode]) -> "tuple[int, int]":
            """Return (sum of values, node count) for the subtree at ``node``."""
            nonlocal ans
            if not node:
                return 0, 0
            left_sum, left_count = dfs(node.left)
            right_sum, right_count = dfs(node.right)
            # Totals for this subtree: children plus the node itself.
            total = left_sum + right_sum + node.val
            count = left_count + right_count + 1
            if node.val == total // count:
                ans += 1
            return total, count

        dfs(root)
        return ans