--[[
Cratera Library - (User-defined) Hash Map
Copyright (C) 2024 Soni L.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
--]]
-- it'll be funny if this implementation turns out to also be similar to Lua's
-- own
local error = error
local Iter = Iter
local mktrait = mktrait
local mkstruct = mkstruct
local Struct = Struct
local frexp = math.frexp
local tointeger = math.tointeger or function(x) return x end
local sbyte = string.byte
local ssub = string.sub
-- helper function that preserves "integerness" of val
-- val is expected to be an integer, lim is expected to be an integer power of two
-- The bitwise-AND variant is compiled from a string so that this file still
-- parses on Lua versions that have no `&` operator; if the loader yields a
-- falsy result the `%`-based function after `or` is used instead.
-- NOTE(review): `lua` is not defined in this file - presumably a table provided
-- by the cratera runtime that exposes Lua's load/loadstring; confirm.
local imod = (lua.loadstring or lua.load) "local val, lim = ...; return val & (lim - 1)" or function(val, lim) return val % lim end
-- keep a local handle on the globals table before _ENV is cleared below
local _G=_G
-- NOTE(review): with Lua 5.2+ semantics a nil _ENV makes any later global
-- access an error, so the rest of the file must only use the locals captured
-- above - confirm this is the intent for the targeted Lua versions.
local _ENV = nil
-- Fallback used when the math library does not provide frexp.
-- Splits `n` into a mantissa m with 0.5 <= |m| < 1 and an exponent e such that
-- n == m * 2^e.
--   n   : number to decompose
--   mul : optional sign multiplier applied to the returned mantissa
--         (defaults to 1; forced to -1 for negative n)
-- Returns m, e. Zero yields 0, 0. Infinities and NaN are returned unchanged
-- with exponent 0, because halving/doubling them would never terminate.
frexp = frexp or function(n, mul)
  mul = mul or 1
  if n < 0 then
    -- work on the magnitude and restore the sign on the way out; adding 0.0
    -- first makes the minimum integer negate to a positive float instead of
    -- wrapping back to itself
    n = -(n + 0.0)
    mul = -1
  end
  if n == 0 then return 0, 0 end
  if n - n ~= 0 then
    -- inf - inf and nan - nan are both nan, every finite number gives 0
    return n * mul, 0
  end
  -- naive implementation, replace later
  local e = 0
  while n >= 1 do
    n = n * 0.5
    e = e + 1
  end
  while n < 0.5 do
    n = n * 2
    e = e - 1
  end
  return n * mul, e
end
-- this is the object-to-hash interface
-- Objects used as map keys implement this trait; the map calls
-- `key:[Bucketable].bucket(hasher)` to feed the key into a hasher.
-- NOTE(review): the method body below only returns its documentation string;
-- how mktrait() consumes it is not visible here.
local Bucketable = mktrait()
function Bucketable:bucket(bucket)
return [[
Updates the bucket with the contents of this object.
]]
end
-- this is the "hasher" interface
-- Each method body below only returns the documentation string for that
-- method; concrete hashers (such as DefaultBucket in this file) supply the
-- real implementations.
local Bucket = mktrait()
-- put_number(n): mix a number into the hasher
function Bucket:put_number(n)
return [[
Puts the number in this bucket.
]]
end
-- put_string(s): mix a string into the hasher
function Bucket:put_string(s)
return [[
Puts the string in this bucket.
]]
end
-- put_object(o): mix a Bucketable object into the hasher
function Bucket:put_object(o)
return [[
Puts the object, which must implement `Bucketable`, in this bucket.
]]
end
-- finish(): produce the final hash value
function Bucket:finish()
return [[
Returns an integral value, traditionally from 0 to 2147483647, representing the contents of this bucket.
]]
end
-- Number of low bits the hasher state may occupy, chosen by probing how the
-- host's numbers behave just past a 32-bit and a 64-bit signed maximum.
local bitwidth
do
  -- Returns `wrapping_bits` when limit + 1 wraps around below limit,
  -- `absorbing_bits` when limit + 1 is indistinguishable from limit,
  -- and nil when neither happens.
  local function probe(limit, wrapping_bits, absorbing_bits)
    local successor = limit + 1
    if successor < limit then
      return wrapping_bits
    elseif successor == limit then
      return absorbing_bits
    end
    return nil
  end
  bitwidth = probe(2147483647, 30, 23)
  if not bitwidth then
    bitwidth = probe(9223372036854775807, 62, 52)
  end
  if not bitwidth then
    error("unsupported")
  end
end
-- hasher states are kept modulo this power of two
local wraparound = tointeger(2^bitwidth)
-- Default hasher. The function handed to mkstruct returns the initial fields
-- of a new instance: `_state`, the running hash value, starts at 0.
local DefaultBucket = mkstruct(function(_struct)
return {_state = 0}
end)
-- FIXME clean up once `function foo.bar:[baz].qux(...)` is supported by
-- cratera compiler
-- Until then the Bucket methods are collected in a plain table which is stored
-- on the struct with the trait itself as key.
local DefaultBucketBucket = {}
DefaultBucket[Bucket] = DefaultBucketBucket
-- Multiplies the hasher state by 31 modulo `wraparound`, the way Java hashes
-- strings. The product is formed as (state * 32) - state with a reduction
-- after each step: integer math may be unavailable and floats would lose
-- their low bits if the full product were computed at once.
local function shift_state(state)
  local times_32 = imod(state * 32, wraparound)
  return imod(times_32 - state, wraparound)
end
-- Mixes a string into the hasher: first its length (through put_number), then
-- every byte, multiplying the state by 31 before each byte is added.
function DefaultBucketBucket:put_string(s)
  local length = #s
  self:[Bucket].put_number(length)
  local acc = self._state
  for pos = 1, length do
    -- no reduction needed after the addition: a byte is tiny compared to the
    -- headroom above wraparound, and the next shift reduces again
    acc = shift_state(acc) + sbyte(s, pos)
  end
  self._state = acc
end
-- Mixes a number into the hasher.
-- The state is first multiplied by 31 (shift_state); then a value derived from
-- `n` is added modulo `wraparound`:
--   * zero adds nothing, but the shifted state is still stored so that putting
--     a zero is distinguishable from putting nothing at all
--   * the minimum integer (the only negative n with n == -n) adds wraparound - 1
--   * other integral values whose magnitude fits add that magnitude
--   * everything else is split with frexp and adds a mix of mantissa bits,
--     exponent and sign
-- NOTE(review): non-finite n (inf / NaN) gets no special handling here -
-- confirm callers never pass it.
function DefaultBucketBucket:put_number(n)
  local state = shift_state(self._state)
  if n == 0 then
    self._state = state
    return
  end
  local sign = 0
  if n < 0 then
    if n == -n then
      self._state = imod(state + (wraparound - 1), wraparound)
      return
    end
    n = -n
    sign = 1
  end
  if n % 1 == 0 and n <= (2*wraparound-1) then
    n = tointeger(n) -- ensure it is an integer
    n = imod(n, wraparound) -- ensure it fits
    self._state = imod(state + n, wraparound)
    return
  end
  local m, e = frexp(n)
  m = 2 * m - 1 -- within [0, 1), contains 52/23 bits
  m = tointeger(m * wraparound) -- within [0, wraparound)
  e = (e - 1) * 2 + sign -- significantly smaller than wraparound
  local tmp = imod(e*2 + m, wraparound)
  self._state = imod(state + tmp, wraparound)
end
-- Mixes an object into the hasher by delegating to the object's own
-- Bucketable implementation, which receives this hasher as its argument.
function DefaultBucketBucket:put_object(o)
o:[Bucketable].bucket(self)
end
-- Returns the accumulated state as the hash value; the state is returned
-- as-is, without any further reduction.
-- NOTE(review): put_string adds its last byte without a final imod(), so the
-- value may slightly exceed `wraparound` - confirm consumers do not rely on
-- an upper bound.
function DefaultBucketBucket:finish()
return self._state
end
-- this is the actual hashmap
-- Open-addressing hash map keyed by Bucketable objects.
-- Constructor argument:
--   bucket_factory : optional callable producing a fresh hasher for each
--                    lookup; defaults to DefaultBucket
-- Fields:
--   _size / _capacity / _buckets : bookkeeping used by the methods below
--   load_factor : fill ratio at which `put` grows the table; may be changed
--                 at runtime
-- `__eq` treats two maps as equal when they are the same struct type and hold
-- the same key/value pairs.
local BucketingMap = mkstruct(function(_struct, bucket_factory)
  return {
    _bucket_factory = bucket_factory or DefaultBucket,
    _size = 0,
    _capacity = 0,
    _buckets = {},
    -- can be changed at runtime
    load_factor = 0.5,
  }
end, {
  __eq=function(a, b)
    -- are types equal?
    if a[Struct] ~= b[Struct] then
      return false
    end
    -- is every element of a present in b with an equal value?
    for k, v in a:[Iter].iter() do
      if b:get(k) ~= v then
        return false
      end
    end
    -- ...and the converse, otherwise a map would compare equal to any
    -- superset of itself
    for k, v in b:[Iter].iter() do
      if a:get(k) ~= v then
        return false
      end
    end
    return true
  end
})
do local BucketingMapIter = {}
  BucketingMap[Iter] = BucketingMapIter
  -- Iterator step for the generic `for`.
  -- `state` is {buckets, capacity, position of the last bucket returned}.
  -- Scans forward to the next occupied bucket and returns its key and value,
  -- or nil once the scan runs past the capacity.
  local function iterator(state)
    local buckets, capacity = state[1], state[2]
    local bucket_pos = state[3] + 1
    local bucket = buckets[bucket_pos]
    while not bucket do
      if bucket_pos > capacity then
        return nil
      end
      bucket_pos = bucket_pos + 1
      bucket = buckets[bucket_pos]
    end
    -- remember where we stopped; without this every step would start over
    -- and return the first occupied bucket forever
    state[3] = bucket_pos
    return bucket[2], bucket[3]
  end
  -- Returns the iterator function and its state table for use in
  -- `for k, v in map:[Iter].iter() do`.
  function BucketingMapIter:iter()
    return iterator, {self._buckets, self._capacity, 0}
  end
end
-- finds the (possibly free) bucket for a given key
--   buckets  : array of `capacity` slots, each `false` (free) or
--              {hash, key, value}
--   capacity : number of slots
--   hash     : hash of `key`
--   key      : key to look for
-- Probes linearly from the slot selected by the hash, wrapping around at the
-- end. Returns the index of the slot holding `key`, or of the first free slot
-- met on the way; returns nil when the table has no slots at all or when every
-- slot is taken by another key.
local function find_bucket(buckets, capacity, hash, key)
  if capacity == 0 then
    -- a map that never grew has nowhere to look, and integer `hash % 0`
    -- would raise an error
    return nil
  end
  local i = (hash % capacity) + 1
  -- visiting `capacity` slots covers the whole table exactly once
  for _ = 1, capacity do
    local bucket = buckets[i]
    if not bucket then
      return i
    end
    if bucket[1] == hash and bucket[2] == key then
      return i
    end
    i = i + 1
    if i > capacity then
      i = 1
    end
  end
  return nil
end
-- Inserts or updates an entry.
--   self       : the map
--   key, value : entry to store; `key` must implement Bucketable
--   update_key : when true and the key is already present, the stored key
--                object is replaced by `key` as well
-- Returns the previous value when the key was already present, nil otherwise.
-- Raises "no space for key" when no slot can be found.
local function put(self, key, value, update_key)
  local hasher = self._bucket_factory()
  key:[Bucketable].bucket(hasher)
  local hash = hasher:[Bucket].finish()

  -- grow (and rehash) first when the load factor has been reached
  if self._size >= self._capacity * self.load_factor then
    local capacity = self._capacity
    local grown_capacity
    if capacity == 0 then
      grown_capacity = 1
    else
      grown_capacity = capacity * 2
    end
    local previous = self._buckets
    local grown = {}
    for slot = 1, grown_capacity do
      grown[slot] = false
    end
    for slot = 1, capacity do
      local entry = previous[slot]
      if entry then
        local target = find_bucket(grown, grown_capacity, entry[1], entry[2])
        grown[target] = entry
      end
    end
    self._buckets = grown
    self._capacity = grown_capacity
  end

  local buckets = self._buckets
  local slot = find_bucket(buckets, self._capacity, hash, key)
  if not slot then
    error("no space for key")
  end
  local entry = buckets[slot]
  if entry then
    -- existing key: overwrite in place and hand back what was there
    local previous_value = entry[3]
    entry[3] = value
    if update_key then
      entry[2] = key
    end
    return previous_value
  end
  -- free slot: a brand new entry
  self._size = self._size + 1
  buckets[slot] = {hash, key, value}
  return nil
end
-- Stores `value` under `key`. When the key is already present only the value
-- is overwritten; the key object stored originally is kept.
function BucketingMap:put(key, value)
  local replace_stored_key = false
  put(self, key, value, replace_stored_key)
end
-- Returns true when an entry for `key` exists, false otherwise.
function BucketingMap:has_key(key)
  local hasher = self._bucket_factory()
  key:[Bucketable].bucket(hasher)
  local hash = hasher:[Bucket].finish()
  local buckets = self._buckets
  local slot = find_bucket(buckets, self._capacity, hash, key)
  -- indexing with a nil slot yields nil, a free slot holds false:
  -- both mean "absent"
  return not not buckets[slot]
end
-- Returns the value stored under `key`, or nil when there is no such entry.
-- The `value` parameter is not used.
function BucketingMap:get(key, value)
  local hasher = self._bucket_factory()
  key:[Bucketable].bucket(hasher)
  local hash = hasher:[Bucket].finish()
  local buckets = self._buckets
  local slot = find_bucket(buckets, self._capacity, hash, key)
  -- indexing with a nil slot yields nil, a free slot holds false
  local entry = buckets[slot]
  if entry then
    return entry[3]
  end
  return nil
end
-- Removes the entry for `key`.
-- Returns the value that was stored, or nil when the key was not present.
-- The table is rebuilt without the removed entry so that linear probing never
-- has to cross a hole left behind by the removal.
function BucketingMap:remove(key)
  local hasher = self._bucket_factory()
  key:[Bucketable].bucket(hasher)
  local hash = hasher:[Bucket].finish()
  local buckets = self._buckets
  local capacity = self._capacity
  local bucket_pos = find_bucket(buckets, capacity, hash, key)
  local bucket = buckets[bucket_pos]
  if not bucket then
    return nil
  end
  local value = bucket[3]
  -- one entry fewer
  self._size = self._size - 1
  -- just rehash lol
  local new_buckets = {}
  for i=1,capacity do
    new_buckets[i] = false
  end
  for i=1,capacity do
    if i ~= bucket_pos then
      local old_bucket = buckets[i]
      if old_bucket then
        local new_bucket_pos = find_bucket(new_buckets, capacity, old_bucket[1], old_bucket[2])
        new_buckets[new_bucket_pos] = old_bucket
      end
    end
  end
  self._buckets = new_buckets
  return value
end
-- Module exports: the map itself, the two traits, and the default hasher.
return {
BucketingMap = BucketingMap,
Bucket = Bucket,
Bucketable = Bucketable,
DefaultBucket = DefaultBucket,
}