307. Range Sum Query - Mutable (Medium)

Given an integer array nums, find the sum of the elements between indices i and j ($$i \le j$$), inclusive.

The update(i, val) function modifies nums by updating the element at index i to val.

Example:

Given nums = [1, 3, 5]

sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8

Note:

  1. The array is only modifiable by the update function.
  2. You may assume that calls to update and sumRange are distributed evenly.

Solution 1: Segment Tree (132 ms)

Time Complexity: $$O(\log n)$$ per update and per sumRange; $$O(n)$$ to build the tree.

Version 1: store the tree in a vector and keep a copy of the input array

// Segment tree stored in a flat vector: node `pos` has children 2*pos+1 and
// 2*pos+2 and holds the sum of input[low..high] for the range it covers.
// A copy of the input is kept so update() can compute the delta to apply.
class NumArray {
public:
    // Builds the tree over a copy of nums in O(n).
    NumArray(vector<int> nums): n(nums.size()), input(nums) {
        if (n == 0) return;
        // Room for a full binary tree over the smallest power of two >= n.
        // Integer doubling avoids the floating-point ceil(log2(n)) / pow().
        int leaves = 1;
        while (leaves < n) leaves <<= 1;
        segTree.assign(2 * leaves - 1, 0);
        constructTree(0, n - 1, 0);
    }

    // Sets input[i] = val in O(log n). Indices outside [0, n) -- including
    // any index into an empty array -- are ignored instead of reading past
    // the end of `input`.
    void update(int i, int val) {
        if (i < 0 || i >= n) return;
        const int diff = val - input[i];
        input[i] = val;
        updateTree(i, diff, 0, n - 1, 0);
    }

    // Returns input[i] + ... + input[j] in O(log n); 0 for an empty array.
    int sumRange(int i, int j) {
        if (n == 0) return 0;
        return rangeSumQuery(i, j, 0, n - 1, 0);
    }

private:
    int n;
    vector<int> input, segTree;

    // Fills node `pos` (covering input[low..high]) and its subtree from the
    // stored input; returns the node's sum.
    int constructTree(int low, int high, int pos) {
        if (low == high) {
            segTree[pos] = input[low];
            return segTree[pos];
        }
        const int mid = low + (high - low) / 2;
        segTree[pos] = constructTree(low, mid, 2 * pos + 1) +
                       constructTree(mid + 1, high, 2 * pos + 2);
        return segTree[pos];
    }

    // Adds diff to every node on the root-to-leaf path of index i.
    // Precondition: low <= i <= high (guaranteed by update()).
    void updateTree(int i, int diff, int low, int high, int pos) {
        segTree[pos] += diff;
        if (low == high) return;
        const int mid = low + (high - low) / 2;
        // Only the child containing i needs updating.
        if (i <= mid)
            updateTree(i, diff, low, mid, 2 * pos + 1);
        else
            updateTree(i, diff, mid + 1, high, 2 * pos + 2);
    }

    // Sum of input[qlow..qhigh] restricted to this node's range [low, high].
    int rangeSumQuery(int qlow, int qhigh, int low, int high, int pos) {
        if (qlow <= low && qhigh >= high) { // total overlap
            return segTree[pos];
        }
        if (qlow > high || qhigh < low) { // no overlap
            return 0;
        }
        // partial overlap: combine both halves
        const int mid = low + (high - low) / 2;
        return rangeSumQuery(qlow, qhigh, low, mid, 2 * pos + 1) +
               rangeSumQuery(qlow, qhigh, mid + 1, high, 2 * pos + 2);
    }
};

/**
 * Your NumArray object will be instantiated and called as such:
 * NumArray* obj = new NumArray(nums);
 * obj->update(i,val);
 * int param_2 = obj->sumRange(i,j);
 */

Version 2: store the tree in a vector without keeping a copy of the input array

// Segment tree stored in a flat vector: node `pos` has children 2*pos+1 and
// 2*pos+2 and holds the sum of the range it covers. Unlike the version that
// keeps a copy of the input, update() writes the new value at the leaf and
// propagates the resulting difference back up to the root.
class NumArray {
public:
    // Builds the tree over nums in O(n).
    NumArray(vector<int> nums): n(nums.size()) {
        if (n == 0) return;
        // Room for a full binary tree over the smallest power of two >= n.
        // Integer doubling avoids the floating-point ceil(log2(n)) / pow().
        int leaves = 1;
        while (leaves < n) leaves <<= 1;
        segTree.assign(2 * leaves - 1, 0);
        constructTree(nums, 0, n - 1, 0);
    }

    // Sets element i to val in O(log n). Out-of-range indices (and any index
    // into an empty array) leave the tree unchanged.
    void update(int i, int val) {
        updateTree(i, val, 0, n - 1, 0);
    }

    // Returns the sum of elements i..j in O(log n); 0 for an empty array.
    int sumRange(int i, int j) {
        if (n == 0) return 0;
        return rangeSumQuery(i, j, 0, n - 1, 0);
    }

private:
    int n;
    vector<int> segTree;

    // Fills node `pos` (covering input[low..high]) and its subtree; returns
    // the node's sum.
    int constructTree(const vector<int>& input, int low, int high, int pos) {
        if (low == high) {
            segTree[pos] = input[low];
            return segTree[pos];
        }
        const int mid = low + (high - low) / 2;
        segTree[pos] = constructTree(input, low, mid, 2 * pos + 1) +
                       constructTree(input, mid + 1, high, 2 * pos + 2);
        return segTree[pos];
    }

    // Stores val at leaf i and returns (new leaf value - old leaf value);
    // every ancestor on the path adds that difference. Returns 0 when i is
    // outside [low, high].
    int updateTree(int i, int val, int low, int high, int pos) {
        if (i < low || i > high) return 0;
        if (low == high) {
            const int diff = val - segTree[pos];
            segTree[pos] = val;
            return diff;
        }
        const int mid = low + (high - low) / 2;
        const int diff = (i <= mid)
            ? updateTree(i, val, low, mid, 2 * pos + 1)
            : updateTree(i, val, mid + 1, high, 2 * pos + 2);
        segTree[pos] += diff;
        return diff;
    }

    // Sum of elements qlow..qhigh restricted to this node's range [low, high].
    int rangeSumQuery(int qlow, int qhigh, int low, int high, int pos) {
        if (qlow <= low && qhigh >= high) { // total overlap
            return segTree[pos];
        }
        if (qlow > high || qhigh < low) { // no overlap
            return 0;
        }
        // partial overlap: combine both halves
        const int mid = low + (high - low) / 2;
        return rangeSumQuery(qlow, qhigh, low, mid, 2 * pos + 1) +
               rangeSumQuery(qlow, qhigh, mid + 1, high, 2 * pos + 2);
    }
};

Version 3: use pointer-based tree nodes (166 ms)

// Node of a pointer-based segment tree: `sum` is the total of the input
// elements in the range the node covers; leaves have no children.
// Nodes are allocated and freed by NumArray below.
struct SegTreeNode {
    int sum;
    SegTreeNode* left, *right;
    SegTreeNode(): sum(0), left(nullptr), right(nullptr) {}
};

// Segment tree built from heap-allocated SegTreeNode objects.
// NumArray owns the whole tree: it frees it on destruction, deep-copies it
// on copy (value semantics, like the vector-based versions), and transfers
// it on move.
class NumArray {
public:
    // Builds the tree over nums in O(n).
    NumArray(vector<int> nums): n(nums.size()), root(nullptr) {
        if (n > 0) root = constructTree(nums, 0, n - 1);
    }

    // Copies get their own tree, so updating one never affects the other.
    NumArray(const NumArray& other): n(other.n), root(cloneTree(other.root)) {}

    // Takes over other's tree; other is left empty.
    NumArray(NumArray&& other) noexcept: n(other.n), root(other.root) {
        other.n = 0;
        other.root = nullptr;
    }

    NumArray& operator=(const NumArray& other) {
        if (this != &other) {
            // Clone before freeing so *this is untouched if allocation throws.
            SegTreeNode* copy = cloneTree(other.root);
            destroyTree(root);
            root = copy;
            n = other.n;
        }
        return *this;
    }

    NumArray& operator=(NumArray&& other) noexcept {
        if (this != &other) {
            destroyTree(root);
            n = other.n;
            root = other.root;
            other.n = 0;
            other.root = nullptr;
        }
        return *this;
    }

    // Frees every node; without this each tree leaked.
    ~NumArray() { destroyTree(root); }

    // Sets element i to val in O(log n); no-op on an empty array.
    void update(int i, int val) {
        if (n == 0) return;
        updateTree(i, val, 0, n - 1, root);
    }

    // Returns the sum of elements i..j in O(log n); 0 for an empty array.
    int sumRange(int i, int j) {
        // Without this guard an empty array dereferences the null root.
        if (n == 0) return 0;
        return rangeSumQuery(i, j, 0, n - 1, root);
    }

private:
    int n;
    SegTreeNode* root;

    // Allocates the subtree covering nums[low..high]; returns its root.
    SegTreeNode* constructTree(vector<int>& nums, int low, int high) {
        SegTreeNode* node = new SegTreeNode();
        if (low == high) {
            node->sum = nums[low];
            return node;
        }
        const int mid = low + (high - low) / 2;
        node->left = constructTree(nums, low, mid);
        node->right = constructTree(nums, mid + 1, high);
        node->sum = node->left->sum + node->right->sum;
        return node;
    }

    // Returns a freshly allocated copy of the subtree rooted at node.
    static SegTreeNode* cloneTree(const SegTreeNode* node) {
        if (node == nullptr) return nullptr;
        SegTreeNode* copy = new SegTreeNode();
        copy->sum = node->sum;
        copy->left = cloneTree(node->left);
        copy->right = cloneTree(node->right);
        return copy;
    }

    // Deletes the subtree rooted at node (post-order).
    static void destroyTree(SegTreeNode* node) {
        if (node == nullptr) return;
        destroyTree(node->left);
        destroyTree(node->right);
        delete node;
    }

    // Stores val at leaf i and returns (new value - old value); ancestors add
    // that difference. Returns 0 when i is outside [low, high].
    int updateTree(int i, int val, int low, int high, SegTreeNode* node) {
        if (i < low || i > high) return 0;
        int diff = 0;
        if (low == high) {
            diff = val - node->sum;
            node->sum = val;
            return diff;
        }
        const int mid = low + (high - low) / 2;
        if (i <= mid)
            diff = updateTree(i, val, low, mid, node->left);
        else
            diff = updateTree(i, val, mid + 1, high, node->right);
        node->sum += diff;
        return diff;
    }

    // Sum of elements qlow..qhigh restricted to this node's range [low, high].
    int rangeSumQuery(int qlow, int qhigh, int low, int high, SegTreeNode* node) {
        if (qlow <= low && qhigh >= high) return node->sum; // total overlap
        if (qlow > high || qhigh < low) return 0;           // no overlap
        // partial overlap
        const int mid = low + (high - low) / 2;
        return rangeSumQuery(qlow, qhigh, low, mid, node->left) +
               rangeSumQuery(qlow, qhigh, mid + 1, high, node->right);
    }
};

Solution 2: Binary Indexed Tree (Fenwick Tree), 159 ms

// Fenwick (binary indexed) tree over the input. BIT is 1-based: BIT[k]
// holds the sum of a block of elements ending at input[k-1], with the
// block length given by the lowest set bit of k. `input` mirrors the
// current values so update() can compute the delta to add.
class NumArray {
public:
    // Builds the tree by adding each element in turn.
    NumArray(vector<int> nums): n(nums.size()+1), input(nums), BIT(n, 0) {
        for (int k = 0; k + 1 < n; ++k)
            addAt(k + 1, nums[k]);
    }

    // Replaces element i with val by adding the difference to the tree.
    void update(int i, int val) {
        const int delta = val - input[i];
        input[i] = val;
        addAt(i + 1, delta);
    }

    // Sum of elements i..j as the difference of two prefix sums.
    int sumRange(int i, int j) {
        const int throughJ = prefixSum(j + 1);
        const int beforeI = prefixSum(i);
        return throughJ - beforeI;
    }

private:
    int n;
    vector<int> input, BIT;

    // Adds delta to every tree cell whose block covers 1-based position pos.
    void addAt(int pos, int delta) {
        for (; pos < n; pos += pos & -pos)
            BIT[pos] += delta;
    }

    // Sum of the first `pos` elements (1-based prefix).
    int prefixSum(int pos) {
        int total = 0;
        for (; pos > 0; pos -= pos & -pos)
            total += BIT[pos];
        return total;
    }
};

results matching ""

    No results matching ""