0473.cpp (1648B)
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Decides whether every stick can be used exactly once to form four groups of
// equal total length (the four sides of a square). Two alternative solutions are
// kept in separate namespaces so they can coexist in one translation unit; the
// global `Solution` alias selects the default one.

namespace backtracking {

// Backtracking: 4 ^ N worst case, heavily pruned in practice.
class Solution {
    // Assigns sticks[idx..] to the four sides. crnt[k] is the length already laid
    // on side k; no side is ever allowed to grow beyond `side`.
    static bool find(const std::vector<int> &sticks, const long long side, std::size_t idx,
                     long long crnt[4]) {
        // All sticks are placed and no side exceeds `side`; because the sides sum to
        // 4 * side, checking three of them is enough.
        if (idx == sticks.size()) return crnt[0] == side && crnt[1] == side && crnt[2] == side;

        for (int i = 0; i < 4; ++i) {
            if (crnt[i] + sticks[idx] > side) continue;
            // Sides of equal current length are interchangeable: if an earlier side
            // has the same length, placing the stick here repeats a search that has
            // already failed.
            if (std::find(crnt, crnt + i, crnt[i]) != crnt + i) continue;

            crnt[i] += sticks[idx];
            if (find(sticks, side, idx + 1, crnt)) return true;
            crnt[i] -= sticks[idx];
        }
        return false;
    }

public:
    // Note: sorts `sticks` in place (descending) when the total is divisible by 4.
    bool makesquare(std::vector<int> &sticks) const {
        // Guard: the code below reads sticks[0].
        if (sticks.empty()) return false;

        // 64-bit accumulator so large inputs cannot overflow the total.
        const long long sum = std::accumulate(sticks.begin(), sticks.end(), 0LL);
        if (sum % 4 != 0) return false;
        const long long side = sum / 4;

        // Longest sticks first: they hit the capacity limit soonest, pruning early.
        std::sort(sticks.begin(), sticks.end(), std::greater<>());
        if (sticks[0] > side) return false;

        // The first stick can be put on side 0 without loss of generality.
        long long crnt[4] = {sticks[0], 0, 0, 0};
        return find(sticks, side, 1, crnt);
    }
};

}  // namespace backtracking

namespace bitmask_dp {

// DP over subsets of used sticks: N * 2 ^ N time, 4 * 2 ^ N bytes of memo.
class Solution {
    // Largest input handled with the subset memo (4 << 20 bytes = 4 MiB); larger
    // inputs fall back to the backtracking solution.
    static constexpr std::size_t kMaxSticks = 20;

    // Returns true if the sticks not in `mask` can finish the square.
    // left:    sides still to complete before the last one (3 at the start)
    // crnt:    length already laid on the side under construction
    // visited: one flag per (left, mask) state already explored; a revisited state
    //          must have failed, since success returns all the way up immediately.
    static bool find(const std::vector<int> &sticks, int left, long long side, long long crnt,
                     std::uint32_t mask, std::vector<std::uint8_t> &visited) {
        const std::size_t state = (static_cast<std::size_t>(left) << sticks.size()) | mask;
        if (visited[state]) return false;
        visited[state] = 1;

        if (crnt == side) --left, crnt = 0;
        // Three sides of exactly `side` are done; the unused sticks sum to the fourth.
        if (left == 0) return true;

        for (std::size_t i = 0; i < sticks.size(); ++i) {
            const std::uint32_t bit = std::uint32_t{1} << i;
            if (mask & bit) continue;
            if (crnt + sticks[i] > side) continue;
            if (find(sticks, left, side, crnt + sticks[i], mask | bit, visited)) return true;
        }
        return false;
    }

public:
    // Does not modify `sticks`.
    bool makesquare(std::vector<int> &sticks) {
        if (sticks.empty()) return false;
        if (sticks.size() > kMaxSticks) {
            // The subset memo would be too large; search on a copy so the caller's
            // vector is left untouched, as on the DP path.
            std::vector<int> copy(sticks);
            return backtracking::Solution().makesquare(copy);
        }

        const long long sum = std::accumulate(sticks.begin(), sticks.end(), 0LL);
        if (sum % 4 != 0) return false;
        const long long side = sum / 4;

        // Fresh memo per call: a member memo would leak results between calls.
        std::vector<std::uint8_t> visited(std::size_t{4} << sticks.size(), 0);
        return find(sticks, 3, side, 0, 0, visited);
    }
};

}  // namespace bitmask_dp

// Default entry point; bitmask_dp::Solution is the alternative implementation.
using Solution = backtracking::Solution;