Skip to content

线段树

SegmentTree

  • tree/segment-tree/SegmentTree.js
/**
 * Segment tree over an input array.
 *
 * Elements are combined with a binary `operation` (addition by default), which
 * should be associative for range queries to be meaningful.
 * - Building the tree visits every node once.
 * - A point update or a range query visits O(log n) nodes.
 * This makes the structure suitable for repeated aggregate queries over
 * contiguous index ranges.
 */
export default class SegmentTree {
  /**
   * @param {number[]} [inputArray=[]] - Source values. Stored by reference:
   *   `update` writes the new value into this same array.
   * @param {function(*, *): *} [operation] - Combiner used for internal nodes
   *   and range queries; defaults to `(a, b) => a + b`.
   */
  constructor(inputArray = [], operation) {
    this.inputArray = inputArray;
    this.operation = operation || ((a, b) => a + b);

    this.segmentTree = this.initSegmentTree(inputArray);

    this.buildSegmentTree();
  }

  /**
   * Allocate the backing array for the tree, sized as a complete binary tree
   * with at least `inputArray.length` leaves.
   *
   * @param {number[]} inputArray
   * @return {Array<null>} `2 ** (ceil(log2(n)) + 1) - 1` nulls, or `[]` when the input is empty.
   */
  initSegmentTree(inputArray = []) {
    const leafCount = inputArray.length;

    if (!leafCount) {
      return [];
    }

    // Levels needed to hold `leafCount` leaves, plus one level for the root.
    const nextPower = Math.ceil(Math.log2(leafCount)) + 1;
    const segmentTreeArrayLength = 2 ** nextPower - 1;

    return new Array(segmentTreeArrayLength).fill(null);
  }

  /**
   * Fill `this.segmentTree` from `this.inputArray`, starting at the root (position 0).
   */
  buildSegmentTree() {
    // Nothing to build for an empty input; without this guard the leaf branch
    // of buildTreeRecursively would write `undefined` into segmentTree[0].
    if (this.inputArray.length === 0) {
      return;
    }

    this.buildTreeRecursively(0, this.inputArray.length - 1, 0);
  }

  /**
   * Build the subtree rooted at `position`, covering inputArray[startInputIndex..endInputIndex].
   *
   * @param {number} startInputIndex - First input index covered by this node (inclusive).
   * @param {number} endInputIndex - Last input index covered by this node (inclusive).
   * @param {number} position - Index of this node in `this.segmentTree`.
   */
  buildTreeRecursively(startInputIndex, endInputIndex, position) {
    if (startInputIndex >= endInputIndex) {
      // Leaf: holds a single input element.
      this.segmentTree[position] = this.inputArray[startInputIndex];
      return;
    }

    const middleIndex = Math.floor((startInputIndex + endInputIndex) / 2);

    this.buildTreeRecursively(startInputIndex, middleIndex, this.getLeftChildIndex(position));
    this.buildTreeRecursively(middleIndex + 1, endInputIndex, this.getRightChildIndex(position));

    this.segmentTree[position] = this.operation(
      this.segmentTree[this.getLeftChildIndex(position)],
      this.segmentTree[this.getRightChildIndex(position)],
    );
  }

  /**
   * Set inputArray[index] to `value` and refresh every ancestor node on the path to the root.
   *
   * @param {number} index - Index into `inputArray`.
   * @param {*} value - New value for that element.
   * @throws {RangeError} If `index` is not an integer within `[0, inputArray.length)`.
   */
  update(index, value) {
    const length = this.inputArray.length;

    // Without this check an out-of-range or negative index keeps descending
    // right and silently overwrites the last element.
    if (!Number.isInteger(index) || index < 0 || index >= length) {
      throw new RangeError(`Index ${index} is out of bounds for an input array of length ${length}`);
    }

    this.updateRecursively(0, length - 1, 0, index, value);
  }

  /**
   * Recursive helper for `update`.
   *
   * Descends to the leaf for `index`, writes `value` there and into
   * `inputArray`, then recombines each node on the way back up.
   *
   * @param {number} startInputIndex - First input index covered by this node (inclusive).
   * @param {number} endInputIndex - Last input index covered by this node (inclusive).
   * @param {number} position - Index of this node in `this.segmentTree`.
   * @param {number} index - Input index being updated; must lie within this node's range.
   * @param {*} value - New value.
   */
  updateRecursively(startInputIndex, endInputIndex, position, index, value) {
    if (startInputIndex >= endInputIndex) {
      this.inputArray[startInputIndex] = value;
      this.segmentTree[position] = value;
      return;
    }

    const middleIndex = Math.floor((startInputIndex + endInputIndex) / 2);

    if (index >= startInputIndex && index <= middleIndex) {
      this.updateRecursively(startInputIndex, middleIndex, this.getLeftChildIndex(position), index, value);
    } else {
      this.updateRecursively(middleIndex + 1, endInputIndex, this.getRightChildIndex(position), index, value);
    }

    this.segmentTree[position] = this.operation(
      this.segmentTree[this.getLeftChildIndex(position)],
      this.segmentTree[this.getRightChildIndex(position)],
    );
  }

  /**
   * Combine inputArray[queryLeftIndex..queryRightIndex] (inclusive) with `this.operation`.
   *
   * @param {number} queryLeftIndex
   * @param {number} queryRightIndex
   * @return {*} The combined value, or `null` when the query range covers no elements.
   */
  rangeQuery(queryLeftIndex, queryRightIndex) {
    if (this.inputArray.length === 0) {
      return null;
    }

    return this.rangeQueryRecursive(queryLeftIndex, queryRightIndex, 0, this.inputArray.length - 1, 0);
  }

  /**
   * Recursive helper for `rangeQuery`.
   *
   * The node at `position` covers `[startIndex, endIndex]`.
   *
   * @param {number} queryLeftIndex
   * @param {number} queryRightIndex
   * @param {number} startIndex
   * @param {number} endIndex
   * @param {number} position
   * @return {*} Combined value of the overlap, or `null` when there is no overlap.
   */
  rangeQueryRecursive(queryLeftIndex, queryRightIndex, startIndex, endIndex, position) {
    if (queryLeftIndex > endIndex || queryRightIndex < startIndex) {
      // No overlap
      return null;
    }

    if (queryLeftIndex <= startIndex && queryRightIndex >= endIndex) {
      // Total overlap
      return this.segmentTree[position];
    }

    // Partial overlap
    const middleIndex = Math.floor((startIndex + endIndex) / 2);

    const leftResult = this.rangeQueryRecursive(queryLeftIndex, queryRightIndex, startIndex, middleIndex, this.getLeftChildIndex(position));
    const rightResult = this.rangeQueryRecursive(queryLeftIndex, queryRightIndex, middleIndex + 1, endIndex, this.getRightChildIndex(position));

    // `null` marks "no overlap" and must act as the identity. Passing it to the
    // operation would coerce it to 0, e.g. Math.min(3, null) === 0.
    if (leftResult === null) {
      return rightResult;
    }
    if (rightResult === null) {
      return leftResult;
    }

    return this.operation(leftResult, rightResult);
  }

  /**
   * @param {number} position - Node index in `this.segmentTree`.
   * @return {number} Index of the left child in the implicit binary-heap layout.
   */
  getLeftChildIndex(position) {
    return position * 2 + 1;
  }

  /**
   * @param {number} position - Node index in `this.segmentTree`.
   * @return {number} Index of the right child in the implicit binary-heap layout.
   */
  getRightChildIndex(position) {
    return position * 2 + 2;
  }
}