practice · medium · from Binary Search Trees

Kth Smallest Element in a BST

time
O(h + k)
space
O(h)
tags
bst · in-order

In the wild: order-statistic queries on a sorted index, e.g. 'show me the 100th highest-priority pending task' or 'find the median sensor reading in this BST-backed sketch'.

Problem

Given the root of a BST and an integer k (1-indexed), return the value of the k-th smallest element.

You may assume k is always valid: 1 ≤ k ≤ n, where n is the number of nodes.

Examples

        3
       / \
      1   4         k=1 → 1
       \           k=2 → 2
        2          k=3 → 3
                   k=4 → 4

        5
       / \
      3   6         k=1 → 1
     / \           k=3 → 3
    2   4          k=6 → 6
   /
  1

Why this is the canonical "in-order = sorted" problem

The crucial BST fact: in-order traversal visits keys in increasing order. So "k-th smallest" is just "stop the in-order traversal after k visits". You're not searching — you're counting.

Three ways to implement, with strictly increasing elegance:

  1. Collect everything, take index k-1. O(n) time, O(n) space. Wasteful — you read past the k-th element.
  2. Recursive in-order with a counter, short-circuit when you hit k. O(h + k) time, O(h) space. Better — but recursion is hard to "pause" cleanly.
  3. Iterative in-order using an explicit stack. O(h + k) time, O(h) space — and the stack also gives you a proper iterator, so you could query successive k's without retraversing.

The third version is what real database engines do for ordered indices: they treat the BST as a cursor that can step "next smallest" on demand. Once you've internalized the iterative in-order pattern, you can answer "k-th smallest", "all values in range [lo, hi]", "kth order-statistic with deletion", and "merge two BSTs" — all variations on the same machinery.

Hints

Hint 1 — what traversal order gives sorted?

Which of the four traversal orders (pre / in / post / level) visits the keys of a BST in sorted ascending order? (See the BST lesson — "the corollary".)

Hint 2 — short-circuit

You don't need to traverse the whole tree. Once you've visited k nodes in-order, the most-recent one is your answer. Pass a mutable counter down, and bail out when it hits k.

Hint 3 — the iterative in-order pattern
# Pseudocode: the body of a function that returns the k-th smallest value of `root`.
stack = []
node  = root
while stack or node:
    while node:                      # push all-left chain
        stack.append(node)
        node = node.left
    node = stack.pop()               # process: this is the next-smallest
    k -= 1                           # one more key visited in ascending order
    if k == 0: return node.val       # the k-th visit is the answer
    node = node.right                # move to right subtree

This is the iterative-in-order pattern. Memorize it once; it pays dividends across every ordered-tree problem.

Solution

from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    """A binary-tree node: an integer key plus optional left/right children."""

    val: int
    left: Optional["Node"] = None   # left child, or None when absent
    right: Optional["Node"] = None  # right child, or None when absent

def kth_smallest(root: Optional[Node], k: int) -> int:
    """Return the k-th smallest key (1-indexed) of the BST rooted at ``root``.

    The tree is walked in-order lazily (explicit stack, no recursion), and the
    walk stops as soon as the k-th node has been produced.

    Raises:
        ValueError: if the traversal runs out of nodes before rank ``k`` is
            reached (for example k < 1, or k larger than the node count).
    """

    def _ascending(start):
        # Yields nodes in ascending key order using the iterative in-order pattern.
        pending: list[Node] = []
        cursor = start
        while pending or cursor is not None:
            while cursor is not None:      # descend the all-left chain
                pending.append(cursor)
                cursor = cursor.left
            visited = pending.pop()        # smallest node not yet produced
            yield visited
            cursor = visited.right         # continue into its right subtree

    remaining = k
    for node in _ascending(root):
        remaining -= 1
        if remaining == 0:
            return node.val
    raise ValueError("k out of range")

# BST insertion, used below to build the demo tree.
def insert(root: Optional[Node], v: int) -> Node:
    """Insert ``v`` into the BST rooted at ``root`` and return the root.

    An empty tree (``root is None``) becomes a single new node. Keys smaller
    than a node go into its left subtree; equal or larger keys go right, so
    duplicates are kept.
    """
    if root is None:
        return Node(v)
    parent = root
    while True:
        side = "left" if v < parent.val else "right"
        child = getattr(parent, side)
        if child is None:
            setattr(parent, side, Node(v))
            return root
        parent = child

root: Optional[Node] = None
for v in (5, 3, 6, 2, 4, 1):           # builds 5(3(2(1),4),6)
    root = insert(root, v)

# Every key in this tree equals its rank, so each case is (k, k).
cases = [(rank, rank) for rank in range(1, 7)]
for k, expected in cases:
    got = kth_smallest(root, k)
    mark = "✗" if got != expected else "✓"
    print(f"{mark} kth_smallest(root, k={k}) = {got}  (expected {expected})")
// Binary-tree node: a key plus optional left/right children (null when absent).
class Node {
  constructor(val, left = null, right = null) {
    Object.assign(this, { val, left, right });
  }
}

// Returns the k-th smallest key (1-indexed) of the BST rooted at `root`.
// Nodes are produced lazily in ascending order (explicit stack, no recursion)
// and the walk stops at the k-th one; throws if the tree runs out first.
function kthSmallest(root, k) {
  function* ascending(start) {
    const pending = [];
    let cursor = start;
    while (pending.length > 0 || cursor !== null) {
      for (; cursor !== null; cursor = cursor.left) pending.push(cursor); // all-left chain
      const next = pending.pop();                                          // smallest unvisited
      yield next;
      cursor = next.right;
    }
  }

  let remaining = k;
  for (const node of ascending(root)) {
    if (--remaining === 0) return node.val;
  }
  throw new Error("k out of range");
}

// Inserts v into the BST rooted at `root` and returns the root (a fresh node
// when the tree is empty). Smaller keys go left; ties and larger keys go right.
function insert(root, v) {
  if (root === null) return new Node(v);
  let parent = root;
  for (;;) {
    const side = v < parent.val ? "left" : "right";
    if (parent[side] === null) {
      parent[side] = new Node(v);
      return root;
    }
    parent = parent[side];
  }
}

// Build 5(3(2(1),4),6); every key equals its rank, so each case is [k, k].
let root = [5, 3, 6, 2, 4, 1].reduce((tree, v) => insert(tree, v), null);

const cases = [1, 2, 3, 4, 5, 6].map((k) => [k, k]);
cases.forEach(([k, expected]) => {
  const got = kthSmallest(root, k);
  const mark = got === expected ? "✓" : "✗";
  console.log(`${mark} kthSmallest(root, k=${k}) = ${got}  (expected ${expected})`);
});
/// A binary-tree node holding an `i32` key and optional boxed children.
///
/// [`kth_smallest`] expects the tree to be ordered as a binary search tree.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Node {
    pub val: i32,
    pub left: Option<Box<Node>>,
    pub right: Option<Box<Node>>,
}

/// Returns the `k`-th smallest key (1-indexed) in the BST rooted at `root`.
///
/// Uses an iterative in-order traversal with an explicit stack and stops as
/// soon as the `k`-th key is visited: O(h + k) time, O(h) extra space.
///
/// # Panics
///
/// Panics with `"k out of range"` if `k < 1` or `k` exceeds the node count.
pub fn kth_smallest(root: &Option<Box<Node>>, mut k: i32) -> i32 {
    // Counting down from k < 1 can never reach zero; fail fast instead of
    // walking the whole tree first.
    assert!(k >= 1, "k out of range");
    let mut stack: Vec<&Node> = Vec::new();
    let mut cur = root.as_deref();
    loop {
        // Push the all-left chain; its deepest node is the smallest unvisited key.
        while let Some(node) = cur {
            stack.push(node);
            cur = node.left.as_deref();
        }
        // An empty stack here means every node was visited without reaching k.
        let node = match stack.pop() {
            Some(node) => node,
            None => break,
        };
        k -= 1;
        if k == 0 {
            return node.val;
        }
        // In-order continues with the right subtree of the node just visited.
        cur = node.right.as_deref();
    }
    panic!("k out of range");
}

/// Inserts `v` into the BST rooted at `root` and returns the (possibly new) root.
///
/// Keys smaller than a node go into its left subtree; equal or larger keys go
/// right, so duplicates are kept.
fn insert(root: Option<Box<Node>>, v: i32) -> Option<Box<Node>> {
    if let Some(mut node) = root {
        // Pick the child slot v belongs under and rebuild that subtree in place.
        let child = if v < node.val { &mut node.left } else { &mut node.right };
        *child = insert(child.take(), v);
        Some(node)
    } else {
        Some(Box::new(Node { val: v, left: None, right: None }))
    }
}

fn main() {
    let mut root: Option<Box<Node>> = None;
    for v in [5, 3, 6, 2, 4, 1] { root = insert(root, v); }
    for k in 1..=6 {
        let got  = kth_smallest(&root, k);
        let mark = if got == k { "✓" } else { "✗" };
        println!("{mark} kth_smallest(root, k={}) = {}  (expected {})", k, got, k);
    }
}
#include <iostream>
#include <stack>
#include <stdexcept>

// Binary-tree node; kth_smallest expects the tree to be BST-ordered.
struct Node {
    int   val;
    Node* left;
    Node* right;
    Node(int v, Node* l = nullptr, Node* r = nullptr) : val(v), left(l), right(r) {}
};

// Returns the k-th smallest key (1-indexed) of the BST rooted at `root`.
// Iterative in-order traversal with an explicit stack: O(h + k) time, O(h) space.
// Throws std::out_of_range("k out of range") when k < 1 or k exceeds the node
// count. A sentinel return value such as -1 would be indistinguishable from a
// stored key of -1, and throwing matches the Python/JS/Rust versions.
int kth_smallest(Node* root, int k) {
    if (k < 1) throw std::out_of_range("k out of range");   // can never count down to 0
    std::stack<Node*> stk;
    Node* cur = root;
    while (!stk.empty() || cur) {
        while (cur) { stk.push(cur); cur = cur->left; }   // all-left chain
        cur = stk.top(); stk.pop();                       // smallest unvisited
        if (--k == 0) return cur->val;
        cur = cur->right;                                 // continue in right subtree
    }
    throw std::out_of_range("k out of range");
}

// Inserts v into the BST rooted at `root` and returns the (possibly new) root.
// Walks a pointer-to-link down the tree and drops the new node into the first
// empty slot; keys < node go left, ties and larger keys go right.
Node* insert(Node* root, int v) {
    Node** link = &root;
    while (*link != nullptr) {
        Node* here = *link;
        link = (v < here->val) ? &here->left : &here->right;
    }
    *link = new Node(v);
    return root;
}

int main() {
    Node* root = nullptr;
    for (int key : {5, 3, 6, 2, 4, 1}) root = insert(root, key);
    // Every key equals its rank in this tree, so k doubles as the expected answer.
    for (int k = 1; k <= 6; ++k) {
        const int got = kth_smallest(root, k);
        const char* mark = (got == k) ? "✓" : "✗";
        std::cout << mark
                  << " kth_smallest(root, k=" << k << ") = " << got << "\n";
    }
}
#include <stdio.h>
#include <stdlib.h>

typedef struct Node {
    int          val;
    struct Node* left;
    struct Node* right;
} Node;

// Returns the k-th smallest key (1-indexed) in the BST rooted at `root`, or -1
// if k < 1, k exceeds the node count, or the traversal stack cannot be
// allocated. Note: -1 is ambiguous if the tree itself stores the key -1.
//
// Iterative in-order traversal with a heap-allocated stack that doubles on
// demand. A fixed-size array would overflow on a tall, list-shaped tree, for
// example one built from sorted inserts. O(h + k) time, O(h) space.
int kth_smallest(Node* root, int k) {
    if (k < 1) return -1;                 // counting down from k can never hit 0

    size_t cap = 64, sp = 0;
    Node** stack = malloc(cap * sizeof *stack);
    if (stack == NULL) return -1;

    int result = -1;
    Node* cur = root;
    while (sp > 0 || cur != NULL) {
        while (cur != NULL) {             // push the all-left chain
            if (sp == cap) {              // full: double the capacity
                Node** grown = realloc(stack, 2 * cap * sizeof *stack);
                if (grown == NULL) { free(stack); return -1; }
                stack = grown;
                cap *= 2;
            }
            stack[sp++] = cur;
            cur = cur->left;
        }
        cur = stack[--sp];                // smallest unvisited key
        if (--k == 0) { result = cur->val; break; }
        cur = cur->right;                 // in-order continues to the right
    }
    free(stack);
    return result;
}

// Allocates a detached leaf node holding v. If malloc fails, the program prints
// a diagnostic and exits instead of dereferencing NULL.
static Node* mk(int v) {
    Node* n = malloc(sizeof *n);
    if (n == NULL) {
        perror("malloc");
        exit(EXIT_FAILURE);
    }
    n->val = v; n->left = NULL; n->right = NULL;
    return n;
}
// Inserts v into the BST rooted at `root` and returns the (possibly new) root.
// Walks a pointer-to-link down to the first empty slot; keys < node go left,
// ties and larger keys go right.
static Node* insert(Node* root, int v) {
    Node** link = &root;
    while (*link != NULL) {
        Node* here = *link;
        link = (v < here->val) ? &here->left : &here->right;
    }
    *link = mk(v);
    return root;
}

int main(void) {
    const int keys[] = {5, 3, 6, 2, 4, 1};
    const size_t nkeys = sizeof keys / sizeof keys[0];
    Node* root = NULL;
    for (size_t i = 0; i < nkeys; ++i) root = insert(root, keys[i]);
    // Every key equals its rank in this tree, so k doubles as the expected answer.
    for (int k = 1; k <= 6; ++k) {
        const int got = kth_smallest(root, k);
        const char* mark = (got == k) ? "✓" : "✗";
        printf("%s kth_smallest(root, k=%d) = %d\n", mark, k, got);
    }
    return 0;
}

The iterative in-order template is the "push the all-left chain, pop and visit, then step into the right subtree" loop. It is one of the highest-leverage patterns in tree algorithms. The same handful of lines power "k-th smallest", "range query", and iterator-style BST traversal.

Complexity

Time
O(h + k) — the all-left descent costs h; visiting k elements costs k. In the worst case (k = n), O(n). In the best case (k = 1 on a balanced tree), O(log n).
Space
O(h) — the stack never holds more than h nodes. They are the ancestors of the current position whose own keys have not been visited yet.
Augmented-tree alternative
If you also stored subtree_size at every node, kth_smallest drops to O(h) regardless of k. At each node, compare k with the left subtree's size L. If k ≤ L, go left. If k = L + 1, stop at the current node. Otherwise go right with k reduced by L + 1. That's exactly what an order-statistic tree does.

In the wild

Variations