307. Range Sum Query - Mutable (Medium)
Given an integer array nums, find the sum of the elements between indices i and j (i ≤ j), inclusive.
The update(i, val) function modifies nums by updating the element at index i to val.
Example:
Given nums = [1, 3, 5]
sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8
Note:
- The array is only modifiable by the update function.
- You may assume the number of calls to update and sumRange function is distributed evenly.
Solution 1: Segment Tree 132ms
Time Complexity: update and sumRange each take $$O(\log n)$$; building the tree takes $$O(n)$$.
version 1: use vector
/**
 * Mutable range-sum structure backed by an array-embedded segment tree.
 *
 * Node `pos` of `segTree` holds the sum of one contiguous index range of the
 * array; its children sit at 2*pos+1 (lower half) and 2*pos+2 (upper half).
 * The current element values are mirrored in `input` so that update() can
 * convert "set to val" into a delta that is added along one root-to-leaf path.
 *
 * update() and sumRange() each touch O(log n) nodes; construction is O(n).
 */
class NumArray {
public:
    /** Builds the tree over `nums`. An empty vector yields an empty structure. */
    NumArray(std::vector<int> nums): n(static_cast<int>(nums.size())) {
        if (n == 0) return;
        input.swap(nums);  // adopt the by-value argument instead of copying it again

        // Leaves are padded to the next power of two so that the 2*pos+1 /
        // 2*pos+2 indexing never leaves the vector. Pure integer arithmetic:
        // no dependence on floating-point rounding of log2/pow.
        std::vector<int>::size_type leaves = 1;
        while (leaves < input.size()) leaves *= 2;
        segTree.assign(2 * leaves - 1, 0);

        constructTree(0, n - 1, 0);
    }

    /**
     * Sets element `i` to `val`.
     * Indices outside [0, n) -- which includes every index of an empty
     * array -- are ignored.
     */
    void update(int i, int val) {
        if (i < 0 || i >= n) return;
        const int diff = val - input[i];
        input[i] = val;
        if (diff != 0) updateTree(i, diff, 0, n - 1, 0);
    }

    /**
     * Returns the sum of elements with indices in [i, j], inclusive.
     * Parts of the query that fall outside the array contribute 0, so an
     * empty array or an empty intersection yields 0.
     */
    int sumRange(int i, int j) {
        if (n == 0) return 0;
        return rangeSumQuery(i, j, 0, n - 1, 0);
    }

private:
    int n;                   // number of elements in the array
    std::vector<int> input;  // current element values
    std::vector<int> segTree;

    // Fills the subtree rooted at `pos`, which covers input[low..high],
    // and returns the sum stored in that node.
    int constructTree(int low, int high, int pos) {
        if (low == high) {
            // Single element: the leaf stores the value itself.
            segTree[pos] = input[low];
            return segTree[pos];
        }
        const int mid = low + (high - low) / 2;
        const int leftSum = constructTree(low, mid, 2 * pos + 1);
        const int rightSum = constructTree(mid + 1, high, 2 * pos + 2);
        segTree[pos] = leftSum + rightSum;
        return segTree[pos];
    }

    // Adds `diff` to every node whose range contains index `i`.
    // Precondition (enforced by update()): low <= i <= high.
    void updateTree(int i, int diff, int low, int high, int pos) {
        segTree[pos] += diff;
        if (low == high) return;
        const int mid = low + (high - low) / 2;
        // Exactly one child contains i, so only that child is visited.
        if (i <= mid) {
            updateTree(i, diff, low, mid, 2 * pos + 1);
        } else {
            updateTree(i, diff, mid + 1, high, 2 * pos + 2);
        }
    }

    // Sum of the part of [qlow, qhigh] that lies inside the node range
    // [low, high] stored at `pos`.
    int rangeSumQuery(int qlow, int qhigh, int low, int high, int pos) {
        if (qlow <= low && qhigh >= high) {  // total overlap
            return segTree[pos];
        }
        if (qlow > high || qhigh < low) {    // no overlap
            return 0;
        }
        // Partial overlap: combine the answers of both halves.
        const int mid = low + (high - low) / 2;
        return rangeSumQuery(qlow, qhigh, low, mid, 2 * pos + 1) +
               rangeSumQuery(qlow, qhigh, mid + 1, high, 2 * pos + 2);
    }
};
/**
* Your NumArray object will be instantiated and called as such:
* NumArray obj(nums);
* obj.update(i,val);
* int param_2 = obj.sumRange(i,j);
*/
version 2: don't store input array
/**
 * Mutable range-sum structure backed by an array-embedded segment tree that
 * does not keep a separate copy of the input.
 *
 * Node `pos` of `segTree` holds the sum of one contiguous index range; its
 * children sit at 2*pos+1 (lower half) and 2*pos+2 (upper half). A leaf
 * stores the current value of its element, so update() reads the old value
 * from the leaf and propagates the difference back up the path.
 */
class NumArray {
public:
    /** Builds the tree over `nums`. An empty vector yields an empty structure. */
    NumArray(std::vector<int> nums): n(static_cast<int>(nums.size())) {
        if (n == 0) return;

        // Pad the leaf level to the next power of two so child indices stay
        // inside the vector. Computed with integers only, rather than via
        // floating-point log2/pow.
        std::vector<int>::size_type leaves = 1;
        while (leaves < nums.size()) leaves *= 2;
        segTree.assign(2 * leaves - 1, 0);

        constructTree(nums, 0, n - 1, 0);
    }

    /**
     * Sets element `i` to `val`.
     * Indices outside [0, n) -- including any index on an empty array --
     * are ignored.
     */
    void update(int i, int val) {
        if (i < 0 || i >= n) return;
        updateTree(i, val, 0, n - 1, 0);
    }

    /**
     * Returns the sum of elements with indices in [i, j], inclusive.
     * Parts of the query outside the array contribute 0.
     */
    int sumRange(int i, int j) {
        if (n == 0) return 0;
        return rangeSumQuery(i, j, 0, n - 1, 0);
    }

private:
    int n;  // number of elements in the array
    std::vector<int> segTree;

    // Fills the subtree rooted at `pos`, which covers values[low..high],
    // and returns the sum stored in that node.
    int constructTree(const std::vector<int>& values, int low, int high, int pos) {
        if (low == high) {
            // Single element: the leaf stores the value itself.
            segTree[pos] = values[low];
            return segTree[pos];
        }
        const int mid = low + (high - low) / 2;
        const int leftSum = constructTree(values, low, mid, 2 * pos + 1);
        const int rightSum = constructTree(values, mid + 1, high, 2 * pos + 2);
        segTree[pos] = leftSum + rightSum;
        return segTree[pos];
    }

    // Writes `val` into the leaf for index `i` and returns (new - old) so each
    // ancestor on the way back up can adjust its sum by the same amount.
    // Precondition (enforced by update()): low <= i <= high.
    int updateTree(int i, int val, int low, int high, int pos) {
        if (low == high) {
            const int diff = val - segTree[pos];
            segTree[pos] = val;
            return diff;
        }
        const int mid = low + (high - low) / 2;
        const int diff = (i <= mid)
            ? updateTree(i, val, low, mid, 2 * pos + 1)
            : updateTree(i, val, mid + 1, high, 2 * pos + 2);
        segTree[pos] += diff;
        return diff;
    }

    // Sum of the part of [qlow, qhigh] that lies inside the node range
    // [low, high] stored at `pos`.
    int rangeSumQuery(int qlow, int qhigh, int low, int high, int pos) {
        if (qlow <= low && qhigh >= high) {  // total overlap
            return segTree[pos];
        }
        if (qlow > high || qhigh < low) {    // no overlap
            return 0;
        }
        // Partial overlap: combine the answers of both halves.
        const int mid = low + (high - low) / 2;
        return rangeSumQuery(qlow, qhigh, low, mid, 2 * pos + 1) +
               rangeSumQuery(qlow, qhigh, mid + 1, high, 2 * pos + 2);
    }
};
version 3: use TreeNode 166ms
/**
 * Node of a pointer-based segment tree.
 *
 * `sum` is the total of the index range the node represents. A node owns its
 * children: destroying a node destroys its whole subtree. Copying is disabled
 * because a shallow copy would lead to the same children being deleted twice.
 */
struct SegTreeNode {
    int sum;
    SegTreeNode* left;
    SegTreeNode* right;

    SegTreeNode(): sum(0), left(nullptr), right(nullptr) {}
    ~SegTreeNode() {
        delete left;   // deleting a null pointer is a no-op
        delete right;
    }
    SegTreeNode(const SegTreeNode&) = delete;
    SegTreeNode& operator=(const SegTreeNode&) = delete;
};

/**
 * Mutable range-sum structure backed by a pointer-based segment tree.
 *
 * The root covers indices [0, n-1]; each internal node splits its range at
 * the midpoint between its two children. The tree is released when the
 * NumArray is destroyed. Instances can be moved but not copied.
 */
class NumArray {
public:
    /** Builds the tree over `nums`. An empty vector leaves the tree empty. */
    NumArray(std::vector<int> nums): n(static_cast<int>(nums.size())), root(nullptr) {
        if (n > 0) root = constructTree(nums, 0, n - 1);
    }

    ~NumArray() { delete root; }

    NumArray(const NumArray&) = delete;
    NumArray& operator=(const NumArray&) = delete;

    /** Takes over the tree of `other`, leaving `other` empty. */
    NumArray(NumArray&& other) noexcept: n(other.n), root(other.root) {
        other.n = 0;
        other.root = nullptr;
    }

    /** Releases the current tree and takes over the tree of `other`. */
    NumArray& operator=(NumArray&& other) noexcept {
        if (this != &other) {
            delete root;
            n = other.n;
            root = other.root;
            other.n = 0;
            other.root = nullptr;
        }
        return *this;
    }

    /**
     * Sets element `i` to `val`.
     * Indices outside [0, n) -- including any index on an empty array --
     * are ignored.
     */
    void update(int i, int val) {
        if (root == nullptr || i < 0 || i >= n) return;
        updateTree(i, val, 0, n - 1, root);
    }

    /**
     * Returns the sum of elements with indices in [i, j], inclusive.
     * Parts of the query outside the array contribute 0; an empty array
     * always yields 0.
     */
    int sumRange(int i, int j) {
        if (root == nullptr) return 0;
        return rangeSumQuery(i, j, 0, n - 1, root);
    }

private:
    int n;              // number of elements in the array
    SegTreeNode* root;  // owned; nullptr when the array is empty

    // Allocates and returns the subtree covering nums[low..high].
    SegTreeNode* constructTree(const std::vector<int>& nums, int low, int high) {
        SegTreeNode* node = new SegTreeNode();
        if (low == high) {
            node->sum = nums[low];
            return node;
        }
        try {
            const int mid = low + (high - low) / 2;
            node->left = constructTree(nums, low, mid);
            node->right = constructTree(nums, mid + 1, high);
        } catch (...) {
            // An allocation further down failed: free what was built here
            // (the node destructor frees any attached child) and propagate.
            delete node;
            throw;
        }
        node->sum = node->left->sum + node->right->sum;
        return node;
    }

    // Writes `val` into the leaf for index `i` and returns (new - old) so each
    // ancestor on the way back up can adjust its sum by the same amount.
    // Precondition (enforced by update()): low <= i <= high.
    int updateTree(int i, int val, int low, int high, SegTreeNode* node) {
        if (low == high) {
            const int diff = val - node->sum;
            node->sum = val;
            return diff;
        }
        const int mid = low + (high - low) / 2;
        const int diff = (i <= mid)
            ? updateTree(i, val, low, mid, node->left)
            : updateTree(i, val, mid + 1, high, node->right);
        node->sum += diff;
        return diff;
    }

    // Sum of the part of [qlow, qhigh] that lies inside the range [low, high]
    // represented by `node`.
    int rangeSumQuery(int qlow, int qhigh, int low, int high, SegTreeNode* node) {
        if (qlow <= low && qhigh >= high) return node->sum;  // total overlap
        if (qlow > high || qhigh < low) return 0;            // no overlap
        // Partial overlap: combine the answers of both halves.
        const int mid = low + (high - low) / 2;
        return rangeSumQuery(qlow, qhigh, low, mid, node->left) +
               rangeSumQuery(qlow, qhigh, mid + 1, high, node->right);
    }
};
Solution 2: Binary Indexed Tree 159 ms
/**
 * Mutable range-sum structure backed by a Binary Indexed (Fenwick) Tree.
 *
 * `BIT` is 1-based: slot k accumulates the elements of the block that ends
 * at array index k-1 and whose length is the lowest set bit of k. Slot 0 is
 * unused. `input` mirrors the current element values so that update() can
 * convert "set to val" into a delta.
 *
 * update() and sumRange() each touch O(log n) slots; construction is O(n).
 */
class NumArray {
public:
    /** Builds the tree over `nums`. An empty vector yields an empty structure. */
    NumArray(std::vector<int> nums)
        : n(static_cast<int>(nums.size()) + 1), input(), BIT(nums.size() + 1, 0) {
        input.swap(nums);  // adopt the by-value argument instead of copying it again

        // Linear-time build: once slot i is complete, push its total into the
        // next slot whose block contains it (the same step updateBIT takes).
        for (int i = 1; i < n; ++i) {
            BIT[i] += input[i - 1];
            const int next = i + (i & -i);
            if (next < n) BIT[next] += BIT[i];
        }
    }

    /**
     * Sets element `i` to `val`.
     * Indices outside the array -- including any index on an empty array --
     * are ignored.
     */
    void update(int i, int val) {
        if (i < 0 || i >= n - 1) return;
        const int delta = val - input[i];
        input[i] = val;
        if (delta != 0) updateBIT(i + 1, delta);
    }

    /**
     * Returns the sum of elements with indices in [i, j], inclusive.
     * The query is clamped to the array first, so out-of-range parts
     * contribute 0 and an empty intersection (or i > j) yields 0.
     */
    int sumRange(int i, int j) {
        if (i < 0) i = 0;
        if (j > n - 2) j = n - 2;  // n - 2 is the last valid array index
        if (i > j) return 0;
        return getSum(j + 1) - getSum(i);
    }

private:
    int n;  // size of BIT, i.e. number of elements + 1
    std::vector<int> input;
    std::vector<int> BIT;

    // Adds `val` to every slot whose block contains 1-based position `i`.
    void updateBIT(int i, int val) {
        for (; i < n; i += i & (-i)) {  // step to the next covering slot
            BIT[i] += val;
        }
    }

    // Returns the sum of the first `i` elements (array indices 0 .. i-1).
    int getSum(int i) {
        int sum = 0;
        for (; i > 0; i -= i & (-i)) {  // strip the lowest set bit
            sum += BIT[i];
        }
        return sum;
    }
};