class Solution {
public:
int averageOfSubtree(TreeNode *root) {
stack<TreeNode *> st;
unordered_map<TreeNode *, int> um;
int res = 0;
typedef tuple<int, int> tii;
public:
int averageOfSubtree(const TreeNode *root) const {
unordered_map<const TreeNode *, tii> um;
stack<const TreeNode *> st;
int count = 0;
st.push(root);
while (!st.empty()) {
TreeNode *root = st.top();
if (root == nullptr) {
st.pop(), root = st.top(), st.pop();
int sum = root->val, count = 1;
if (root->left) {
sum += root->left->val;
count += um[root->left];
}
if (root->right) {
sum += root->right->val;
count += um[root->right];
}
if (root->val == sum / count) res++;
um[root] = count;
root->val = sum;
if (st.top() != nullptr) {
const TreeNode *root = st.top();
st.push(nullptr);
if (root->left) st.push(root->left);
if (root->right) st.push(root->right);
continue;
}
st.push(nullptr);
if (root->left) st.push(root->left);
if (root->right) st.push(root->right);
st.pop();
const TreeNode *root = st.top();
st.pop();
tii left = um[root->left], right = um[root->right];
um[root] = {root->val + get<0>(left) + get<0>(right), 1 + get<1>(left) + get<1>(right)};
if (get<0>(um[root]) / get<1>(um[root]) == root->val) count++;
}
return res;
return count;
}
};