00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00013 
00014 
00015 #include "comma/ast/AttribDecl.h"
00016 #include "comma/ast/AttribExpr.h"
00017 #include "comma/ast/Expr.h"
00018 
00019 using namespace comma;
00020 using llvm::dyn_cast;
00021 using llvm::cast;
00022 using llvm::isa;
00023 
00024 namespace {
00025 
00034 bool staticDiscreteFunctionValue(const FunctionCallExpr *expr,
00035                                  llvm::APInt &result);
00036 
00041 bool staticDiscreteFunctionAttribValue(const FunctionCallExpr *expr,
00042                                        llvm::APInt &result);
00043 
00055 bool staticDiscretePosAttribValue(const DiscreteType *prefix, const Expr *arg,
00056                                   llvm::APInt &result);
00057 
00069 bool staticDiscreteValAttribValue(const DiscreteType *prefix, const Expr *arg,
00070                                   llvm::APInt &result);
00071 
00082 bool staticDiscreteUnaryValue(PO::PrimitiveID ID,
00083                              const Expr *expr, llvm::APInt &result);
00084 
00097 bool staticDiscreteBinaryValue(PO::PrimitiveID ID,
00098                                const Expr *x, const Expr *y,
00099                                llvm::APInt &result);
00100 
00107 bool staticDiscreteAttribExpr(const AttribExpr *expr, llvm::APInt &result);
00108 
00109 PO::PrimitiveID getCallPrimitive(const FunctionCallExpr *call)
00110 {
00111     if (call->isAmbiguous())
00112         return PO::NotPrimitive;
00113     else {
00114         const FunctionDecl *decl = cast<FunctionDecl>(call->getConnective());
00115         return decl->getPrimitiveID();
00116     }
00117 }
00118 
00119 void signExtend(llvm::APInt &x, llvm::APInt &y);
00120 
00122 inline llvm::APInt &zeroExtend(llvm::APInt &x) {
00123     x.zext(x.getBitWidth() + 1);
00124     return x;
00125 }
00126 
00128 inline llvm::APInt &negate(llvm::APInt &x) {
00129     if (x.isMinSignedValue())
00130         zeroExtend(x);
00131     else {
00132         x.flip();
00133         ++x;
00134     }
00135     return x;
00136 }
00137 
00139 inline llvm::APInt &minimizeWidth(llvm::APInt &x)
00140 {
00141     return x.trunc(x.getMinSignedBits());
00142 }
00143 
00146 llvm::APInt add(llvm::APInt x, llvm::APInt y);
00147 llvm::APInt subtract(llvm::APInt x, llvm::APInt y);
00148 llvm::APInt multiply(llvm::APInt x, llvm::APInt y);
00149 llvm::APInt exponentiate(llvm::APInt x, llvm::APInt y);
00150 
00151 
00152 
00153 
00154 bool staticDiscreteFunctionValue(const FunctionCallExpr *expr,
00155                                  llvm::APInt &result)
00156 {
00157     PO::PrimitiveID ID = getCallPrimitive(expr);
00158 
00159     if (ID == PO::NotPrimitive)
00160         return staticDiscreteFunctionAttribValue(expr, result);
00161 
00162     typedef FunctionCallExpr::const_arg_iterator iterator;
00163     iterator I = expr->begin_arguments();
00164     if (PO::denotesUnaryOp(ID)) {
00165         assert(expr->getNumArgs() == 1);
00166         const Expr *arg = *I;
00167         return staticDiscreteUnaryValue(ID, arg, result);
00168     }
00169     else if (PO::denotesBinaryOp(ID)) {
00170         assert(expr->getNumArgs() == 2);
00171         const Expr *lhs = *I;
00172         const Expr *rhs = *(++I);
00173         return staticDiscreteBinaryValue(ID, lhs, rhs, result);
00174     }
00175     else if (ID == PO::ENUM_op) {
00176         const EnumLiteral *lit = cast<EnumLiteral>(expr->getConnective());
00177         const EnumerationType *enumTy = lit->getReturnType();
00178         unsigned idx = lit->getIndex();
00179         unsigned size = enumTy->getSize();
00180         result = llvm::APInt(size, idx);
00181         return true;
00182     }
00183     else
00184         
00185         return false;
00186 }
00187 
00188 bool staticDiscreteFunctionAttribValue(const FunctionCallExpr *expr,
00189                                        llvm::APInt &result)
00190 {
00191     bool success = false;
00192     const FunctionAttribDecl *decl;
00193 
00194     if (expr->isAmbiguous() ||
00195         !(decl = dyn_cast<FunctionAttribDecl>(expr->getConnective())))
00196         return false;
00197 
00198     switch (decl->getKind()) {
00199 
00200     default:
00201         
00202         success = false;
00203         break;
00204 
00205     case Ast::AST_PosAD: {
00206         const PosAD *attrib = cast<PosAD>(decl);
00207         success = staticDiscretePosAttribValue
00208             (attrib->getPrefix(), *expr->begin_arguments(), result);
00209         break;
00210     }
00211 
00212     case Ast::AST_ValAD: {
00213         const ValAD *attrib = cast<ValAD>(decl);
00214         success = staticDiscreteValAttribValue
00215             (attrib->getPrefix(), *expr->begin_arguments(), result);
00216     }
00217     };
00218 
00219     return success;
00220 }
00221 
00222 bool staticDiscretePosAttribValue(const DiscreteType *prefix, const Expr *arg,
00223                                   llvm::APInt &result)
00224 {
00225     llvm::APInt lower;
00226     llvm::APInt pos;
00227 
00228     
00229     if (const Range *constraint = prefix->getConstraint()) {
00230         if (constraint->hasStaticLowerBound())
00231             lower = constraint->getStaticLowerBound();
00232         else
00233             return false;
00234     }
00235     else
00236         prefix->getLowerLimit(lower);
00237 
00238     
00239     if (!arg->staticDiscreteValue(pos))
00240         return false;
00241 
00242     
00243     result = subtract(pos, lower);
00244     return true;
00245 }
00246 
00247 bool staticDiscreteValAttribValue(const DiscreteType *prefix, const Expr *arg,
00248                                   llvm::APInt &result)
00249 {
00250     llvm::APInt lower;
00251     llvm::APInt val;
00252 
00253     
00254     if (const Range *constraint = prefix->getConstraint()) {
00255         if (constraint->hasStaticLowerBound())
00256             lower = constraint->getStaticLowerBound();
00257         else
00258             return false;
00259     }
00260     else
00261         prefix->getLowerLimit(lower);
00262 
00263     
00264     if (!arg->staticDiscreteValue(val))
00265         return false;
00266 
00267     
00268     result = add(val, lower);
00269     return true;
00270 }
00271 
00272 bool staticDiscreteBinaryValue(PO::PrimitiveID ID,
00273                                const Expr *x, const Expr *y,
00274                                llvm::APInt &result)
00275 {
00276     llvm::APInt LHS, RHS;
00277     if (!x->staticDiscreteValue(LHS) || !y->staticDiscreteValue(RHS))
00278         return false;
00279 
00280     switch (ID) {
00281 
00282     default:
00283         return false;
00284 
00285     case PO::ADD_op:
00286         result = add(LHS, RHS);
00287         break;
00288 
00289     case PO::SUB_op:
00290         result = subtract(LHS, RHS);
00291         break;
00292 
00293     case PO::MUL_op:
00294         result = multiply(LHS, RHS);
00295         break;
00296 
00297     case PO::POW_op:
00298         result = exponentiate(LHS, RHS);
00299         break;
00300     }
00301     return true;
00302 }
00303 
00304 bool staticDiscreteUnaryValue(PO::PrimitiveID ID, const Expr *arg,
00305                               llvm::APInt &result)
00306 {
00307     if (!arg->staticDiscreteValue(result))
00308         return false;
00309 
00310     
00311     
00312     switch (ID) {
00313     default:
00314         assert(false && "Bad primitive ID for a unary operator!");
00315         return false;
00316     case PO::NEG_op:
00317         negate(result);
00318         break;
00319     case PO::POS_op:
00320         break;
00321     }
00322     return true;
00323 }
00324 
00325 bool staticDiscreteAttribExpr(const AttribExpr *expr, llvm::APInt &result)
00326 {
00327     bool status = false;
00328 
00329     
00330     
00331     switch (expr->getKind()) {
00332 
00333     default:
00334         
00335         break;
00336 
00337     case Ast::AST_FirstAE: {
00338         const DiscreteType *intTy = cast<FirstAE>(expr)->getType();
00339         if (const Range *constraint = intTy->getConstraint()) {
00340             if (constraint->hasStaticLowerBound()) {
00341                 result = constraint->getStaticLowerBound();
00342                 status = true;
00343             }
00344         }
00345         else {
00346             intTy->getLowerLimit(result);
00347             status = true;
00348         }
00349         break;
00350     }
00351 
00352     case Ast::AST_LastAE: {
00353         const DiscreteType *intTy = cast<LastAE>(expr)->getType();
00354         if (const Range *constraint = intTy->getConstraint()) {
00355             if (constraint->hasStaticUpperBound()) {
00356                 result = constraint->getStaticUpperBound();
00357                 status = true;
00358             }
00359         }
00360         else {
00361             intTy->getUpperLimit(result);
00362             status = true;
00363         }
00364         break;
00365     }
00366     };
00367     return status;
00368 }
00369 
00370 void signExtend(llvm::APInt &x, llvm::APInt &y)
00371 {
00372     unsigned xWidth = x.getBitWidth();
00373     unsigned yWidth = y.getBitWidth();
00374     unsigned target = std::max(xWidth, yWidth);
00375 
00376     if (xWidth < yWidth)
00377         x.sext(target);
00378     else if (yWidth < xWidth)
00379         y.sext(target);
00380 }
00381 
00382 llvm::APInt add(llvm::APInt x, llvm::APInt y)
00383 {
00384     if (y.isNonNegative()) {
00385         signExtend(x, y);
00386         llvm::APInt result(x + y);
00387 
00388         
00389         if (result.slt(x))
00390             zeroExtend(result);
00391         return result;
00392     }
00393     else {
00394         signExtend(x, negate(y));
00395         llvm::APInt result(x - y);
00396 
00397         
00398         if (result.sgt(x))
00399             zeroExtend(result);
00400         return result;
00401     }
00402 }
00403 
00404 llvm::APInt subtract(llvm::APInt x, llvm::APInt y)
00405 {
00406     return add(x, negate(y));
00407 }
00408 
00409 llvm::APInt multiply(llvm::APInt x, llvm::APInt y)
00410 {
00411     unsigned xWidth = x.getBitWidth();
00412     unsigned yWidth = y.getBitWidth();
00413     unsigned target = 2 * std::max(xWidth, yWidth);
00414     x.sext(target);
00415     y.sext(target);
00416     llvm::APInt result(x * y);
00417     return minimizeWidth(result);
00418 }
00419 
00420 llvm::APInt exponentiate(llvm::APInt x, llvm::APInt y)
00421 {
00422     assert(y.isNonNegative() && "Negative power in exponentiation!");
00423 
00424     if (y == 0) {
00425         x = 1;
00426         return minimizeWidth(x);
00427     }
00428 
00429     llvm::APInt result(x);
00430     while (--y != 0)
00431         result = multiply(result, x);
00432     return result;
00433 }
00434 
00435 } 
00436 
00437 bool Expr::staticDiscreteValue(llvm::APInt &result) const
00438 {
00439     if (const IntegerLiteral *ILit = dyn_cast<IntegerLiteral>(this)) {
00440         result = ILit->getValue();
00441         return true;
00442     }
00443 
00444     if (const FunctionCallExpr *FCall = dyn_cast<FunctionCallExpr>(this))
00445         return staticDiscreteFunctionValue(FCall, result);
00446 
00447     if (const ConversionExpr *CExpr = dyn_cast<ConversionExpr>(this))
00448         return CExpr->getOperand()->staticDiscreteValue(result);
00449 
00450     if (const AttribExpr *AExpr = dyn_cast<AttribExpr>(this))
00451         return staticDiscreteAttribExpr(AExpr, result);
00452 
00453     return false;
00454 }
00455 
00456 bool Expr::isStaticDiscreteExpr() const
00457 {
00458     llvm::APInt tmp;
00459     return staticDiscreteValue(tmp);
00460 }
00461 
00462 bool Expr::staticStringValue(std::string &result) const
00463 {
00464     
00465     if (const StringLiteral *lit = dyn_cast<StringLiteral>(this)) {
00466         result = lit->getString().str();
00467         return true;
00468     }
00469     return false;
00470 }
00471 
00472 bool Expr::isStaticStringExpr() const
00473 {
00474     return isa<StringLiteral>(this);
00475 }
00476