// LARA

import SimpleCFG._, SimpleAST._
/** Translates a `SimpleAST` statement into a control-flow graph.
  *
  * Code is emitted into `cfg` by threading a "current vertex" (the program
  * counter, held in `Emit`) through the translation: each emitted statement
  * becomes an edge leaving the current vertex. Translation functions are also
  * given continuation vertices that say where control goes afterwards.
  *
  * @param cfg the graph to populate; emission starts at `cfg.entry`, and
  *            `translate` routes the finished program to `cfg.exit`.
  */
final class ASTtoCFG(cfg : Cfg) {

  private type Vertex = cfg.Vertex

  // Integer encoding used when a boolean expression is stored in a variable
  // (exprStoreBool); condExpr treats a variable as false iff it equals FalseConst.
  private val TrueConst = 1
  private val FalseConst = 0

  /** Generator of temporary variable names "x1", "x2", ... (one counter per translator).
    *
    * NOTE(review): these names are ordinary identifiers, so they can clash with a
    * source-program variable of the same name (e.g. a program that assigns `x1`);
    * consider a prefix the source language cannot produce.
    */
  object FreshName {
    var counter : Int = 0
    def get : String = {
      counter = counter + 1
      "x" + counter
    }
  }

  /** Emission helper: tracks the current vertex `pc` and adds edges to `cfg`. */
  object Emit {
    private var pc : Vertex = cfg.entry
    /** The vertex from which the next emitted code starts. */
    def getPC : Vertex = { pc }
    def setPC(v : Vertex) = { pc = v }
    /** Adds the edge `v1 --s--> v2`; `pc` is not changed. */
    def statementBetween(v1 : Vertex, s : CfgStmt, v2 : Vertex) = {
      cfg += (v1, s, v2)
    }
    /** Adds `pc --s--> v` for a fresh vertex `v` and moves `pc` to `v`. */
    def statement(s : CfgStmt) = {
      val v = cfg.newVertex
      cfg += (pc, s, v)
      setPC(v)
    }
    /** Adds `pc --s--> cont`; `pc` is not changed. */
    def statementCont(s : CfgStmt, cont : Vertex) = {
      cfg += (pc, s, cont)
    }
    /** Adds an unconditional `Skip` edge `pc --> cont`; `pc` is not changed. */
    def goto(cont : Vertex) = {
      cfg += (pc, Skip, cont)
    }
  }

  /** Emits code, starting at the current pc, that evaluates the boolean
    * expression `e` and branches to `trueCont` if it holds, `falseCont` otherwise.
    *
    * `And` short-circuits: code for `e2` is only reachable once `e1` is true.
    * A variable is false iff its value equals `FalseConst` (zero).
    *
    * @throws IllegalArgumentException if `e` is not `Leq`, `And` or `Var`
    */
  def condExpr(e : Expression, falseCont : Vertex, trueCont : Vertex) : Unit =
    e match {
      case Leq(e1, e2) => {
        val x1 = expr(e1)
        val x2 = expr(e2)
        Emit.statementCont(Assume(x1, LeqOp, x2), trueCont)
        // not (x1 <= x2)  <=>  x2 < x1
        Emit.statementCont(Assume(x2, LTOp, x1), falseCont)
      }
      case And(e1, e2) => {
        // short-circuit: e2 is evaluated only from the vertex where e1 held
        val soFarTrueV = cfg.newVertex
        condExpr(e1, falseCont, soFarTrueV)
        Emit.setPC(soFarTrueV)
        condExpr(e2, falseCont, trueCont)
      }
      case Var(n) => {
        val current = Emit.getPC
        Emit.statementBetween(current, Assume(VarValue(n), EqOp, Const(FalseConst)),
                              falseCont)
        Emit.statementBetween(current, Assume(VarValue(n), NeqOp, Const(FalseConst)),
                              trueCont)
      }
      case _ => throw new IllegalArgumentException("Not a boolean")
    }

  /** Stores the truth value of boolean expression `e` into `lhs` as
    * `TrueConst`/`FalseConst`, leaving the pc at the join vertex after both
    * assignments.
    */
  def exprStoreBool(lhs : Variable, e : Expression) = {
    val trueV = cfg.newVertex
    val falseV = cfg.newVertex
    condExpr(e, falseV, trueV)
    val afterV = cfg.newVertex
    Emit.statementBetween(falseV, Copy(lhs, Const(FalseConst)), afterV)
    Emit.statementBetween(trueV, Copy(lhs, Const(TrueConst)), afterV)
    Emit.setPC(afterV)
  }

  /** Emits code at the pc that stores the value of `e` into `lhs`, advancing the pc.
    *
    * Operands of `Plus` are first reduced to simple values (possibly via
    * temporaries); boolean forms are delegated to `exprStoreBool`. Other
    * expression forms are not matched here.
    */
  def exprStore(lhs : Variable, e : Expression) = e match {
    case Var(id) => Emit.statement(Copy(lhs, VarValue(id)))
    case IntConst(c) => Emit.statement(Copy(lhs, Const(c)))
    case Plus(e1, e2) => {
      val x1 = expr(e1)
      val x2 = expr(e2)
      Emit.statement(AssignBin(lhs, x1, PlusOp, x2))
    }
    case And(e1, e2) => exprStoreBool(lhs, e)
    case Leq(e1, e2) => exprStoreBool(lhs, e)
  }

  /** Returns `e` as a simple value when it is a variable or constant, else `None`. */
  def alreadySimple(e : Expression) : Option[SimpleValue] = e match {
    case Var(id) => Some(VarValue(id))
    case IntConst(c) => Some(Const(c))
    case _ => None
  }

  /** Reduces `e` to a simple value, emitting code into a fresh temporary if needed. */
  def expr(e : Expression) : SimpleValue = alreadySimple(e) match {
    case Some(v) => v
    case None => {
      val x = FreshName.get
      exprStore(x, e)
      VarValue(x)
    }
  }

  /** Translates a statement sequence starting at the pc; control ends at `cont`.
    * An empty sequence becomes a single `Skip` edge to `cont`.
    */
  def stmts(sts : List[Statement], cont : Vertex) : Unit =
    sts match {
      case Nil => Emit.goto(cont)
      case s :: Nil => stmt(s, cont)
      case s :: sts1 => {
        // fresh vertex between s and the rest of the sequence
        val v = cfg.newVertex
        stmt(s, v)
        Emit.setPC(v)
        stmts(sts1, cont)
      }
    }

  /** Translates `s` starting at the current pc.
    *
    * @param cont where control continues after the statement
    */
  def stmt(s : Statement, cont : Vertex) : Unit = {
    s match {
      case AssignStat(lhs, rhs) => {
        exprStore(lhs, rhs)
        Emit.goto(cont)
      }
      case PrintStat(e) => {
        val v = expr(e)
        Emit.statementCont(Print(v), cont)
      }
      case IfStat(cond, trueS, falseS) => {
        val falseV = cfg.newVertex
        val trueV = cfg.newVertex
        condExpr(cond, falseV, trueV)

        Emit.setPC(falseV)
        stmt(falseS, cont)

        Emit.setPC(trueV)
        stmt(trueS, cont)
      }
      case WhileStat(cond, body) => {
        // The loop head is the pc before the condition is evaluated, so the
        // body's back edge re-runs the whole condition (including temporaries).
        val beginning = Emit.getPC
        val trueV = cfg.newVertex
        condExpr(cond, cont, trueV)
        Emit.setPC(trueV)
        stmt(body, beginning)
      }
      case BlockStat(sts) => stmts(sts, cont)
    }
  }

  /** Removes `Skip` edges by bypassing their source vertex.
    *
    * For every vertex `v` other than entry/exit whose only outgoing edge is
    * `v --Skip--> w`, each incoming edge `u --lab--> v` is replaced by
    * `u --lab--> w` and the `Skip` edge is removed. The vertex itself stays in
    * `cfg.V`; a vertex with no incoming edges keeps its `Skip` edge.
    */
  def fewerSkips : Unit = {
    // Every collection is snapshotted with toList before iterating: the loop
    // bodies add and remove edges, which mutates v.in / v.out (and the edge
    // sets of neighbouring vertices) while they would otherwise be traversed.
    for (v <- cfg.V.toList) {
      if ((v != cfg.entry) &&
          (v != cfg.exit) &&
          (v.out.size == 1))
      {
        for (eOut <- v.out.toList) {
          if (eOut.lab == Skip) {
            for (eIn <- v.in.toList) {
              // remove old edges (removing eOut again on later iterations is a no-op)
              cfg -= (eIn.v1, eIn.lab, eIn.v2)
              cfg -= (eOut.v1, eOut.lab, eOut.v2)
              // insert new edge with label of incoming one, bypassing v
              cfg += (eIn.v1, eIn.lab, eOut.v2)
            }
          }
        }
      }
    }
  }

  /** Translates `s` into `cfg`, from the current pc (initially `cfg.entry`)
    * to `cfg.exit`, then simplifies the graph with `fewerSkips`.
    *
    * As in other code generation, and in translating a regular expression to
    * a finite-state machine, the place of the code being generated is given by
    * two vertices: the pc, where control comes from, and the continuation,
    * where control must go afterwards. The continuation is sometimes not yet
    * known inside a sub-translation, so callers create it up front and pass
    * it down.
    */
  def translate(s : Statement) = {
    stmt(s, cfg.exit)
    fewerSkips
  }
}
 
/** Shared driver for the translation demos. */
object TestTranslations {
  /** Builds a fresh CFG for `s`, prints it, and opens it in the dotty viewer. */
  def test(s : Statement) : Unit = {
    val graph = new Cfg
    val translator = new ASTtoCFG(graph)
    translator.translate(s)
    println(graph)
    graph.dottyView
  }
}
 
/** Demo: straight-line assignments followed by an if/else. */
object TestTrans1 {
  def main(args : Array[String]) : Unit = {
    // x = 3; y = x + x
    val initX = AssignStat("x", IntConst(3))
    val doubleX = AssignStat("y", Plus(Var("x"), Var("x")))
    // if (y <= x + 1) z = (x + y) + 42 else z = y
    val guard = Leq(Var("y"), Plus(Var("x"), IntConst(1)))
    val thenBranch = AssignStat("z", Plus(Plus(Var("x"), Var("y")), IntConst(42)))
    val elseBranch = AssignStat("z", Var("y"))
    val program = BlockStat(List(initX, doubleX, IfStat(guard, thenBranch, elseBranch)))
    TestTranslations.test(program)
  }
}
/** Demo: `i = 0; while (i <= 10) { j = i + 42; i = j - 40 }; res = j`. */
object TestTrans2 {
  def main(args : Array[String]) : Unit = {
    val loopBody = BlockStat(List(
      AssignStat("j", Plus(Var("i"), IntConst(42))),
      AssignStat("i", Plus(Var("j"), IntConst(-40)))))
    val loop = WhileStat(Leq(Var("i"), IntConst(10)), loopBody)
    val program = BlockStat(List(
      AssignStat("i", IntConst(0)),
      loop,
      AssignStat("res", Var("j"))))
    TestTranslations.test(program)
  }
}
 
/** Demo: a while loop nested inside another while loop. */
object TestTrans3 {
  def main(args : Array[String]) : Unit = {
    // inner: i = 0; while (i <= 10) { j = i + 42; i = j - 40 }; res = j
    val innerBody = BlockStat(List(
      AssignStat("j", Plus(Var("i"), IntConst(42))),
      AssignStat("i", Plus(Var("j"), IntConst(-40)))))
    val inner = BlockStat(List(
      AssignStat("i", IntConst(0)),
      WhileStat(Leq(Var("i"), IntConst(10)), innerBody),
      AssignStat("res", Var("j"))))
    // outer: k = 0; while (k <= 5) { <inner>; k = k + 1 }
    val outerBody = BlockStat(List(inner, AssignStat("k", Plus(Var("k"), IntConst(1)))))
    val program = BlockStat(List(
      AssignStat("k", IntConst(0)),
      WhileStat(Leq(Var("k"), IntConst(5)), outerBody)))
    TestTranslations.test(program)
  }
}