play around with matmul

pull/81/head
Dibyendu Majumdar 9 years ago
parent c22e685127
commit f0173d0b04

@ -1,93 +1,142 @@
-- Adapted from https://github.com/attractivechaos/plb/blob/master/matmul/matmul_v1.lua
-- dummy cast
-- Dummy cast used by the benchmark: returns `n` unchanged. The `to`
-- argument names the intended target type (e.g. "integer") but is
-- ignored by this stub. (The duplicated `return n` from the scraped
-- diff has been removed — a statement after `return` is a syntax error.)
function cast(n, to)
  return n
end
local slice = table.slice
matrix = {}
-- Create an m x n matrix represented as a triple:
-- [1] = row count, [2] = column count,
-- [3] = flat array of m*n zeros (table.numarray is a Ravi extension).
function matrix.new(m, n)
  local t = {m, n, table.numarray(m*n, 0)}
  return t
end
-- Return row `row` of matrix m as a slice (a view over the flat data
-- array, via Ravi's table.slice). Keeps the bounds assert so an
-- out-of-range row fails fast instead of returning a bad slice.
function matrix.getrow(m, row)
  local rows = m[1]
  local cols = m[2]
  local data = m[3]
  assert(row > 0 and row <= rows)
  return slice(data, (row-1)*cols+1, cols)
end
-- Return the flat one-dimensional data array backing matrix m.
function matrix.getdata(m)
  return m[3]
end
-- Return the number of rows in matrix m.
function matrix.rows(m)
  return m[1]
end
-- Return the number of columns in matrix m.
function matrix.cols(m)
  return m[2]
end
-- Transpose: return a new n x m matrix that is the transpose of the
-- m x n matrix `a`. Uses Ravi type annotations so the JIT can emit
-- specialized code. (The body appeared twice in the scraped diff; one
-- copy is kept. The inner local was renamed from `slice`, which
-- shadowed the file-level `slice = table.slice`, and the unused
-- `mtran` alias was dropped.)
function matrix.T(a)
  local mrows, mcols, mnew, mdata, mrow = matrix.rows, matrix.cols, matrix.new, matrix.getdata, matrix.getrow
  local m: integer, n: integer = mrows(a), mcols(a)
  local x = mnew(n, m)
  local data: number[] = mdata(x)
  -- for each row of a
  for i = 1, m do
    local row: number[] = mrow(a, i)
    -- for each column of a
    for j = 1, n do
      -- element a[i][j] lands at [j][i] in the transpose; the transpose
      -- has m columns, so its flat index is (j-1)*m + i
      local pos: integer = (j-1)*m + i
      data[pos] = row[j]
    end
  end
  return x
end
-- Matrix multiply using row slices. The second operand is transposed
-- first so both operands are scanned row-wise (cache friendly).
-- (The body appeared twice in the scraped diff; the post-commit copy,
-- which calls the cached `mtran` alias, is kept. The unused `mdata`
-- alias was dropped.)
function matrix.mul(a, b)
  local mrows, mcols, mnew, mrow, mtran = matrix.rows, matrix.cols, matrix.new, matrix.getrow, matrix.T
  local m: integer, n: integer, p: integer = mrows(a), mcols(a), mcols(b)
  -- NOTE(review): this compares cols(a) with cols(b); a general matmul
  -- check would compare cols(a) with rows(b). Adequate for the square
  -- matrices this benchmark generates.
  assert(n == p)
  local x = mnew(m, n)
  local c = mtran(b) -- transpose for efficiency
  for i = 1, m do
    local xi: number[] = mrow(x, i)
    for j = 1, p do
      local sum: number, ai: number[], cj: number[] = 0.0, mrow(a, i), mrow(c, j)
      for k = 1, n do
        sum = sum + ai[k] * cj[k]
      end
      xi[j] = sum
    end
  end
  return x
end
-- this version avoids using slices - we operate on the
-- one dimensional array; however the version using slices
-- is faster
-- Matrix multiply operating directly on the flat one-dimensional
-- arrays instead of row slices. (Per the comment above, the
-- slice-based matrix.mul is faster in practice.)
function matrix.mul2(a, b)
  local mrows, mcols, mnew, mdata, mtran = matrix.rows, matrix.cols, matrix.new, matrix.getdata, matrix.T
  local m: integer, n: integer, p: integer = mrows(a), mcols(a), mcols(b)
  assert(n == p)
  local x = mnew(m, n)
  local c = mtran(b) -- transpose for efficiency
  local xdata: number[] = mdata(x)
  local adata: number[] = mdata(a)
  local cdata: number[] = mdata(c)
  local sum: number
  local ci: integer -- flat offset of the start of row j of c
  local ri: integer -- flat offset of the start of row i of a and x
  for i = 1, m do
    -- BUGFIX: the row stride of a and x is the column count n, not the
    -- row count m; the original `(i-1)*m` only happened to work because
    -- the benchmark uses square matrices.
    ri = (i-1)*n
    for j = 1, p do
      sum = 0.0
      -- each row of c (= column of b) holds n elements for a valid matmul
      ci = (j-1)*n
      for k = 1, n do
        sum = sum + adata[ri+k] * cdata[ci+k]
      end
      xdata[ri+j] = sum
    end
  end
  return x
end
-- Generate the standard n x n benchmark matrix:
--   a[i][j] = (i - j) * (i + j - 2) / n^2
-- (The body appeared twice in the scraped diff; one copy is kept, and
-- the unused function aliases were trimmed.)
function matrix.gen(arg)
  local mnew, mrow = matrix.new, matrix.getrow
  local n: integer = cast(arg, "integer")
  local a = mnew(n, n)
  local tmp: number = 1.0 / n / n
  for i = 1, n do
    local row: number[] = mrow(a, i)
    for j = 1, #row do
      row[j] = tmp * (i - j) * (i + j - 2)
    end
  end
  return a
end
-- Print matrix a, one row per line, values separated by ", ".
-- Rows are assembled with table.concat instead of repeated string
-- concatenation (the `str = str .. v` loop was O(n^2) per row); the
-- printed output is identical. Unused function aliases were removed.
function matrix.print(a)
  local mrows, mcols, mrow = matrix.rows, matrix.cols, matrix.getrow
  local m: integer, n: integer = mrows(a), mcols(a)
  -- for each row
  for i = 1, m do
    local row: number[] = mrow(a, i)
    local parts = {}
    -- for each column
    for j = 1, n do
      parts[j] = tostring(row[j])
    end
    print(table.concat(parts, ", "))
  end
end
-- Force JIT compilation of every benchmark function up front, so the
-- timed section below measures compiled execution rather than the
-- interpreter. Each assert fails loudly if ravi.compile rejects a
-- function (ravi.compile is a Ravi runtime extension).
assert(ravi.compile(cast))
assert(ravi.compile(matrix.gen))
assert(ravi.compile(matrix.mul))
assert(ravi.compile(matrix.mul2))
assert(ravi.compile(matrix.T))
assert(ravi.compile(matrix.cols))
assert(ravi.compile(matrix.rows))
@ -105,3 +154,11 @@ print("time taken ", t2-t1)
--print(a[n/2+1][n/2+1]);
local y: integer = cast(n/2+1, "integer")
print(matrix.getrow(a, y)[y])
--ravi.dumplua(matrix.mul)
--matrix.print(matrix.gen(2))
--matrix.print(matrix.T(matrix.gen(2)))
--matrix.print(matrix.mul(matrix.gen(2), matrix.gen(2)))
--ravi.dumpllvmasm(matrix.mul)

@ -21,11 +21,12 @@ The programs used in the performance testing can be found at `Ravi Tests <https:
|matmul(1000) | 34.604 | 4.2 | 0.968 |
+---------------+---------+----------+-----------+
There are a number of reasons why Ravi's performance is not as good as Luajit.
The following points are worth bearing in mind when looking at the above benchmarks.
1. Luajit uses an optimized representation of values. In Lua 5.3 and
in Ravi, the value is 16 bytes - and many operations require two loads
/ two stores. Luajit therefore will always have an advantage here.
1. Luajit uses an optimized representation of double values. In Lua 5.3 and
in Ravi, a value is 16 bytes - and floating point operations require two loads
/ two stores. Luajit has a performance advantage when it comes to floating
point operations due to this.
2. More work is needed to optimize fornum loops in Ravi. I have some
ideas on what can be improved but have parked this for now as I want
@ -35,6 +36,10 @@ There are a number of reasons why Ravi's performance is not as good as Luajit.
the actual execution path taken by the code at runtime whereas Ravi
compiles each function as a whole regardless of how it will be used.
4. For Ravi the timings above do not include the LLVM compilation time.
The LuaJIT timings, by contrast, do include its JIT compilation time,
which makes LuaJIT's results all the more impressive.
Ideas
-----
There are a number of improvements possible. Below are some of my thoughts.
@ -49,7 +54,7 @@ external variable if necessary.
The Fornum loop needs to handle four different scenarios, resulting from the type of the index variable and whether the loop increments or decrements.
The generated code is not very efficient due to branching. The common case of an integer index with a constant step can be specialized for greater
performance. I have implemented the case where the index is an integer and the step size is a positive constant, as this appears to be the most common case.
The Value Storage
-----------------

Loading…
Cancel
Save