1530.cpp (1797B)
1 // Recursive 2 class Solution { 3 public: 4 int countPairs(TreeNode *root, int distance) { 5 unordered_map<TreeNode *, vector<int>> um; 6 stack<TreeNode *> st; 7 int res = 0; 8 9 st.push(root); 10 while (!st.empty()) { 11 TreeNode *root = st.top(); 12 if (root) { 13 st.push(nullptr); 14 if (!root->left && !root->right) 15 um[root].push_back(1); 16 else { 17 if (root->left) st.push(root->left); 18 if (root->right) st.push(root->right); 19 } 20 continue; 21 } 22 st.pop(); 23 root = st.top(); 24 st.pop(); 25 26 for (const int n : um[root->right]) 27 um[root].push_back(n + 1); 28 29 for (const int a : um[root->left]) { 30 um[root].push_back(a + 1); 31 for (const int b : um[root->right]) 32 if (a + b <= distance) res++; 33 } 34 } 35 return res; 36 } 37 }; 38 39 // Iterative 40 class Solution { 41 int res = 0; 42 vector<int> rec(TreeNode *root, int distance) { 43 if (!root->left && !root->right) return {1}; 44 vector<int> left, right, sum; 45 if (root->left) left = rec(root->left, distance); 46 if (root->right) right = rec(root->right, distance); 47 48 sum.reserve(left.size() + right.size()); 49 for (const int b : right) 50 sum.push_back(b + 1); 51 for (const int a : left) { 52 sum.push_back(a + 1); 53 for (const int b : right) { 54 res += (a + b <= distance); 55 } 56 } 57 return sum; 58 } 59 60 public: 61 int countPairs(TreeNode *root, int distance) { 62 rec(root, distance); 63 return res; 64 } 65 };