-- SPDX-FileCopyrightText: 2025 Sören Tempel <soeren+git@soeren-tempel.net>
--
-- SPDX-License-Identifier: MIT AND GPL-3.0-only
{-# LANGUAGE PatternSynonyms #-}

-- | A small bit-vector expression layer on top of "SimpleSMT".
--
-- Every 'SExpr' carries its width in bits. Several smart constructors
-- ('eq', 'concat', 'extract', 'bvURem') perform local simplifications
-- before a term is converted to an SMT-LIB S-expression with 'toSMT'.
module SimpleBV
  ( SExpr,
    SMT.Solver,
    SMT.defaultConfig,
    SMT.newLogger,
    SMT.newLoggerWithHandle,
    SMT.newSolver,
    SMT.newSolverWithConfig,
    SMT.solverLogger,
    SMT.smtSolverLogger,
    SMT.setLogic,
    SMT.push,
    SMT.pop,
    SMT.popMany,
    SMT.check,
    SMT.Result (..),
    SMT.Value (..),
    pattern W,
    pattern Byte,
    pattern Half,
    pattern Word,
    pattern Long,
    width,
    const,
    declareBV,
    assert,
    sexprToVal,
    getValue,
    getValues,
    toSMT,
    ite,
    and,
    or,
    not,
    eq,
    bvLit,
    bvAdd,
    bvAShr,
    bvLShr,
    bvAnd,
    bvMul,
    bvNeg,
    bvOr,
    bvSDiv,
    bvSLeq,
    bvSLt,
    bvSGeq,
    bvSGt,
    bvSRem,
    bvShl,
    bvSub,
    bvUDiv,
    bvULeq,
    bvUGeq,
    bvUGt,
    bvULt,
    bvURem,
    bvXOr,
    concat,
    extract,
    signExtend,
    zeroExtend,
  )
where

import Control.DeepSeq (NFData, NFData1)
import Data.Bits (shiftL, shiftR, (.&.))
import GHC.Generics (Generic, Generic1)
import SimpleSMT qualified as SMT
import Prelude hiding (and, concat, const, not, or)

-- | One layer of a bit-vector expression; @a@ is the type of sub-expressions.
data Expr a
  = Var String
  | Int Integer
  | And a a
  | Or a a
  | Neg a
  | Not a
  | Eq a a
  | BvAdd a a
  | BvAShr a a
  | BvLShr a a
  | BvAnd a a
  | BvMul a a
  | BvOr a a
  | BvSDiv a a
  | BvSLeq a a
  | BvSLt a a
  | BvSGeq a a
  | BvSGt a a
  | BvSRem a a
  | BvShl a a
  | BvSub a a
  | BvUDiv a a
  | BvULeq a a
  | BvUGeq a a
  | BvUGt a a
  | BvULt a a
  | BvURem a a
  | BvXOr a a
  | Concat a a
  | Ite a a a
  | -- | Bit offset, number of bits, operand.
    Extract Int Int a
  | SignExtend Integer a
  | ZeroExtend Integer a
  deriving (Show, Eq, Generic, Generic1)

instance (NFData a) => NFData (Expr a)

instance NFData1 Expr

-- | An expression tree annotated with its width in bits.
data SExpr
  = SExpr
  { -- | Width of the expression in bits.
    width :: Int,
    sexpr :: Expr SExpr
  }
  deriving (Show, Eq, Generic)

instance NFData SExpr

-- | Convert an expression to a "SimpleSMT" S-expression.
--
-- Integer literals are emitted as @(_ bvN w)@ where @N@ is the literal
-- reduced to the unsigned range of width @w@: SMT-LIB numerals cannot be
-- negative, so e.g. @bvLit 8 (-1)@ becomes @(_ bv255 8)@.
toSMT :: SExpr -> SMT.SExpr
toSMT expr =
  case sexpr expr of
    (Var name) -> SMT.const name
    (Int v) ->
      SMT.List
        [ SMT.Atom "_",
          SMT.Atom ("bv" ++ show (toUnsigned (width expr) v)),
          SMT.Atom $ show (width expr)
        ]
    (Or lhs rhs) -> SMT.or (toSMT lhs) (toSMT rhs)
    (Ite cond lhs rhs) -> SMT.ite (toSMT cond) (toSMT lhs) (toSMT rhs)
    (And lhs rhs) -> SMT.and (toSMT lhs) (toSMT rhs)
    (Not v) -> SMT.not (toSMT v)
    (Neg v) -> SMT.bvNeg (toSMT v)
    (SignExtend n v) -> SMT.signExtend n (toSMT v)
    (ZeroExtend n v) -> SMT.zeroExtend n (toSMT v)
    (Eq lhs rhs) -> SMT.eq (toSMT lhs) (toSMT rhs)
    (Concat lhs rhs) -> SMT.concat (toSMT lhs) (toSMT rhs)
    -- SMT-LIB extract takes the (inclusive) high and low bit indices.
    (Extract o w e) -> SMT.extract (toSMT e) (fromIntegral $ o + w - 1) (fromIntegral o)
    (BvAnd lhs rhs) -> SMT.bvAnd (toSMT lhs) (toSMT rhs)
    (BvAShr lhs rhs) -> SMT.bvAShr (toSMT lhs) (toSMT rhs)
    (BvLShr lhs rhs) -> SMT.bvLShr (toSMT lhs) (toSMT rhs)
    (BvAdd lhs rhs) -> SMT.bvAdd (toSMT lhs) (toSMT rhs)
    (BvMul lhs rhs) -> SMT.bvMul (toSMT lhs) (toSMT rhs)
    (BvOr lhs rhs) -> SMT.bvOr (toSMT lhs) (toSMT rhs)
    (BvSDiv lhs rhs) -> SMT.bvSDiv (toSMT lhs) (toSMT rhs)
    (BvSLeq lhs rhs) -> SMT.bvSLeq (toSMT lhs) (toSMT rhs)
    (BvSLt lhs rhs) -> SMT.bvSLt (toSMT lhs) (toSMT rhs)
    (BvSGeq lhs rhs) -> SMT.fun "bvsge" [toSMT lhs, toSMT rhs]
    (BvSGt lhs rhs) -> SMT.fun "bvsgt" [toSMT lhs, toSMT rhs]
    (BvSRem lhs rhs) -> SMT.bvSRem (toSMT lhs) (toSMT rhs)
    (BvShl lhs rhs) -> SMT.bvShl (toSMT lhs) (toSMT rhs)
    (BvSub lhs rhs) -> SMT.bvSub (toSMT lhs) (toSMT rhs)
    (BvUDiv lhs rhs) -> SMT.bvUDiv (toSMT lhs) (toSMT rhs)
    (BvULeq lhs rhs) -> SMT.bvULeq (toSMT lhs) (toSMT rhs)
    (BvUGeq lhs rhs) -> SMT.fun "bvuge" [toSMT lhs, toSMT rhs]
    (BvUGt lhs rhs) -> SMT.fun "bvugt" [toSMT lhs, toSMT rhs]
    (BvULt lhs rhs) -> SMT.bvULt (toSMT lhs) (toSMT rhs)
    (BvURem lhs rhs) -> SMT.bvURem (toSMT lhs) (toSMT rhs)
    (BvXOr lhs rhs) -> SMT.bvXOr (toSMT lhs) (toSMT rhs)

-- | Width assigned to the result of 'eq'.
boolWidth :: Int
boolWidth = 1

-- | Reduce an integer literal to the unsigned value of a bit-vector of width
-- @w@, i.e. compute it modulo @2^w@. Negative literals map to their
-- two's-complement bit pattern (e.g. @toUnsigned 8 (-1) == 255@).
toUnsigned :: Int -> Integer -> Integer
toUnsigned w v = v .&. ((1 `shiftL` w) - 1)

pattern E :: Expr SExpr -> SExpr
pattern E expr <- SExpr {sexpr = expr, width = _}

pattern W :: Int -> SExpr
pattern W w <- SExpr {width = w}

pattern Byte :: SExpr
pattern Byte <- SExpr {width = 8}

pattern Half :: SExpr
pattern Half <- SExpr {width = 16}

pattern Word :: SExpr
pattern Word <- SExpr {width = 32}

pattern Long :: SExpr
pattern Long <- SExpr {width = 64}

-- | The unsigned value of a literal expression (reduced to its own width),
-- or 'Nothing' when the expression is not a literal.
literalValue :: SExpr -> Maybe Integer
literalValue e@(E (Int v)) = Just (toUnsigned (width e) v)
literalValue _ = Nothing

------------------------------------------------------------------------

-- | A named bit-vector variable of the given width (no solver declaration).
const :: String -> Int -> SExpr
const name w = SExpr w (Var name)

-- | Declare a bit-vector constant of the given width in the solver and
-- return an expression referring to it.
declareBV :: SMT.Solver -> String -> Int -> IO SExpr
declareBV solver name w = do
  let bits = SMT.tBits $ fromIntegral w
  SMT.declare solver name bits >> pure (const name w)

-- | A bit-vector literal of the given width.
bvLit :: Int -> Integer -> SExpr
bvLit w value = SExpr w (Int value)

-- | Convert a variable or literal to a "SimpleSMT" value. Literals are
-- reported with their unsigned value; any other expression becomes the
-- placeholder atom @_@.
sexprToVal :: SExpr -> SMT.Value
sexprToVal (E (Var n)) = SMT.Other $ SMT.Atom n
sexprToVal e@(E (Int i)) = SMT.Bits (width e) (toUnsigned (width e) i)
sexprToVal _ = SMT.Other $ SMT.Atom "_"

-- | Assert an expression in the solver.
assert :: SMT.Solver -> SExpr -> IO ()
assert solver = SMT.assert solver . toSMT

-- | Query the solver for the model value of an expression.
getValue :: SMT.Solver -> SExpr -> IO SMT.Value
getValue solver = SMT.getExpr solver . toSMT

-- | Query the solver for the model values of several expressions, keyed by
-- name. Every expression must render to an atom (e.g. a variable);
-- otherwise this calls 'error'.
getValues :: SMT.Solver -> [SExpr] -> IO [(String, SMT.Value)]
getValues solver exprs =
  map go <$> SMT.getExprs solver (map toSMT exprs)
  where
    go :: (SMT.SExpr, SMT.Value) -> (String, SMT.Value)
    go (SMT.Atom name, value) = (name, value)
    go _ = error "non-atomic variable in inputVars"

---------------------------------------------------------------------------

-- | If-then-else; the result has the width of the then-branch.
ite :: SExpr -> SExpr -> SExpr -> SExpr
ite cond ifT ifF = SExpr (width ifT) (Ite cond ifT ifF)

-- | Boolean negation, eliminating double negation.
not :: SExpr -> SExpr
not (E (Not cond)) = cond
not expr = expr {sexpr = Not expr}

-- | Boolean conjunction (keeps the width of @lhs@).
and :: SExpr -> SExpr -> SExpr
and lhs rhs = lhs {sexpr = And lhs rhs}

-- | Boolean disjunction (keeps the width of @lhs@).
or :: SExpr -> SExpr -> SExpr
or lhs rhs = lhs {sexpr = Or lhs rhs}

-- | Sign-extend by @n@ additional bits.
signExtend :: Integer -> SExpr -> SExpr
signExtend n expr = SExpr (width expr + fromIntegral n) $ SignExtend n expr

-- | Zero-extend by @n@ additional bits.
zeroExtend :: Integer -> SExpr -> SExpr
zeroExtend n expr = SExpr (width expr + fromIntegral n) $ ZeroExtend n expr

------------------------------------------------------------------------

eq' :: SExpr -> SExpr -> SExpr
eq' lhs rhs = SExpr boolWidth $ Eq lhs rhs

-- | Equality of two expressions; the result has width 'boolWidth'.
--
-- Eliminates ITE expressions when comparing with constant values; this is
-- useful in the QBE context to eliminate comparisons with truth values:
--
-- * @(= (ite c t f) t)@ becomes @c@,
-- * @(= (ite c t f) f)@ becomes @(not c)@,
--
-- provided the literals @t@ and @f@ denote different values (if they are
-- equal, neither rewrite is sound). Literals are compared by their unsigned
-- value at their own width, so @-1@ and @255@ are the same 8-bit constant.
eq :: SExpr -> SExpr -> SExpr
eq lexpr@(E (Ite cond ifT ifF)) rexpr
  | Just t <- literalValue ifT,
    Just f <- literalValue ifF,
    t /= f,
    Just other <- literalValue rexpr =
      decide t f other
  where
    decide t f other
      | other == t = cond
      | other == f = not cond
      | otherwise = eq' lexpr rexpr
eq lhs rhs = eq' lhs rhs

concat' :: SExpr -> SExpr -> SExpr
concat' lhs rhs =
  SExpr (width lhs + width rhs) $ Concat lhs rhs

-- Replace 0 concats with zero extension: (concat (_ bv0 8) buf6)
concatZeros :: SExpr -> SExpr -> SExpr
concatZeros lhs@(E (Int 0)) rhs = zeroExtend (fromIntegral $ width lhs) rhs
concatZeros lhs rhs = concat' lhs rhs

-- | Concatenation (@lhs@ forms the high bits).
--
-- Two adjacent extracts of the same expression are merged into a single
-- extract: @(concat ((_ extract 15 8) x) ((_ extract 7 0) x))@ becomes
-- @((_ extract 15 0) x)@.
concat :: SExpr -> SExpr -> SExpr
concat
  lhs@(E (Extract loff lwidth latom))
  rhs@(E (Extract roff rwidth ratom))
    | latom == ratom && (roff + rwidth) == loff = extract latom roff (lwidth + rwidth)
    | otherwise = concatZeros lhs rhs
concat lhs rhs = concatZeros lhs rhs

extract' :: SExpr -> Int -> Int -> SExpr
extract' expr off w = SExpr w $ Extract off w expr

-- Eliminate extract expression where the value already has the desired bits.
extractSameWidth :: SExpr -> Int -> Int -> SExpr
extractSameWidth expr off w
  | off == 0 && width expr == w = expr
  | otherwise = extract' expr off w

-- Collapse nested extracts: bits [off, off+w) of an extract that starts at
-- bit innerOff of `inner` are bits [innerOff+off, innerOff+off+w) of `inner`.
extractNested :: SExpr -> Int -> Int -> SExpr
extractNested (E (Extract innerOff _ inner)) off w = extract inner (innerOff + off) w
extractNested expr off w = extractSameWidth expr off w

-- Performs direct extractions of constant immediate values.
extractConst :: SExpr -> Int -> Int -> SExpr
extractConst (E (Int value)) off w =
  SExpr w . Int $ toUnsigned w (value `shiftR` off)
extractConst expr off w = extractNested expr off w

-- This performs constant propagation for subtyping of condition values (i.e.
-- the conversion from long to word).
extractIte :: SExpr -> Int -> Int -> SExpr
extractIte (E (Ite cond ifT@(E (Int _)) ifF@(E (Int _)))) off w =
  let ex x = extractConst x off w
   in SExpr w $ Ite cond (ex ifT) (ex ifF)
extractIte expr off w = extractConst expr off w

-- Extracting only bits that a zero extension added yields a zero literal.
extractZeros ::
  SExpr ->
  Int ->
  Int ->
  SExpr
extractZeros expr@(E (ZeroExtend extBits inner)) exOff exWidth
  | exOff >= width inner && extBits > 0 = bvLit exWidth 0 -- only extracting zeros
  | otherwise = extractIte expr exOff exWidth
extractZeros outer exOff exWidth = extractIte outer exOff exWidth

-- Shared folding for extracts of a sign/zero extension `outer` of `inner`;
-- `cons` rebuilds the same kind of extension.
extractExt' ::
  (Integer -> SExpr -> Expr SExpr) ->
  SExpr ->
  SExpr ->
  Int ->
  Int ->
  SExpr
extractExt' cons outer inner exOff exWidth
  -- If we are only extracting non-extended bits, drop the extension.
  | width inner >= exOff + exWidth = extract inner exOff exWidth
  -- The low exWidth bits of an extension are the extension of `inner` to
  -- exWidth bits (the first guard ensures exWidth > width inner here).
  -- Consider: ((_ extract 15 0) ((_ zero_extend 56) byte))
  --        == ((_ zero_extend 8) byte)
  | exOff == 0 =
      SExpr exWidth $ cons (fromIntegral $ exWidth - width inner) inner
  -- No folding...
  | otherwise = extractZeros outer exOff exWidth

-- Remove ZeroExtend and SignExtend expression where we don't use
-- the extended bits because we extract below the extended size.
extractExt :: SExpr -> Int -> Int -> SExpr
extractExt expr@(E (SignExtend _ inner)) exOff exWidth =
  extractExt' SignExtend expr inner exOff exWidth
extractExt expr@(E (ZeroExtend _ inner)) exOff exWidth =
  extractExt' ZeroExtend expr inner exOff exWidth
extractExt expr off w = extractIte expr off w

-- | @extract e off w@ selects the @w@ bits of @e@ starting at bit @off@,
-- simplifying where possible.
extract :: SExpr -> Int -> Int -> SExpr
extract = extractExt

------------------------------------------------------------------------

-- NOTE(review): the result inherits the width of `lhs`, even for the
-- comparison operators (bvult, bvslt, ...) whose SMT-LIB result is a Bool,
-- whereas `eq` uses `boolWidth`. Callers may rely on this; confirm before
-- changing.
binOp' :: (SExpr -> SExpr -> Expr SExpr) -> SExpr -> SExpr -> SExpr
binOp' op lhs rhs = lhs {sexpr = op lhs rhs}

binOp :: (SExpr -> SExpr -> Expr SExpr) -> SExpr -> SExpr -> SExpr
-- Consider: (bvslt ((_ zero_extend 24) byte0) ((_ zero_extend 24) byte1))
-- TODO: The following only works if 'op' does not consider sign-bits. Otherwise,
-- there is no semantic expression equivalence after this folding operation.
-- binOp op lhs@(E (ZeroExtend _ lhsInner)) rhs@(E (ZeroExtend _ rhsInner)) =
--   if width lhsInner == width rhsInner
--     then binOp op lhsInner rhsInner
--     else binOp' op lhs rhs
binOp op lhs rhs = binOp' op lhs rhs

-- TODO: Generate these using template-haskell.

bvNeg :: SExpr -> SExpr
bvNeg x = x {sexpr = Neg x}

bvAdd :: SExpr -> SExpr -> SExpr
bvAdd = binOp BvAdd

bvAShr :: SExpr -> SExpr -> SExpr
bvAShr = binOp BvAShr

bvLShr :: SExpr -> SExpr -> SExpr
bvLShr = binOp BvLShr

bvAnd :: SExpr -> SExpr -> SExpr
bvAnd = binOp BvAnd

bvMul :: SExpr -> SExpr -> SExpr
bvMul = binOp BvMul

bvOr :: SExpr -> SExpr -> SExpr
bvOr = binOp BvOr

bvSDiv :: SExpr -> SExpr -> SExpr
bvSDiv = binOp BvSDiv

bvSLeq :: SExpr -> SExpr -> SExpr
bvSLeq = binOp BvSLeq

bvSLt :: SExpr -> SExpr -> SExpr
bvSLt = binOp BvSLt

bvSGeq :: SExpr -> SExpr -> SExpr
bvSGeq = binOp BvSGeq

bvSGt :: SExpr -> SExpr -> SExpr
bvSGt = binOp BvSGt

bvSRem :: SExpr -> SExpr -> SExpr
bvSRem = binOp BvSRem

bvShl :: SExpr -> SExpr -> SExpr
bvShl = binOp BvShl

bvSub :: SExpr -> SExpr -> SExpr
bvSub = binOp BvSub

bvUDiv :: SExpr -> SExpr -> SExpr
bvUDiv = binOp BvUDiv

bvULeq :: SExpr -> SExpr -> SExpr
bvULeq = binOp BvULeq

bvUGeq :: SExpr -> SExpr -> SExpr
bvUGeq = binOp BvUGeq

bvUGt :: SExpr -> SExpr -> SExpr
bvUGt = binOp BvUGt

bvULt :: SExpr -> SExpr -> SExpr
bvULt = binOp BvULt

-- | Unsigned remainder.
--
-- Fold constant bvURem operations which are emitted a lot in our generated
-- SMT-LIB because of QBE's "shift-value modulo bitsize"-semantics. Both
-- literals are first reduced to their unsigned value at their own width,
-- since bvurem treats its operands as unsigned (e.g. 8-bit @-1@ is 255).
bvURem :: SExpr -> SExpr -> SExpr
bvURem lhs rhs
  | Just l <- literalValue lhs,
    Just r <- literalValue rhs =
      -- XXX: On urem-by-zero, SMT-LIB returns the lhs.
      SExpr (width lhs) . Int $ if r == 0 then l else l `rem` r
  | otherwise = binOp BvURem lhs rhs

bvXOr :: SExpr -> SExpr -> SExpr
bvXOr = binOp BvXOr