/*
 *              __  ____________        ____         __    
 *             / / / /_  __/ __/ ____  / __/______ _/ /__ _
 *            / /_/ / / / _\ \  /___/ _\ \/ __/ _ `/ / _ `/
 *            \____/ /_/ /___/       /___/\__/\_,_/_/\_,_/ 
 * 
 * This file is part of an implementation of the Universe Type System for
 * Scala.
 * 
 * Copyright (C) 2007-2008  Swiss Federal Institute of Technology, Zurich
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License along
 * with this program; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 * 
 * 
 * $Id: RuntimeCheckTransformBase.scala 883 2008-02-01 18:59:56Z ms $
 */
package ch.ethz.inf.sct.uts.plugin.runtimecheck

import scala.tools.nsc.transform._
import ch.ethz.inf.sct.uts.plugin.common._
import scala.tools.nsc.util._
import scala.tools.nsc._
import ch.ethz.inf.sct.uts.annotation._
import scala.tools.nsc.symtab._

/**
 * Base trait with the transformation which adds runtime checks to the AST.
 *
 * @author  Manfred Stock
 * @version $Revision: 883 $
 */
trait RuntimeCheckTransformBase {
  /**
   * Compiler instance this trait works on. Its members (trees, symbols,
   * definitions, typer, ...) are imported right below.
   */
  val global: Global
  
  import global._
  import definitions._ // standard classes and methods
  import typer.{typed, atOwner}    // methods to type trees
  import posAssigner.atPos         // for filling in tree positions
  import UTSDefaults._

  /**
   * <code>ExtendedType</code> instance which is bound to the same compiler 
   * instance as this trait, its members are imported right after this 
   * definition.
   * NOTE(review): presumably this is what supplies the ownership related 
   * operations on types used by the transformer (e.g. 
   * <code>ownershipModifiers</code>, <code>isImmutable</code>) - confirm 
   * against <code>ExtendedType</code>.
   */
  object extendedType extends ExtendedType {
    // Singleton type so that trees/types of both components are compatible.
    val global : RuntimeCheckTransformBase.this.global.type = RuntimeCheckTransformBase.this.global
  }
  import extendedType._
  
  /**
   * Factory method to create new <code>Transformer</code>s. Abstract, the 
   * concrete kind of <code>Transformer</code> is chosen by the implementation 
   * of this trait.
   * @param unit The unit which will be processed by the <code>Transformer</code>.
   * @return the new <code>Transformer</code>.
   */
  def newTransformer(unit: CompilationUnit) : Transformer
  
  /**
   * Logger for output of the plugin. The transformer uses it for debug 
   * messages about the trees it found and for a notice if a type carries 
   * more than one ownership modifier.
   */
  val logger : UTSLogger

  /**
   * Base Transformer to add UTS runtime checks to the AST.
   * @param unit The unit where the checks should be added.
   */
  abstract class RTCTransformer(unit: CompilationUnit) extends Transformer {
    /**
     * Name of the object which provides the methods to call at runtime. The 
     * name is looked up with <code>getModule</code> when the trees for the 
     * calls are built, so it has to denote a module known to the compiler.
     */
    val runtimeObject : String
   
    /**
     * Build the tree for a call to <code>scala.Predef.println</code> which 
     * prints a message at runtime. The resulting tree is neither typed nor 
     * positioned.
     * @param message Message to print.
     * @return the tree which prints the given message at runtime.
     */
    def log(message: String) : Tree = {
      // scala.Predef.println, built from the outside in
      val scalaPackage = Ident(nme.scala_) setSymbol ScalaPackage
      val predefRef    = Select(scalaPackage, nme.Predef)
      val printlnRef   = Select(predefRef, newTermName("println"))
      val argument     = Literal(Constant(message))
      Apply(printlnRef, List(argument))
    }
    
    /**
     * Build the typed tree which prints a message at runtime.
     * @param pos Position in the source where this code will seem to come from.
     * @param message Message to print.
     * @return the typed tree which prints the given message at runtime.
     */
    def log(pos: Position, message: String) : Tree = {
      // Position the untyped call first, then let the typer process it.
      val positioned = atPos(pos)(log(message))
      typed(positioned)
    }
    
    /**
     * Get the class object of a class from a compiler type. The generated 
     * tree is an (untyped) call to <code>java.lang.Class.forName</code> with 
     * the name computed from the type's symbol as its only argument.
     * @param tpe Type of the compiler.
     * @return the <code>Tree</code> which results in the class object 
     *         of the given type when it is evaluated at runtime.
     */
    def getClassFromType(tpe: Type) : Tree = {
      /**
       * Get name of a symbol, insert '$' if it is an inner class. Workaround
       * for bug/problem with Scala compiler, it would be better if the runtime-check 
       * phase would run after flatten, where the 'real' name is known. Unfortunately, 
       * this currently crashes the compiler, even if the plugin does not do anything.
       * Running the plugin after flatten should be reconsidered as soon as Ticket #375
       * (https://lampsvn.epfl.ch/trac/scala/ticket/375) was fixed.
       * 
       * This workaround works for inner classes, but not if they are anonymous, as their 
       * name seems to be calculated differently.
       * @param sym The symbol.
       * @return the name of the symbol.
       */
      def getName(sym: Symbol) : String = {
        val owner = sym.rawowner
        val isNested = !sym.hasFlag(Flags.METHOD) && owner != NoSymbol && !owner.isPackageClass
        if (isNested) {
          // Plain string concatenation is sufficient here, there is no need to
          // enter every computed class name into the compiler's name table.
          getName(owner) + "$" + sym.rawname
        }
        else {
          sym.fullNameString
        }
      }

      val forName = Select(
          Ident(newTermName("java.lang.Class")) setSymbol getModule("java.lang.Class"),
          newTermName("forName")
      )
      Apply(
          forName,
          List(
              Literal(Constant(getName(tpe.typeSymbol)))
          )
      )
    }
     
    /**
     * Build the untyped selection 
     * <code>&lt;runtimeObject&gt;.&lt;target&gt;.&lt;name&gt;</code>, i.e. the 
     * tree which leads to a given method of the runtime object's 
     * <code>target</code> (e.g. handler, policy, etc.) reference.
     * @param target Field of the runtime object the method should be called on.
     * @param name   Name of the method.
     * @return the tree which leads to the given method.
     */
    protected def rtMethod(target: String, name: String) : Tree = {
      val runtimeRef = Ident(newTermName(runtimeObject)) setSymbol getModule(runtimeObject)
      val targetRef  = Select(runtimeRef, newTermName(target))
      Select(targetRef, newTermName(name))
    }
        
    /**
     * Build the untyped application of a given method of the runtime object's
     * <code>target</code> (e.g. handler, policy, etc.) reference to some 
     * arguments.
     * @param target Field of the runtime object the method should be called on.
     * @param name Name of the method.
     * @param args Arguments to the method call.
     * @return the tree which applies the given method to the arguments.
     */
    protected def rtMethod(target: String, name: String, args: List[Tree]) : Tree = {
      val method = rtMethod(target, name)
      Apply(method, args)
    }
    
    /**
     * Build the untyped application of a given method of the runtime object's
     * <code>handler</code> reference to some arguments.
     * @param name Name of the method.
     * @param args Arguments to the method call.
     * @return the tree which applies the given method to the arguments.
     */
    protected def rthMethod(name: String, args: List[Tree]) : Tree =
      rtMethod("handler", name, args)
    
    /**
     * Build the untyped application of a given method of the runtime object's
     * <code>policy</code> reference to some arguments.
     * @param name Name of the method.
     * @param args Arguments to the method call.
     * @return the tree which applies the given method to the arguments.
     */
    protected def rtpMethod(name: String, args: List[Tree]) : Tree =
      rtMethod("policy", name, args)
     
    /**
     * Call a given method on the runtime object's <code>target</code> 
     * (e.g. handler, policy, etc.) reference which does not take arguments.
     * The selection of the method is positioned at <code>pos</code> and then 
     * handed to the typer.
     * @param pos Position where the generated code should appear.
     * @param target Field of the runtime object the method should be called on.
     * @param name Name of the method to call.
     * @return the typed tree which calls the specified method. 
     */
    def rtCall(pos: Position, target: String, name: String) : Tree = {
      val selection = atPos(pos)(rtMethod(target, name))
      typed(selection)
    }
       
    /**
     * Call a given method on the runtime object's <code>target</code> 
     * (e.g. handler, policy, etc.) reference. The call is positioned at 
     * <code>pos</code> and then handed to the typer.
     * @param pos Position where the generated code should appear.
     * @param target Field of the runtime object the method should be called on.
     * @param name Name of the method to call.
     * @param args List of arguments to be passed to the method.
     * @return the typed tree which calls the specified method. 
     */
    def rtCall(pos: Position, target: String, name: String, args: List[Tree]) : Tree = {
      val call = atPos(pos)(rtMethod(target, name, args))
      typed(call)
    }
     
    /**
     * Call a given method on the runtime object's <code>handler</code> reference.
     * Shortcut for <code>rtCall</code> with <code>"handler"</code> as target.
     * @param pos Position where the generated code should appear.
     * @param name Name of the method to call.
     * @param args List of arguments to be passed to the method.
     * @return the typed tree which calls the specified method. 
     */
    def rthCall(pos: Position, name: String, args: List[Tree]) : Tree =
      rtCall(pos, "handler", name, args)

    /**
     * Call a given method on the runtime object's <code>policy</code> reference.
     * Shortcut for <code>rtCall</code> with <code>"policy"</code> as target.
     * @param pos Position where the generated code should appear.
     * @param name Name of the method to call.
     * @param args List of arguments to be passed to the method.
     * @return the typed tree which calls the specified method. 
     */
    def rtpCall(pos: Position, name: String, args: List[Tree]) : Tree =
      rtCall(pos, "policy", name, args)
     
    /**
     * Get the main modifier of a type, i.e. the first of its ownership 
     * modifiers. A notice is logged if the type has more than one modifier, 
     * and <code>defaultOwnershipModifier</code> is returned if it has none.
     * @param tpe  The type whose main modifiers is of interest.
     * @return the main modifier or some default.
     */
    def getMainModifier(tpe: Type) : OwnershipModifier = {
      val found = tpe.ownershipModifiers
      val count = found.length
      if (count > 0) {
        val main = found.head
        if (count > 1) {
          logger.notice("More than one ownership modifier found - using "+main)
        }
        main
      }
      else {
        defaultOwnershipModifier
      }
    }
    
    /**
     * Get the main modifier of a type, warn if more than one given. If the 
     * type does not carry an ownership annotation, the main modifier of the 
     * target's type (or the default, if that one has none either) is used 
     * instead.
     * @param ttpe Type of the target.
     * @param tpe  The type whose main modifiers is of interest.
     * @return the main modifier or some default.
     */
    def getMainModifier(ttpe: Type, tpe: Type) : OwnershipModifier = {
      // Fall back to the target's type if the type itself was not annotated.
      val relevant = if (tpe.isOwnershipAnnotated) tpe else ttpe
      getMainModifier(relevant)
    }
    
    /**
     * Create some temporary variable. The variable is a new synthetic value 
     * with a fresh name (prefix <code>tmp$</code>) which is owned by the 
     * current owner.
     * @param pos Position where the variable shall be created.
     * @param rhs Tree which should be assigned to the newly created value.
     * @param tpe Type of the new variable.
     * @return the tree with the value definition and a function to create the 
     *         tree for access to the value. Each call of the function yields 
     *         a new <code>Ident</code> of type <code>tpe</code>.
     */
    def createTempVar(pos: Position, rhs: Tree, tpe: Type) : (Tree, () => Ident) = {
      val freshName = newTermName(unit.fresh.newName("tmp$"))
      val symbol = currentOwner.newValue(pos, freshName)
      symbol.setFlag(Flags.SYNTHETIC)
      symbol.setInfo(tpe)
      (ValDef(symbol, rhs), () => Ident(symbol) setType tpe)
    }

    /**
     * Process the definition of a constructor: a call to the handler's 
     * <code>setOwner</code> method with <code>this</code> of the enclosing 
     * class as argument is inserted right after the first statement of the 
     * constructor's body. If the body does not contain any statement, the call 
     * becomes its only statement. Trees which are not a method definition 
     * with a block as right hand side are returned unchanged.
     * @param tree The constructor's definition.
     * @return the modified tree.
     */
    def processConstructor(tree: Tree) : Tree = {
      tree match {
        case DefDef(mods, name, tparams, vparamss, tpt, Block(body, result)) =>
          val setOwnerCall = rthCall(
              tree.pos,
              "setOwner",
              List(
                  This(currentOwner.enclClass)
              )
          )
          // Keep the first statement in front of the new call. A body without
          // statements must not be accessed with head/tail, as this would
          // throw an exception.
          val extendedBody = body match {
            case first :: rest => first :: setOwnerCall :: rest
            case Nil           => List(setOwnerCall)
          }
          copy.DefDef(tree, mods, name, tparams, vparamss, tpt, 
            Block(
                extendedBody,
                result
            )
          )
        case _ =>
          // Nothing to extend, do not fail with a MatchError.
          tree
      }
    }
    
    /**
     * Process the call of a constructor. Abstract, called by 
     * <code>transform</code> for every call of a constructor via 
     * <code>new</code>. The returned statements become the body of a block 
     * whose result is the access to the temporary variable.
     * @param tree      Tree with the call. 
     * @param om        Main modifier of the newly created instance's type.
     * @param tmp       Tree with the definition of the temporary variable which 
     *                  contains the new instance (<code>transform</code> passes 
     *                  the value definition from <code>createTempVar</code>, its 
     *                  right hand side is the constructor call).
     * @param tmpaccess Function which returns a tree for access to the temporary variable.
     * @return the body of the block which is to be added in the place of the constructor call.
     *         NOTE(review): presumably <code>tmp</code> has to be part of the 
     *         returned statements, as <code>transform</code> does not add it 
     *         to the block itself - confirm against the implementations.
     */
    def processConstructorCall(tree: Tree, om: OwnershipModifier, tmp: Tree, tmpaccess: () => Ident) : List[Tree]
    
    /**
     * Process a call of the <code>isInstanceOf</code>/<code>asInstanceOf</code> method.
     * Abstract, <code>transform</code> only calls it if the main modifier is 
     * <code>peer</code> or <code>rep</code> and replaces the original call by 
     * the returned tree.
     * @param tree    The tree containing the call.
     * @param mainmod Main modifier of the type the object is checked for/cast to.
     * @return the tree which executes the additional check.
     */
    def processIsAsInstanceOf(tree: Tree, mainmod: OwnershipModifier) : Tree
    
    /**
     * Do the actual transformation. The children of the tree are transformed 
     * first, afterwards the following trees get special treatment:
     * <ul>
     *   <li>definitions of constructors whose right hand side is a block,</li>
     *   <li>calls of constructors via <code>new</code>,</li>
     *   <li>calls of <code>isInstanceOf</code>/<code>asInstanceOf</code>.</li>
     * </ul>
     * All other trees are returned as they come back from the transformation 
     * of their children.
     * @param tree The subtree to transform.
     * @return the transformed subtree.
     */
    override def transform(tree: Tree): Tree = {
      // Transformers always maintain `currentOwner' while descending.
      val transformed = super.transform(tree)

      transformed match {
        // Process constructors
        case DefDef(_, _, _, _, _, Block(_, _)) if transformed.symbol.isConstructor =>
          logger.debug("Found "+transformed.symbol+" owned by "+transformed.symbol.owner)
          processConstructor(transformed)

        // Add code around a call to 'new'
        case Apply(Select(New(tpt), _), args) =>
          logger.debug("Call to constructor of "+tpt+" with arguments "+args)

          // Temporary reference to the new instance, initialized by the call itself
          val (tmp, tmpaccess) = createTempVar(transformed.pos, transformed, tpt.tpe)

          // Ownership modifier of the created instance
          val owner : OwnershipModifier = getMainModifier(tpt.tpe)

          // Statements for the ownership hierarchy maintenance
          val stats = processConstructorCall(transformed, owner, tmp, tmpaccess)

          // The extended new statement is a block which results in the new instance
          typed(atPos(transformed.pos)(Block(stats, tmpaccess())))

        // Add checks to asInstanceOf[T] and isInstanceOf[T]
        case TypeApply(sel @ Select(qualifier, selector), tpelist) =>
          logger.debug("Found TypeApply on "+qualifier+": "+qualifier.tpe.extractType+" to "+selector+".")

          // Don't do tests on immutable type or casts to/checks for immutable
          // types (ie. Int, String, etc.)
          def onMutableTypes = 
            tpelist.length == 1 && ! tpelist.head.tpe.isImmutable && ! qualifier.tpe.isImmutable
          // Only handle calls to isInstanceOf/asInstanceOf
          def isTypeTestOrCast = 
            sel.symbol.isMethod && (sel.symbol == Any_isInstanceOf || sel.symbol == Any_asInstanceOf)

          if (onMutableTypes && isTypeTestOrCast) {
            val mainmod = getMainModifier(qualifier.tpe.extractType, tpelist.head.tpe)
            if (mainmod == peer() || mainmod == rep()) {
              processIsAsInstanceOf(transformed, mainmod)
            }
            else {
              transformed
            }
          }
          else {
            transformed
          }

        case _ =>
          transformed
      }
    }
  }
}
