LCOV - code coverage report
Current view: top level - buildbot/coverage/build/src/theory/arith - bound_inference.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 89 128 69.5 %
Date: 2025-02-16 11:43:33 Functions: 8 11 72.7 %
Branches: 72 111 64.9 %

           Branch data     Line data    Source code
       1                 :            : /******************************************************************************
       2                 :            :  * Top contributors (to current version):
       3                 :            :  *   Gereon Kremer, Andrew Reynolds, Andres Noetzli
       4                 :            :  *
       5                 :            :  * This file is part of the cvc5 project.
       6                 :            :  *
       7                 :            :  * Copyright (c) 2009-2025 by the authors listed in the file AUTHORS
       8                 :            :  * in the top-level source directory and their institutional affiliations.
       9                 :            :  * All rights reserved.  See the file COPYING in the top-level source
      10                 :            :  * directory for licensing information.
      11                 :            :  * ****************************************************************************
      12                 :            :  *
      13                 :            :  * Extract bounds on variables from theory atoms.
      14                 :            :  */
      15                 :            : 
      16                 :            : #include "theory/arith/bound_inference.h"
      17                 :            : 
      18                 :            : #include "smt/env.h"
      19                 :            : #include "theory/arith/arith_utilities.h"
      20                 :            : #include "theory/arith/linear/normal_form.h"
      21                 :            : #include "theory/rewriter.h"
      22                 :            : 
      23                 :            : using namespace cvc5::internal::kind;
      24                 :            : 
      25                 :            : namespace cvc5::internal {
      26                 :            : namespace theory {
      27                 :            : namespace arith {
      28                 :            : 
      29                 :          0 : std::ostream& operator<<(std::ostream& os, const Bounds& b) {
      30         [ -  - ]:          0 :   return os << (b.lower_strict ? '(' : '[') << b.lower_value << " .. "
      31         [ -  - ]:          0 :             << b.upper_value << (b.upper_strict ? ')' : ']');
      32                 :            : }
      33                 :            : 
      34                 :      45231 : BoundInference::BoundInference(Env& env) : EnvObj(env) {}
      35                 :            : 
      36                 :         10 : void BoundInference::reset() { d_bounds.clear(); }
      37                 :            : 
      38                 :     428187 : Bounds& BoundInference::get_or_add(const Node& lhs)
      39                 :            : {
      40                 :     428187 :   auto it = d_bounds.find(lhs);
      41         [ +  + ]:     428187 :   if (it == d_bounds.end())
      42                 :            :   {
      43                 :     193522 :     it = d_bounds.emplace(lhs, Bounds()).first;
      44                 :            :   }
      45                 :     428187 :   return it->second;
      46                 :            : }
      47                 :         40 : Bounds BoundInference::get(const Node& lhs) const
      48                 :            : {
      49                 :         40 :   auto it = d_bounds.find(lhs);
      50         [ +  + ]:         40 :   if (it == d_bounds.end())
      51                 :            :   {
      52                 :         25 :     return Bounds();
      53                 :            :   }
      54                 :         15 :   return it->second;
      55                 :            : }
      56                 :            : 
      57                 :      12256 : const std::map<Node, Bounds>& BoundInference::get() const { return d_bounds; }
      58                 :     498984 : bool BoundInference::add(const Node& n, bool onlyVariables)
      59                 :            : {
      60                 :     997968 :   Node tmp = rewrite(n);
      61         [ -  + ]:     498984 :   if (tmp.getKind() == Kind::CONST_BOOLEAN)
      62                 :            :   {
      63                 :          0 :     return false;
      64                 :            :   }
      65                 :            :   // Parse the node as a comparison
      66                 :     997968 :   auto comp = linear::Comparison::parseNormalForm(tmp);
      67                 :     997968 :   auto dec = comp.decompose(true);
      68 [ +  + ][ +  + ]:     498984 :   if (onlyVariables && !std::get<0>(dec).isVariable())
                 [ +  + ]
      69                 :            :   {
      70                 :         34 :     return false;
      71                 :            :   }
      72                 :            : 
      73                 :     997900 :   Node lhs = std::get<0>(dec).getNode();
      74                 :     498950 :   Kind relation = std::get<1>(dec);
      75         [ +  + ]:     498950 :   if (relation == Kind::DISTINCT) return false;
      76                 :     389469 :   Node bound = std::get<2>(dec).getNode();
      77                 :            :   // has the form  lhs  ~relation~  bound
      78                 :            : 
      79         [ +  + ]:     389469 :   if (lhs.getType().isInteger())
      80                 :            :   {
      81                 :     238245 :     Rational br = bound.getConst<Rational>();
      82                 :     238245 :     auto* nm = nodeManager();
      83 [ -  + ][ -  + ]:     238245 :     switch (relation)
                    [ + ]
      84                 :            :     {
      85                 :          0 :       case Kind::LEQ: bound = nm->mkConstInt(br.floor()); break;
      86                 :      79001 :       case Kind::LT:
      87                 :      79001 :         bound = nm->mkConstInt((br - 1).ceiling());
      88                 :      79001 :         relation = Kind::LEQ;
      89                 :      79001 :         break;
      90                 :          0 :       case Kind::GT:
      91                 :          0 :         bound = nm->mkConstInt((br + 1).floor());
      92                 :          0 :         relation = Kind::GEQ;
      93                 :          0 :         break;
      94                 :     133326 :       case Kind::GEQ: bound = nm->mkConstInt(br.ceiling()); break;
      95                 :      25918 :       default:
      96                 :            :         // always ensure integer
      97                 :      25918 :         bound = nm->mkConstInt(br);
      98                 :      25918 :         break;
      99                 :            :     }
     100         [ +  - ]:     476490 :     Trace("bound-inf") << "Strengthened " << n << " to " << lhs << " "
     101                 :     238245 :                        << relation << " " << bound << std::endl;
     102                 :            :   }
     103                 :            : 
     104 [ +  + ][ +  + ]:     389469 :   switch (relation)
                 [ +  - ]
     105                 :            :   {
     106                 :     101877 :     case Kind::LEQ: update_upper_bound(n, lhs, bound, false); break;
     107                 :      42599 :     case Kind::LT: update_upper_bound(n, lhs, bound, true); break;
     108                 :      38718 :     case Kind::EQUAL:
     109                 :      38718 :       update_lower_bound(n, lhs, bound, false);
     110                 :      38718 :       update_upper_bound(n, lhs, bound, false);
     111                 :      38718 :       break;
     112                 :      30112 :     case Kind::GT: update_lower_bound(n, lhs, bound, true); break;
     113                 :     176163 :     case Kind::GEQ: update_lower_bound(n, lhs, bound, false); break;
     114                 :          0 :     default: Assert(false);
     115                 :            :   }
     116                 :     389469 :   return true;
     117                 :            : }
     118                 :            : 
     119                 :          0 : void BoundInference::replaceByOrigins(std::vector<Node>& nodes) const
     120                 :            : {
     121                 :          0 :   std::vector<Node> toAdd;
     122         [ -  - ]:          0 :   for (auto& n : nodes)
     123                 :            :   {
     124         [ -  - ]:          0 :     for (const auto& b : d_bounds)
     125                 :            :     {
     126                 :          0 :       if (n == b.second.lower_bound && n == b.second.upper_bound)
     127                 :            :       {
     128                 :          0 :         if (n != b.second.lower_origin && n != b.second.upper_origin)
     129                 :            :         {
     130         [ -  - ]:          0 :           Trace("bound-inf")
     131                 :          0 :               << "Replace " << n << " by origins " << b.second.lower_origin
     132                 :          0 :               << " and " << b.second.upper_origin << std::endl;
     133                 :          0 :           n = b.second.lower_origin;
     134                 :          0 :           toAdd.emplace_back(b.second.upper_origin);
     135                 :            :         }
     136                 :            :       }
     137         [ -  - ]:          0 :       else if (n == b.second.lower_bound)
     138                 :            :       {
     139         [ -  - ]:          0 :         if (n != b.second.lower_origin)
     140                 :            :         {
     141         [ -  - ]:          0 :           Trace("bound-inf") << "Replace " << n << " by origin "
     142                 :          0 :                              << b.second.lower_origin << std::endl;
     143                 :          0 :           n = b.second.lower_origin;
     144                 :            :         }
     145                 :            :       }
     146         [ -  - ]:          0 :       else if (n == b.second.upper_bound)
     147                 :            :       {
     148         [ -  - ]:          0 :         if (n != b.second.upper_origin)
     149                 :            :         {
     150         [ -  - ]:          0 :           Trace("bound-inf") << "Replace " << n << " by origin "
     151                 :          0 :                              << b.second.upper_origin << std::endl;
     152                 :          0 :           n = b.second.upper_origin;
     153                 :            :         }
     154                 :            :       }
     155                 :            :     }
     156                 :            :   }
     157                 :          0 :   nodes.insert(nodes.end(), toAdd.begin(), toAdd.end());
     158                 :          0 : }
     159                 :            : 
     160                 :     244993 : void BoundInference::update_lower_bound(const Node& origin,
     161                 :            :                                         const Node& lhs,
     162                 :            :                                         const Node& value,
     163                 :            :                                         bool strict)
     164                 :            : {
     165 [ -  + ][ -  + ]:     244993 :   Assert(value.isConst());
                 [ -  - ]
     166                 :            :   // lhs > or >= value because of origin
     167 [ +  - ][ -  - ]:     489986 :   Trace("bound-inf") << "\tNew bound " << lhs << (strict ? ">" : ">=") << value
     168                 :     244993 :                      << " due to " << origin << std::endl;
     169                 :     244993 :   Bounds& b = get_or_add(lhs);
     170                 :     244993 :   if (b.lower_value.isNull()
     171 [ +  + ][ +  + ]:     244993 :       || b.lower_value.getConst<Rational>() < value.getConst<Rational>())
                 [ +  + ]
     172                 :            :   {
     173                 :     155300 :     auto* nm = nodeManager();
     174                 :     155300 :     b.lower_value = value;
     175                 :     155300 :     b.lower_strict = strict;
     176                 :            : 
     177                 :     155300 :     b.lower_origin = origin;
     178                 :            : 
     179 [ +  + ][ +  + ]:     155300 :     if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
         [ +  + ][ +  + ]
     180                 :            :     {
     181                 :       2768 :       Node eq = mkEquality(lhs, value);
     182                 :       2768 :       b.lower_bound = b.upper_bound = rewrite(eq);
     183                 :            :     }
     184                 :            :     else
     185                 :            :     {
     186                 :            :       b.lower_bound =
     187         [ +  + ]:     152532 :           rewrite(nm->mkNode(strict ? Kind::GT : Kind::GEQ, lhs, value));
     188                 :            :     }
     189                 :            :   }
     190 [ +  + ][ +  + ]:      89693 :   else if (strict && b.lower_value == value)
                 [ +  + ]
     191                 :            :   {
     192                 :       4318 :     auto* nm = nodeManager();
     193                 :       4318 :     b.lower_strict = strict;
     194                 :       4318 :     b.lower_bound = rewrite(nm->mkNode(Kind::GT, lhs, value));
     195                 :       4318 :     b.lower_origin = origin;
     196                 :            :   }
     197                 :     244993 : }
     198                 :     183194 : void BoundInference::update_upper_bound(const Node& origin,
     199                 :            :                                         const Node& lhs,
     200                 :            :                                         const Node& value,
     201                 :            :                                         bool strict)
     202                 :            : {
     203                 :            :   // lhs < or <= value because of origin
     204 [ +  - ][ -  - ]:     366388 :   Trace("bound-inf") << "\tNew bound " << lhs << (strict ? "<" : "<=") << value
     205                 :     183194 :                      << " due to " << origin << std::endl;
     206                 :     183194 :   Bounds& b = get_or_add(lhs);
     207                 :     183194 :   if (b.upper_value.isNull()
     208 [ +  + ][ +  + ]:     183194 :       || b.upper_value.getConst<Rational>() > value.getConst<Rational>())
                 [ +  + ]
     209                 :            :   {
     210                 :     126022 :     auto* nm = nodeManager();
     211                 :     126022 :     b.upper_value = value;
     212                 :     126022 :     b.upper_strict = strict;
     213                 :     126022 :     b.upper_origin = origin;
     214 [ +  + ][ +  + ]:     126022 :     if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
         [ +  + ][ +  + ]
     215                 :            :     {
     216                 :      37231 :       Node eq = mkEquality(lhs, value);
     217                 :      37231 :       b.lower_bound = b.upper_bound = rewrite(eq);
     218                 :            :     }
     219                 :            :     else
     220                 :            :     {
     221                 :            :       b.upper_bound =
     222         [ +  + ]:      88791 :           rewrite(nm->mkNode(strict ? Kind::LT : Kind::LEQ, lhs, value));
     223                 :            :     }
     224                 :            :   }
     225 [ +  + ][ +  + ]:      57172 :   else if (strict && b.upper_value == value)
                 [ +  + ]
     226                 :            :   {
     227                 :       2907 :     auto* nm = nodeManager();
     228                 :       2907 :     b.upper_strict = strict;
     229                 :       2907 :     b.upper_bound = rewrite(nm->mkNode(Kind::LT, lhs, value));
     230                 :       2907 :     b.upper_origin = origin;
     231                 :            :   }
     232                 :     183194 : }
     233                 :            : 
     234                 :          0 : std::ostream& operator<<(std::ostream& os, const BoundInference& bi)
     235                 :            : {
     236                 :          0 :   os << "Bounds:" << std::endl;
     237         [ -  - ]:          0 :   for (const auto& b : bi.get())
     238                 :            :   {
     239                 :          0 :     os << "\t" << b.first << " -> " << b.second.lower_value << ".."
     240                 :          0 :        << b.second.upper_value << std::endl;
     241                 :            :   }
     242                 :          0 :   return os;
     243                 :            : }
     244                 :            : 
     245                 :            : }  // namespace arith
     246                 :            : }  // namespace theory
     247                 :            : }  // namespace cvc5::internal

Generated by: LCOV version 1.14