Jump to content

Module:memoize

ពីWiktionary

Documentation for this module may be created at Module:memoize/doc

local math_module = "Module:math"

local select = select
local unpack = unpack or table.unpack -- Lua 5.2 compatibility

-- Stub standing in for `sign` from Module:math. Module:math is only required the first time this is
-- called; the loaded function is then assigned to the local `sign` (declared by `local function`, so the
-- assignment rebinds this very upvalue), which means every later call goes straight to the real function.
local function sign(...)
	local loaded = require(math_module).sign
	sign = loaded
	return loaded(...)
end

----- M E M O I Z A T I O N-----
-- Memoizes a function or callable table.
-- Supports any number of arguments and return values.
-- If the optional parameter `simple` is set, then the memoizer will use a faster implementation, but this is only compatible with one argument and one return value: only the first argument is used as the cache key, and only the first return value is memoized and returned. Additional arguments are still passed on to the function, but this should only be done if those arguments will always be the same, since they are not part of the key.

-- Sentinels: private tables that stand in, as memo keys, for argument values that cannot safely be
-- used as table keys directly. Being fresh tables, they can never be equal to any caller-supplied value.
local _nil = {}
local neg_0 = {}
local pos_nan = {}
local neg_nan = {}

-- Converts one argument into the key used to look it up in a memo table.
-- Most values are returned unchanged; the following are mapped to the sentinels above, so that e.g.
-- f("foo", nil, "bar") is memoized at memo["foo"][_nil]["bar"][memo]:
	-- nil, which cannot be a table key at all.
	-- -0, which compares equal to 0 (so t[-0] and t[0] would be the same entry), but becomes "-0" on
	-- conversion to string and behaves differently in some operations (e.g. 1/a evaluates to inf if a is
	-- 0, but -inf if a is -0).
	-- NaN and -NaN, which are the only values for which n == n is false, and which cannot be table keys;
	-- they only seem to differ on conversion to string ("nan" and "-nan"), so `sign` is used to tell them apart.
local function get_key(x)
	if x ~= x then
		-- NaN: choose the sentinel according to its sign.
		if sign(x) == 1 then
			return pos_nan
		end
		return neg_nan
	elseif x == nil then
		return _nil
	elseif x == 0 and 1 / x < 0 then
		-- Negative zero: 1 / -0 is -inf, whereas 1 / 0 is inf.
		return neg_0
	end
	return x
end

-- Walks the chain of nested memo tables for an argument list, creating missing levels on the way, and
-- returns the innermost table. Each argument, normalised by get_key, selects the next level: arguments
-- (1, 2, 3) lead to memo[1][2][3]. The packed return values are then stored in that innermost table under
-- the root memo table itself, i.e. at memo[1][2][3][memo]. Using the root as the final key keeps the entry
-- for f(1, 2, 3) separate from the entry for f(1, 2), which lives at memo[1][2][memo].
-- `n` is the 1-based position of `key` in the argument list and `nargs` is the total argument count, so
-- the walk continues through trailing nil arguments until position `nargs` is reached.
local function get_memo(memo, n, nargs, key, ...)
	local normalized = get_key(key)
	local child = memo[normalized]
	if child == nil then
		child = {}
		memo[normalized] = child
	end
	if n == nargs then
		return child
	end
	return get_memo(child, n + 1, nargs, ...)
end

-- Portable stand-in for table.pack (Lua 5.2+): returns a table holding all the arguments, with the
-- total number of arguments stored in field `n`.
--
-- It deliberately does not rely on the implicit `arg` vararg table. That is a Lua 5.0 compatibility
-- feature which LuaJIT does not implement and which Lua 5.1 only provides when built with
-- LUA_COMPAT_VARARG. select("#", ...) works on every version.
local function pack_fallback(...)
	return {n = select("#", ...), ...}
end

-- pack: wraps its arguments into a table with the additional key n, which holds the total number of
-- arguments given (like table.pack).
--
-- This is used to catch the function output values. We cannot catch the output in a table directly,
-- because we also need to account for any nils returned after the last non-nil value
-- (e.g. select("#", nil) == 1, select("#") == 0, select("#", nil, "foo", nil, nil) == 4 etc.).
--
-- The distinction between nil and nothing affects some native functions (e.g. tostring() throws an error,
-- but tostring(nil) returns "nil"), so it needs to be reconstructable from the memo.
--
-- table.pack is used where it exists; otherwise (Lua 5.1, LuaJIT) pack_fallback is used.
local pack = table.pack or pack_fallback

-- Returns a memoized wrapper around `func` (a function or callable table).
-- If `simple` is truthy, the wrapper caches on its first argument only and keeps only the first return
-- value. Otherwise it caches on the whole argument list, including trailing nils, and replays every
-- return value, including trailing nils.
return function(func, simple)
	-- Shared memo table, created on the first call rather than up front.
	local memo

	if simple then
		return function(...)
			local key = get_key(...)
			memo = memo or {}
			local cached = memo[key]
			if cached ~= nil then
				-- `_nil` records that func previously returned nil for this key.
				if cached == _nil then
					return nil
				end
				return cached
			end
			local result = func(...)
			if result == nil then
				memo[key] = _nil
				return nil
			end
			memo[key] = result
			return result
		end
	end

	return function(...)
		local nargs = select("#", ...)
		memo = memo or {}
		-- Every possible value (true, false, nil, ...) may appear as an argument, so the private memo table
		-- itself is the key under which the results for an argument list are stored.
		local node = memo
		if nargs > 0 then
			node = get_memo(memo, 1, nargs, ...)
		end
		local results = node[memo]
		if results == nil then
			results = pack(func(...))
			node[memo] = results
		end
		-- Replay exactly results.n values; unpack yields nil for any position missing from the table.
		return unpack(results, 1, results.n)
	end
end