Keyboard shortcuts

Press ← or → to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

CSC3200 Data Structures and Advanced Programming

Introduction

Teaching C++ first.

Then data structures.

Assignment 1

Trigonometric Functions

Implement some of the trigonometric functions from scratch without using any mathematical functions from STL.

student_math.h

#ifndef STUDENT_MATH_H
#define STUDENT_MATH_H

// Trigonometric functions implemented without the standard math library.
// The plain names take an angle in radians; the *_deg variants take degrees.
namespace student_std {
    double sin(double x); // sine; x is an angle in radians
    double sin_deg(double x); // sine; x is an angle in degrees
    double cos(double x); // cosine; x is an angle in radians
    double cos_deg(double x); // cosine; x is an angle in degrees
    double tan(double x); // tangent (sin/cos); x is an angle in radians
    double tan_deg(double x); // tangent; x is an angle in degrees
    double cot(double x); // cotangent (1/tan); x is an angle in radians
    double cot_deg(double x); // cotangent; x is an angle in degrees
}

#endif

student_math.cpp

#include "student_math.h"

namespace {
    // File-local numeric helpers for the student_std trigonometric functions.
    // The anonymous namespace keeps names such as fabs from leaking into (or
    // colliding at link time with) other translation units.

    const double pi = 3.1415926535897932384626;

    // Tolerance: series terms with magnitude <= eps are dropped, and final
    // results with magnitude <= eps are reported as exactly 0.
    const double eps = 1e-8;

    // Converts an angle from degrees to radians.
    double toRadian(double x) {
        return pi*x/180.0;
    }

    // Exact absolute value. It has no tolerance dead zone, so fabs(-1e-9) == 1e-9.
    double fabs(double x) {
        return x < 0 ? -x : x;
    }

    // Returns sin(x) when issin is true, otherwise cos(x), by summing the
    // Maclaurin series
    //     sin x = x - x^3/3! + x^5/5! - ...
    //     cos x = 1 - x^2/2! + x^4/4! - ...
    // (This is a Taylor expansion; the "Lagrange" in the name is historical.)
    // Results with magnitude <= eps are snapped to exactly 0. This makes
    // sin(pi) and cos(pi/2) return 0, so tan(pi/2) becomes infinite.
    // NaN and +/-infinity inputs yield NaN.
    double TrigonometricLagrange(double x, bool issin) {
        const double two_pi = 2.0 * pi;
        // At or above 2^52 every double is already an integer.
        const double two_pow_52 = 4503599627370496.0;

        // Reduce x into [-pi, pi] by subtracting the nearest multiple of 2*pi.
        // One pass suffices for ordinary inputs. For huge |x| each pass shrinks
        // |x| by roughly 2^-52, so the loop terminates. An infinity turns into
        // NaN here, which also ends the loop.
        while (x > pi || x < -pi) {
            const double turns = x / two_pi;
            double k = turns;
            if (turns > -two_pow_52 && turns < two_pow_52) {
                // Round to the nearest integer without <cmath>.
                k = static_cast<double>(static_cast<long long>(turns < 0 ? turns - 0.5 : turns + 0.5));
            }
            x -= k * two_pi;
        }
        if (x != x) {
            return x; // NaN input, or an infinity turned into NaN above
        }

        double term = issin ? x : 1.0; // first term of the series
        double n = issin ? 2.0 : 1.0;  // next term = -term * x^2 / (n * (n + 1))
        double sum = term;
        while (fabs(term) > eps) {
            term = term*x*x/n/(n+1.0)*(-1.0);
            n += 2.0;
            sum += term;
        }
        return (fabs(sum) > eps ? sum : 0.0);
    }
}

namespace student_std {
    // ---- Radian-based functions -------------------------------------------
    // sin and cos evaluate the series in TrigonometricLagrange; tan and cot
    // are derived from them as quotients.

    double sin(double x) {
        const bool sine_series = true;
        return TrigonometricLagrange(x, sine_series);
    }

    double cos(double x) {
        const bool sine_series = false;
        return TrigonometricLagrange(x, sine_series);
    }

    double tan(double x) {
        const double numerator = sin(x);
        const double denominator = cos(x);
        return numerator / denominator;
    }

    double cot(double x) {
        return 1.0 / tan(x);
    }

    // ---- Degree-based functions: convert to radians, then delegate ---------

    double sin_deg(double x) { return sin(toRadian(x)); }

    double cos_deg(double x) { return cos(toRadian(x)); }

    double tan_deg(double x) { return tan(toRadian(x)); }

    double cot_deg(double x) { return cot(toRadian(x)); }
}

String

Implement a C++ string class from scratch without using STL.

student_string.h

#ifndef STUDENT_STRING_H
#define STUDENT_STRING_H

// Size of student_std::string's inline character buffer, including the
// terminating '\0' (so at most MAXLEN - 1 characters fit).
#define MAXLEN 256
namespace student_std {
    // Fixed-capacity string stored inline, with no heap allocation. The text
    // lives in a char[MAXLEN] member and its length is cached in `strlen`.
    class string {
        public:
            string();                          // empty string ""
            string(const char* str);           // copy of a NUL-terminated C string
            string(string const&);             // copy
            ~string();
            int get_strlen() const;            // number of characters, excluding '\0'
            const char* get_c_str() const;     // NUL-terminated contents
            void strcat(string const&);        // append in place
            string& operator=(string const&);
            string& operator+=(string const&); // append (strcat), returns *this
            char& operator[](int);             // element access, no bounds check
            const char& operator[](int) const;
            void to_upper();                   // ASCII 'a'-'z' -> 'A'-'Z' in place
            void to_lower();                   // ASCII 'A'-'Z' -> 'a'-'z' in place
            void strcopy(string const&);       // same as assignment
            bool equals(string const&) const;  // exact, case-sensitive comparison
            bool equals_ignore_case(string const&) const; // ASCII case-insensitive comparison
            void trim(); // Removes spaces ' ' from beginning and end
            bool is_alphabetic() const;        // true iff every character is an ASCII letter
        private: // hint
            char c_str[MAXLEN];                // contents followed by '\0'
            int strlen;                        // cached length of c_str
    };
}

#endif

student_string.cpp

#include "student_string.h"

namespace student_std {
    // Creates the empty string "".
    string::string() {
        this->c_str[0] = '\0';
        this->strlen = 0;
    }

    // Copies the NUL-terminated C string `str`.
    // Input longer than the fixed buffer can hold is truncated, leaving room
    // for the terminating '\0', instead of being written past the array.
    // A null pointer yields "".
    string::string(const char* str) {
        const int limit = static_cast<int>(sizeof(this->c_str)) - 1;
        int len = 0;
        if(str != nullptr) {
            while(len < limit && str[len] != '\0') {
                this->c_str[len] = str[len];
                len++;
            }
        }
        this->c_str[len] = '\0';
        this->strlen = len;
    }

    // Copies `str` including its terminator.
    string::string(string const& str) {
        this->strlen = str.strlen;
        for(int i = 0; i <= this->strlen; i++) {
            this->c_str[i] = str.c_str[i];
        }
    }

    string::~string() {}

    // Number of characters, excluding the terminating '\0'.
    int string::get_strlen() const {
        return this->strlen;
    }

    // NUL-terminated contents; valid while *this is alive and unmodified.
    const char* string::get_c_str() const {
        return this->c_str;
    }

    // Appends `str` in place, truncating at the buffer's capacity.
    // Safe when `str` is *this: the source length is captured before any
    // write, and the appended bytes are read only from the unchanged prefix.
    // The terminator is written explicitly, because the source's own '\0'
    // may already have been overwritten.
    void string::strcat(string const& str) {
        const int limit = static_cast<int>(sizeof(this->c_str)) - 1;
        const int len = this->strlen;
        int count = str.strlen;
        if(count > limit - len) {
            count = limit - len;
        }
        for(int i = 0; i < count; i++) {
            this->c_str[len + i] = str.c_str[i];
        }
        this->strlen = len + count;
        this->c_str[this->strlen] = '\0';
    }

    // Copy assignment (self-assignment is a no-op).
    string& string::operator=(string const& str) {
        if(this == &str) {
            return *this;
        }
        this->strlen = str.strlen;
        for(int i = 0; i <= this->strlen; i++) {
            this->c_str[i] = str.c_str[i];
        }
        return *this;
    }

    // Appends `str` and returns *this.
    string& string::operator+=(string const& str) {
        this->strcat(str);
        return *this;
    }

    // Element access; the index is not bounds-checked.
    char& string::operator[](int i) {
        return this->c_str[i];
    }

    const char& string::operator[](int i) const {
        return this->c_str[i];
    }

    // Converts ASCII lowercase letters to uppercase in place.
    void string::to_upper() {
        for(int i = 0; i < this->strlen; i++) {
            if(this->c_str[i] >= 'a' && this->c_str[i] <= 'z') {
                this->c_str[i] = char(c_str[i]+('A'-'a'));
            }
        }
    }

    // Converts ASCII uppercase letters to lowercase in place.
    void string::to_lower() {
        for(int i = 0; i < this->strlen; i++) {
            if(this->c_str[i] >= 'A' && this->c_str[i] <= 'Z') {
                this->c_str[i] = char(c_str[i]+('a'-'A'));
            }
        }
    }

    // Same as assignment.
    void string::strcopy(string const& str) {
        *this = str;
    }

    // Case-sensitive equality of contents.
    bool string::equals(string const& str) const {
        if(this->strlen != str.strlen) return false;
        for(int i = 0; i < this->strlen; i++) {
            if(this->c_str[i] != str.c_str[i]) return false;
        }
        return true;
    }

    // ASCII case-insensitive equality. Compares character by character
    // instead of making two lowered copies of both strings.
    bool string::equals_ignore_case(string const& str) const {
        if(this->strlen != str.strlen) return false;
        for(int i = 0; i < this->strlen; i++) {
            char a = this->c_str[i];
            char b = str.c_str[i];
            if(a >= 'A' && a <= 'Z') a = char(a + ('a' - 'A'));
            if(b >= 'A' && b <= 'Z') b = char(b + ('a' - 'A'));
            if(a != b) return false;
        }
        return true;
    }

    // Removes leading and trailing ' ' characters in place and zero-fills the
    // freed tail. Both scans are bounded by the current length, so empty and
    // all-space strings become "" instead of reading c_str[-1].
    void string::trim() {
        int first = 0;
        while(first < this->strlen && this->c_str[first] == ' ') {
            first++;
        }
        int end = this->strlen; // one past the last kept character
        while(end > first && this->c_str[end - 1] == ' ') {
            end--;
        }
        const int newlen = end - first;
        for(int i = 0; i < newlen; i++) {
            this->c_str[i] = this->c_str[i + first];
        }
        for(int i = newlen; i <= this->strlen; i++) {
            this->c_str[i] = '\0';
        }
        this->strlen = newlen;
    }

    // True iff every character is an ASCII letter (vacuously true for "").
    bool string::is_alphabetic() const {
        for(int i = 0; i < this->strlen; i++) {
            if(!((this->c_str[i] >= 'a' && this->c_str[i] <= 'z') || (this->c_str[i] >= 'A' && this->c_str[i] <= 'Z'))) {
                return false;
            }
        }
        return true;
    }
}

Assignment 2

Vector

Implement a C++ vector (dynamic array) from scratch without using STL.

student_vector.h

#ifndef STUDENT_VECTOR_H
#define STUDENT_VECTOR_H

#include <cstddef>
#include <cassert>
#include <stdexcept>
#include <iterator>

namespace student_std {
    // A minimal dynamic array modelled on std::vector.
    //
    // Storage is one new[]-allocated buffer of `cap` default-constructed T
    // objects whose first `sz` slots hold the live elements. T must therefore
    // be default-constructible, copy-constructible and copy-assignable. Slots
    // past `sz` stay constructed after pop_back/erase/clear.
    template <typename T>
    class vector {

        public:
            using size_type = std::size_t;
            using difference_type = std::ptrdiff_t;
            using value_type = T;

        private:
            T* s;           // owned buffer (nullptr for a default-constructed vector)
            size_type sz;   // number of live elements
            size_type cap;  // number of slots in s

            // Capacity to grow to when the buffer is full: 1, 2, 4, 8, ...
            size_type grown_capacity() const {
                return cap == 0 ? 1 : cap * 2;
            }

            // Replaces the buffer with one of newcap slots (never fewer than sz),
            // copying the live elements across. Leaves *this untouched if the
            // allocation or an element copy throws.
            void reallocate(size_type newcap) {
                if(newcap < sz) newcap = sz;
                T* news = new T[newcap];
                try {
                    for(size_type i = 0; i < sz; i++) {
                        news[i] = s[i];
                    }
                } catch (...) {
                    delete[] news;
                    throw;
                }
                delete[] s;
                s = news;
                cap = newcap;
            }

        public:
            // Empty vector; allocates nothing.
            vector() : s(nullptr), sz(0), cap(0) {}

            // Deep copy with the same capacity as `other`.
            vector(const vector& other) : s(nullptr), sz(0), cap(0) {
                T* buf = new T[other.cap];
                try {
                    for(size_type i = 0; i < other.sz; i++) {
                        buf[i] = other.s[i];
                    }
                } catch (...) {
                    delete[] buf;
                    throw;
                }
                s = buf;
                sz = other.sz;
                cap = other.cap;
            }

            // Takes over other's buffer; other is left empty.
            vector(vector&& other) noexcept : s(other.s), sz(other.sz), cap(other.cap) {
                other.s = nullptr;
                other.sz = 0;
                other.cap = 0;
            }

            // Copy-and-swap: if copying throws, *this is left unchanged
            // instead of holding an already-deleted buffer.
            vector& operator= (const vector& other) {
                if(this != &other) {
                    vector tmp(other);
                    swap(tmp);
                }
                return *this;
            }

            // Releases the current buffer and takes over other's; other is left empty.
            vector& operator= (vector&& other) noexcept {
                if(this != &other) {
                    delete[] s;
                    s = other.s;
                    sz = other.sz;
                    cap = other.cap;
                    other.s = nullptr;
                    other.sz = 0;
                    other.cap = 0;
                }
                return *this;
            }

            ~vector() {
                delete[] s;
            }

            size_type size() const {
                return sz;
            }

            size_type capacity() const {
                return cap;
            }

            bool empty() const {
                return sz == 0;
            }

            const T* data() const {
                return s;
            }

            // Bounds-checked access; throws std::out_of_range when p >= size().
            const T& at(size_type p) const {
                if(p >= sz) throw std::out_of_range("vector::at out of range");
                else return s[p];
            }

            // Unchecked access (asserted in debug builds).
            const T& operator[] (size_type p) const {
                assert(p < sz);
                return s[p];
            }

            const T& front() const {
                assert(sz > 0);
                return s[0];
            }

            const T& back() const {
                assert(sz > 0);
                return s[sz-1];
            }

            // Ensures capacity() >= newcap; never shrinks.
            void reserve(size_type newcap) {
                if(newcap > cap) {
                    reallocate(newcap);
                }
            }

            // Appends a copy of val. val may refer to an element of this vector:
            // it is copied before a reallocation frees the old buffer.
            void push_back(const T& val) {
                if(sz == cap) {
                    T copy(val);
                    reserve(grown_capacity());
                    s[sz] = copy;
                } else {
                    s[sz] = val;
                }
                sz++;
            }

            void pop_back() {
                assert(sz > 0);
                sz--;
            }

            // Grows to newsz by appending copies of val, or shrinks to newsz.
            // val may refer to an element of this vector (it is copied first).
            void resize(size_type newsz, const T& val) {
                if(newsz > sz) {
                    T fill(val);
                    reserve(newsz);
                    for(size_type i = sz; i < newsz; i++) {
                        s[i] = fill;
                    }
                }
                sz = newsz;
            }

            void resize(size_type newsz) {
                resize(newsz, T());
            }

            // Forgets all elements; keeps the capacity.
            void clear() {
                sz = 0;
            }

            void swap(vector& other) noexcept {
                std::swap(s, other.s);
                std::swap(sz, other.sz);
                std::swap(cap, other.cap);
            }

            T* data() {
                return s;
            }

            T& at(size_type p) {
                if(p >= sz) throw std::out_of_range("vector::at out of range");
                else return s[p];
            }

            T& operator[] (size_type p) {
                assert(p < sz);
                return s[p];
            }

            T& front() {
                assert(sz > 0);
                return s[0];
            }

            T& back() {
                assert(sz > 0);
                return s[sz-1];
            }

        public:
            // Random-access iterator wrapping a raw element pointer. It is
            // invalidated by any reallocation (push_back, insert, reserve, resize).
            class iterator {
                public:
                    using difference_type = std::ptrdiff_t;
                    using value_type = T;
                    using pointer = T*;
                    using reference = T&;
                    using iterator_category = std::random_access_iterator_tag;

                private:
                    pointer ptr;

                public:
                    iterator() {
                        ptr = nullptr;
                    }

                    iterator(pointer p) {
                        ptr = p;
                    }

                    reference operator*() const {
                        return *ptr;
                    }

                    pointer operator->() const {
                        return ptr;
                    }

                    // it[n] == *(it + n), required of random-access iterators.
                    reference operator[](difference_type n) const {
                        return ptr[n];
                    }

                    iterator& operator++() {
                        ptr++;
                        return *this;
                    }

                    iterator operator++(int) {
                        iterator tmp = *this;
                        ptr++;
                        return tmp;
                    }

                    iterator& operator--() {
                        ptr--;
                        return *this;
                    }

                    iterator operator--(int) {
                        iterator tmp = *this;
                        ptr--;
                        return tmp;
                    }

                    iterator& operator+=(difference_type num) {
                        ptr += num;
                        return *this;
                    }

                    iterator& operator-=(difference_type num) {
                        ptr -= num;
                        return *this;
                    }

                    iterator operator+(difference_type num) const {
                        return iterator(ptr+num);
                    }

                    // n + it, required of random-access iterators.
                    friend iterator operator+(difference_type num, const iterator& it) {
                        return it + num;
                    }

                    iterator operator-(difference_type num) const {
                        return iterator(ptr-num);
                    }

                    difference_type operator-(const iterator& other) const {
                        return ptr-other.ptr;
                    }

                    bool operator==(const iterator& other) const{
                        return ptr == other.ptr;
                    }

                    bool operator!=(const iterator& other) const{
                        return ptr != other.ptr;
                    }

                    bool operator<=(const iterator& other) const{
                        return ptr <= other.ptr;
                    }

                    bool operator>=(const iterator& other) const{
                        return ptr >= other.ptr;
                    }

                    bool operator<(const iterator& other) const{
                        return ptr < other.ptr;
                    }

                    bool operator>(const iterator& other) const{
                        return ptr > other.ptr;
                    }
            };

        public:
            iterator begin() {
                return iterator(s);
            }

            iterator end() {
                return iterator(s+sz);
            }

            // Removes the element at `it`; returns an iterator to the element
            // that followed it.
            iterator erase(iterator it) {
                assert(it >= begin() && it < end());
                const size_type p = static_cast<size_type>(it - begin());
                for(size_type i = p; i + 1 < sz; i++) {
                    s[i] = s[i + 1];
                }
                sz--;
                return iterator(s + p);
            }

            // Removes [it1, it2); an empty range is a no-op. Returns an iterator
            // to the element that followed the removed range. Works on indices,
            // so it never dereferences end().
            iterator erase(iterator it1, iterator it2) {
                assert(it1 >= begin() && it1 <= it2 && it2 <= end());
                const size_type first = static_cast<size_type>(it1 - begin());
                const size_type count = static_cast<size_type>(it2 - it1);
                for(size_type i = first; i + count < sz; i++) {
                    s[i] = s[i + count];
                }
                sz -= count;
                return iterator(s + first);
            }

            // Inserts a copy of val before `it`; returns an iterator to it.
            // val is copied up front because it may refer to an element that
            // the reallocation frees or the shift below overwrites.
            iterator insert(iterator it, const T& val) {
                assert(it >= begin() && it <= end());
                const size_type p = static_cast<size_type>(it - begin());
                T copy(val);
                if(sz == cap) reserve(grown_capacity());
                for(size_type i = sz; i > p; i--) {
                    s[i] = s[i-1];
                }
                s[p] = copy;
                sz++;
                return iterator(s + p);
            }
    };
}

#endif

Assignment 3

Maze

Too simple to include here.

Priority Queue

The assignment requires a priority queue backed by a `std::list`, so `push` costs $O(n)$. This is a puzzling choice by CUHK-SZ, since a binary heap would give $O(\log n)$ insertion.

student_priority_queue.h

#ifndef STUDENT_PRIORITY_QUEUE_H
#define STUDENT_PRIORITY_QUEUE_H

#include <list>

namespace student_std {
    // Max-priority queue that keeps its container sorted in non-increasing
    // order, so the largest element is always at c.begin().
    // push is O(n) (linear scan for the insertion point); top and pop are O(1)
    // for std::list. Elements are compared with operator<.
    template <typename T, typename Container = std::list<T>>
    class priority_queue {
        public:
            using container_type = Container;
            using value_type = typename Container::value_type;
            using size_type = typename Container::size_type;
            // ...
        private:
            Container c; // sorted, largest first

        public:
            priority_queue() = default;
            // Memberwise copy/move of the container. The defaulted copy
            // assignment returns *this; the hand-written one returned `this`,
            // a pointer, and did not compile when used.
            priority_queue(const priority_queue& other) = default;
            priority_queue(priority_queue&& other) = default;
            priority_queue& operator=(const priority_queue& other) = default;
            priority_queue& operator=(priority_queue&& other) = default;

            // Largest element. Precondition: !empty().
            value_type const& top() const {
                return *c.begin();
            }

            // Removes the largest element. Precondition: !empty().
            void pop() {
                c.erase(c.begin());
            }

            size_type size() const {
                return c.size();
            }

            bool empty() const {
                return c.empty();
            }

            // Inserts val before the first element that is not greater than it,
            // which keeps c sorted largest-first.
            void push(const value_type& val) {
                auto it = c.begin();
                while(it != c.end() && val < *it) ++it;
                c.insert(it, val);
            }

            void swap(priority_queue& other) {
                c.swap(other.c);
            }
    };
}

#endif

Assignment 4

Unordered Map

Implement an unordered map using bucket hashing.

Bucket structure: a `std::vector` of buckets, each bucket a `std::list` of key–value pairs.

student_unordered_map.h

#ifndef STUDENT_UNORDERED_MAP_H
#define STUDENT_UNORDERED_MAP_H

#include <vector>
#include <list>
#include <functional>
#include <utility>
#include <cassert>

namespace student_std {
    // Hash map with separate chaining: a vector of buckets, each a std::list of
    // (key, value) pairs. The table doubles its bucket count as soon as the
    // average bucket length would reach 2.
    template <typename Key, typename T, typename Hash = std::hash<Key>>
    class unordered_map {
        public:
            using key_type = Key;
            using mapped_type = T;
            using size_type = std::size_t;
            using difference_type = std::ptrdiff_t;
            using value_type = std::pair<Key, T>;
            using hasher = Hash;
            using reference = value_type&;
            using const_reference = const value_type&;

        private:
            std::vector <std::list <value_type>> buckets; // never empty
            size_type sz = 0;                             // number of stored pairs
            hasher hsh;

            // Bucket that key k belongs to at the current table size.
            size_type index_for(Key const& k) const {
                return hsh(k) % buckets.size();
            }

            // Rebuilds the table with newsz buckets (at least 1), moving every
            // pair into its new bucket. Invalidates all references to values.
            void rehash(size_type newsz) {
                if(newsz == 0) newsz = 1;
                auto tmp = std::move(buckets);
                buckets = std::vector <std::list <value_type>> (newsz);
                for(auto& lst : tmp) {
                    for(auto& x : lst) {
                        buckets[index_for(x.first)].push_back(std::move(x));
                    }
                }
            }

        public:
            // Empty map with init_sz buckets. 0 is bumped to 1 because
            // index_for takes the hash modulo the bucket count.
            unordered_map(size_type init_sz = 8) : buckets(init_sz == 0 ? 1 : init_sz) {}

            size_type size() const {
                return sz;
            }

            bool empty() const {
                return sz == 0;
            }

            size_type bucket_count() const {
                return buckets.size();
            }

            // Average number of pairs per bucket.
            double load_factor() const {
                return double(sz) / double(buckets.size());
            }

            bool contains(key_type const& k) const {
                for(auto const& x : buckets[index_for(k)]) {
                    if(x.first == k) {
                        return true;
                    }
                }
                return false;
            }

            // Removes every pair; keeps the bucket count.
            void clear() {
                for(auto& x : buckets) {
                    x.clear();
                }
                sz = 0;
            }

            // Removes k if present; returns the number of pairs removed (0 or 1).
            size_type erase(key_type const& k) {
                auto& lst = buckets[index_for(k)];
                for(auto it = lst.begin(); it != lst.end(); ++it) {
                    if(it->first == k) {
                        lst.erase(it);
                        sz--;
                        return 1;
                    }
                }
                return 0;
            }

            // Returns the value for k, inserting a value-initialized T first
            // if k is absent.
            T& operator [](key_type const& k) {
                for(auto& x : buckets[index_for(k)]) {
                    if(x.first == k) {
                        return x.second;
                    }
                }
                // Grow before inserting: rehash() rebuilds every list node, so
                // a reference taken before it would dangle. The trigger is the
                // same as checking the load factor just after insertion.
                if(double(sz + 1) / double(buckets.size()) >= 2.0) {
                    rehash(buckets.size()*2);
                }
                auto& target = buckets[index_for(k)];
                target.emplace_back(k, T());
                sz++;
                return target.back().second;
            }
    };
}

#endif

BST's Inorder Iterator

Implement, from scratch, a forward iterator that visits the nodes of a binary tree in in-order sequence.

student_inorder_iterator.h

#ifndef STUDENT_INORDER_ITERATOR_H
#define STUDENT_INORDER_ITERATOR_H

namespace student_std {
    // Forward iterator that visits a binary tree's nodes in in-order sequence
    // (left subtree, node, right subtree).
    //
    // BinaryTree is the node type. It must provide value_type, value(), left(),
    // right() and parent(); the last three return node pointers (nullptr when
    // absent). A default-constructed iterator is the end sentinel.
    //
    // NOTE(review): this header names std::ptrdiff_t and
    // std::forward_iterator_tag but includes neither <cstddef> nor <iterator>.
    // It relies on the including file for them.
    template <typename BinaryTree>
    class inorder_iterator {
        public:
            using value_type = typename BinaryTree::value_type;
            using difference_type = std::ptrdiff_t;
            using iterator_category = std::forward_iterator_tag;
            using reference = value_type const&;
            using pointer = value_type const*;

        private:
            BinaryTree const* cur; // node being visited; nullptr means end

            // First in-order node of the subtree rooted at `node` (non-null).
            static BinaryTree const* leftmost(BinaryTree const* node) {
                for(auto next = node->left(); next; next = node->left()) {
                    node = next;
                }
                return node;
            }

        public:
            inorder_iterator() : cur(nullptr) {}

            // Positions the iterator on the first in-order node of the subtree
            // rooted at `node`. That is usually not `node` itself; nullptr gives end.
            inorder_iterator(BinaryTree const* node) : cur(node ? leftmost(node) : nullptr) {}

            reference operator*() const {
                return cur->value();
            }

            pointer operator->() const {
                return &(cur->value());
            }

            // Advances to the in-order successor; incrementing end is a no-op.
            inorder_iterator& operator++() {
                if(cur == nullptr) {
                    return *this;
                }
                if(cur->right()) {
                    // The successor is the leftmost node of the right subtree.
                    cur = leftmost(cur->right());
                    return *this;
                }
                // Otherwise climb while we are a right child. The first ancestor
                // reached from a left subtree is the successor (nullptr after
                // the last node).
                BinaryTree const* child = cur;
                auto ancestor = cur->parent();
                while(ancestor && ancestor->right() == child) {
                    child = ancestor;
                    ancestor = ancestor->parent();
                }
                cur = ancestor;
                return *this;
            }

            inorder_iterator operator++(int) {
                inorder_iterator before(*this);
                ++(*this);
                return before;
            }

            bool operator==(inorder_iterator const& other) const {
                return cur == other.cur;
            }

            bool operator!=(inorder_iterator const& other) const {
                return cur != other.cur;
            }
    };
}

#endif

Assignment 5

AVL Tree

Implement an AVL Tree from scratch.

student_avl_tree.h

#ifndef STUDENT_AVL_TREE_H
#define STUDENT_AVL_TREE_H

#include <algorithm>
#include <functional>
#include <utility>
#include <cstddef>
#include <iterator>

namespace student_std {
    // Self-balancing (AVL) binary search tree of unique keys ordered by Comp.
    // Every node caches the height and size of its subtree, so size() and
    // height() are O(1); insert, erase and find are O(log n).
    template <typename Key, typename Comp = std::less<Key>>
    class avl_tree {
        // One tree node. The read-only accessors are public so that code holding
        // an iterator can inspect keys and tree shape; only avl_tree (a friend)
        // mutates nodes.
        class avl_node {
            public:
                using size_type = std::size_t;
                using difference_type = std::ptrdiff_t;

                // New leaf: subtree size 1, height 0, no children.
                // Initializers follow declaration order (avoids -Wreorder).
                avl_node(const Key& k, avl_node* p = nullptr) :
                    m_size(1), m_height(0), m_key(k), m_parent(p), m_left(nullptr), m_right(nullptr) {}
                Key const& value() const { return m_key; }
                avl_node const* parent() const { return m_parent; }
                avl_node const* left() const { return m_left; }
                avl_node const* right() const { return m_right; }

                size_type size() const { return m_size; }          // nodes in this subtree
                std::ptrdiff_t height() const { return m_height; } // a leaf has height 0

            private:
                size_type m_size;
                std::ptrdiff_t m_height;
                Key m_key;
                avl_node* m_parent;
                avl_node* m_left;
                avl_node* m_right;
                friend class avl_tree;
        };

        // Bidirectional iterator over nodes in key order; dereferencing yields
        // the node, so use it->value() for the key. It holds only a node pointer,
        // so decrementing end() stays at end().
        class iterator {
            public:
                using value_type = avl_node;
                using reference = value_type const&;
                using pointer = value_type const*;
                using difference_type = std::ptrdiff_t;
                using iterator_category = std::bidirectional_iterator_tag;

                iterator(avl_node* node = nullptr) : m_node{node} {}
                iterator(pointer node) : m_node{const_cast<avl_node*>(node)} {}

                iterator& operator++() { // O(log n)
                    if (m_node == nullptr) return *this;

                    if (m_node->m_right) {
                        m_node = m_node->m_right;
                        while (m_node->m_left) {
                            m_node = m_node->m_left;
                        }
                    } else {
                        avl_node* p = m_node->m_parent;
                        while (p && m_node == p->m_right) {
                            m_node = p;
                            p = p->m_parent;
                        }
                        m_node = p;
                    }
                    return *this;
                }

                iterator operator++(int) { // O(log n)
                    iterator tmp = *this;
                    ++(*this);
                    return tmp;
                }

                iterator& operator--() { // O(log n)
                    if (m_node == nullptr) return *this;

                    if (m_node->m_left) {
                        m_node = m_node->m_left;
                        while (m_node->m_right) {
                            m_node = m_node->m_right;
                        }
                    } else {
                        avl_node* p = m_node->m_parent;
                        while (p && m_node == p->m_left) {
                            m_node = p;
                            p = p->m_parent;
                        }
                        m_node = p;
                    }
                    return *this;
                }

                iterator operator--(int) { // O(log n)
                    iterator tmp = *this;
                    --(*this);
                    return tmp;
                }

                reference operator*() const { // O(1)
                    return *m_node;
                }

                pointer operator->() const { // O(1)
                    return m_node;
                }

                bool operator==(iterator const& other) const {
                    return m_node == other.m_node;
                }
                bool operator!=(iterator const& other) const {
                    return m_node != other.m_node;
                }
            private:
                avl_node* m_node;
                friend class avl_tree;
        };

        public:
            using key_type = Key;
            using node_type = avl_node;
            using size_type = std::size_t;
            using comparison = Comp;
            using const_iterator = iterator;

            avl_tree() : m_root(nullptr), m_comp(comparison()) {}

            // Deep copy. The implicit copy shared nodes, so both trees deleted them.
            avl_tree(avl_tree const& other) : m_root(nullptr), m_comp(other.m_comp) {
                m_root = clone_subtree(other.m_root, nullptr);
            }

            // Takes over other's nodes; other is left empty.
            avl_tree(avl_tree&& other) noexcept : m_root(other.m_root), m_comp(std::move(other.m_comp)) {
                other.m_root = nullptr;
            }

            // Copy/move assignment via copy-and-swap; `other` owns our old nodes.
            avl_tree& operator=(avl_tree other) {
                std::swap(m_root, other.m_root);
                std::swap(m_comp, other.m_comp);
                return *this;
            }

            ~avl_tree() { clear(m_root); }

            // Inserts key if absent; returns an iterator to the node holding it.
            iterator insert(Key const& key) { // O(log n)
                if (!m_root) {
                    m_root = new avl_node(key);
                    return iterator(m_root);
                }

                avl_node* current = m_root;
                avl_node* parent = nullptr;

                while (current) {
                    parent = current;
                    if (m_comp(key, current->m_key)) {
                        current = current->m_left;
                    } else if (m_comp(current->m_key, key)) {
                        current = current->m_right;
                    } else {
                        return iterator(current);
                    }
                }

                avl_node* new_node = new avl_node(key, parent);
                if (m_comp(key, parent->m_key)) {
                    parent->m_left = new_node;
                } else {
                    parent->m_right = new_node;
                }

                rebalance_path(new_node);
                return iterator(new_node);
            }

            // Removes key if present. Returns an iterator to the element that
            // followed it, or end() if key was absent or was the largest key.
            iterator erase(Key const& key) { // O(log n)
                avl_node* node = find_node(key);
                if (!node) return end();

                avl_node* successor = next_node(node);
                avl_node* node_to_delete = node;

                if (node->m_left && node->m_right) {
                    // Two children: the successor is the leftmost node of the
                    // right subtree. Move its key up into `node` and unlink the
                    // successor's node instead. The successor's value now lives
                    // in `node`, so that is what we return; the successor node
                    // itself is freed below.
                    node->m_key = std::move(successor->m_key);
                    node_to_delete = successor;
                    successor = node;
                }

                avl_node* child = (node_to_delete->m_left) ? node_to_delete->m_left : node_to_delete->m_right;
                avl_node* parent_of_deleted = node_to_delete->m_parent;

                if (child) {
                    child->m_parent = parent_of_deleted;
                }

                if (!parent_of_deleted) {
                    m_root = child;
                } else if (parent_of_deleted->m_left == node_to_delete) {
                    parent_of_deleted->m_left = child;
                } else {
                    parent_of_deleted->m_right = child;
                }

                delete node_to_delete;
                if (parent_of_deleted) {
                    // Rotations relink nodes but never free them or move keys,
                    // so `successor` stays valid.
                    rebalance_path(parent_of_deleted);
                }
                return iterator(successor);
            }

            iterator find(Key const& key) const { // O(log n)
                avl_node* res = find_node(key);
                return iterator(res);
            }

            bool contains(Key const& key) const { // O(log n)
                return find_node(key) != nullptr;
            }

            size_type size() const { // O(1)
                return m_root ? m_root->m_size : 0;
            }

            // Height of the tree; -1 when empty, 0 for a single node.
            std::ptrdiff_t height() const { // O(1)
                return m_root ? m_root->m_height : -1;
            }

            // Balance factor (left height minus right height) of the root.
            std::ptrdiff_t getBL() const {
                return get_balance_factor(m_root);
            }

            iterator begin() const { // O(log n)
                if (!m_root) return iterator(static_cast<avl_node*>(nullptr));
                avl_node* curr = m_root;
                while (curr->m_left) {
                    curr = curr->m_left;
                }
                return iterator(curr);
            }

            iterator end() const {
                return iterator(static_cast<avl_node*>(nullptr));
            }

            iterator root() const { // O(1)
                return iterator(m_root);
            }

        private:
            avl_node* m_root;
            comparison m_comp;

            avl_node* find_node(Key const& key) const { // O(log n)
                avl_node* curr = m_root;
                while (curr) {
                    if (m_comp(key, curr->m_key)) {
                        curr = curr->m_left;
                    } else if (m_comp(curr->m_key, key)) {
                        curr = curr->m_right;
                    } else {
                        return curr;
                    }
                }
                return nullptr;
            }

            // In-order successor of n (non-null), or nullptr if n is the last node.
            static avl_node* next_node(avl_node* n) {
                if (n->m_right) {
                    n = n->m_right;
                    while (n->m_left) {
                        n = n->m_left;
                    }
                    return n;
                }
                avl_node* p = n->m_parent;
                while (p && n == p->m_right) {
                    n = p;
                    p = p->m_parent;
                }
                return p;
            }

            // Deep-copies the subtree rooted at src, attaching it to `parent`.
            // Frees the partial copy if an allocation or a key copy throws.
            avl_node* clone_subtree(avl_node const* src, avl_node* parent) {
                if (!src) return nullptr;
                avl_node* copy = new avl_node(src->m_key, parent);
                copy->m_size = src->m_size;
                copy->m_height = src->m_height;
                try {
                    copy->m_left = clone_subtree(src->m_left, copy);
                    copy->m_right = clone_subtree(src->m_right, copy);
                } catch (...) {
                    clear(copy);
                    throw;
                }
                return copy;
            }

            // Deletes the subtree rooted at node (post-order).
            void clear(avl_node* node) {
                if (node) {
                    clear(node->m_left);
                    clear(node->m_right);
                    delete node;
                }
            }

            std::ptrdiff_t get_height(avl_node* n) const {
                return n ? n->m_height : -1;
            }

            size_type get_size(avl_node* n) const {
                return n ? n->m_size : 0;
            }

            std::ptrdiff_t get_balance_factor(avl_node* n) const {
                if (!n) return 0;
                return get_height(n->m_left) - get_height(n->m_right);
            }

            // Recomputes n's cached height and size from its children.
            void update_stats(avl_node* n) {
                if (n) {
                    n->m_height = 1 + std::max(get_height(n->m_left), get_height(n->m_right));
                    n->m_size = 1 + get_size(n->m_left) + get_size(n->m_right);
                }
            }

            // Right rotation around y; y's left child x takes y's place.
            void rotate_right(avl_node* y) {
                avl_node* x = y->m_left;
                avl_node* T2 = x->m_right;

                x->m_right = y;
                y->m_left = T2;

                x->m_parent = y->m_parent;
                y->m_parent = x;
                if (T2) T2->m_parent = y;

                if (x->m_parent) {
                    if (x->m_parent->m_left == y) x->m_parent->m_left = x;
                    else x->m_parent->m_right = x;
                } else {
                    m_root = x;
                }

                update_stats(y);
                update_stats(x);
            }

            // Left rotation around x; x's right child y takes x's place.
            void rotate_left(avl_node* x) {
                avl_node* y = x->m_right;
                avl_node* T2 = y->m_left;

                y->m_left = x;
                x->m_right = T2;

                y->m_parent = x->m_parent;
                x->m_parent = y;
                if (T2) T2->m_parent = x;

                if (y->m_parent) {
                    if (y->m_parent->m_left == x) y->m_parent->m_left = y;
                    else y->m_parent->m_right = y;
                } else {
                    m_root = y;
                }

                update_stats(x);
                update_stats(y);
            }

            // Walks from node up to the root, refreshing cached stats and
            // rotating wherever the balance factor leaves [-1, 1].
            void rebalance_path(avl_node* node) {
                while (node) {
                    update_stats(node);
                    std::ptrdiff_t balance = get_balance_factor(node);

                    // Left heavy
                    if (balance > 1) {
                        if (get_balance_factor(node->m_left) < 0) {
                            rotate_left(node->m_left); // LR
                        }
                        rotate_right(node); // LL or LR
                    }
                    // Right heavy
                    else if (balance < -1) {
                        if (get_balance_factor(node->m_right) > 0) {
                            rotate_right(node->m_right); // RL
                        }
                        rotate_left(node); // RR or RL
                    }
                    node = node->m_parent;
                }
            }
    };
}

#endif