diff --git a/vm/cmds.go b/vm/cmds.go index 4464154..20cf6a8 100644 --- a/vm/cmds.go +++ b/vm/cmds.go @@ -11,7 +11,7 @@ type buildCmd struct { pattern lang.Tuple inTypeFn func(lang.Term) (llvm.Type, error) outTypeFn func(lang.Term) (llvm.Type, error) - buildFn func(lang.Term) (llvm.Value, error) + buildFn func(lang.Term) (val, error) } func (cmd buildCmd) matches(t lang.Term) bool { @@ -32,7 +32,7 @@ func (cmd buildCmd) outType(t lang.Term) (llvm.Type, error) { return cmd.outTypeFn(t) } -func (cmd buildCmd) build(t lang.Term) (llvm.Value, error) { +func (cmd buildCmd) build(t lang.Term) (val, error) { return cmd.buildFn(t) } @@ -55,14 +55,17 @@ func buildCmds(mod *Module) []buildCmd { outTypeFn: func(t lang.Term) (llvm.Type, error) { return llvm.Int64Type(), nil }, - buildFn: func(t lang.Term) (llvm.Value, error) { + buildFn: func(t lang.Term) (val, error) { con := t.(lang.Const) coni, err := strconv.ParseInt(string(con), 10, 64) if err != nil { - return llvm.Value{}, err + return val{}, err } - // TODO why does this have to be cast? - return llvm.ConstInt(llvm.Int64Type(), uint64(coni), false), nil + return val{ + // TODO why does this have to be cast? + v: llvm.ConstInt(llvm.Int64Type(), uint64(coni), false), + typ: lang.AInt, + }, nil }, }, @@ -82,28 +85,36 @@ func buildCmds(mod *Module) []buildCmd { } return llvm.StructType(typs, false), nil }, - buildFn: func(t lang.Term) (llvm.Value, error) { + buildFn: func(t lang.Term) (val, error) { tup := t.(lang.Tuple) // if the tuple is empty then it is a void if len(tup) == 0 { - return llvm.Undef(llvm.VoidType()), nil + return val{ + v: llvm.Undef(llvm.VoidType()), + typ: lang.Tuple{lang.ATuple, lang.Tuple{}}, + }, nil } var err error - vals := make([]llvm.Value, len(tup)) + vals := make([]val, len(tup)) typs := make([]llvm.Type, len(tup)) + ttyps := make([]lang.Term, len(tup)) for i := range tup { if vals[i], err = mod.build(tup[i]); err != nil { - return llvm.Value{}, err + return val{}, err } - typs[i] = vals[i].Type() + typs[i] = vals[i].v.Type() + ttyps[i] = vals[i].typ } str := llvm.Undef(llvm.StructType(typs, false)) for i := range vals { - str = mod.b.CreateInsertValue(str, vals[i], i, "") + str = mod.b.CreateInsertValue(str, vals[i].v, i, "") } - return str, nil + return val{ + v: str, + typ: lang.Tuple{lang.ATuple, lang.Tuple(ttyps)}, + }, nil }, }, @@ -112,17 +123,20 @@ func buildCmds(mod *Module) []buildCmd { outTypeFn: func(t lang.Term) (llvm.Type, error) { return llvm.Int64Type(), nil }, - buildFn: func(t lang.Term) (llvm.Value, error) { + buildFn: func(t lang.Term) (val, error) { tup := t.(lang.Tuple) v1, err := mod.build(tup[0]) if err != nil { - return llvm.Value{}, err + return val{}, err } v2, err := mod.build(tup[1]) if err != nil { - return llvm.Value{}, err + return val{}, err } - return mod.b.CreateAdd(v1, v2, ""), nil + return val{ + v: mod.b.CreateAdd(v1.v, v2.v, ""), + typ: v1.typ, + }, nil }, }, } diff --git a/vm/vm.go b/vm/vm.go index 2e83a19..ffb49ea 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -10,9 +10,9 @@ import ( "llvm.org/llvm/bindings/go/llvm" ) -// Val holds onto a value which has been created within the VM -type Val struct { - v llvm.Value +type val struct { + typ lang.Term + v llvm.Value } // Module contains a compiled set of code which can be run, dumped in IR form, @@ -87,14 +87,15 @@ func (mod *Module) outType(t lang.Term) (llvm.Type, error) { return cmd.outType(t.(lang.Tuple)[1]) } -func (mod *Module) build(t lang.Term) (llvm.Value, error) { +func (mod *Module) build(t lang.Term) (val, error) { cmd, err := mod.matchingBuildCmd(t) if err != nil { - return llvm.Value{}, err + return val{}, err } return cmd.build(t.(lang.Tuple)[1]) } +// TODO make this return a val once we get function types func (mod *Module) buildFn(tt ...lang.Term) (llvm.Value, error) { if len(tt) == 0 { return llvm.Value{}, errors.New("function cannot be empty") @@ -120,13 +121,13 @@ func (mod *Module) buildFn(tt ...lang.Term) (llvm.Value, error) { block := llvm.AddBasicBlock(fn, "") mod.b.SetInsertPoint(block, block.FirstInstruction()) - var out llvm.Value + var out val for _, t := range tt { if out, err = mod.build(t); err != nil { return llvm.Value{}, err } } - mod.b.CreateRet(out) + mod.b.CreateRet(out.v) return fn, nil }