/**
* Definition for a binary tree node.
* type TreeNode struct {
* Val int
* Left *TreeNode
* Right *TreeNode
* }
*/
// findDistance returns the number of edges on the path between the node whose
// value is p and the node whose value is q in the tree rooted at root.
//
// The tree is walked once in post-order. Each call reports how far below the
// current node each target sits; the first node that sees both targets is
// where the two paths meet, and the sum of the two depths is the answer.
// If no node ever sees both values, the result is 0.
func findDistance(root *TreeNode, p int, q int) int {
	// missing marks "target not found in this subtree".
	const missing = -1

	distance := 0

	// walk returns the edge counts from node down to the p-valued node and
	// down to the q-valued node, using missing for a target that is absent.
	var walk func(node *TreeNode) (int, int)
	walk = func(node *TreeNode) (int, int) {
		if node == nil {
			return missing, missing
		}

		toP, toQ := missing, missing
		// Left child first, then right; a hit in a child is one edge further
		// away when seen from this node.
		for _, child := range [2]*TreeNode{node.Left, node.Right} {
			childP, childQ := walk(child)
			if childP != missing {
				toP = childP + 1
			}
			if childQ != missing {
				toQ = childQ + 1
			}
		}

		// The node itself may be one (or both) of the targets.
		if node.Val == p {
			toP = 0
		}
		if node.Val == q {
			toQ = 0
		}

		if toP == missing || toQ == missing {
			return toP, toQ
		}

		// Both targets are reachable from here: record the path length and
		// report nothing upwards so that ancestors do not recompute it.
		distance = toP + toQ
		return missing, missing
	}

	walk(root)
	return distance
}
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
    """Distance between two values in a binary tree.

    The approach is: locate the lowest common ancestor (LCA) of the two
    values, then add the depth of each value measured from that ancestor.
    """

    def findDistance(self, root: "TreeNode", p: int, q: int) -> int:
        """Return the number of edges between the nodes holding ``p`` and ``q``.

        Args:
            root: Root of the tree, or ``None`` for an empty tree.
            p: Value of the first node.
            q: Value of the second node.

        Returns:
            The edge count of the path between the two nodes (0 when
            ``p == q`` and the value is present), or -1 when either value
            cannot be found.
        """
        lca = self.find_LCA(root, p, q)  # lowest common ancestor
        p_dep = self.dfs(lca, p)
        q_dep = self.dfs(lca, q)
        # A missing target yields the -1 sentinel; adding sentinels together
        # would produce a meaningless -1/-2, so report "not found" explicitly.
        if p_dep == -1 or q_dep == -1:
            return -1
        return p_dep + q_dep

    def dfs(self, root: "TreeNode", target: int) -> int:
        """Return the depth of ``target`` below ``root``, or -1 if absent.

        ``root`` itself is at depth 0.
        """
        if root is None:
            return -1
        if root.val == target:
            return 0
        left = self.dfs(root.left, target)
        right = self.dfs(root.right, target)
        # Depths are never below -1, so the maximum is -1 only when the
        # target is in neither subtree.
        deepest = max(left, right)
        if deepest == -1:
            return -1
        return deepest + 1

    def find_LCA(self, root: "TreeNode", p: int, q: int) -> "TreeNode":
        """Return the lowest node whose subtree holds both ``p`` and ``q``.

        If only one of the values occurs under ``root``, the node holding
        that value is returned; if neither occurs, ``None`` is returned.
        """
        if root is None or root.val == p or root.val == q:
            return root
        left = self.find_LCA(root.left, p, q)
        right = self.find_LCA(root.right, p, q)
        # One target on each side: this node is where the paths meet.
        if left is not None and right is not None:
            return root
        # Otherwise pass up whichever side found something (possibly None).
        return left if left is not None else right