implement OP_CALL

Dibyendu Majumdar 9 years ago
parent c94b382231
commit 74e2a829d1

@ -70,7 +70,7 @@ set (LUA_CORE_SRCS src/lapi.c src/lcode.c src/lctype.c src/ldebug.c src/ldo.c sr
src/lvm.c src/lzio.c src/ravijit.cpp src/ravi_llvmtypes.cpp
src/ravi_llvmcodegen.cpp src/ravi_llvmforprep.cpp src/ravi_llvmcomp.cpp
src/ravi_llvmreturn.cpp src/ravi_llvmload.cpp src/ravi_llvmforloop.cpp
src/ravi_llvmarith1.cpp)
src/ravi_llvmarith1.cpp src/ravi_llvmcall.cpp)
# define the lua lib source files
set (LUA_LIB_SRCS src/lauxlib.c src/lbaselib.c src/lbitlib.c src/lcorolib.c src/ldblib.c src/liolib.c
src/lmathlib.c src/loslib.c src/ltablib.c src/lstrlib.c src/loadlib.c src/linit.c src/lutf8lib.c)

@ -164,14 +164,16 @@ struct LuaLLVMTypes {
llvm::FunctionType *jitFunctionT;
llvm::FunctionType *luaD_poscallT;
llvm::FunctionType *luaD_precallT;
llvm::FunctionType *luaF_closeT;
llvm::FunctionType *luaG_runerrorT;
llvm::FunctionType *luaV_equalobjT;
llvm::FunctionType *luaV_lessthanT;
llvm::FunctionType *luaV_lessequalT;
llvm::FunctionType *luaG_runerrorT;
llvm::FunctionType *luaV_forlimitT;
llvm::FunctionType *luaV_tonumberT;
llvm::FunctionType *luaV_tointegerT;
llvm::FunctionType *luaV_executeT;
llvm::FunctionType *luaV_op_loadnilT;
@ -346,14 +348,16 @@ struct RaviFunctionDef {
// Lua function declarations
llvm::Constant *luaD_poscallF;
llvm::Constant *luaD_precallF;
llvm::Constant *luaF_closeF;
llvm::Constant *luaG_runerrorF;
llvm::Constant *luaV_equalobjF;
llvm::Constant *luaV_lessthanF;
llvm::Constant *luaV_lessequalF;
llvm::Constant *luaG_runerrorF;
llvm::Constant *luaV_forlimitF;
llvm::Constant *luaV_tonumberF;
llvm::Constant *luaV_tointegerF;
llvm::Constant *luaV_executeF;
// Some cheats
llvm::Constant *luaV_op_loadnilF;
@ -552,6 +556,9 @@ public:
void emit_RETURN(RaviFunctionDef *def, llvm::Value *L_ci, llvm::Value *proto,
int A, int B);
void emit_CALL(RaviFunctionDef *def, llvm::Value *L_ci, llvm::Value *proto,
int A, int B, int C);
void emit_JMP(RaviFunctionDef *def, int j);
void emit_FORPREP(RaviFunctionDef *def, llvm::Value *L_ci, llvm::Value *proto,

@ -0,0 +1,32 @@
local z,x,y
-- test 1
z = function(a)
print(a)
return a+1
end
x = function(yy)
local j = 5
j = yy(j)
return j
end
y = x(z)
assert(y == 6)
-- test 2
z = function (a,p)
p(a)
return 6
end
x = function (yy,p)
local j = 5
j = yy(j,p)
return j
end
y = x(z,print)
assert(y == 6)

@ -380,8 +380,8 @@ int luaD_precall (lua_State *L, StkId func, int nresults, int compile) {
lua_assert(L->ci == prevci);
ci = L->ci;
lua_assert(isLua(ci));
lua_assert(GET_OPCODE(*((ci)->u.l.savedpc - 1)) == OP_CALL);
return 1;
/* Return a different value from 1 to allow luaV_execute() to distinguish between JITed function and true C function*/
return 2;
}
}
return 0;

@ -1035,8 +1035,12 @@ newframe: /* reentry point when frame changes (call/return) */
int b = GETARG_B(i);
int nresults = GETARG_C(i) - 1;
if (b != 0) L->top = ra + b; /* else previous instruction set top */
if (luaD_precall(L, ra, nresults, 1)) { /* C function? */
if (nresults >= 0) L->top = ci->top; /* adjust results */
int c_or_compiled = luaD_precall(L, ra, nresults, 1);
if (c_or_compiled) { /* C or Lua JITed function? */
/* RAVI change - if the Lua function was JIT compiled then luaD_precall() returns 2
* A return value of 1 indicates non Lua C function
*/
if (c_or_compiled == 1 && nresults >= 0) L->top = ci->top; /* adjust results */
base = ci->u.l.base;
}
else { /* Lua function */

@ -0,0 +1,80 @@
#include "ravi_llvmcodegen.h"
namespace ravi {
void RaviCodeGenerator::emit_CALL(RaviFunctionDef *def, llvm::Value *L_ci, llvm::Value *proto,
int A, int B, int C) {
//int nresults = c - 1;
//if (b != 0) L->top = ra + b; /* else previous instruction set top */
//if (luaD_precall(L, ra, nresults, 1)) { /* C function? */
// if (nresults >= 0) L->top = ci->top; /* adjust results */
//}
//else { /* Lua function */
// luaV_execute(L);
//}
llvm::Instruction *base_ptr = emit_load_base(def);
llvm::Value *top = nullptr;
int nresults = C - 1;
if (B != 0) {
// L->top = ra + b
// See similar construct in OP_RETURN
// Get pointer to register at ra + b
llvm::Value *ptr = emit_array_get(def, base_ptr, A + B);
// Get pointer to L->top
top = emit_gep(def, "L.top", def->L, 0, 4);
// Assign to L->top
llvm::Instruction *ins = def->builder->CreateStore(ptr, top);
ins->setMetadata(llvm::LLVMContext::MD_tbaa,
def->types->tbaa_luaState_topT);
}
llvm::Value *ra = emit_gep_ra(def, base_ptr, A);
llvm::Value *precall_result = def->builder->CreateCall4(def->luaD_precallF, def->L, ra, llvm::ConstantInt::get(def->types->C_intT, nresults),
def->types->kInt[2]);
llvm::Value *precall_bool = def->builder->CreateICmpEQ(precall_result, def->types->kInt[0]);
llvm::BasicBlock *then_block =
llvm::BasicBlock::Create(def->jitState->context(), "if.lua.function", def->f);
llvm::BasicBlock *else_block =
llvm::BasicBlock::Create(def->jitState->context(), "if.not.lua.function");
llvm::BasicBlock *end_block =
llvm::BasicBlock::Create(def->jitState->context(), "op_call.done");
def->builder->CreateCondBr(precall_bool, then_block, else_block);
def->builder->SetInsertPoint(then_block);
// Call luaV_execute
def->builder->CreateCall(def->luaV_executeF, def->L);
def->builder->CreateBr(end_block);
def->f->getBasicBlockList().push_back(else_block);
def->builder->SetInsertPoint(else_block);
if (nresults >= 0) {
llvm::Value *precall_C = def->builder->CreateICmpEQ(precall_result, def->types->kInt[1]);
llvm::BasicBlock *then1_block =
llvm::BasicBlock::Create(def->jitState->context(), "if.C.function", def->f);
def->builder->CreateCondBr(precall_C, then1_block, end_block);
def->builder->SetInsertPoint(then1_block);
// Get pointer to ci->top
llvm::Value *citop = emit_gep(def, "ci_top", def->ci_val, 0, 1);
// Load ci->top
llvm::Instruction *citop_val = def->builder->CreateLoad(citop);
// TODO set tbaa
if (!top)
// Get L->top
top = emit_gep(def, "L_top", def->L, 0, 4);
// Assign ci>top to L->top
auto ins = def->builder->CreateStore(citop_val, top);
ins->setMetadata(llvm::LLVMContext::MD_tbaa, def->types->tbaa_luaState_topT);
}
def->builder->CreateBr(end_block);
def->f->getBasicBlockList().push_back(end_block);
def->builder->SetInsertPoint(end_block);
}
}

@ -170,7 +170,8 @@ bool RaviCodeGenerator::canCompile(Proto *p) {
int pc, n = p->sizecode;
// TODO we cannot handle variable arguments or
// if the function has sub functions (closures)
if (p->sizep > 0 || p->is_vararg) {
//if (p->sizep > 0 || p->is_vararg) {
if (p->is_vararg) {
p->ravi_jit.jit_status = 1;
return false;
}
@ -182,6 +183,7 @@ bool RaviCodeGenerator::canCompile(Proto *p) {
switch (o) {
case OP_LOADK:
case OP_LOADNIL:
case OP_CALL:
case OP_RETURN:
case OP_JMP:
case OP_EQ:
@ -263,9 +265,15 @@ void RaviCodeGenerator::emit_extern_declarations(RaviFunctionDef *def) {
def->luaD_poscallF = def->raviF->addExternFunction(
def->types->luaD_poscallT, reinterpret_cast<void *>(&luaD_poscall),
"luaD_poscall");
def->luaD_precallF = def->raviF->addExternFunction(
def->types->luaD_precallT, reinterpret_cast<void *>(&luaD_precall),
"luaD_precall");
def->luaF_closeF = def->raviF->addExternFunction(
def->types->luaF_closeT, reinterpret_cast<void *>(&luaF_close),
"luaF_close");
def->luaG_runerrorF = def->raviF->addExternFunction(
def->types->luaG_runerrorT, reinterpret_cast<void *>(&luaG_runerror),
"luaG_runerror");
def->luaV_equalobjF = def->raviF->addExternFunction(
def->types->luaV_equalobjT, reinterpret_cast<void *>(&luaV_equalobj),
"luaV_equalobj");
@ -275,9 +283,6 @@ void RaviCodeGenerator::emit_extern_declarations(RaviFunctionDef *def) {
def->luaV_lessequalF = def->raviF->addExternFunction(
def->types->luaV_lessequalT, reinterpret_cast<void *>(&luaV_lessequal),
"luaV_lessequal");
def->luaG_runerrorF = def->raviF->addExternFunction(
def->types->luaG_runerrorT, reinterpret_cast<void *>(&luaG_runerror),
"luaG_runerror");
def->luaV_forlimitF = def->raviF->addExternFunction(
def->types->luaV_forlimitT, reinterpret_cast<void *>(&luaV_forlimit),
"luaV_forlimit");
@ -287,9 +292,13 @@ void RaviCodeGenerator::emit_extern_declarations(RaviFunctionDef *def) {
def->luaV_tointegerF = def->raviF->addExternFunction(
def->types->luaV_tointegerT, reinterpret_cast<void *>(&luaV_tointeger_),
"luaV_tointeger_");
def->luaV_executeF = def->raviF->addExternFunction(
def->types->luaV_executeT, reinterpret_cast<void *>(&luaV_execute),
"luaV_execute");
def->luaV_op_loadnilF = def->raviF->addExternFunction(
def->types->luaV_op_loadnilT, reinterpret_cast<void *>(&luaV_op_loadnil),
"luaV_op_loadnil");
// Create printf declaration
std::vector<llvm::Type *> args;
args.push_back(def->types->C_pcharT);
@ -319,7 +328,8 @@ void RaviCodeGenerator::emit_extern_declarations(RaviFunctionDef *def) {
check_exp(getCMode(GET_OPCODE(i)) == OpArgK, k + INDEXK(GETARG_C(i)))
void RaviCodeGenerator::link_block(RaviFunctionDef *def, int pc) {
if (def->jmp_targets[pc].jmp2 && !def->builder->GetInsertBlock()->getTerminator()) {
if (def->jmp_targets[pc].jmp2 &&
!def->builder->GetInsertBlock()->getTerminator()) {
// Handle special case for body of FORLOOP
auto b = def->builder->CreateLoad(def->jmp_targets[pc].forloop_branch);
auto idb = def->builder->CreateIndirectBr(b, 4);
@ -352,7 +362,6 @@ void RaviCodeGenerator::link_block(RaviFunctionDef *def, int pc) {
void RaviCodeGenerator::compile(lua_State *L, Proto *p) {
if (p->ravi_jit.jit_status != 0 || !canCompile(p))
return;
#if 1
RaviFunctionDef def = {0};
llvm::LLVMContext &context = jitState_->context();
@ -484,6 +493,11 @@ void RaviCodeGenerator::compile(lua_State *L, Proto *p) {
case OP_RAVI_LOADIZ: {
emit_LOADIZ(&def, L_ci, proto, A);
} break;
case OP_CALL: {
int B = GETARG_B(i);
int C = GETARG_C(i);
emit_CALL(&def, L_ci, proto, A, B, C);
} break;
case OP_RAVI_ADDFN: {
int B = GETARG_B(i);
@ -616,10 +630,6 @@ void RaviCodeGenerator::compile(lua_State *L, Proto *p) {
} else {
p->ravi_jit.jit_status = 2;
}
#else
// For now
p->ravi_jit.jit_status = 1; // can't compile
#endif
}
void RaviCodeGenerator::scan_jump_targets(RaviFunctionDef *def, Proto *p) {

@ -98,7 +98,8 @@ void RaviCodeGenerator::emit_RETURN(RaviFunctionDef *def, llvm::Value *L_ci,
// Get L->top
top = emit_gep(def, "L_top", def->L, 0, 4);
// Assign ci>top to L->top
def->builder->CreateStore(citop_val, top);
auto ins = def->builder->CreateStore(citop_val, top);
ins->setMetadata(llvm::LLVMContext::MD_tbaa, def->types->tbaa_luaState_topT);
def->builder->CreateBr(ElseBB);
def->f->getBasicBlockList().push_back(ElseBB);
def->builder->SetInsertPoint(ElseBB);

@ -598,6 +598,19 @@ LuaLLVMTypes::LuaLLVMTypes(llvm::LLVMContext &context) : mdbuilder(context) {
elements.push_back(StkIdT);
luaD_poscallT = llvm::FunctionType::get(C_intT, elements, false);
// int luaD_precall (lua_State *L, StkId func, int nresults, int compile);
elements.clear();
elements.push_back(plua_StateT);
elements.push_back(StkIdT);
elements.push_back(C_intT);
elements.push_back(C_intT);
luaD_precallT = llvm::FunctionType::get(C_intT, elements, false);
// void luaV_execute(lua_State L);
elements.clear();
elements.push_back(plua_StateT);
luaV_executeT = llvm::FunctionType::get(llvm::Type::getVoidTy(context), elements, false);
// void luaF_close (lua_State *L, StkId level)
elements.clear();
elements.push_back(plua_StateT);

@ -174,7 +174,7 @@ void *RaviJITFunctionImpl::compile() {
MPM->run(*module_);
delete MPM;
// module_->dump();
//module_->dump();
// We don't need this anymore

@ -122,6 +122,9 @@ int main(int argc, const char *argv[])
{
int failures = 0;
//failures += test_luafileexec1("\\github\\ravi\\ravi-tests\\mandel1.ravi", 0);
failures += test_luacompexec1("local function z(a); print(a); return a+1; end; local function x(yy); local j = 5; j = yy(j); return j; end; local y = x(z); return y", 6);
failures += test_luacompexec1("local function z(a,p); p(a); return 6; end; local function x(yy,p); local j = 5; j = yy(j,p); return j; end; local y = x(z,print); return y", 6);
failures += test_luacompexec1("local function x(yy); local j = 5; yy(j); return j; end; local y = x(print); return y", 5);
failures += test_luacompexec1("local function x(); local i, j:int; j=0; for i=1,1000000000 do; j = j+1; end; return j; end; local y = x(); print(y); return y", 1000000000);
failures += test_luacompexec1("local function x(); local j:double; for i=1,1000000000 do; j = j+1; end; return j; end; local y = x(); print(y); return y", 1000000000);

Loading…
Cancel
Save