class TreeNode:
    """A binary-tree node: a stored value plus optional left/right children.

    Attributes:
        val: the value held by this node (defaults to 0).
        left: the left child (in this file, another ``TreeNode``) or ``None``.
        right: the right child (in this file, another ``TreeNode``) or ``None``.
    """

    def __init__(self, val=0, left=None, right=None):
        # Tuple assignment binds left-to-right, so attributes are created in
        # the order val, left, right.
        self.val, self.left, self.right = val, left, right
def isSumTree(root):
    """Check whether a binary tree is a sum tree.

    A tree is a sum tree when every non-leaf node's ``val`` equals the sum of
    all values in its left subtree plus all values in its right subtree; a
    missing child contributes 0. Leaves and the empty tree trivially qualify.

    The check runs iteratively with an explicit stack, so chain-shaped trees
    deeper than Python's recursion limit do not raise ``RecursionError``. The
    bottom-up pass returns as soon as it meets a node that fails the check.

    Args:
        root: root node (an object with ``val``, ``left`` and ``right``
            attributes, such as ``TreeNode``), or ``None`` for an empty tree.

    Returns:
        A tuple ``(is_sum_tree, total)``. When the tree is a sum tree,
        ``total`` is the sum of every value in it; otherwise the result is
        ``(False, 0)``.
    """
    if root is None:
        return (True, 0)

    # Root-first listing: each node is recorded before its children, so the
    # reversed list visits every child before its parent (bottom-up order).
    order = []
    pending = [root]
    while pending:
        node = pending.pop()
        order.append(node)
        if node.left is not None:
            pending.append(node.left)
        if node.right is not None:
            pending.append(node.right)

    # Keyed by id() so node objects need not be hashable; every node stays
    # referenced by `order`, so ids cannot be recycled during the scan.
    subtree_sum = {}
    for node in reversed(order):
        if node.left is None and node.right is None:
            # A leaf is always a sum tree; its subtree sum is its own value.
            subtree_sum[id(node)] = node.val
            continue
        left_sum = subtree_sum[id(node.left)] if node.left is not None else 0
        right_sum = subtree_sum[id(node.right)] if node.right is not None else 0
        children_sum = left_sum + right_sum
        if not node.val == children_sum:
            return (False, 0)
        # Valid internal node: its subtree sums to its value plus its children's.
        subtree_sum[id(node)] = node.val + children_sum
    return (True, subtree_sum[id(root)])
# Example usage:
if __name__ == '__main__':
    # Sample tree. Every internal node equals the sum of all values below it
    # (10 = 4 + 6, 3 = 0 + 3, 26 = 20 + 6), so this prints True.
    #
    #          26
    #         /  \
    #       10    3
    #      /  \    \
    #     4    6    3
    left_subtree = TreeNode(10, TreeNode(4), TreeNode(6))
    right_subtree = TreeNode(3, right=TreeNode(3))
    root = TreeNode(26, left_subtree, right_subtree)
    is_sum, _ = isSumTree(root)
    print(is_sum)