summary refs log blame commit diff stats
path: root/src/cratera.cratera.d/lib.cratera.d/bucket.cratera
blob: 8aa074b9146a37d8894873699cc68ca3559777ee (plain) (tree)


















































































































































































































































































































































































                                                                                                                                   
--[[
    Cratera Library - (User-defined) Hash Map
    Copyright (C) 2024  Soni L.

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
--]]

-- it'll be funny if this implementation turns out to also be similar to Lua's
-- own

local error = error
local Iter = Iter
local mktrait = mktrait
local mkstruct = mkstruct
local Struct = Struct
local frexp = math.frexp
local tointeger = math.tointeger or function(x) return x end
local sbyte = string.byte
local ssub = string.sub
-- helper function that preserves "integerness" of val
-- val is expected to be an integer, lim is expected to be an integer power of two
local imod = (lua.loadstring or lua.load) "local val, lim = ...; return val & (lim - 1)" or function(val, lim) return val % lim end
local _G=_G
local _ENV = nil
-- Fallback for interpreters that do not provide math.frexp.
-- Splits `n` into a mantissa `m` and an integer exponent `e` such that
-- n == m * 2^e, with 0.5 <= |m| < 1 for finite non-zero `n`.
--   n   : the number to decompose
--   mul : internal sign carrier (1 or -1) used by the recursive call for
--         negative inputs; callers should leave it unset
-- Returns m, e. Zero yields (0, 0); infinities and NaN are returned unchanged
-- with exponent 0.
frexp = frexp or function(n, mul)
    mul = mul or 1
    -- naive implementation, replace later
    if n < 0 then return frexp(-n, -1) end
    if n == 0 then return 0, 0 end
    -- inf - inf and nan - nan are both nan, which is never equal to 0.
    -- Without this guard the halving loop below would never terminate for
    -- infinity, since inf * 0.5 == inf.
    if n - n ~= 0 then return n * mul, 0 end
    local e = 0
    while n >= 1 do
        n = n * 0.5
        e = e + 1
    end
    while n < 0.5 do
        n = n * 2
        e = e - 1
    end
    return n*mul, e
end

-- this is the object-to-hash interface
-- Keys used with BucketingMap are hashed by calling
-- `key:[Bucketable].bucket(hasher)`, where `hasher` is a fresh object produced
-- by the map's bucket factory.
local Bucketable = mktrait()

-- Trait method declaration: `bucket` is the hasher to feed.
-- NOTE(review): the body only returns a documentation string; presumably
-- implementors override it -- confirm against mktrait.
function Bucketable:bucket(bucket)
    return [[
Updates the bucket with the contents of this object.
]]
end

-- this is the "hasher" interface
-- A hasher accumulates numbers, strings and Bucketable objects, and is then
-- asked for a final hash value with `finish`.
-- NOTE(review): each method body below only returns its documentation string;
-- presumably concrete implementations (such as DefaultBucket) replace them --
-- confirm against mktrait.
local Bucket = mktrait()

-- Feed a number `n` into the hasher.
function Bucket:put_number(n)
    return [[
Puts the number in this bucket.
]]
end

-- Feed a string `s` into the hasher.
function Bucket:put_string(s)
    return [[
Puts the string in this bucket.
]]
end

-- Feed an object `o` implementing `Bucketable` into the hasher.
function Bucket:put_object(o)
    return [[
Puts the object, which must implement `Bucketable`, in this bucket.
]]
end

-- Produce the final hash value for everything fed so far.
function Bucket:finish()
    return [[
Returns an integral value, traditionally from 0 to 2147483647, representing the contents of this bucket.
]]
end

-- Number of low bits of a number that the hasher may safely use, detected by
-- probing how the interpreter behaves when a large constant is incremented:
-- wrapping to a smaller value indicates integer arithmetic, while being
-- absorbed (c + 1 == c) indicates floating point that has run out of
-- precision.
local bitwidth
do
    -- Returns `on_wrap` if c + 1 wraps below c, `on_absorb` if c + 1 equals c,
    -- and nil if c + 1 is represented exactly.
    local function probe(c, on_wrap, on_absorb)
        local successor = c + 1
        if successor < c then
            return on_wrap
        elseif successor == c then
            return on_absorb
        end
        return nil
    end
    -- try the 32-bit sized constant first, then the 64-bit sized one
    bitwidth = probe(2147483647, 30, 23)
        or probe(9223372036854775807, 62, 52)
    if not bitwidth then
        error("unsupported")
    end
end
-- hash state is always reduced modulo this power of two
local wraparound = tointeger(2^bitwidth)

-- Default hasher struct. Each instance starts with `_state = 0`; `_state` is
-- the running hash value updated by the Bucket methods below.
local DefaultBucket = mkstruct(function(_struct)
    return {_state = 0}
end)

-- FIXME clean up once `function foo.bar:[baz].qux(...)` is supported by
-- cratera compiler
-- Table holding DefaultBucket's implementation of the Bucket trait; it is
-- registered on the struct under the trait itself as the key.
local DefaultBucketBucket = {}
DefaultBucket[Bucket] = DefaultBucketBucket

-- Advances the hash state by multiplying it by 31 modulo `wraparound`, the
-- same multiplier Java uses for string hashes.
-- The product is computed as state * 32 - state, reducing after each step:
-- integer math may be unavailable, and floats would otherwise lose their low
-- bits.
local function shift_state(state)
    local times32 = imod(state * 32, wraparound)
    return imod(times32 - state, wraparound)
end

-- Feeds the string `s` into the hash: first its length (as a number), then
-- every byte in order, advancing the state before each byte is added.
function DefaultBucketBucket:put_string(s)
    local length = #s
    self:[Bucket].put_number(length)
    local acc = self._state
    local pos = 1
    while pos <= length do
        -- the sum is not reduced here; shift_state reduces it on the next
        -- round
        acc = shift_state(acc) + sbyte(s, pos)
        pos = pos + 1
    end
    self._state = acc
end

-- Feeds the number `n` into the hash.
-- The state is always advanced first, so that every value put (including
-- zero) affects the result. Then:
--   * zero contributes nothing further;
--   * a negative value equal to its own negation (the minimum integer) is
--     mixed in as `wraparound - 1`;
--   * integral values in range are mixed in by magnitude, so 1 and 1.0 hash
--     the same;
--   * infinities and NaN are mixed in as fixed sentinel values;
--   * any other value is split with frexp into mantissa and exponent, and
--     both are mixed in together with the sign.
function DefaultBucketBucket:put_number(n)
    local state = shift_state(self._state)
    if n == 0 then
        -- keep the shift, otherwise putting 0 (or an empty string's length)
        -- would leave the hash untouched
        self._state = state
        return
    end
    local sign = 0
    if n < 0 then
        if n == -n then
            -- only a wrapping minimum integer is its own negation
            self._state = imod(state + (wraparound - 1), wraparound)
            return
        end
        n = -n
        sign = 1
    end
    -- NOTE(review): with wrapping integers, 2*wraparound overflows and the
    -- "- 1" wraps back to the largest integer, so this still acts as an upper
    -- bound -- confirm on every supported interpreter.
    if n % 1 == 0 and n <= (2*wraparound-1) then
        n = tointeger(n) -- ensure it is an integer
        n = imod(n, wraparound) -- ensure it fits
        self._state = imod(state + n, wraparound)
        return
    end
    if n - n ~= 0 then
        -- inf - inf and nan - nan are nan: these have no usable
        -- mantissa/exponent and cannot be converted to an integer, so give
        -- them fixed values
        self._state = imod(state + (wraparound - 2 - sign), wraparound)
        return
    end
    local m, e = frexp(n)
    m = 2 * m - 1 -- within [0, 1), contains 52/23 bits
    m = tointeger(m * wraparound) -- within [0, wraparound)
    e = (e - 1) * 2 + sign -- significantly smaller than wraparound
    local tmp = imod(e*2 + m, wraparound)
    self._state = imod(state + tmp, wraparound)
end

-- Feeds the object `o` into the hash by delegating to its Bucketable
-- implementation, passing this hasher so the object can put its own contents.
function DefaultBucketBucket:put_object(o)
    o:[Bucketable].bucket(self)
end

-- Returns the accumulated hash state as-is; the hasher is not reset.
function DefaultBucketBucket:finish()
    return self._state
end

-- this is the actual hashmap
-- Constructor argument:
--   bucket_factory : callable producing a fresh hasher (an object
--                    implementing Bucket); defaults to DefaultBucket
-- Fields:
--   _size     : entry count maintained by put/remove
--   _capacity : number of slots in `_buckets`
--   _buckets  : array of slots, each either false or {hash, key, value}
--   load_factor : fill ratio at which `put` grows the table
local BucketingMap = mkstruct(function(_struct, bucket_factory)
    return {
        _bucket_factory = bucket_factory or DefaultBucket,
        _size = 0,
        _capacity = 0,
        _buckets = {},
        -- can be changed at runtime
        load_factor = 0.5,
    }
end, {
    -- Two maps are equal when they are the same struct type and hold exactly
    -- the same keys mapped to equal values.
    __eq=function(a, b)
        -- are types equal?
        if a[Struct] ~= b[Struct] then
            return false
        end
        -- is every element of a present in b with an equal value?
        -- has_key distinguishes "key maps to nil" from "key is absent"
        local count_a = 0
        for k, v in a:[Iter].iter() do
            if not b:has_key(k) or b:get(k) ~= v then
                return false
            end
            count_a = count_a + 1
        end
        -- does b hold anything else? a is contained in b at this point, so the
        -- maps are equal exactly when they have the same number of entries.
        -- The entries are counted by iteration rather than by trusting the
        -- `_size` bookkeeping.
        local count_b = 0
        for _ in b:[Iter].iter() do
            count_b = count_b + 1
        end
        return count_a == count_b
    end
})

-- Iter implementation for BucketingMap: yields every (key, value) pair, in
-- slot order.
do local BucketingMapIter = {}
    BucketingMap[Iter] = BucketingMapIter
    -- Stateful iterator function for the generic `for`.
    -- `state` is {buckets, capacity, last_position}; the position of the slot
    -- just returned is written back into state[3] so that the next call
    -- resumes after it.
    -- Returns key, value of the next occupied slot, or nil when exhausted.
    local function iterator(state)
        local buckets = state[1]
        local capacity = state[2]
        local bucket_pos = state[3] + 1
        local bucket = buckets[bucket_pos]
        -- skip over empty slots
        while not bucket do
            if bucket_pos > capacity then
                return nil
            end
            bucket_pos = bucket_pos + 1
            bucket = buckets[bucket_pos]
        end
        -- remember where we are; without this every call would restart from
        -- the first slot and the loop would never end
        state[3] = bucket_pos
        return bucket[2], bucket[3]
    end
    -- Returns the iterator function and a fresh state over the map's current
    -- slot array.
    function BucketingMapIter:iter()
        return iterator, {self._buckets, self._capacity, 0}
    end
end

-- finds the (possibly free) bucket for a given key
--   buckets  : array of slots, each false/nil (free) or {hash, key, value}
--   capacity : number of slots in `buckets`
--   hash     : hash of `key`
--   key      : key to look for (compared with ==)
-- Probes linearly from the slot selected by `hash`, wrapping around at the
-- end of the table.
-- Returns the index of the slot holding `key`, or of the first free slot
-- encountered; returns nil if the table has no slots or the probe finds
-- neither.
local function find_bucket(buckets, capacity, hash, key)
    if capacity == 0 then
        -- an empty table holds nothing; also avoids `hash % 0`, which raises
        -- an error for integer operands
        return nil
    end
    local i = (hash % capacity) + 1
    local overflow = false
    while true do
        local bucket = buckets[i]
        if not bucket then
            return i
        end
        if bucket[1] == hash and bucket[2] == key then
            return i
        end
        i = i + 1
        if i > capacity then
            if overflow then
                -- ran past the end a second time: every slot has been seen
                return nil
            end
            overflow = true
            i = 1
        end
    end
end

-- Inserts or updates an entry in the map `self`.
--   key        : a Bucketable key
--   value      : value to associate with the key
--   update_key : when true and the key already exists, the stored key object
--                is replaced by `key` as well
-- The table is grown (capacity 1, then doubling) whenever the entry count has
-- reached capacity * load_factor, before the key is looked up.
-- Returns the previous value when the key already existed, nil otherwise.
-- Raises "no space for key" if no slot can be found.
local function put(self, key, value, update_key)
    local hasher = self._bucket_factory()
    key:[Bucketable].bucket(hasher)
    local hash = hasher:[Bucket].finish()

    if self._size >= self._capacity * self.load_factor then
        local capacity = self._capacity
        local grown_capacity
        if capacity == 0 then
            grown_capacity = 1
        else
            grown_capacity = capacity * 2
        end
        local previous = self._buckets
        local grown = {}
        for slot = 1, grown_capacity do
            grown[slot] = false
        end
        -- move every existing entry to its slot in the larger table
        for slot = 1, capacity do
            local entry = previous[slot]
            if entry then
                local target = find_bucket(grown, grown_capacity, entry[1], entry[2])
                grown[target] = entry
            end
        end
        self._buckets = grown
        self._capacity = grown_capacity
    end

    local slots = self._buckets
    local slot = find_bucket(slots, self._capacity, hash, key)
    if not slot then
        error("no space for key")
    end

    local entry = slots[slot]
    if entry then
        -- existing key: swap in the new value and hand back the old one
        local previous_value = entry[3]
        entry[3] = value
        if update_key then
            entry[2] = key
        end
        return previous_value
    end

    -- free slot: create a new entry
    self._size = self._size + 1
    slots[slot] = {hash, key, value}
    return nil
end

-- Associates `value` with `key`, keeping the originally stored key object if
-- the key is already present.
-- Returns the value previously associated with `key`, or nil if there was
-- none.
function BucketingMap:put(key, value)
    return put(self, key, value, false)
end

-- Returns true if `key` (a Bucketable) has an entry in the map, false
-- otherwise.
function BucketingMap:has_key(key)
    local hasher = self._bucket_factory()
    key:[Bucketable].bucket(hasher)
    local hash = hasher:[Bucket].finish()
    local slots = self._buckets
    local slot = find_bucket(slots, self._capacity, hash, key)
    -- a nil slot index reads as nil, and a free slot holds false; both mean
    -- "absent"
    return not not slots[slot]
end

-- Returns the value stored for `key` (a Bucketable), or nil if the key has no
-- entry. The `value` parameter is not used.
function BucketingMap:get(key, value)
    local hasher = self._bucket_factory()
    key:[Bucketable].bucket(hasher)
    local hash = hasher:[Bucket].finish()
    local slots = self._buckets
    -- indexing with a nil slot index simply yields nil
    local entry = slots[find_bucket(slots, self._capacity, hash, key)]
    if entry then
        return entry[3]
    end
    return nil
end

-- Removes the entry for `key` (a Bucketable) from the map.
-- Returns the value that was stored, or nil if the key had no entry (in which
-- case the map is left untouched).
-- The remaining entries are re-inserted into a fresh slot array of the same
-- capacity, so that linear probing still finds entries that had been
-- displaced past the removed slot.
function BucketingMap:remove(key)
    local hasher = self._bucket_factory()
    key:[Bucketable].bucket(hasher)
    local hash = hasher:[Bucket].finish()
    local buckets = self._buckets
    local bucket_pos = find_bucket(buckets, self._capacity, hash, key)
    local bucket = buckets[bucket_pos]
    if not bucket then
        return nil
    end
    local value = bucket[3]
    -- one entry fewer
    self._size = self._size - 1
    -- just rehash lol
    local capacity = self._capacity
    local new_buckets = {}
    for i=1,capacity do
        new_buckets[i] = false
    end
    for i=1,capacity do
        -- carry over everything except the removed slot
        if i ~= bucket_pos then
            local old_bucket = buckets[i]
            if old_bucket then
                local new_bucket_pos = find_bucket(new_buckets, capacity, old_bucket[1], old_bucket[2])
                new_buckets[new_bucket_pos] = old_bucket
            end
        end
    end
    self._buckets = new_buckets
    return value
end

return {
    BucketingMap = BucketingMap,
    Bucket = Bucket,
    Bucketable = Bucketable,
    DefaultBucket = DefaultBucket,
}