00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 #include "CGContext.h"
00010 #include "CodeGen.h"
00011 #include "CodeGenTypes.h"
00012 #include "CommaRT.h"
00013 #include "comma/ast/Decl.h"
00014 
00015 #include "llvm/DerivedTypes.h"
00016 #include "llvm/Target/TargetData.h"
00017 
00018 using namespace comma;
00019 
00020 using llvm::dyn_cast;
00021 using llvm::cast;
00022 using llvm::isa;
00023 
00024 namespace {
00025 
00029 uint64_t getArrayWidth(const llvm::APInt &low, const llvm::APInt &high,
00030                        bool isSigned)
00031 {
00032     llvm::APInt lower(low);
00033     llvm::APInt upper(high);
00034 
00035     
00036     if (isSigned) {
00037         if (upper.slt(lower))
00038             return 0;
00039     }
00040     else {
00041         if (upper.ult(lower))
00042             return 0;
00043     }
00044 
00045     
00046     
00047     unsigned width = std::max(lower.getBitWidth(), upper.getBitWidth()) + 1;
00048 
00049     if (isSigned) {
00050         lower.sext(width);
00051         upper.sext(width);
00052     }
00053     else {
00054         lower.zext(width);
00055         upper.zext(width);
00056     }
00057 
00058     llvm::APInt range(upper);
00059     range -= lower;
00060     range++;
00061 
00062     
00063     assert(range.getActiveBits() <= 64 && "Index too wide for array type!");
00064     return range.getZExtValue();
00065 }
00066 
00067 } 
00068 
00069 unsigned CodeGenTypes::getTypeAlignment(const llvm::Type *type) const
00070 {
00071     return CG.getTargetData().getABITypeAlignment(type);
00072 }
00073 
00074 uint64_t CodeGenTypes::getTypeSize(const llvm::Type *type) const
00075 {
00076     return CG.getTargetData().getTypeStoreSize(type);
00077 }
00078 
00079 const llvm::Type *CodeGenTypes::lowerType(const Type *type)
00080 {
00081     switch (type->getKind()) {
00082 
00083     default:
00084         assert(false && "Cannot lower the given Type!");
00085         return 0;
00086 
00087     case Ast::AST_DomainType:
00088         return lowerDomainType(cast<DomainType>(type));
00089 
00090     case Ast::AST_EnumerationType:
00091     case Ast::AST_IntegerType:
00092         return lowerDiscreteType(cast<DiscreteType>(type));
00093 
00094     case Ast::AST_ArrayType:
00095         return lowerArrayType(cast<ArrayType>(type));
00096 
00097     case Ast::AST_RecordType:
00098         return lowerRecordType(cast<RecordType>(type));
00099 
00100     case Ast::AST_AccessType:
00101         return lowerAccessType(cast<AccessType>(type));
00102 
00103     case Ast::AST_IncompleteType:
00104         return lowerIncompleteType(cast<IncompleteType>(type));
00105 
00106     case Ast::AST_UniversalType:
00107         return lowerUniversalType(cast<UniversalType>(type));
00108     }
00109 }
00110 
00111 void CodeGenTypes::addInstanceRewrites(const DomainInstanceDecl *instance)
00112 {
00113     const FunctorDecl *functor = instance->getDefiningFunctor();
00114     if (!functor)
00115         return;
00116 
00117     unsigned arity = functor->getArity();
00118     for (unsigned i = 0; i < arity; ++i) {
00119         const Type *key = functor->getFormalType(i);
00120         const Type *value = instance->getActualParamType(i);
00121         rewrites.insert(key, value);
00122     }
00123 }
00124 
00125 const DomainType *
00126 CodeGenTypes::rewriteAbstractDecl(const AbstractDomainDecl *abstract)
00127 {
00128     typedef RewriteMap::iterator iterator;
00129     iterator I = rewrites.begin(abstract->getType());
00130     assert(I != rewrites.end() && "Could not resolve abstract type!");
00131     return cast<DomainType>(*I);
00132 }
00133 
00134 const Type *CodeGenTypes::resolveType(const Type *type)
00135 {
00136     if (const DomainType *domTy = dyn_cast<DomainType>(type)) {
00137 
00138         if (const AbstractDomainDecl *decl = domTy->getAbstractDecl())
00139             return resolveType(rewriteAbstractDecl(decl));
00140 
00141         const DomainInstanceDecl *instance;
00142         if (const PercentDecl *percent = domTy->getPercentDecl()) {
00143             
00144             
00145             
00146             assert(percent->getDefinition() == context->getDefinition());
00147             ((void*)percent);
00148             instance = context;
00149         }
00150         else
00151             instance = domTy->getInstanceDecl();
00152 
00153         if (instance->isParameterized() && instance->isDependent()) {
00154             RewriteScope scope(rewrites);
00155             addInstanceRewrites(instance);
00156             return resolveType(instance->getRepresentationType());
00157         }
00158         else
00159             return resolveType(instance->getRepresentationType());
00160     }
00161     else if (const IncompleteType *IT = dyn_cast<IncompleteType>(type))
00162         return resolveType(IT->getCompleteType());
00163 
00164     return type;
00165 }
00166 
00167 const llvm::Type *CodeGenTypes::lowerDomainType(const DomainType *type)
00168 {
00169     const llvm::Type *entry = 0;
00170 
00171     if (type->isAbstract())
00172         type = rewriteAbstractDecl(type->getAbstractDecl());
00173 
00174     if (const PercentDecl *percent = type->getPercentDecl()) {
00175         assert(percent->getDefinition() == context->getDefinition() &&
00176                "Inconsistent context for PercentDecl!");
00177         ((void*)percent);
00178         entry = lowerType(context->getRepresentationType());
00179     }
00180     else {
00181         const DomainInstanceDecl *instance = type->getInstanceDecl();
00182         if (instance->isParameterized()) {
00183             RewriteScope scope(rewrites);
00184             addInstanceRewrites(instance);
00185             entry = lowerType(instance->getRepresentationType());
00186         }
00187         else
00188             entry = lowerType(instance->getRepresentationType());
00189     }
00190     return entry;
00191 }
00192 
00193 const llvm::FunctionType *
00194 CodeGenTypes::lowerSubroutine(const SubroutineDecl *decl)
00195 {
00196     std::vector<const llvm::Type*> args;
00197     const llvm::Type *retTy = 0;
00198 
00200     if (const FunctionDecl *fdecl = dyn_cast<FunctionDecl>(decl)) {
00201 
00202         const Type *targetTy = fdecl->getReturnType();
00203 
00204         switch (getConvention(decl)) {
00205 
00206         case CC_Simple:
00207             
00208             retTy = lowerType(targetTy);
00209             break;
00210 
00211         case CC_Sret: {
00212             
00213             
00214             const llvm::Type *sretTy = lowerType(fdecl->getReturnType());
00215             args.push_back(sretTy->getPointerTo());
00216             retTy = CG.getVoidTy();
00217             break;
00218         }
00219 
00220         case CC_Vstack:
00221             
00222             
00223             retTy = CG.getVoidTy();
00224             break;
00225         }
00226     }
00227     else
00228         retTy = CG.getVoidTy();
00229 
00230     
00231     
00232     
00233     
00234     
00235     if (!decl->findPragma(pragma::Import))
00236         args.push_back(CG.getRuntime().getType<CommaRT::CRT_DomainInstance>());
00237 
00238     SubroutineDecl::const_param_iterator I = decl->begin_params();
00239     SubroutineDecl::const_param_iterator E = decl->end_params();
00240     for ( ; I != E; ++I) {
00241         const ParamValueDecl *param = *I;
00242         const Type *paramTy = resolveType(param->getType());
00243         const llvm::Type *loweredTy = lowerType(paramTy);
00244 
00245         if (const CompositeType *compTy = dyn_cast<CompositeType>(paramTy)) {
00246             
00247             
00248             
00249             args.push_back(loweredTy->getPointerTo());
00250 
00251             
00252             
00253             if (const ArrayType *arrTy = dyn_cast<ArrayType>(compTy)) {
00254                 if (!compTy->isConstrained())
00255                     args.push_back(lowerArrayBounds(arrTy)->getPointerTo());
00256             }
00257         }
00258         else if (paramTy->isFatAccessType()) {
00259             
00260             
00261             args.push_back(loweredTy->getPointerTo());
00262         }
00263         else {
00264             
00265             
00266             PM::ParameterMode mode = param->getParameterMode();
00267             if (mode == PM::MODE_OUT or mode == PM::MODE_IN_OUT)
00268                 loweredTy = loweredTy->getPointerTo();
00269             args.push_back(loweredTy);
00270         }
00271     }
00272 
00273     return llvm::FunctionType::get(retTy, args, false);
00274 }
00275 
00276 const llvm::IntegerType *CodeGenTypes::lowerDiscreteType(const DiscreteType *type)
00277 {
00278     
00279     
00280     
00281     return getTypeForWidth(type->getSize());
00282 }
00283 
00284 const llvm::ArrayType *CodeGenTypes::lowerArrayType(const ArrayType *type)
00285 {
00286     assert(type->getRank() == 1 &&
00287            "Cannot codegen multidimensional arrays yet!");
00288 
00289     const llvm::Type *elementTy = lowerType(type->getComponentType());
00290 
00291     
00292     
00293     if (!type->isConstrained())
00294         return llvm::ArrayType::get(elementTy, 0);
00295 
00296     const DiscreteType *idxTy = type->getIndexType(0);
00297 
00298     
00299     
00300     
00301     llvm::APInt lowerBound(idxTy->getSize(), 0);
00302     llvm::APInt upperBound(idxTy->getSize(), 0);
00303     if (const IntegerType *subTy = dyn_cast<IntegerType>(idxTy)) {
00304         if (subTy->isConstrained()) {
00305             if (subTy->isStaticallyConstrained()) {
00306                 const Range *range = subTy->getConstraint();
00307                 lowerBound = range->getStaticLowerBound();
00308                 upperBound = range->getStaticUpperBound();
00309             }
00310             else
00311                 return llvm::ArrayType::get(elementTy, 0);
00312         }
00313         else {
00314             const IntegerType *rootTy = subTy->getRootType();
00315             rootTy->getLowerLimit(lowerBound);
00316             rootTy->getUpperLimit(upperBound);
00317         }
00318     }
00319     else {
00320         
00321         const EnumerationType *enumTy = cast<EnumerationType>(idxTy);
00322         assert(enumTy && "Unexpected array index type!");
00323         lowerBound = 0;
00324         upperBound = enumTy->getNumLiterals() - 1;
00325     }
00326 
00327     uint64_t numElems;
00328     numElems = getArrayWidth(lowerBound, upperBound, idxTy->isSigned());
00329     const llvm::ArrayType *result = llvm::ArrayType::get(elementTy, numElems);
00330     return result;
00331 }
00332 
00333 const llvm::StructType *CodeGenTypes::lowerRecordType(const RecordType *recTy)
00334 {
00335     unsigned maxAlignment = 0;
00336     uint64_t currentOffset = 0;
00337     uint64_t requiredOffset = 0;
00338     unsigned currentIndex = 0;
00339     std::vector<const llvm::Type*> fields;
00340 
00341     const RecordDecl *recDecl = recTy->getDefiningDecl();
00342     for (unsigned i = 0; i < recDecl->numComponents(); ++i) {
00343         const ComponentDecl *componentDecl = recDecl->getComponent(i);
00344         const llvm::Type *componentTy = lowerType(componentDecl->getType());
00345         unsigned alignment = getTypeAlignment(componentTy);
00346         requiredOffset = llvm::TargetData::RoundUpAlignment(currentOffset,
00347                                                             alignment);
00348         maxAlignment = std::max(maxAlignment, alignment);
00349 
00350         
00351         
00352         while (currentOffset < requiredOffset) {
00353             fields.push_back(CG.getInt8Ty());
00354             currentOffset++;
00355             currentIndex++;
00356         }
00357         fields.push_back(componentTy);
00358         ComponentIndices[componentDecl] = currentIndex;
00359         currentOffset = requiredOffset + getTypeSize(componentTy);
00360         currentIndex++;
00361     }
00362 
00363     
00364     requiredOffset = llvm::TargetData::RoundUpAlignment(currentOffset,
00365                                                         maxAlignment);
00366     while (currentOffset < requiredOffset) {
00367         fields.push_back(CG.getInt8Ty());
00368         currentOffset++;
00369     }
00370 
00371     const llvm::StructType *result;
00372     result = llvm::StructType::get(CG.getLLVMContext(), fields);
00373     return result;
00374 }
00375 
00376 const llvm::Type *CodeGenTypes::lowerIncompleteType(const IncompleteType *type)
00377 {
00378     return lowerType(type->getCompleteType());
00379 }
00380 
00381 const llvm::PointerType *
00382 CodeGenTypes::lowerThinAccessType(const AccessType *access)
00383 {
00384     assert(access->isThinAccessType() && "Use lowerFatAccessType instead!");
00385 
00386     
00387     
00388     
00389     
00390     
00391     
00392     
00393     
00394     
00395     
00396     const llvm::Type *&entry = getLoweredType(access);
00397     if (entry)
00398         return llvm::cast<llvm::PointerType>(entry);
00399 
00400     llvm::PATypeHolder holder = llvm::OpaqueType::get(CG.getLLVMContext());
00401     const llvm::PointerType *barrier = llvm::PointerType::getUnqual(holder);
00402 
00403     entry = barrier;
00404     const llvm::Type *targetType = lowerType(access->getTargetType());
00405 
00406     
00407     
00408     cast<llvm::OpaqueType>(holder.get())->refineAbstractTypeTo(targetType);
00409     const llvm::PointerType *result = holder.get()->getPointerTo();
00410     entry = result;
00411 
00412     return result;
00413 }
00414 
00415 const llvm::StructType *
00416 CodeGenTypes::lowerFatAccessType(const AccessType *access)
00417 {
00418     assert(access->isFatAccessType() && "Use lowerThinAccessType instead!");
00419 
00420     const llvm::Type *&entry = getLoweredType(access);
00421     if (entry)
00422         return llvm::cast<llvm::StructType>(entry);
00423 
00424     
00425     
00426     llvm::PATypeHolder holder = llvm::OpaqueType::get(CG.getLLVMContext());
00427     const llvm::PointerType *barrier = llvm::PointerType::getUnqual(holder);
00428 
00429     entry = barrier;
00430     const llvm::Type *targetType = lowerType(access->getTargetType());
00431 
00432     
00433     
00434     cast<llvm::OpaqueType>(holder.get())->refineAbstractTypeTo(targetType);
00435     const llvm::PointerType *pointerTy = holder.get()->getPointerTo();
00436 
00437     
00438     
00439     const ArrayType *indefiniteTy = cast<ArrayType>(access->getTargetType());
00440     const llvm::StructType *boundsTy = lowerArrayBounds(indefiniteTy);
00441 
00442     
00443     std::vector<const llvm::Type*> elements;
00444     elements.push_back(pointerTy);
00445     elements.push_back(boundsTy);
00446 
00447     const llvm::StructType *result;
00448     result = llvm::StructType::get(CG.getLLVMContext(), elements);
00449     entry = result;
00450     return result;
00451 }
00452 
00453 const llvm::Type *CodeGenTypes::lowerAccessType(const AccessType *type)
00454 {
00455     if (type->isThinAccessType())
00456         return lowerThinAccessType(type);
00457     else
00458         return lowerFatAccessType(type);
00459 }
00460 
00461 const llvm::Type *CodeGenTypes::lowerUniversalType(const UniversalType *type)
00462 {
00463     if (type->isUniversalIntegerType()) {
00464         
00465         
00466         
00467         
00468         
00469         
00470         
00471         
00472         
00473         
00474         
00475         
00476         
00477         return CG.getIntPtrTy();
00478     }
00479     else {
00480         assert(false && "Cannot lower the given universal type.");
00481     }
00482     return 0;
00483 }
00484 
00485 unsigned CodeGenTypes::getComponentIndex(const ComponentDecl *component)
00486 {
00487     ComponentIndexMap::iterator I = ComponentIndices.find(component);
00488 
00489     
00490     
00491     if (I == ComponentIndices.end()) {
00492         lowerRecordType(component->getDeclRegion()->getType());
00493         I = ComponentIndices.find(component);
00494     }
00495 
00496     assert (I != ComponentIndices.end()  && "Component index does not exist!");
00497     return I->second;
00498 }
00499 
00500 const llvm::StructType *CodeGenTypes::lowerArrayBounds(const ArrayType *arrTy)
00501 {
00502     std::vector<const llvm::Type*> elts;
00503     const ArrayType *baseTy = arrTy->getRootType();
00504 
00505     for (unsigned i = 0; i < baseTy->getRank(); ++i) {
00506         const llvm::Type *boundTy = lowerType(baseTy->getIndexType(i));
00507         elts.push_back(boundTy);
00508         elts.push_back(boundTy);
00509     }
00510 
00511     return CG.getStructTy(elts);
00512 }
00513 
00514 const llvm::StructType *
00515 CodeGenTypes::lowerScalarBounds(const DiscreteType *type)
00516 {
00517     std::vector<const llvm::Type*> elts;
00518     const llvm::Type *elemTy = lowerType(type);
00519     elts.push_back(elemTy);
00520     elts.push_back(elemTy);
00521     return CG.getStructTy(elts);
00522 }
00523 
00524 const llvm::StructType *CodeGenTypes::lowerRange(const Range *range)
00525 {
00526     std::vector<const llvm::Type*> elts;
00527     const llvm::Type *elemTy = lowerType(range->getType());
00528     elts.push_back(elemTy);
00529     elts.push_back(elemTy);
00530     return CG.getStructTy(elts);
00531 }
00532 
00533 const llvm::IntegerType *CodeGenTypes::getTypeForWidth(unsigned numBits)
00534 {
00535     
00536     if (numBits <= 1)
00537         return CG.getInt1Ty();
00538     if (numBits <= 8)
00539         return CG.getInt8Ty();
00540     else if (numBits <= 16)
00541         return CG.getInt16Ty();
00542     else if (numBits <= 32)
00543         return CG.getInt32Ty();
00544     else if (numBits <= 64)
00545         return CG.getInt64Ty();
00546     else {
00547         
00548         assert(false && "Bit size too large to codegen!");
00549         return 0;
00550     }
00551 }
00552 
00553 CodeGenTypes::CallConvention
00554 CodeGenTypes::getConvention(const SubroutineDecl *decl)
00555 {
00556     const FunctionDecl *fdecl = dyn_cast<FunctionDecl>(decl);
00557 
00558     
00559     
00560     if (!fdecl) return CC_Simple;
00561 
00562     const PrimaryType *targetTy;
00563     targetTy = dyn_cast<PrimaryType>(resolveType(fdecl->getReturnType()));
00564 
00565     
00566     if (!targetTy)
00567         return CC_Simple;
00568 
00569     
00570     
00571     if (targetTy->isFatAccessType())
00572         return CC_Sret;
00573 
00574     
00575     if (targetTy->isCompositeType() && targetTy->isUnconstrained())
00576         return CC_Vstack;
00577 
00578     
00579     if (targetTy->isCompositeType())
00580         return CC_Sret;
00581 
00582     
00583     return CC_Simple;
00584 }
00585