You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ravi/tests/performance/matmul1.ravi

215 lines
5.5 KiB

-- Adapted from https://github.com/attractivechaos/plb/blob/master/matmul/matmul_v1.lua
-- Placeholder cast: the requested target type is ignored and the value is
-- handed back untouched.
-- @param n  value to "cast"
-- @param to name of the target type (accepted but unused)
-- @return n, unchanged
function cast(n, to) return n end
-- Local alias for table.slice; matrix.getrow calls it to return one row of
-- a matrix's flat data array.
local slice = table.slice
-- Table that namespaces the matrix functions. It is assigned without `local`,
-- so it is a global; the script at the bottom of the file passes it to
-- ravi.compile.
matrix = {}
-- Build an m-by-n matrix.
-- A matrix is the triple {rows, cols, data}: data is a single flat array
-- obtained from table.numarray(m*n, 0), i.e. sized for m*n elements with 0
-- passed as the fill value.
-- @param m number of rows
-- @param n number of columns
-- @return the new matrix triple
function matrix.new(m, n)
  return { m, n, table.numarray(m * n, 0) }
end
-- Return row `row` of matrix `m` as a slice of its flat data array.
-- Rows are stored one after another, so row `row` (1-based) starts at
-- index (row-1)*cols + 1 and spans `cols` elements.
-- @param m   matrix triple {rows, cols, data}
-- @param row 1-based row index (not range-checked here)
-- @return the result of slice(data, start, cols)
function matrix.getrow(m, row)
  local cols = m[2]
  return slice(m[3], (row - 1) * cols + 1, cols)
end
-- Accessor: the flat data array of matrix `m` (third field of the triple).
function matrix.getdata(m) return m[3] end
-- Accessor: the row count of matrix `m` (first field of the triple).
function matrix.rows(m) return m[1] end
-- Accessor: the column count of matrix `m` (second field of the triple).
function matrix.cols(m) return m[2] end
-- Matrix transpose, slice-based variant.
-- Reads `a` one row at a time through matrix.getrow and writes each element
-- into a newly allocated matrix whose dimensions are swapped.
-- @param a matrix triple {rows, cols, data}; `a` itself is not modified
-- @return a new cols(a)-by-rows(a) matrix holding the transpose
function matrix.T(a)
  --local t1 = os.clock()
  local mrows, mcols, mnew, mdata, mrow = matrix.rows, matrix.cols, matrix.new, matrix.getdata, matrix.getrow
  local m: integer, n: integer = mrows(a), mcols(a)
  local x = mnew(n, m)
  local data: number[] = mdata(x)
  -- for each row of the input
  for i = 1, m do
    local row: number[] = mrow(a, i)
    -- for each column of the input
    for j = 1, n do
      -- a[i][j] lands at [j][i] in the result, whose rows are m elements long
      local pos: integer = (j - 1) * m + i
      data[pos] = row[j]
    end
  end
  --local t2 = os.clock()
  --print("T: time ", t2-t1)
  return x
end
-- Matrix transpose, slice-free variant.
-- Works directly on the flat data arrays of the input and the result instead
-- of going through matrix.getrow.
-- @param a matrix triple {rows, cols, data}; `a` itself is not modified
-- @return a new cols(a)-by-rows(a) matrix holding the transpose
function matrix.T2(a)
  --local t1 = os.clock()
  local mrows, mcols, mnew, mdata = matrix.rows, matrix.cols, matrix.new, matrix.getdata
  local m: integer, n: integer = mrows(a), mcols(a)
  local x = mnew(n, m)
  local data: number[] = mdata(x)
  local adata: number[] = mdata(a)
  -- for each row of the input
  for i = 1, m do
    -- offset of row i inside the input's flat array (rows are n long)
    local ri: integer = (i - 1) * n
    -- for each column of the input
    for j = 1, n do
      -- a[i][j] lands at [j][i] in the result, whose rows are m elements long
      data[(j - 1) * m + i] = adata[ri + j]
    end
  end
  --local t2 = os.clock()
  --print("T2: time ", t2-t1)
  return x
end
-- Matrix multiply, slice-based variant.
-- Computes a * b where `a` is m-by-n and `b` is n-by-p; the result is m-by-p.
-- `b` is transposed first (matrix.T) so that the inner product reads a row of
-- `a` and a row of the transposed `b`, both of length n.
-- The result is filled in through row slices of `x` and `x` itself is
-- returned, so this relies on writes to a slice reaching the underlying data
-- array.
-- @param a left operand, matrix triple {rows, cols, data}
-- @param b right operand; rows(b) must equal cols(a)
-- @return a new rows(a)-by-cols(b) matrix
-- Raises (via assert) when the inner dimensions do not match.
function matrix.mul(a, b)
  --local t1 = os.clock()
  local mrows, mcols, mnew, mrow, mtran = matrix.rows, matrix.cols, matrix.new, matrix.getrow, matrix.T
  local m: integer, n: integer, p: integer = mrows(a), mcols(a), mcols(b)
  -- the product is defined only when the inner dimensions agree
  assert(n == mrows(b), "matrix.mul: cols(a) must equal rows(b)")
  local x = mnew(m, p)
  local c = mtran(b) -- p-by-n; transpose for efficiency
  for i = 1, m do
    local xi: number[] = mrow(x, i)
    -- row i of `a` is the same for every j, so fetch it once per i
    local ai: number[] = mrow(a, i)
    for j = 1, p do
      local cj: number[] = mrow(c, j)
      local sum: number = 0.0
      for k = 1, n do
        sum = sum + ai[k] * cj[k]
      end
      xi[j] = sum
    end
  end
  --local t2 = os.clock()
  --print("mul: time ", t2-t1)
  return x
end
-- Matrix multiply, slice-free variant: operates directly on the
-- one-dimensional data arrays.
-- Computes a * b where `a` is m-by-n and `b` is n-by-p; the result is m-by-p.
-- `b` is transposed first (matrix.T2) so that the inner product walks a row
-- of `a` and a row of the transposed `b`, each n elements long.
-- @param a left operand, matrix triple {rows, cols, data}
-- @param b right operand; rows(b) must equal cols(a)
-- @return a new rows(a)-by-cols(b) matrix
-- Raises (via assert) when the inner dimensions do not match.
function matrix.mul2(a, b)
  --local t1 = os.clock()
  local mrows, mcols, mnew, mdata, mtran = matrix.rows, matrix.cols, matrix.new, matrix.getdata, matrix.T2
  local m: integer, n: integer, p: integer = mrows(a), mcols(a), mcols(b)
  -- the product is defined only when the inner dimensions agree
  assert(n == mrows(b), "matrix.mul2: cols(a) must equal rows(b)")
  local x = mnew(m, p)
  local c = mtran(b) -- p-by-n; transpose for efficiency
  local xdata: number[] = mdata(x)
  local adata: number[] = mdata(a)
  local cdata: number[] = mdata(c)
  for i = 1, m do
    -- row offsets: rows of `a` are n long, rows of `x` are p long
    local ai: integer = (i - 1) * n
    local xi: integer = (i - 1) * p
    for j = 1, p do
      -- rows of the transposed `b` are n long
      local cj: integer = (j - 1) * n
      local sum: number = 0.0
      for k = 1, n do
        sum = sum + adata[ai + k] * cdata[cj + k]
      end
      xdata[xi + j] = sum
    end
  end
  --local t2 = os.clock()
  --print("mul2: time ", t2-t1)
  return x
end
-- Generate the n-by-n test matrix, slice-based variant.
-- Element [i][j] is set to (1/n/n) * (i - j) * (i + j - 2); each row is
-- filled through the slice returned by matrix.getrow.
-- @param n matrix dimension
-- @return the populated n-by-n matrix
function matrix.gen(n: integer)
  --local t1 = os.clock()
  local mnew, mrow = matrix.new, matrix.getrow
  local a = mnew(n, n)
  local tmp: number = 1.0 / n / n
  for i = 1, n do
    local row: number[] = mrow(a, i)
    for j = 1, #row do
      row[j] = tmp * (i - j) * (i + j - 2)
    end
  end
  --local t2 = os.clock()
  --print("gen: time ", t2-t1)
  return a
end
-- Generate the n-by-n test matrix, slice-free variant.
-- Element [i][j] is set to (1/n/n) * (i - j) * (i + j - 2), written straight
-- into the flat data array.
-- @param n matrix dimension
-- @return the populated n-by-n matrix
function matrix.gen2(n: integer)
  --local t1 = os.clock()
  local mnew, mdata = matrix.new, matrix.getdata
  local a = mnew(n, n)
  local data: number[] = mdata(a)
  local tmp: number = 1.0 / n / n
  for i = 1, n do
    -- offset of row i inside the flat array (rows are n long)
    local ri: integer = (i - 1) * n
    for j = 1, n do
      data[ri + j] = tmp * (i - j) * (i + j - 2)
    end
  end
  --local t2 = os.clock()
  --print("gen2: time ", t2-t1)
  return a
end
-- Print matrix `a`, one print() call per row.
-- Each row is fetched with matrix.getrow and expanded with table.unpack, so
-- the elements of a row appear as separate arguments to print.
-- @param a matrix triple {rows, cols, data}
function matrix.print(a)
  local mrows, mrow = matrix.rows, matrix.getrow
  local m: integer = mrows(a)
  -- for each row
  for i = 1, m do
    local row: number[] = mrow(a, i)
    print(table.unpack(row))
  end
end
-- When ravi.jit() is truthy, hand `cast` and the `matrix` table to
-- ravi.compile before the timed run; a failed compile aborts via assert.
if ravi.jit() then
  assert(ravi.compile(cast))
  assert(ravi.compile(matrix, {omitArrayGetRangeCheck=1}))
  --ravi.dumpir(matrix.new)
end

-- Matrix dimension: first command-line argument, defaulting to 1000, then
-- rounded down to an even number.
local size = arg[1] or 1000
size = math.floor(size / 2) * 2

-- Timed section: generate two matrices and multiply them.
local started = os.clock()
local product = matrix.mul2(matrix.gen2(size), matrix.gen2(size))
local finished = os.clock()
print("total time taken ", finished - started)

-- Print the element at [size/2+1][size/2+1] of the product.
--print(a[n/2+1][n/2+1]);
local mid: integer = cast(size / 2 + 1, "integer")
print(matrix.getrow(product, mid)[mid])

-- Disabled debugging aids:
--matrix.print(matrix.gen(2))
--matrix.print(matrix.gen2(2))
--matrix.print(matrix.T(matrix.gen(2)))
--matrix.print(matrix.T2(matrix.gen(2)))
--matrix.print(matrix.mul(matrix.gen(2), matrix.gen(2)))
--ravi.dumpllvm(matrix.mul2)
--ravi.dumplua(matrix.T2)