Diameter of a Binary Tree


Vinay Khatri
Last updated on December 22, 2024

    The diameter of a tree or a binary tree is the number of nodes on the longest path between any two nodes of the tree. This path does not have to start at the root or pass through it. Sometimes, the diameter is also known as the width of a tree. The diameter of a tree is the largest of the following three values:

    • The diameter of the left subtree.
    • The diameter of the right subtree.
    • The number of nodes on the longest path that passes through the root, which equals the heights of the left and right subtrees added together, plus 1 for the root itself.
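    The definition can be checked by brute force: measure the path from every node to every other node and keep the longest. The Python sketch below does this with a breadth-first search from each node, treating parent and child links as undirected edges. The names `Node` and `brute_force_diameter` are illustrative assumptions, and the O(N^2) cost makes it suitable only as a reference for small trees, not as a solution.

```python
from collections import deque


class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


def brute_force_diameter(root):
    """Diameter in nodes: the longest shortest path over all pairs of nodes."""
    if root is None:
        return 0
    # Collect every node with its undirected neighbours (parent and children).
    neighbours = {}
    stack = [root]
    while stack:
        node = stack.pop()
        neighbours.setdefault(node, [])
        for child in (node.left, node.right):
            if child is not None:
                neighbours[node].append(child)
                neighbours.setdefault(child, []).append(node)
                stack.append(child)
    best = 0
    # BFS from every node; seen[v] is the number of nodes on the path start..v.
    for start in neighbours:
        seen = {start: 1}
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nxt in neighbours[node]:
                if nxt not in seen:
                    seen[nxt] = seen[node] + 1
                    queue.append(nxt)
        best = max(best, max(seen.values()))
    return best


# Usage: a tree whose longest path does not pass through the root (node 1).
#         1
#        /
#       2
#      / \
#     3   4
#    /     \
#   5       6
example = Node(1)
example.left = Node(2)
example.left.left = Node(3)
example.left.right = Node(4)
example.left.left.left = Node(5)
example.left.right.right = Node(6)
print(brute_force_diameter(example))  # 5 (path 5-3-2-4-6)
```

    In the usage tree, the longest path through the root (5-3-2-1) has only 4 nodes, which is why the diameters of the subtrees have to be considered alongside the path through the root.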

    Example 1

    Output: 7

    Explanation: The longest path in this tree contains 7 nodes and passes through the root node. Hence, the diameter of this tree is 7.

    Example 2

    Output: 6

    Explanation: In this example, the longest path contains 6 nodes and passes through the root node. Hence, the diameter of this tree is 6.

    Approach 1: Recursive Approach

    The idea is to first compute the heights of the left and right subtrees, and then the diameters of the left and right subtrees. The height of a tree is the number of nodes along the longest path from the root node down to the farthest leaf node.

    The diameter of the tree is the maximum of three values: the sum of the two subtree heights plus 1 (the longest path through the root, counting the root itself), the diameter of the left subtree, and the diameter of the right subtree.
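    Written as a recurrence, with T_L and T_R the left and right subtrees of a tree T, h the height and d the diameter (both counted in nodes, and both 0 for an empty tree), the idea above reads as follows. The notation is introduced only for this summary.

```latex
\begin{aligned}
h(\varnothing) &= 0, & h(T) &= 1 + \max\bigl(h(T_L),\, h(T_R)\bigr) \\
d(\varnothing) &= 0, & d(T) &= \max\bigl(h(T_L) + h(T_R) + 1,\; d(T_L),\; d(T_R)\bigr)
\end{aligned}
```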

    Algorithm

    Let n be the node currently being processed (initially the root).

    1. Start traversing the tree from the top.
    2. If n is NULL, return 0, because an empty tree contains no nodes.
    3. Compute the heights of the left and right subtrees of n.
    4. Add the two heights, plus 1 for n itself.
    5. Recursively compute the diameters of the left and right subtrees.
    6. Return the maximum of the value from step 4, the left diameter, and the right diameter.
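    As a worked check of these steps, take the seven-node tree that the driver programs below construct: root 1 with children 2 and 3, node 2 with children 4 and 5, and node 4 with children 6 and 7. Write T_k for the subtree rooted at node k, h for height and d for diameter (both in nodes), with d(T) = max(h of left subtree + h of right subtree + 1, d of left subtree, d of right subtree) and a single node having h = d = 1:

```latex
\begin{aligned}
h(T_6) &= h(T_7) = h(T_5) = h(T_3) = 1, \quad h(T_4) = 2, \quad h(T_2) = 3 \\
d(T_4) &= \max(1 + 1 + 1,\; 1,\; 1) = 3 \\
d(T_2) &= \max\bigl(h(T_4) + h(T_5) + 1,\; d(T_4),\; d(T_5)\bigr) = \max(2 + 1 + 1,\; 3,\; 1) = 4 \\
d(T_1) &= \max\bigl(h(T_2) + h(T_3) + 1,\; d(T_2),\; d(T_3)\bigr) = \max(3 + 1 + 1,\; 4,\; 1) = 5
\end{aligned}
```

    The value 5 corresponds to the path 6-4-2-1-3 (or 7-4-2-1-3), which matches the output of the programs.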

    Below is the implementation of the approach discussed above:

    CPP

    #include <bits/stdc++.h>
    using namespace std;
    
    
    // Structure to store binary tree having 
    // data with left child pointer and right child pointer.
    struct node {
        int data;
        struct node *left, *right;
    };
    
    // declaring prototype of
    // a function to create a new node in the tree.
    struct node* newNode(int data);
    
    // utility function to get the 
    // maximum of two integers.
    int max(int a, int b) { return (a > b) ? a : b; }
    
    // declaring prototype of
    // utility function to calculate the 
    // height of a tree.
    int height(struct node* node);
    
    // Function that takes a root node 
    // as its parameter to 
    // calculate the diameter 
    // of a tree.
    int diameter(struct node* tree)
    {
        // base case
        if (tree == NULL)
            return 0;
    
        // calculating height of the
        // left subtree and
        // right subtree
        int lheight = height(tree->left);
        int rheight = height(tree->right);
    
        // recursively calculate the diameter of
        // left subtree and
        // right subtree
        int ldiameter = diameter(tree->left);
        int rdiameter = diameter(tree->right);
    
    
        // return maximum of:
        // diameter of left subtree, that of right subtree and
        // 1 + height of left subtree + right subtree
        return max(lheight + rheight + 1,
                max(ldiameter, rdiameter));
    }
    
    // defining the utility function to calculate 
    // the height of a tree.  
    // The height of a tree is the 
    // number of nodes along the longest path 
    // from the root node down to the farthest leaf node.
    int height(struct node* node)
    {
        // base case 
        if (node == NULL)
            return 0;
    
        // calculate the height as 
        // 1 + maximum of left subtree height and right subtree height
        return 1 + max(height(node->left), height(node->right));
    }
    
    // defining the utility function to create a new node in the tree.
    struct node* newNode(int data)
    {
        struct node* node
            = (struct node*)malloc(sizeof(struct node));
        node->data = data;
        node->left = NULL;
        node->right = NULL;
    
        return (node);
    }
    
    // Driver Code
    int main()
    {
        /* Construct the following tree
             1
           /   \
          /     \
         2       3
        / \      
       4   5     
      / \ 
      6  7    
        
      */
    
    
        struct node* root = newNode(1);
        root->left = newNode(2);
        root->right = newNode(3);
        root->left->left = newNode(4);
        root->left->right = newNode(5);
        root->left->left->left = newNode(6);
        root->left->left->right = newNode(7);
    
    
        // Function Call
        cout << "Diameter: " <<diameter(root);
    
    
        return 0;
    }
    

    Output

    Diameter: 5

    JAVA

    // Class to store binary tree having 
    // data with left child pointer and right child pointer.
    class Node {
        int data;
        Node left, right;
    
        public Node(int item)
        {
            data = item;
            left = right = null;
        }
    }
    
    // A class to print diameter of a tree
    class BinaryTree {
        Node root;
    
        // Method that takes a root node 
        // as its parameter to 
        // calculate the diameter 
        // of a tree.
    
        int diameter(Node root)
        {
            // base case 
            if (root == null)
                return 0;
    
            // calculating height of the
            // left subtree and
            // right subtree
    
            int lheight = height(root.left);
            int rheight = height(root.right);
    
            // recursively calculate the diameter of
            // left subtree and
            // right subtree
            int ldiameter = diameter(root.left);
            int rdiameter = diameter(root.right);
    
            // return maximum of:
            // diameter of left subtree, that of right subtree and
            // 1 + height of left subtree + right subtree
            return Math.max(lheight + rheight + 1,
                            Math.max(ldiameter, rdiameter));
        }
    
        // wrapper function for diameter(Node root)
        int diameter() { return diameter(root); }
    
        // defining the utility function to calculate 
        // the height of a tree.  
        // The height of a tree is the 
        // number of nodes along the longest path 
        // from the root node down to the farthest leaf node.
    
        static int height(Node node)
        {
            // base case 
            if (node == null)
                return 0;
    
            // calculate the height as 
            // 1 + maximum of left subtree height and right subtree height
            return (1
                    + Math.max(height(node.left),
                            height(node.right)));
        }
    
        // Driver Code
        public static void main(String args[])
        {
            /* Construct the following tree
             1
           /   \
          /     \
         2       3
        / \      
       4   5     
      / \ 
      6  7    
        
      */
    
            BinaryTree tree = new BinaryTree();
            tree.root = new Node(1);
            tree.root.left = new Node(2);
            tree.root.right = new Node(3);
            tree.root.left.left = new Node(4);
            tree.root.left.right = new Node(5);
            tree.root.left.left.left = new Node(6);
            tree.root.left.left.right = new Node(7);
    
    
    
            // Function Call
            System.out.println(
                "Diameter: "
                + tree.diameter());
        }
    }
    

    Output

    Diameter: 5

    Python

    # Class to store binary tree having 
    # data with left child pointer and right child pointer.
    
    class Node:
    
        # Constructor to create a new node
        def __init__(self, data):
            self.data = data
            self.left = None
            self.right = None
    
    
    
    # defining the utility function to calculate 
    # the height of a tree.  
    # The height of a tree is the 
    # number of nodes along the longest path 
    # from the root node down to the farthest leaf node.
    
    
    def height(node):
    
        # Base Case 
        if node is None:
            return 0
    
        # calculate the height as 
        # 1 + maximum of left subtree height and right subtree height
        return 1 + max(height(node.left), height(node.right))
    
    # Function that takes a root node 
    # as its parameter to 
    # calculate the diameter 
    # of a tree.
    def diameter(root):
    
        # Base Case 
        if root is None:
            return 0
    
        # calculating height of the
        # left subtree and
        # right subtree
        lheight = height(root.left)
        rheight = height(root.right)
    
        # recursively calculate the diameter of
        # left subtree and
        # right subtree
        ldiameter = diameter(root.left)
        rdiameter = diameter(root.right)
    
        # return maximum of:
        # diameter of left subtree, that of right subtree and
        # 1 + height of left subtree + right subtree
        return max(lheight + rheight + 1, max(ldiameter, rdiameter))
    
    
    # Driver Code
    
    root = Node(1)
    root.left = Node(2)
    root.right = Node(3)
    root.left.left = Node(4)
    root.left.right = Node(5)
    root.left.left.left = Node(6)
    root.left.left.right = Node(7)
    
    print("Diameter:", diameter(root))
    

    Output

    Diameter: 5

    Complexity Analysis

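    Time complexity: O(N^2) in the worst case, where N is the total number of nodes. height() is called on the subtrees of every node, and on a skewed tree each of those calls can visit O(N) nodes (on a balanced tree the total is closer to O(N log N)). Space complexity: O(h), where h is the height of the tree. This space is used by the recursion stack, whose depth equals the height of the tree in the worst case.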
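    The quadratic bound can be observed by counting work. The sketch below repeats the Approach 1 logic in Python with a counter added to `height()`, then runs it on left-skewed chains of 100, 200, and 400 nodes. The helper names `left_chain` and `count_height_visits` and the global counter `visits` are assumptions added only for this measurement.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


visits = 0  # hypothetical counter: non-empty nodes touched by height()


def height(node):
    global visits
    if node is None:
        return 0
    visits += 1
    return 1 + max(height(node.left), height(node.right))


def diameter(root):
    # Same recurrence as Approach 1: height() is re-run for every node.
    if root is None:
        return 0
    lheight = height(root.left)
    rheight = height(root.right)
    return max(lheight + rheight + 1, diameter(root.left), diameter(root.right))


def left_chain(n):
    """Build a left-skewed tree of n nodes, the worst case for Approach 1."""
    root = Node(1)
    node = root
    for value in range(2, n + 1):
        node.left = Node(value)
        node = node.left
    return root


def count_height_visits(n):
    """Run diameter() on a chain of n nodes; return how many nodes height() touched."""
    global visits
    visits = 0
    diameter(left_chain(n))
    return visits


# Sizes stay well below CPython's default recursion limit of 1000.
for n in (100, 200, 400):
    print(n, count_height_visits(n))
```

    It prints 4950, 19900, and 79800 visits for the three sizes. On a chain of n nodes, `height()` touches n(n-1)/2 nodes in total, so doubling n roughly quadruples the work.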

    Approach 2: Optimised Recursive Approach

    The above approach can be optimized by computing the height in the same recursion that computes the diameter, instead of calling a separate height function for every node.

    Algorithm:

    1. Start traversing the tree from the top.
    2. If the current node is NULL, set its height to 0 and return 0, because an empty tree contains no nodes.
    3. Recursively process the left and right subtrees. Each call returns that subtree's diameter and reports the subtree's height through a separate output parameter.
    4. Set the height of the current node to 1 + the maximum of the two subtree heights.
    5. Return the maximum of (left height + right height + 1), the left diameter, and the right diameter.
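    In Python, the Height helper object can be avoided by returning the height and the diameter together as a tuple. The sketch below is one way to express the algorithm above; the name `height_and_diameter` is an assumption, not part of the original programs. It still visits each node once, like the versions below.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


def height_and_diameter(node):
    """Return (height, diameter) of the subtree rooted at node, both in nodes."""
    if node is None:
        return 0, 0
    lheight, ldiameter = height_and_diameter(node.left)
    rheight, rdiameter = height_and_diameter(node.right)
    height = 1 + max(lheight, rheight)
    # Longest path through this node, or the best path found inside a subtree.
    diameter = max(lheight + rheight + 1, ldiameter, rdiameter)
    return height, diameter


def diameter(root):
    return height_and_diameter(root)[1]


root = Node(1)
root.left = Node(2)
root.right = Node(3)
root.left.left = Node(4)
root.left.right = Node(5)
root.left.left.left = Node(6)
root.left.left.right = Node(7)
print("Diameter:", diameter(root))  # Diameter: 5
```

    Returning a tuple keeps the function free of side effects, and it is the usual Python substitute for an output parameter such as `int* height` in C++ or the `Height` object in Java.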

    The implementation of the approach discussed above is as follows:

    CPP

    #include <bits/stdc++.h>
    using namespace std;
    
    // Structure to store binary tree having 
    // data with left child pointer and right child pointer.
    struct node {
        int data;
        struct node *left, *right;
    };
    
    // declaring prototype of
    // a function to create a new node in the tree.
    struct node* newNode(int data);
    
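    // Returns the diameter of the tree rooted at root
    // and stores the height of that tree in *height.
    int diameterOpt(struct node* root, int* height)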
    {
        // lHeight is the left subtree's height
        // rHeight is the right subtree's height
        int lHeight = 0, rHeight = 0;
    
        // lDiameter is the left subtree's diameter 
        // rDiameter is the right subtree's diameter 
        int lDiameter = 0, rDiameter = 0;
    
        if (root == NULL) {
            *height = 0;
            return 0; 
        }
    
        // recursively compute the diameters of the left and
        // right subtrees; the calls also write the subtree
        // heights into lHeight and rHeight.
        lDiameter = diameterOpt(root->left, &lHeight);
        rDiameter = diameterOpt(root->right, &rHeight);
    
        // calculate the height of the current node as
        // 1 + maximum of the height of left subtree and right subtree.
        *height = max(lHeight, rHeight) + 1;
    
        return max(lHeight + rHeight + 1, max(lDiameter, rDiameter));
    }
    
    // defining the utility function to create a new node in the tree.
    struct node* newNode(int data)
    {
        struct node* node
            = (struct node*)malloc(sizeof(struct node));
        node->data = data;
        node->left = NULL;
        node->right = NULL;
    
        return (node);
    }
    
    // Driver Code
    int main()
    {
    
          /* Construct the following tree
             1
           /   \
          /     \
         2       3
        / \      
       4   5     
      / \ 
      6  7    
        
      */
    
    
        struct node* root = newNode(1);
        root->left = newNode(2);
        root->right = newNode(3);
        root->left->left = newNode(4);
        root->left->right = newNode(5);
        root->left->left->left = newNode(6);
        root->left->left->right = newNode(7);
    
        int height = 0;
        // Function Call
        cout << "Diameter: " << diameterOpt(root, &height);
    
        return 0;
    }
    

    Output

    Diameter: 5

    JAVA

    // Class to store binary tree having 
    // data with left child pointer and right child pointer.
    class Node {
        int data;
        Node left, right;
    
        public Node(int item)
        {
            data = item;
            left = right = null;
        }
    }
    
    // Helper class to get a height object.
    class Height {
        int h;
    }
    
    // A class to print diameter of a tree
    class BinaryTree {
        Node root;
    
        int diameterOpt(Node root, Height height)
        {
            // lHeight is the left subtree's height
            // rHeight is the right subtree's height
            Height lHeight = new Height(), rHeight = new Height();
    
            if (root == null) {
                height.h = 0;
                return 0; 
            }
    
            // lDiameter is the left subtree's diameter 
            // rDiameter is the right subtree's diameter 
    
            // recursively compute the diameters of the left and
            // right subtrees; the calls also store the subtree
            // heights in lHeight.h and rHeight.h.
            int lDiameter = diameterOpt(root.left, lHeight);
            int rDiameter = diameterOpt(root.right, rHeight);
    
            // calculate height of the current node as
            // 1 + maximum of height of left subtree and right subtree.
            height.h = Math.max(lHeight.h, rHeight.h) + 1;
    
            return Math.max(lHeight.h + rHeight.h + 1,
                            Math.max(lDiameter, rDiameter));
        }
    
        // wrapper function for diameterOpt(Node root, Height height)
        int diameter()
        {
            Height height = new Height();
            return diameterOpt(root, height);
        }
    
        // utility function to calculate the height of a tree;
        // diameterOpt() does not need it, because it computes
        // heights during the same traversal.
        static int height(Node node)
        {
            // base case 
            if (node == null)
                return 0;
    
            // calculate the height of the tree as
            // 1 + maximum of left subtree's height and right subtree's height
            return (1
                    + Math.max(height(node.left),
                            height(node.right)));
        }
    
        // Driver Code
        public static void main(String args[])
        {
            /* Construct the following tree
             1
           /   \
          /     \
         2       3
        / \      
       4   5     
      / \ 
      6  7    
        
      */
            
            BinaryTree tree = new BinaryTree();
            tree.root = new Node(1);
            tree.root.left = new Node(2);
            tree.root.right = new Node(3);
            tree.root.left.left = new Node(4);
            tree.root.left.right = new Node(5);
            tree.root.left.left.left = new Node(6);
            tree.root.left.left.right = new Node(7);
    
            // Function Call
            System.out.println(
                "Diameter: "
                + tree.diameter());
        }
    }
    

    Output

    Diameter: 5

    Python

    class Node:
    
        # Constructor to store binary tree having 
        # data with left child pointer and right child pointer.
        def __init__(self, data):
            self.data = data
            self.left = self.right = None
    
    # Helper class to get a height object.
    
    class Height:
        def __init__(self):
            self.h = 0
    
    
    def diameterOpt(root, height):
    
        # lHeight is the left subtree's height
        # rHeight is the right subtree's height
        lHeight = Height()
        rHeight = Height()
    
        # base case
        if root is None:
            height.h = 0
            return 0
    
        
        # lDiameter is the left subtree's diameter 
        # rDiameter is the right subtree's diameter 
    
        # recursively compute the diameters of the left and
        # right subtrees; the calls also store the subtree
        # heights in lHeight.h and rHeight.h.
        
        lDiameter = diameterOpt(root.left, lHeight)
        rDiameter = diameterOpt(root.right, rHeight)
    
        # calculate the height of the current node as
        # 1 + maximum of height of left subtree and right subtree.
    
        height.h = max(lHeight.h, rHeight.h) + 1
    
        # the diameter is either the longest path through this node
        # (lHeight.h + rHeight.h + 1) or the larger subtree diameter.
        return max(lHeight.h + rHeight.h + 1, max(lDiameter, rDiameter))
    
    # defining a function that 
    # calculates the diameter of a tree
    def diameter(root):
        height = Height()
        return diameterOpt(root, height)
    
    
    # Driver Code
    root = Node(1)
    root.left = Node(2)
    root.right = Node(3)
    root.left.left = Node(4)
    root.left.right = Node(5)
    root.left.left.left = Node(6)
    root.left.left.right = Node(7)
    
    
    # Function Call
    print("Diameter:", diameter(root))
    

    Output

    Diameter: 5

    Complexity Analysis

    Time complexity: O(N), where N is the total number of nodes. Each node is visited exactly once, because the height is computed in the same recursion that computes the diameter instead of by a separate height function. Space complexity: O(h), where h is the height of the tree, for the recursion stack. In the worst case of a skewed tree, h equals N, so this is O(N).
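    The stack space matters in practice: CPython's default recursion limit is 1000 frames, so the recursive Python versions raise RecursionError on a skewed tree of roughly a thousand nodes or more. A possible workaround, sketched below under the assumption that an explicit stack is acceptable, is an iterative post-order traversal that stores each subtree's height in a dictionary. The name `diameter_iterative` is hypothetical, and memory becomes O(N) for the dictionary instead of O(h) for the call stack.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


def diameter_iterative(root):
    """Diameter in nodes, computed bottom-up without recursion."""
    if root is None:
        return 0
    height = {None: 0}  # the empty subtree has height 0
    best = 0
    stack = [(root, False)]  # (node, children_already_processed)
    while stack:
        node, children_done = stack.pop()
        if children_done:
            lh, rh = height[node.left], height[node.right]
            height[node] = 1 + max(lh, rh)
            best = max(best, lh + rh + 1)  # longest path through this node
        else:
            # Revisit this node once both children have their heights.
            stack.append((node, True))
            for child in (node.left, node.right):
                if child is not None:
                    stack.append((child, False))
    return best


# Usage: a left-skewed chain of 10,000 nodes, far deeper than the
# default recursion limit, is handled without RecursionError.
root = Node(1)
node = root
for value in range(2, 10_001):
    node.left = Node(value)
    node = node.left
print(diameter_iterative(root))  # 10000
```

    Taking the maximum of lh + rh + 1 over all nodes gives the same value as the recursive versions, because their "maximum of the subtree diameters" simply unrolls into a maximum over every node of the tree.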

    Approach 3: Using DFS

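    In this approach, we find the diameter of the tree with two depth-first searches (DFS) over an adjacency list. It relies on a standard property of trees: the node farthest from any starting node is always one endpoint of a longest path, so a second DFS from that endpoint reaches the other endpoint. For a tree with at least two nodes, both endpoints of a longest path are leaves when the tree is viewed as an undirected graph.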

    Algorithm:

    1. Start the DFS traversal of the tree from any node (the programs below start from node 1).
    2. Compute the farthest node from the starting point.
    3. Take the obtained node as a start point and again start DFS traversal from it.
    4. Calculate the farthest node from it and trace the path.
    5. The total number of nodes in the obtained path will be the diameter of the tree.
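    The two passes above can be sketched compactly in Python. Two changes from the programs below are deliberate assumptions: the DFS uses an explicit stack instead of recursion, and it skips the node it came from (its parent) instead of keeping a visited array, which is enough because a tree has no cycles. The adjacency list is a dictionary, and the names `farthest` and `diameter` are illustrative.

```python
def farthest(adj, start):
    """Iterative DFS from start; return (farthest node, nodes on the path to it)."""
    best_node, best_count = start, 1
    stack = [(start, None, 1)]  # (node, parent, nodes on the path from start)
    while stack:
        node, parent, count = stack.pop()
        if count > best_count:
            best_node, best_count = node, count
        for nxt in adj[node]:
            # A tree has no cycles, so skipping the parent is enough
            # to avoid walking back along the same edge.
            if nxt != parent:
                stack.append((nxt, node, count + 1))
    return best_node, best_count


def diameter(adj, any_node):
    end, _ = farthest(adj, any_node)  # pass 1: one end of a longest path
    _, count = farthest(adj, end)     # pass 2: nodes on that longest path
    return count


# The tree from the programs below, as an undirected adjacency list.
adj = {
    1: [2, 3], 2: [1, 4, 5], 3: [1],
    4: [2, 6, 7], 5: [2], 6: [4], 7: [4],
}
print("Diameter:", diameter(adj, 1))  # Diameter: 5
```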

    CPP

    #include <iostream>
    #include <list>
    #include <vector>
    using namespace std;
    
    // a variable to track the node 
    // that is the farthest.
    int x;
    
    // a utility function to 
    // find maximum distance from node 
    // and assign the result to maxCount
    void dfsUtil(int node, int count, vector<bool>& visited,
                    int& maxCount, list<int>* adj)
    {
        visited[node] = true;
        count++;
        for (auto i = adj[node].begin(); i != adj[node].end(); ++i) {
            if (!visited[*i]) {
                if (count >= maxCount) {
                    maxCount = count;
                    x = *i;
                }
                dfsUtil(*i, count, visited, maxCount, adj);
            }
        }
    }
    
    // function implementing DFS using the 
    // utility function dfsUtil() recursively.
    void dfs(int node, int n, list<int>* adj, int& maxCount)
    {
        // Initially mark all vertices as unvisited.
        // (std::vector replaces the non-standard
        // variable-length array bool visited[n + 1].)
        vector<bool> visited(n + 1, false);
        int count = 0;
    
        // for the node that is visited
        // add 1 to the count.
        dfsUtil(node, count + 1, visited, maxCount, adj);
    }
    
    // function that calculates the diameter of
    // a binary tree and returns it
    // using adjacency list.
    int diameter(list<int>* adj, int n)
    {
        // A single node is a path of 1 node
        // and is its own farthest node.
        int maxCount = 1;
        x = 1;
    
        // DFS to find the node that is the farthest, i.e. x,
        // starting from node 1.
        dfs(1, n, adj, maxCount);
    
        // DFS to find the node that is the farthest
        // from the node x.
        dfs(x, n, adj, maxCount);
    
        return maxCount;
    }
    
    /* Driver program to test above functions*/
    int main()
    {
        int n = 7;
    
        /* Construct the following tree
                 1
               /   \
              2     3
             / \
            4   5
           / \
          6   7
        */
        list<int>* adj = new list<int>[n + 1];
    
        /*create undirected edges */
        adj[1].push_back(2);
        adj[2].push_back(1);
        adj[1].push_back(3);
        adj[3].push_back(1);
        adj[2].push_back(4);
        adj[4].push_back(2);
        adj[2].push_back(5);
        adj[5].push_back(2);
        adj[4].push_back(6);
        adj[6].push_back(4);
        adj[4].push_back(7);
        adj[7].push_back(4);
    
        // Print the diameter of the above tree.
        cout << "Diameter: "
                << diameter(adj, n);
    
        delete[] adj;
        return 0;
    }
    

    Output

    Diameter: 5

    JAVA

    import java.util.ArrayList;
    import java.util.List;
    
    class BinaryTree {
    
        // a variable to track the node 
        // that is the farthest.
        static int x;
    
        static int maxCount;
        static List<Integer>[] adj;
        
        // a utility function to 
        // find maximum distance from node 
        // and assign the result to maxCount
        static void dfsUtil(int node, int count,
                            boolean[] visited,
                            List<Integer>[] adj)
        {
            visited[node] = true;
            count++;
            
            for (int i : adj[node])
            {
                if (!visited[i]) {
                    if (count >= maxCount) {
                        maxCount = count;
                        x = i;
                    }
                    dfsUtil(i, count, visited, adj);
                }
            }
        }
        
        // function implementing DFS using the 
        // utility function dfsUtil() recursively.
        static void dfs(int node, int n, List<Integer>[] adj)
        {
            // A new boolean array is filled with false,
            // so all vertices start out unvisited.
            boolean[] visited = new boolean[n + 1];
            int count = 0;
        
            // for the node that is visited
            // add 1 to the count.
            dfsUtil(node, count + 1, visited, adj);
        }
        
        // function that calculates the diameter of
        // a binary tree and returns it
        // using adjacency list.
        static int diameter(List<Integer>[] adj, int n)
        {
            // A single node is a path of 1 node
            // and is its own farthest node.
            maxCount = 1;
            x = 1;
        
            // DFS to find the node that is the farthest, i.e. x,
            // starting from node 1.
            dfs(1, n, adj);
        
            // DFS to find the node that is the farthest
            // from the node x.
            dfs(x, n, adj);
        
            return maxCount;
        }
        
        /* Driver program to test above functions*/
        // Creating an array of a generic type needs an unchecked conversion.
        @SuppressWarnings("unchecked")
        public static void main(String[] args)
        {
            int n = 7;
        
            /* Construct the following tree
                 1
               /   \
              2     3
             / \
            4   5
           / \
          6   7
            */
            adj = new List[n + 1];
            for (int i = 0; i < n + 1; i++)
                adj[i] = new ArrayList<>();
        
            /*create undirected edges */
            adj[1].add(2);
            adj[2].add(1);
            adj[1].add(3);
            adj[3].add(1);
            adj[2].add(4);
            adj[4].add(2);
            adj[2].add(5);
            adj[5].add(2);
            adj[4].add(6);
            adj[6].add(4);
            adj[4].add(7);
            adj[7].add(4);
            
            // Print the diameter of the above tree.
            System.out.println("Diameter: " + diameter(adj, n));
        }
    }
    

    Output

    Diameter: 5

    Python

    # a utility function to
    # find maximum distance from node
    # and assign the result to maxCount
    def dfsUtil(node, count):
        global visited, x, maxCount, adj
        visited[node] = 1
        count += 1
        for i in adj[node]:
            if (visited[i] == 0):
                if (count >= maxCount):
                    maxCount = count
                    x = i
                dfsUtil(i, count)
    
    # function implementing DFS using the
    # utility function dfsUtil() recursively.
    def dfs(node, n):
        count = 0
        # Initially set the vertices
        # as false, indicating that
        # vertices are unvisited.
        for i in range(n + 1):
            visited[i] = 0
        # for the node that is visited
        # add 1 to the count.
        dfsUtil(node, count + 1)
    
    # function that calculates the diameter of
    # a binary tree and returns it
    # using adjacency list.
    def diameter(n):
        global adj, maxCount
    
        # DFS to find the node that is the farthest, i.e. x
        # for any random node.
        dfs(1, n)
    
        # DFS to find the node that is the farthest
        # from the node x.
        dfs(x, n)
        return maxCount
    
    # Driver code
    if __name__ == '__main__':
        n = 7
    
        adj, visited = [[] for i in range(n + 1)], [0 for i in range(n + 1)]
        # A single node is a path of 1 node
        # and is its own farthest node.
        maxCount = 1
        x = 1
    
        # create undirected edges
        adj[1].append(2)
        adj[2].append(1)
        adj[1].append(3)
        adj[3].append(1)
        adj[2].append(4)
        adj[4].append(2)
        adj[2].append(5)
        adj[5].append(2)
        adj[4].append(6)
        adj[6].append(4)
        adj[4].append(7)
        adj[7].append(4)
    
    
        # Print the diameter of the above tree.
        print("Diameter:", diameter(n))
    

    Output

    Diameter: 5

    Complexity Analysis

    Time complexity: O(N), where N is the total number of nodes. Each of the two DFS passes visits every node once. Space complexity: O(N), for the adjacency list, the visited array, and the recursion stack, which can be N levels deep on a skewed tree.
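    The DFS programs expect an adjacency list, while Approaches 1 and 2 build linked `Node` objects. A sketch of converting one into the other follows, assuming every node's `data` value is unique, since the values become graph keys; `to_adjacency` is a hypothetical helper name.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


def to_adjacency(root):
    """Return {data: [neighbour data, ...]} for an undirected view of the tree.

    Assumes every node's data value is unique.
    """
    if root is None:
        return {}
    adj = {root.data: []}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                adj[node.data].append(child.data)
                # First time this child is seen: its parent is its first neighbour;
                # its own children are appended when it is popped later.
                adj[child.data] = [node.data]
                stack.append(child)
    return adj


root = Node(1)
root.left = Node(2)
root.right = Node(3)
root.left.left = Node(4)
root.left.right = Node(5)
root.left.left.left = Node(6)
root.left.left.right = Node(7)
print(to_adjacency(root))
# {1: [2, 3], 2: [1, 4, 5], 3: [1], 4: [2, 6, 7], 5: [2], 6: [4], 7: [4]}
```

    For the list-based programs above, whose node labels run from 1 to n, `[adj.get(i, []) for i in range(n + 1)]` turns the dictionary into the equivalent list, with an empty entry at index 0.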

    Wrapping Up!

    In this article, we have learned how to calculate the diameter of a binary tree. We discussed three approaches to this problem: a straightforward recursive solution, an optimized recursion that computes heights and diameters in a single pass, and a two-pass DFS over an adjacency list. Each approach comes with code in C++, Java, and Python, along with its output. We hope this article has helped you understand these important binary tree concepts and how to approach problems like this one.

    Happy Learning!
