issue #98 type check for statements

pull/168/head
Dibyendu Majumdar 5 years ago
parent 859eb75316
commit e4f0b7aa1a

@ -472,3 +472,413 @@ function()
--[suffixed expr end]
end
function()
--locals t
local
--[symbols]
t --local symbol integer[]
if
--[suffixed expr start] boolean
--[primary start] boolean
--[binary expr start] boolean
--[suffixed expr start] integer
--[primary start] integer[]
t --local symbol integer[]
--[primary end]
--[suffix list start]
--[Y index start] integer
[
1
]
--[Y index end]
--[suffix list end]
--[suffixed expr end]
==
5
--[binary expr end]
--[primary end]
--[suffixed expr end]
then
return
true
end
return
false
end
function()
matmul --global symbol ?
=
function(
a --local symbol table
,
b --local symbol table
)
--locals a, b, m, n, p, x, c, i, xi, j, sum, ai, cj, k
--[expression statement start]
--[expression list start]
--[suffixed expr start] any
--[primary start] any
assert --global symbol ?
--[primary end]
--[suffix list start]
--[function call start] any
(
--[binary expr start] any
--[unary expr start] integer
@integer
--[suffixed expr start] any
--[primary start] any
--[unary expr start] any
#
--[suffixed expr start] any
--[primary start] table
a --local symbol table
--[primary end]
--[suffix list start]
--[Y index start] any
[
1
]
--[Y index end]
--[suffix list end]
--[suffixed expr end]
--[unary expr end]
--[primary end]
--[suffixed expr end]
--[unary expr end]
==
--[unary expr start] any
#
--[suffixed expr start] table
--[primary start] table
b --local symbol table
--[primary end]
--[suffixed expr end]
--[unary expr end]
--[binary expr end]
)
--[function call end]
--[suffix list end]
--[suffixed expr end]
--[expression list end]
--[expression statement end]
local
--[symbols]
m --local symbol integer
,
n --local symbol integer
,
p --local symbol integer
,
x --local symbol table
--[expressions]
--[unary expr start] any
#
--[suffixed expr start] table
--[primary start] table
a --local symbol table
--[primary end]
--[suffixed expr end]
--[unary expr end]
,
--[unary expr start] any
#
--[suffixed expr start] any
--[primary start] table
a --local symbol table
--[primary end]
--[suffix list start]
--[Y index start] any
[
1
]
--[Y index end]
--[suffix list end]
--[suffixed expr end]
--[unary expr end]
,
--[unary expr start] any
#
--[suffixed expr start] any
--[primary start] table
b --local symbol table
--[primary end]
--[suffix list start]
--[Y index start] any
[
1
]
--[Y index end]
--[suffix list end]
--[suffixed expr end]
--[unary expr end]
,
{ --[table constructor start] table
} --[table constructor end]
local
--[symbols]
c --local symbol table
--[expressions]
--[suffixed expr start] any
--[primary start] any
matrix --global symbol ?
--[primary end]
--[suffix list start]
--[field selector start] any
.
'T'
--[field selector end]
--[function call start] any
(
--[suffixed expr start] table
--[primary start] table
b --local symbol table
--[primary end]
--[suffixed expr end]
)
--[function call end]
--[suffix list end]
--[suffixed expr end]
for
i --local symbol integer
=
1
,
--[suffixed expr start] integer
--[primary start] integer
m --local symbol integer
--[primary end]
--[suffixed expr end]
do
local
--[symbols]
xi --local symbol number[]
--[expressions]
--[suffixed expr start] any
--[primary start] any
table --global symbol ?
--[primary end]
--[suffix list start]
--[field selector start] any
.
'numarray'
--[field selector end]
--[function call start] any
(
--[suffixed expr start] integer
--[primary start] integer
p --local symbol integer
--[primary end]
--[suffixed expr end]
,
0.0000000000000000
)
--[function call end]
--[suffix list end]
--[suffixed expr end]
--[expression statement start]
--[var list start]
--[suffixed expr start] any
--[primary start] table
x --local symbol table
--[primary end]
--[suffix list start]
--[Y index start] any
[
--[suffixed expr start] integer
--[primary start] integer
i --local symbol integer
--[primary end]
--[suffixed expr end]
]
--[Y index end]
--[suffix list end]
--[suffixed expr end]
= --[var list end]
--[expression list start]
--[suffixed expr start] number[]
--[primary start] number[]
xi --local symbol number[]
--[primary end]
--[suffixed expr end]
--[expression list end]
--[expression statement end]
for
j --local symbol integer
=
1
,
--[suffixed expr start] integer
--[primary start] integer
p --local symbol integer
--[primary end]
--[suffixed expr end]
do
local
--[symbols]
sum --local symbol number
,
ai --local symbol number[]
,
cj --local symbol number[]
--[expressions]
0.0000000000000000
,
--[unary expr start] number[]
@number[]
--[suffixed expr start] any
--[primary start] any
--[suffixed expr start] any
--[primary start] table
a --local symbol table
--[primary end]
--[suffix list start]
--[Y index start] any
[
--[suffixed expr start] integer
--[primary start] integer
i --local symbol integer
--[primary end]
--[suffixed expr end]
]
--[Y index end]
--[suffix list end]
--[suffixed expr end]
--[primary end]
--[suffixed expr end]
--[unary expr end]
,
--[unary expr start] number[]
@number[]
--[suffixed expr start] any
--[primary start] any
--[suffixed expr start] any
--[primary start] table
c --local symbol table
--[primary end]
--[suffix list start]
--[Y index start] any
[
--[suffixed expr start] integer
--[primary start] integer
j --local symbol integer
--[primary end]
--[suffixed expr end]
]
--[Y index end]
--[suffix list end]
--[suffixed expr end]
--[primary end]
--[suffixed expr end]
--[unary expr end]
for
k --local symbol integer
=
1
,
--[suffixed expr start] integer
--[primary start] integer
n --local symbol integer
--[primary end]
--[suffixed expr end]
do
--[expression statement start]
--[var list start]
--[suffixed expr start] number
--[primary start] number
sum --local symbol number
--[primary end]
--[suffixed expr end]
= --[var list end]
--[expression list start]
--[binary expr start] number
--[suffixed expr start] number
--[primary start] number
sum --local symbol number
--[primary end]
--[suffixed expr end]
+
--[binary expr start] number
--[suffixed expr start] number
--[primary start] number[]
ai --local symbol number[]
--[primary end]
--[suffix list start]
--[Y index start] number
[
--[suffixed expr start] integer
--[primary start] integer
k --local symbol integer
--[primary end]
--[suffixed expr end]
]
--[Y index end]
--[suffix list end]
--[suffixed expr end]
*
--[suffixed expr start] number
--[primary start] number[]
cj --local symbol number[]
--[primary end]
--[suffix list start]
--[Y index start] number
[
--[suffixed expr start] integer
--[primary start] integer
k --local symbol integer
--[primary end]
--[suffixed expr end]
]
--[Y index end]
--[suffix list end]
--[suffixed expr end]
--[binary expr end]
--[binary expr end]
--[expression list end]
--[expression statement end]
end
--[expression statement start]
--[var list start]
--[suffixed expr start] number
--[primary start] number[]
xi --local symbol number[]
--[primary end]
--[suffix list start]
--[Y index start] number
[
--[suffixed expr start] integer
--[primary start] integer
j --local symbol integer
--[primary end]
--[suffixed expr end]
]
--[Y index end]
--[suffix list end]
--[suffixed expr end]
= --[var list end]
--[expression list start]
--[suffixed expr start] number
--[primary start] number
sum --local symbol number
--[primary end]
--[suffixed expr end]
--[expression list end]
--[expression statement end]
end
end
return
--[suffixed expr start] table
--[primary start] table
x --local symbol table
--[primary end]
--[suffixed expr end]
end
return
--[suffixed expr start] any
--[primary start] any
matmul --global symbol ?
--[primary end]
--[suffixed expr end]
end

@ -98,4 +98,34 @@ doast(str)
str=
[[return x.y[1]
]]
doast(str)
str=
[[local t: integer[]
if (t[1] == 5) then
return true
end
return false
]]
doast(str)
str=
[[function matmul(a: table, b: table)
assert(@integer(#a[1]) == #b);
local m: integer, n: integer, p: integer, x: table = #a, #a[1], #b[1], {};
local c: table = matrix.T(b); -- transpose for efficiency
for i = 1, m do
local xi: number[] = table.numarray(p, 0.0)
x[i] = xi
for j = 1, p do
local sum: number, ai: number[], cj: number[] = 0.0, @number[](a[i]), @number[](c[j]);
-- for luajit, caching c[j] or not makes no difference; lua is not so clever
for k = 1, n do sum = sum + ai[k] * cj[k] end
xi[j] = sum;
end
end
return x
end
return matmul
]]
doast(str)

@ -925,7 +925,7 @@ static struct ast_node *parse_sub_expression(struct parser_state *parser, int li
expr->type = AST_UNARY_EXPR;
expr->unary_expr.expr = subexpr;
expr->unary_expr.unary_op = uop;
expr->unary_expr.type.type_name = usertype;
expr->unary_expr.type.type_name = usertype;
}
else {
expr = parse_simple_expression(parser);
@ -943,7 +943,7 @@ static struct ast_node *parse_sub_expression(struct parser_state *parser, int li
binexpr->binary_expr.expr_left = expr;
binexpr->binary_expr.expr_right = exprright;
binexpr->binary_expr.binary_op = op;
expr = binexpr; // Becomes the left expr for next iteration
expr = binexpr; // Becomes the left expr for next iteration
op = nextop;
}
*untreated_op = op; /* return first untreated operator */

@ -2,19 +2,19 @@
#include "ravi_ast.h"
/* Type checker - WIP */
static void typecheck_ast_node(struct ast_node *function, struct ast_node *node);
static void typecheck_ast_node(struct ast_container *container, struct ast_node *function, struct ast_node *node);
/* Type checker - WIP */
static void typecheck_ast_list(struct ast_node *function, struct ast_node_list *list) {
static void typecheck_ast_list(struct ast_container *container, struct ast_node *function, struct ast_node_list *list) {
struct ast_node *node;
FOR_EACH_PTR(list, node) { typecheck_ast_node(function, node); }
FOR_EACH_PTR(list, node) { typecheck_ast_node(container, function, node); }
END_FOR_EACH_PTR(node);
}
/* Type checker - WIP */
static void typecheck_unaryop(struct ast_node *function, struct ast_node *node) {
static void typecheck_unaryop(struct ast_container *container, struct ast_node *function, struct ast_node *node) {
UnOpr op = node->unary_expr.unary_op;
typecheck_ast_node(function, node->unary_expr.expr);
typecheck_ast_node(container, function, node->unary_expr.expr);
ravitype_t subexpr_type = node->unary_expr.expr->common_expr.type.type_code;
switch (op) {
case OPR_MINUS:
@ -70,12 +70,12 @@ static void typecheck_unaryop(struct ast_node *function, struct ast_node *node)
}
/* Type checker - WIP */
static void typecheck_binaryop(struct ast_node *function, struct ast_node *node) {
static void typecheck_binaryop(struct ast_container *container, struct ast_node *function, struct ast_node *node) {
BinOpr op = node->binary_expr.binary_op;
struct ast_node *e1 = node->binary_expr.expr_left;
struct ast_node *e2 = node->binary_expr.expr_right;
typecheck_ast_node(function, e1);
typecheck_ast_node(function, e2);
typecheck_ast_node(container, function, e1);
typecheck_ast_node(container, function, e2);
switch (op) {
case OPR_ADD:
case OPR_SUB:
@ -157,12 +157,12 @@ static bool is_unindexable_type(struct var_type *type) {
* x[1][2]
* x.y[1]
*/
static void typecheck_suffixedexpr(struct ast_node *function, struct ast_node *node) {
typecheck_ast_node(function, node->suffixed_expr.primary_expr);
static void typecheck_suffixedexpr(struct ast_container *container, struct ast_node *function, struct ast_node *node) {
typecheck_ast_node(container, function, node->suffixed_expr.primary_expr);
struct ast_node *prev_node = node->suffixed_expr.primary_expr;
struct ast_node *this_node;
FOR_EACH_PTR(node->suffixed_expr.suffix_list, this_node) {
typecheck_ast_node(function, this_node);
typecheck_ast_node(container, function, this_node);
if (this_node->type == AST_Y_INDEX_EXPR) {
if (prev_node->common_expr.type.type_code == RAVI_TARRAYFLT) {
if (this_node->index_expr.expr->common_expr.type.type_code == RAVI_TNUMINT) {
@ -190,7 +190,8 @@ static void typecheck_suffixedexpr(struct ast_node *function, struct ast_node *n
copy_type(node->suffixed_expr.type, prev_node->common_expr.type);
}
static void typecheck_var_assignment(struct var_type *var_type, struct var_type *expr_type, const char *var_name) {
static void typecheck_var_assignment(struct ast_container *container, struct var_type *var_type,
struct var_type *expr_type, const char *var_name) {
if (var_type->type_code == RAVI_TANY)
// Any value can be assigned to type ANY
return;
@ -212,13 +213,14 @@ static void typecheck_var_assignment(struct var_type *var_type, struct var_type
}
}
static void typecheck_local_statement(struct ast_node *function, struct ast_node *node) {
static void typecheck_local_statement(struct ast_container *container, struct ast_node *function,
struct ast_node *node) {
// The local vars should already be annotated
// We need to typecheck the expressions to the right of =
// Then we need to ensure that the assignments are valid
// We can perhaps insert type assertions where we have a mismatch?
typecheck_ast_list(function, node->local_stmt.expr_list);
typecheck_ast_list(container, function, node->local_stmt.expr_list);
struct lua_symbol *var;
struct ast_node *expr;
@ -233,17 +235,18 @@ static void typecheck_local_statement(struct ast_node *function, struct ast_node
struct var_type *expr_type = &expr->common_expr.type;
const char *var_name = getstr(var->var.var_name);
typecheck_var_assignment(var_type, expr_type, var_name);
typecheck_var_assignment(container, var_type, expr_type, var_name);
NEXT_PTR_LIST(var);
NEXT_PTR_LIST(expr);
}
}
static void typecheck_expr_statement(struct ast_node *function, struct ast_node *node) {
static void typecheck_expr_statement(struct ast_container *container, struct ast_node *function,
struct ast_node *node) {
if (node->expression_stmt.var_expr_list)
typecheck_ast_list(function, node->expression_stmt.var_expr_list);
typecheck_ast_list(function, node->expression_stmt.expr_list);
typecheck_ast_list(container, function, node->expression_stmt.var_expr_list);
typecheck_ast_list(container, function, node->expression_stmt.expr_list);
if (!node->expression_stmt.var_expr_list)
return;
@ -261,20 +264,22 @@ static void typecheck_expr_statement(struct ast_node *function, struct ast_node
struct var_type *expr_type = &expr->common_expr.type;
const char *var_name = ""; // FIXME how do we get this?
typecheck_var_assignment(var_type, expr_type, var_name);
typecheck_var_assignment(container, var_type, expr_type, var_name);
NEXT_PTR_LIST(var);
NEXT_PTR_LIST(expr);
}
}
static void typecheck_for_in_statment(struct ast_node *function, struct ast_node *node) {
typecheck_ast_list(function, node->for_stmt.expr_list);
typecheck_ast_list(function, node->for_stmt.for_statement_list);
static void typecheck_for_in_statment(struct ast_container *container, struct ast_node *function,
struct ast_node *node) {
typecheck_ast_list(container, function, node->for_stmt.expr_list);
typecheck_ast_list(container, function, node->for_stmt.for_statement_list);
}
static void typecheck_for_num_statment(struct ast_node *function, struct ast_node *node) {
typecheck_ast_list(function, node->for_stmt.expr_list);
static void typecheck_for_num_statment(struct ast_container *container, struct ast_node *function,
struct ast_node *node) {
typecheck_ast_list(container, function, node->for_stmt.expr_list);
struct ast_node *expr;
enum { I = 1, F = 2, A = 4 }; /* bits representing integer, number, any */
int index_type = 0;
@ -310,29 +315,41 @@ static void typecheck_for_num_statment(struct ast_node *function, struct ast_nod
}
END_FOR_EACH_PTR(sym);
}
typecheck_ast_list(function, node->for_stmt.for_statement_list);
typecheck_ast_list(container, function, node->for_stmt.for_statement_list);
}
static void typecheck_if_statement(struct ast_container *container, struct ast_node *function, struct ast_node *node) {
struct ast_node *test_then_block;
FOR_EACH_PTR(node->if_stmt.if_condition_list, test_then_block) {
typecheck_ast_node(container, function, test_then_block->test_then_block.condition);
typecheck_ast_list(container, function, test_then_block->test_then_block.test_then_statement_list);
}
END_FOR_EACH_PTR(node);
if (node->if_stmt.else_statement_list) {
typecheck_ast_list(container, function, node->if_stmt.else_statement_list);
}
}
/* Type checker - WIP */
static void typecheck_ast_node(struct ast_node *function, struct ast_node *node) {
static void typecheck_ast_node(struct ast_container *container, struct ast_node *function, struct ast_node *node) {
switch (node->type) {
case AST_FUNCTION_EXPR: {
typecheck_ast_list(function, node->function_expr.function_statement_list);
typecheck_ast_list(container, function, node->function_expr.function_statement_list);
break;
}
case AST_NONE: {
break;
}
case AST_RETURN_STMT: {
typecheck_ast_list(function, node->return_stmt.expr_list);
typecheck_ast_list(container, function, node->return_stmt.expr_list);
break;
}
case AST_LOCAL_STMT: {
typecheck_local_statement(function, node);
typecheck_local_statement(container, function, node);
break;
}
case AST_FUNCTION_STMT: {
typecheck_ast_node(function, node->function_stmt.function_expr);
typecheck_ast_node(container, function, node->function_stmt.function_expr);
break;
}
case AST_LABEL_STMT: {
@ -345,15 +362,11 @@ static void typecheck_ast_node(struct ast_node *function, struct ast_node *node)
break;
}
case AST_EXPR_STMT: {
typecheck_expr_statement(function, node);
typecheck_expr_statement(container, function, node);
break;
}
case AST_IF_STMT: {
struct ast_node *test_then_block;
FOR_EACH_PTR(node->if_stmt.if_condition_list, test_then_block) {}
END_FOR_EACH_PTR(node);
if (node->if_stmt.else_block) {
}
typecheck_if_statement(container, function, node);
break;
}
case AST_WHILE_STMT: {
@ -363,15 +376,15 @@ static void typecheck_ast_node(struct ast_node *function, struct ast_node *node)
break;
}
case AST_FORIN_STMT: {
typecheck_for_in_statment(function, node);
typecheck_for_in_statment(container, function, node);
break;
}
case AST_FORNUM_STMT: {
typecheck_for_num_statment(function, node);
typecheck_for_num_statment(container, function, node);
break;
}
case AST_SUFFIXED_EXPR: {
typecheck_suffixedexpr(function, node);
typecheck_suffixedexpr(container, function, node);
break;
}
case AST_FUNCTION_CALL_EXPR: {
@ -379,7 +392,7 @@ static void typecheck_ast_node(struct ast_node *function, struct ast_node *node)
}
else {
}
typecheck_ast_list(function, node->function_call_expr.arg_list);
typecheck_ast_list(container, function, node->function_call_expr.arg_list);
break;
}
case AST_SYMBOL_EXPR: {
@ -388,11 +401,11 @@ static void typecheck_ast_node(struct ast_node *function, struct ast_node *node)
break;
}
case AST_BINARY_EXPR: {
typecheck_binaryop(function, node);
typecheck_binaryop(container, function, node);
break;
}
case AST_UNARY_EXPR: {
typecheck_unaryop(function, node);
typecheck_unaryop(container, function, node);
break;
}
case AST_LITERAL_EXPR: {
@ -400,11 +413,11 @@ static void typecheck_ast_node(struct ast_node *function, struct ast_node *node)
break;
}
case AST_FIELD_SELECTOR_EXPR: {
typecheck_ast_node(function, node->index_expr.expr);
typecheck_ast_node(container, function, node->index_expr.expr);
break;
}
case AST_Y_INDEX_EXPR: {
typecheck_ast_node(function, node->index_expr.expr);
typecheck_ast_node(container, function, node->index_expr.expr);
break;
}
case AST_INDEXED_ASSIGN_EXPR: {
@ -421,12 +434,12 @@ static void typecheck_ast_node(struct ast_node *function, struct ast_node *node)
}
/* Type checker - WIP */
static void typecheck_function(struct ast_node *func) {
typecheck_ast_list(func, func->function_expr.function_statement_list);
static void typecheck_function(struct ast_container *container, struct ast_node *func) {
typecheck_ast_list(container, func, func->function_expr.function_statement_list);
}
/* Type checker - WIP */
void raviA_ast_typecheck(struct ast_container *container) {
struct ast_node *main_function = container->main_function;
typecheck_function(main_function);
typecheck_function(container, main_function);
}

Loading…
Cancel
Save