00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 #include "comma/ast/AstResource.h"
00010 #include "comma/ast/AstRewriter.h"
00011 #include "comma/ast/Decl.h"
00012 #include "comma/ast/Type.h"
00013 
00014 #include "llvm/Support/Casting.h"
00015 #include "llvm/ADT/SmallVector.h"
00016 
00017 using namespace comma;
00018 using llvm::dyn_cast;
00019 using llvm::dyn_cast_or_null;
00020 using llvm::cast;
00021 using llvm::isa;
00022 
00023 Type *AstRewriter::findRewrite(Type *source) const
00024 {
00025     RewriteMap::const_iterator iter = rewrites.find(source);
00026     if (iter == rewrites.end())
00027         return 0;
00028     return iter->second;
00029 }
00030 
00031 Type *AstRewriter::getRewrite(Type *source) const
00032 {
00033     if (Type *res = findRewrite(source))
00034         return res;
00035     return source;
00036 }
00037 
00038 void AstRewriter::installRewrites(DomainType *context)
00039 {
00040     
00041     
00042     if (DomainInstanceDecl *instance = context->getInstanceDecl()) {
00043         if (FunctorDecl *functor = instance->getDefiningFunctor()) {
00044             unsigned arity = instance->getArity();
00045             for (unsigned i = 0; i < arity; ++i) {
00046                 DomainType *formal = functor->getFormalType(i);
00047                 Type *actual = instance->getActualParamType(i);
00048                 rewrites[formal] = actual;
00049             }
00050         }
00051     }
00052 }
00053 
00054 void AstRewriter::installRewrites(SigInstanceDecl *context)
00055 {
00056     VarietyDecl *variety = context->getVariety();
00057 
00058     if (variety) {
00059         unsigned arity = variety->getArity();
00060         for (unsigned i = 0; i < arity; ++i) {
00061             DomainType *formal = variety->getFormalType(i);
00062             Type *actual = context->getActualParamType(i);
00063             addTypeRewrite(formal, actual);
00064         }
00065     }
00066 }
00067 
00068 Type *AstRewriter::rewriteType(Type *type) const
00069 {
00070     if (Type *result = findRewrite(type))
00071         return result;
00072 
00073     switch(type->getKind()) {
00074 
00075     default: return type;
00076 
00077     case Ast::AST_DomainType:
00078         return rewriteType(cast<DomainType>(type));
00079     case Ast::AST_FunctionType:
00080         return rewriteType(cast<FunctionType>(type));
00081     case Ast::AST_ProcedureType:
00082         return rewriteType(cast<ProcedureType>(type));
00083     }
00084 }
00085 
00086 SigInstanceDecl *AstRewriter::rewriteSigInstance(SigInstanceDecl *sig) const
00087 {
00088     if (sig->isParameterized()) {
00089         llvm::SmallVector<DomainTypeDecl*, 4> args;
00090         SigInstanceDecl::arg_iterator iter;
00091         SigInstanceDecl::arg_iterator endIter = sig->endArguments();
00092         for (iter = sig->beginArguments(); iter != endIter; ++iter) {
00093             
00094             
00095             DomainType *argTy = rewriteType((*iter)->getType());
00096             args.push_back(argTy->getDomainTypeDecl());
00097         }
00098         
00099         VarietyDecl *decl = sig->getVariety();
00100         return decl->getInstance(&args[0], args.size());
00101     }
00102     return sig;
00103 }
00104 
00105 DomainType *AstRewriter::rewriteType(DomainType *dom) const
00106 {
00107     if (DomainType *result = dyn_cast_or_null<DomainType>(findRewrite(dom)))
00108         return result;
00109 
00110     if (DomainInstanceDecl *instance = dom->getInstanceDecl()) {
00111         if (FunctorDecl *functor = instance->getDefiningFunctor()) {
00112             typedef DomainInstanceDecl::arg_iterator iterator;
00113             llvm::SmallVector<DomainTypeDecl*, 4> args;
00114             iterator iter;
00115             iterator endIter = instance->endArguments();
00116             for (iter = instance->beginArguments(); iter != endIter; ++iter) {
00117                 
00118                 
00119                 DomainType *argTy = rewriteType((*iter)->getType());
00120                 args.push_back(argTy->getDomainTypeDecl());
00121             }
00122             
00123             instance = functor->getInstance(&args[0], args.size());
00124             return instance->getType();
00125         }
00126     }
00127     return dom;
00128 }
00129 
00130 SubroutineType *AstRewriter::rewriteType(SubroutineType *srType) const
00131 {
00132     if (ProcedureType *ptype = dyn_cast<ProcedureType>(srType))
00133         return rewriteType(ptype);
00134 
00135     return rewriteType(cast<FunctionType>(srType));
00136 }
00137 
00138 void AstRewriter::rewriteParameters(SubroutineType *srType,
00139                                     unsigned count, Type **params) const
00140 {
00141     for (unsigned i = 0; i < count; ++i)
00142         params[i] = getRewrite(srType->getArgType(i));
00143 }
00144 
00145 FunctionType *AstRewriter::rewriteType(FunctionType *ftype) const
00146 {
00147     unsigned arity = ftype->getArity();
00148     Type *returnType = getRewrite(ftype->getReturnType());
00149     Type *paramTypes[arity];
00150 
00151     rewriteParameters(ftype, arity, paramTypes);
00152     return resource.getFunctionType(paramTypes, arity, returnType);
00153 }
00154 
00155 ProcedureType *AstRewriter::rewriteType(ProcedureType *ptype) const
00156 {
00157     unsigned arity = ptype->getArity();
00158     Type *paramTypes[arity];
00159 
00160     rewriteParameters(ptype, arity, paramTypes);
00161     return resource.getProcedureType(paramTypes, arity);
00162 }