-- Adapted from https://github.com/attractivechaos/plb/blob/master/matmul/matmul_v1.lua
-- Dummy cast: hands `n` back exactly as received.
--   n  : the value to pass through
--   to : name of the requested target type (e.g. "integer"); accepted
--        only so call sites read like a cast, it is never inspected
-- Returns n unchanged.
function cast(n, to) return n end
|
|
|
|
-- Global namespace table that receives every matrix routine in this file.
matrix = {}

-- File-local shortcut for table.slice, used when extracting matrix rows.
local slice = table.slice
|
|
-- Create a matrix record for m rows and n columns.
-- The record is a 3-element table:
--   [1] = m (row count)
--   [2] = n (column count)
--   [3] = flat array of m*n elements built by table.numarray(m*n, 0)
-- The accessors in this file address element (r, c) of the flat array at
-- index (r-1)*n + c.
function matrix.new(m, n)
  return { m, n, table.numarray(m * n, 0) }
end
|
|
|
|
-- Return one row of matrix `m` as a slice of its flat data array.
--   m   : matrix record as built by matrix.new ({rows, cols, data})
--   row : 1-based row number
-- The slice starts at flat index (row-1)*cols + 1 and spans `cols`
-- elements. Other routines in this file write through the returned slice
-- to fill the matrix, so it is expected to share storage with m[3].
-- No bounds check is made on `row`.
function matrix.getrow(m, row)
  local cols = m[2]
  return slice(m[3], (row - 1) * cols + 1, cols)
end
|
|
|
|
-- Return the flat data array (third field) of matrix record `m`.
-- The array itself is returned, not a copy.
function matrix.getdata(m) return m[3] end
|
|
|
|
-- Return the row count (first field) of matrix record `m`.
function matrix.rows(m) return m[1] end
|
|
|
|
-- Return the column count (second field) of matrix record `m`.
function matrix.cols(m) return m[2] end
|
|
|
|
-- Matrix transpose (slice-based variant).
-- Reads the source one row slice at a time and scatters each element into
-- the flat data array of a freshly allocated result.
--   a : matrix record with m rows and n columns
-- Returns a new n-by-m matrix; `a` is not modified.
function matrix.T(a)
  local rows_of, cols_of, create = matrix.rows, matrix.cols, matrix.new
  local data_of, row_of = matrix.getdata, matrix.getrow

  local nrows: integer = rows_of(a)
  local ncols: integer = cols_of(a)
  local result = create(ncols, nrows)
  local out: number[] = data_of(result)

  for i = 1, nrows do
    local src: number[] = row_of(a, i)
    for j = 1, ncols do
      -- source element (i, j) becomes result element (j, i); the result
      -- has `nrows` columns, hence the stride
      out[(j - 1) * nrows + i] = src[j]
    end
  end

  return result
end
|
|
|
|
-- Matrix transpose (flat-array variant, no slices).
-- Walks the source data array sequentially and writes each element to its
-- transposed position in the result's data array.
--   a : matrix record with m rows and n columns
-- Returns a new n-by-m matrix; `a` is not modified.
function matrix.T2(a)
  local rows_of, cols_of, create, data_of =
    matrix.rows, matrix.cols, matrix.new, matrix.getdata

  local nrows: integer = rows_of(a)
  local ncols: integer = cols_of(a)
  local result = create(ncols, nrows)
  local dst: number[] = data_of(result)
  local src: number[] = data_of(a)

  -- `at` runs over the source in storage order, so inside the loops it
  -- always equals (i-1)*ncols + j
  local at: integer = 0
  for i = 1, nrows do
    for j = 1, ncols do
      at = at + 1
      dst[(j - 1) * nrows + i] = src[at]
    end
  end

  return result
end
|
|
|
|
-- Matrix multiply (slice-based variant).
--   a : m-by-n matrix record
--   b : n-by-p matrix record
-- Returns a new m-by-p matrix holding the product a*b.
-- Raises an assertion error when the column count of `a` differs from the
-- row count of `b`.
function matrix.mul(a, b)
  local mrows, mcols, mnew = matrix.rows, matrix.cols, matrix.new
  local mrow, mtran = matrix.getrow, matrix.T

  local m: integer, n: integer, p: integer = mrows(a), mcols(a), mcols(b)
  -- the inner dimensions must agree: cols(a) == rows(b)
  assert(n == mrows(b), "matrix.mul: columns of a must equal rows of b")

  local x = mnew(m, p)
  -- transpose for efficiency: row j of c is column j of b (c is p-by-n),
  -- so each dot product walks two contiguous rows
  local c = mtran(b)

  for i = 1, m do
    local xi: number[] = mrow(x, i)
    -- row i of a is the same for every j, so fetch it once per row
    local ai: number[] = mrow(a, i)
    for j = 1, p do
      local cj: number[] = mrow(c, j)
      local sum: number = 0.0
      for k = 1, n do
        sum = sum + ai[k] * cj[k]
      end
      -- xi is a slice of x, so this store lands in the result matrix
      xi[j] = sum
    end
  end

  return x
end
|
|
|
|
-- Matrix multiply (flat-array variant): avoids slices and indexes the
-- one-dimensional data arrays directly.
--   a : m-by-n matrix record
--   b : n-by-p matrix record
-- Returns a new m-by-p matrix holding the product a*b.
-- Raises an assertion error when the column count of `a` differs from the
-- row count of `b`.
function matrix.mul2(a, b)
  local mrows, mcols, mnew = matrix.rows, matrix.cols, matrix.new
  local mdata, mtran = matrix.getdata, matrix.T2

  local m: integer, n: integer, p: integer = mrows(a), mcols(a), mcols(b)
  -- the inner dimensions must agree: cols(a) == rows(b)
  assert(n == mrows(b), "matrix.mul2: columns of a must equal rows of b")

  local x = mnew(m, p)
  -- transpose for efficiency: row j of c is column j of b (c is p-by-n)
  local c = mtran(b)

  local xdata: number[] = mdata(x)
  local adata: number[] = mdata(a)
  local cdata: number[] = mdata(c)

  for i = 1, m do
    -- Each array has its own row stride: a and c have n columns, x has p.
    -- The offsets must use the matching stride or non-square inputs
    -- read and write the wrong elements.
    local ai: integer = (i - 1) * n
    local xi: integer = (i - 1) * p
    for j = 1, p do
      local cj: integer = (j - 1) * n
      local sum: number = 0.0
      for k = 1, n do
        sum = sum + adata[ai + k] * cdata[cj + k]
      end
      xdata[xi + j] = sum
    end
  end

  return x
end
|
|
|
|
-- Generate the n-by-n test matrix (slice-based variant).
-- Element (i, j) is set to (1/n^2) * (i - j) * (i + j - 2).
--   n : matrix dimension
-- Returns the newly built matrix.
function matrix.gen(n: integer)
  local create, row_of = matrix.new, matrix.getrow

  local result = create(n, n)
  local scale: number = 1.0 / n / n

  for i = 1, n do
    -- writes go through the row slice into the matrix's data array
    local cells: number[] = row_of(result, i)
    for j = 1, #cells do
      cells[j] = scale * (i - j) * (i + j - 2)
    end
  end

  return result
end
|
|
|
|
-- Generate the n-by-n test matrix (flat-array variant, no slices).
-- Element (i, j) is set to (1/n^2) * (i - j) * (i + j - 2).
--   n : matrix dimension
-- Returns the newly built matrix.
function matrix.gen2(n: integer)
  local result = matrix.new(n, n)
  local cells: number[] = matrix.getdata(result)
  local scale: number = 1.0 / n / n

  -- `at` advances in storage order, so it always equals (i-1)*n + j
  local at: integer = 0
  for i = 1, n do
    for j = 1, n do
      at = at + 1
      cells[at] = scale * (i - j) * (i + j - 2)
    end
  end

  return result
end
|
|
|
|
-- Print matrix `a` to standard output, one row per print() call.
-- Each row's elements are passed to print() as separate arguments.
--   a : matrix record as built by matrix.new
-- Returns nothing.
function matrix.print(a)
  local row_of = matrix.getrow
  local m: integer = matrix.rows(a)

  for i = 1, m do
    local row: number[] = row_of(a, i)
    print(table.unpack(row))
  end
end
|
|
|
|
-- When ravi.jit() reports a truthy value, compile `cast` and hand the
-- whole `matrix` table to ravi.compile up front; either call failing
-- aborts the script through assert.
if ravi.jit() then
  -- option name suggests array reads are compiled without range checks,
  -- so the matrix routines must never index out of bounds (confirm
  -- against the Ravi documentation)
  local compile_options = { omitArrayGetRangeCheck = 1 }

  assert(ravi.compile(cast))
  assert(ravi.compile(matrix, compile_options))
  --ravi.dumpir(matrix.new)
end
|
|
|
|
-- Benchmark driver: multiply two generated n-by-n matrices, report the
-- CPU time taken, then print one element of the product.
-- The size comes from the first command-line argument, default 1000.

-- `arg` is only populated by the standalone interpreter, so guard the lookup
local requested = arg and arg[1]
local n = 1000
if requested ~= nil then
  -- convert explicitly instead of relying on string arithmetic coercion
  n = assert(tonumber(requested), "matrix size must be a number")
end
-- round down to an even size
n = math.floor(n / 2) * 2

local t1 = os.clock()
local a = matrix.mul2(matrix.gen2(n), matrix.gen2(n))
local t2 = os.clock()
print("total time taken ", t2 - t1)

--print(a[n/2+1][n/2+1]);
-- n is an even integer here, so floor division yields the same value as
-- n/2 while staying an integer for the typed local
local y: integer = cast(n // 2 + 1, "integer")
print(matrix.getrow(a, y)[y])
|
|
|
|
--matrix.print(matrix.gen(2))
|
|
--matrix.print(matrix.gen2(2))
|
|
|
|
--matrix.print(matrix.T(matrix.gen(2)))
|
|
--matrix.print(matrix.T2(matrix.gen(2)))
|
|
|
|
--matrix.print(matrix.mul(matrix.gen(2), matrix.gen(2)))
|
|
|
|
--ravi.dumpllvm(matrix.mul2)
|
|
|
|
--ravi.dumplua(matrix.T2)
|