|
|
|
@ -1,93 +1,142 @@
|
|
|
|
|
-- Adapted from https://github.com/attractivechaos/plb/blob/master/matmul/matmul_v1.lua
|
|
|
|
|
-- dummy cast
|
|
|
|
|
function cast(n, to)
|
|
|
|
|
return n
|
|
|
|
|
return n
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
local slice = table.slice
|
|
|
|
|
|
|
|
|
|
matrix = {}
|
|
|
|
|
function matrix.new(m, n)
|
|
|
|
|
local t = {m, n, table.numarray(m*n, 0)}
|
|
|
|
|
return t
|
|
|
|
|
local t = {m, n, table.numarray(m*n, 0)}
|
|
|
|
|
return t
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function matrix.getrow(m, row)
|
|
|
|
|
local rows = m[1]
|
|
|
|
|
local cols = m[2]
|
|
|
|
|
local data = m[3]
|
|
|
|
|
assert(row > 0 and row <= rows)
|
|
|
|
|
return slice(data, (row-1)*cols+1, cols)
|
|
|
|
|
local rows = m[1]
|
|
|
|
|
local cols = m[2]
|
|
|
|
|
local data = m[3]
|
|
|
|
|
return slice(data, (row-1)*cols+1, cols)
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function matrix.getdata(m)
|
|
|
|
|
return m[3]
|
|
|
|
|
return m[3]
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function matrix.rows(m)
|
|
|
|
|
return m[1]
|
|
|
|
|
return m[1]
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function matrix.cols(m)
|
|
|
|
|
return m[2]
|
|
|
|
|
return m[2]
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function matrix.T(a)
|
|
|
|
|
local mrows, mcols, mnew, mdata, mrow, mtran = matrix.rows, matrix.cols, matrix.new, matrix.getdata, matrix.getrow, matrix.T
|
|
|
|
|
local m: integer, n: integer = mrows(a), mcols(a);
|
|
|
|
|
local x = mnew(n,m)
|
|
|
|
|
local data: number[] = mdata(x)
|
|
|
|
|
-- for each row
|
|
|
|
|
for i = 1, m do
|
|
|
|
|
local slice: number[] = mrow(a, i)
|
|
|
|
|
-- for each column
|
|
|
|
|
for j = 1, n do
|
|
|
|
|
-- switch row and column
|
|
|
|
|
local c: integer, r:integer = i, j
|
|
|
|
|
-- calculate the array position in transposed matrix [j][i]
|
|
|
|
|
local pos: integer = (r-1)*m+c
|
|
|
|
|
data[pos] = slice[j]
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
return x;
|
|
|
|
|
local mrows, mcols, mnew, mdata, mrow, mtran = matrix.rows, matrix.cols, matrix.new, matrix.getdata, matrix.getrow, matrix.T
|
|
|
|
|
local m: integer, n: integer = mrows(a), mcols(a);
|
|
|
|
|
local x = mnew(n,m)
|
|
|
|
|
local data: number[] = mdata(x)
|
|
|
|
|
-- for each row
|
|
|
|
|
for i = 1, m do
|
|
|
|
|
local slice: number[] = mrow(a, i)
|
|
|
|
|
-- for each column
|
|
|
|
|
for j = 1, n do
|
|
|
|
|
-- switch row and column
|
|
|
|
|
local c: integer, r:integer = i, j
|
|
|
|
|
-- calculate the array position in transposed matrix [j][i]
|
|
|
|
|
local pos: integer = (r-1)*m+c
|
|
|
|
|
data[pos] = slice[j]
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
return x;
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function matrix.mul(a, b)
|
|
|
|
|
local mrows, mcols, mnew, mdata, mrow, mtran = matrix.rows, matrix.cols, matrix.new, matrix.getdata, matrix.getrow, matrix.T
|
|
|
|
|
local m: integer, n: integer, p: integer = mrows(a), mcols(a), mcols(b);
|
|
|
|
|
assert(n == p)
|
|
|
|
|
local x = mnew(m,n)
|
|
|
|
|
local c = matrix.T(b); -- transpose for efficiency
|
|
|
|
|
for i = 1, m do
|
|
|
|
|
local xi: number[] = mrow(x,i);
|
|
|
|
|
for j = 1, p do
|
|
|
|
|
local sum: number, ai: number[], cj: number[] = 0.0, mrow(a,i), mrow(c,j);
|
|
|
|
|
-- for luajit, caching c[j] or not makes no difference; lua is not so clever
|
|
|
|
|
for k = 1, n do
|
|
|
|
|
sum = sum + ai[k] * cj[k]
|
|
|
|
|
end
|
|
|
|
|
xi[j] = sum;
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
return x;
|
|
|
|
|
local mrows, mcols, mnew, mdata, mrow, mtran = matrix.rows, matrix.cols, matrix.new, matrix.getdata, matrix.getrow, matrix.T
|
|
|
|
|
local m: integer, n: integer, p: integer = mrows(a), mcols(a), mcols(b);
|
|
|
|
|
assert(n == p)
|
|
|
|
|
local x = mnew(m,n)
|
|
|
|
|
local c = mtran(b); -- transpose for efficiency
|
|
|
|
|
for i = 1, m do
|
|
|
|
|
local xi: number[] = mrow(x,i);
|
|
|
|
|
for j = 1, p do
|
|
|
|
|
local sum: number, ai: number[], cj: number[] = 0.0, mrow(a,i), mrow(c,j);
|
|
|
|
|
for k = 1, n do
|
|
|
|
|
sum = sum + ai[k] * cj[k]
|
|
|
|
|
end
|
|
|
|
|
xi[j] = sum;
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
return x;
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
-- this version avoids using slices - we operate on the
|
|
|
|
|
-- one dimensional array; however the version using slices
|
|
|
|
|
-- is faster
|
|
|
|
|
function matrix.mul2(a, b)
|
|
|
|
|
local mrows, mcols, mnew, mdata, mrow, mtran = matrix.rows, matrix.cols, matrix.new, matrix.getdata, matrix.getrow, matrix.T
|
|
|
|
|
local m: integer, n: integer, p: integer = mrows(a), mcols(a), mcols(b);
|
|
|
|
|
assert(n == p)
|
|
|
|
|
local x = mnew(m,n)
|
|
|
|
|
local c = mtran(b); -- transpose for efficiency
|
|
|
|
|
local xdata: number[] = mdata(x)
|
|
|
|
|
local adata: number[] = mdata(a)
|
|
|
|
|
local cdata: number[] = mdata(c)
|
|
|
|
|
local sum: number
|
|
|
|
|
local cj: integer
|
|
|
|
|
local xi: integer
|
|
|
|
|
local t,s
|
|
|
|
|
for i = 1, m do
|
|
|
|
|
xi = (i-1)*m
|
|
|
|
|
for j = 1, p do
|
|
|
|
|
sum = 0.0;
|
|
|
|
|
cj = (j-1)*p
|
|
|
|
|
for k = 1, n do
|
|
|
|
|
sum = sum + adata[xi+k] * cdata[cj+k]
|
|
|
|
|
end
|
|
|
|
|
xdata[xi+j] = sum;
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
return x;
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function matrix.gen(arg)
|
|
|
|
|
local mrows, mcols, mnew, mdata, mrow, mtran = matrix.rows, matrix.cols, matrix.new, matrix.getdata, matrix.getrow, matrix.T
|
|
|
|
|
local n: integer = cast(arg, "integer")
|
|
|
|
|
local a = mnew(n, n)
|
|
|
|
|
local tmp: number = 1.0 / n / n;
|
|
|
|
|
for i = 1, n do
|
|
|
|
|
local row: number[] = mrow(a, i)
|
|
|
|
|
for j = 1, #row do
|
|
|
|
|
row[j] = tmp * (i - j) * (i + j - 2)
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
return a;
|
|
|
|
|
local mrows, mcols, mnew, mdata, mrow, mtran = matrix.rows, matrix.cols, matrix.new, matrix.getdata, matrix.getrow, matrix.T
|
|
|
|
|
local n: integer = cast(arg, "integer")
|
|
|
|
|
local a = mnew(n, n)
|
|
|
|
|
local tmp: number = 1.0 / n / n;
|
|
|
|
|
for i = 1, n do
|
|
|
|
|
local row: number[] = mrow(a, i)
|
|
|
|
|
for j = 1, #row do
|
|
|
|
|
row[j] = tmp * (i - j) * (i + j - 2)
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
return a;
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function matrix.print(a)
|
|
|
|
|
local mrows, mcols, mnew, mdata, mrow, mtran = matrix.rows, matrix.cols, matrix.new, matrix.getdata, matrix.getrow, matrix.T
|
|
|
|
|
local m: integer, n: integer = mrows(a), mcols(a);
|
|
|
|
|
-- for each row
|
|
|
|
|
for i = 1, m do
|
|
|
|
|
local str = ""
|
|
|
|
|
local slice: number[] = mrow(a, i)
|
|
|
|
|
-- for each column
|
|
|
|
|
for j = 1, n do
|
|
|
|
|
if j == 1 then
|
|
|
|
|
str = str .. slice[j]
|
|
|
|
|
else
|
|
|
|
|
str = str .. ", " .. slice[j]
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
print(str)
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
assert(ravi.compile(cast))
|
|
|
|
|
assert(ravi.compile(matrix.gen))
|
|
|
|
|
assert(ravi.compile(matrix.mul))
|
|
|
|
|
assert(ravi.compile(matrix.mul2))
|
|
|
|
|
assert(ravi.compile(matrix.T))
|
|
|
|
|
assert(ravi.compile(matrix.cols))
|
|
|
|
|
assert(ravi.compile(matrix.rows))
|
|
|
|
@ -105,3 +154,11 @@ print("time taken ", t2-t1)
|
|
|
|
|
--print(a[n/2+1][n/2+1]);
|
|
|
|
|
local y: integer = cast(n/2+1, "integer")
|
|
|
|
|
print(matrix.getrow(a, y)[y])
|
|
|
|
|
|
|
|
|
|
--ravi.dumplua(matrix.mul)
|
|
|
|
|
|
|
|
|
|
--matrix.print(matrix.gen(2))
|
|
|
|
|
--matrix.print(matrix.T(matrix.gen(2)))
|
|
|
|
|
--matrix.print(matrix.mul(matrix.gen(2), matrix.gen(2)))
|
|
|
|
|
|
|
|
|
|
--ravi.dumpllvmasm(matrix.mul)
|