레벨업 일지

[Java] leetcode 307. Range Sum Query - Mutable 본문

알고리즘/leetcode

[Java] leetcode 307. Range Sum Query - Mutable

24시간이모자란 2023. 5. 4. 00:01

문제

https://leetcode.com/problems/range-sum-query-mutable/

 

Range Sum Query - Mutable - LeetCode

Can you solve this real interview question? Range Sum Query - Mutable - Given an integer array nums, handle multiple queries of the following types: 1. Update the value of an element in nums. 2. Calculate the sum of the elements of nums between indices lef

leetcode.com

 

알아야 할 개념

  • 세그먼트 트리

풀이

풀이 알고리즘은 다음과 같다. 

  1. 세그먼트 트리 클래스를 구현.
  2. 트리 클래스의 메소드를 호출하여 정답을 리턴.

세그먼트 트리란?

세그먼트 트리(Segment Tree)는 구간 쿼리를 효율적으로 처리하기 위한 이진 트리 자료구조입니다. 구간 합, 최소값, 최대값 등과 같은 구간 쿼리를 빠르게 수행할 수 있도록 설계되어 있습니다. 세그먼트 트리는 배열을 분할하고 각 구간에 대한 값을 노드에 저장하여 사용합니다.

 

class Smt{
    int tree[]; //세그먼트 트리 배열
    int size; // 트리 사이즈
    ...
}

세그먼트 트리(SMT) 초기화 

  • 배열을 기반으로 세그먼트 트리를 구성한다. 이때 트리 초기 사이즈는 len( 배열 ) * 4 이다.
    이진 트리 기반으로 트리를 만들고 Left, Right 탐색인덱스가 최대 (root.index * 2 + 2 ) 이기 때문에 arr.length  * 4 를 해주면 배열 인덱스 충돌없이 커버 가능하다.
 public void initTree(int start, int end, int rootIdx, int[] arr){
        if(start > end)return;
        if(start == end){
            tree[rootIdx] = arr[start];
            return;
        }
        int L = rootIdx * 2 + 1;
        int R = rootIdx * 2 + 2;
        int P = start + (end - start)/2;
        
        initTree(start, P, L, arr);
        initTree(P+1, end, R, arr);
        tree[rootIdx] = tree[L] + tree[R];
    }

 

그다음, 배열의 특정 원소를 변경하거나 추가할 때 세그먼트 트리를 업데이트 수도 &  구현 코드이다.

getVal (현재 범위 start, 현재 탐색 범위 end, 현재 세그먼트 트리 루트 인덱스, 목표 범위 start, 목표 범위 end){
// 목표 범위가 현재 범위 밖에 있다면 0 반환
if (목표 범위 시작 > 현재 범위 끝 || 목표 범위 끝 < 현재 범위 시작) return 0;

 left 탐색 + right 탐색 한 값 리턴.
}
setVal(현재 시작과 끝 범위, 현재 루트 인덱스, 목표 idx, 더해줄 차이 값){
// 목표 인덱스가 현재 범위 밖에 있다면 종료.
if (현재 범위 시작 > 현재 범위 끝 || 목표 인덱스 < 현재 범위 시작 || 목표 인덱스 > 현재 범위 끝) return;

// 현재 루트 인덱스에 차이값을 더함
tree [현재 세그먼트 트리 루트 인덱스] += 차이값;

// 현재 범위 시작과 끝이 같다면 리턴 (리프 노드)
if (현재 범위 시작 == 현재 범위 끝) return;

Left , Right 탐색
}
 public int getVal(int start, int end, int rootIdx, int targets, int targete){
        if(targets > end || targete < start) return 0;
        if(targets <= start && targete >= end)return tree[rootIdx];
        int L = rootIdx * 2 + 1;
        int R = rootIdx * 2 + 2;
        int P = start + (end - start)/2;
        return getVal(start, P, L, targets, targete ) + getVal(P+1, end, R, targets, targete);
    }
    public void setVal(int start, int end, int rootIdx, int targetIdx, int dif){
        if(start > end || targetIdx < start || targetIdx > end)return;
        tree[rootIdx] += dif;
        if(start == end)return;
        int L = rootIdx * 2 + 1;
        int R = rootIdx * 2 + 2;
        int P = start + (end - start)/2;
        setVal(start, P, L,targetIdx, dif);
        setVal(P+1, end, R,targetIdx, dif);
    }

그림을 그려가며 따라가면 이해가 쉽다.

코드

전체 코드는 다음과 같다. 

 

자바

class Smt{
    int tree[];
    int size;
    Smt(int arr[]){
        size = arr.length*4;
        tree = new int[size];
        initTree(0, arr.length-1, 0 , arr);
    }
    public void initTree(int start, int end, int rootIdx, int[] arr){
        if(start > end)return;
        if(start == end){
            tree[rootIdx] = arr[start];
            return;
        }
        int L = rootIdx * 2 + 1;
        int R = rootIdx * 2 + 2;
        int P = start + (end - start)/2;
        
        initTree(start, P, L, arr);
        initTree(P+1, end, R, arr);
        tree[rootIdx] = tree[L] + tree[R];
    }
    public int getVal(int start, int end, int rootIdx, int targets, int targete){
        if(targets > end || targete < start) return 0;
        if(targets <= start && targete >= end)return tree[rootIdx];
        int L = rootIdx * 2 + 1;
        int R = rootIdx * 2 + 2;
        int P = start + (end - start)/2;
        return getVal(start, P, L, targets, targete ) + getVal(P+1, end, R, targets, targete);
    }
    public void setVal(int start, int end, int rootIdx, int targetIdx, int dif){
        if(start > end || targetIdx < start || targetIdx > end)return;
        tree[rootIdx] += dif;
        if(start == end)return;
        int L = rootIdx * 2 + 1;
        int R = rootIdx * 2 + 2;
        int P = start + (end - start)/2;
        setVal(start, P, L,targetIdx, dif);
        setVal(P+1, end, R,targetIdx, dif);
    }
}
class NumArray {
    Smt smt;
    int nums[];
    public NumArray(int[] nums) {
        this.nums = nums;
        smt = new Smt(nums);     
    }
    public void update(int index, int val) {
        int dif = val - nums[index];
        nums[index] = val;
        smt.setVal(0, nums.length-1, 0,index, dif);
    }
    
    public int sumRange(int left, int right) {
        return smt.getVal(0, nums.length-1,0, left, right);
    }
}
Comments