-- SPDX-FileCopyrightText: 2025 Sören Tempel <soeren+git@soeren-tempel.net>
--
-- SPDX-License-Identifier: MIT AND GPL-3.0-only
{-# LANGUAGE PatternSynonyms #-}

-- | A small bit-vector expression layer on top of "SimpleSMT".
--
-- Expressions are built as an 'SExpr' tree in which every node records its
-- bit width. The smart constructors in this module ('eq', 'concat',
-- 'extract', 'bvURem', ...) perform local simplifications while the tree is
-- built; 'toSMT' lowers the result to a "SimpleSMT" s-expression.
module SimpleBV
  ( SExpr,
    SMT.Solver,
    SMT.defaultConfig,
    SMT.newLogger,
    SMT.newLoggerWithHandle,
    SMT.newSolver,
    SMT.newSolverWithConfig,
    SMT.solverLogger,
    SMT.smtSolverLogger,
    SMT.setLogic,
    SMT.push,
    SMT.pop,
    SMT.popMany,
    SMT.check,
    SMT.Result (..),
    SMT.Value (..),
    pattern W,
    pattern Byte,
    pattern Half,
    pattern Word,
    pattern Long,
    width,
    const,
    declareBV,
    assert,
    sexprToVal,
    getValue,
    getValues,
    toSMT,
    ite,
    and,
    or,
    not,
    eq,
    bvLit,
    bvAdd,
    bvAShr,
    bvLShr,
    bvAnd,
    bvMul,
    bvNeg,
    bvOr,
    bvSDiv,
    bvSLeq,
    bvSLt,
    bvSGeq,
    bvSGt,
    bvSRem,
    bvShl,
    bvSub,
    bvUDiv,
    bvULeq,
    bvUGeq,
    bvUGt,
    bvULt,
    bvURem,
    bvXOr,
    concat,
    extract,
    signExtend,
    zeroExtend,
  )
where

import Data.Bits (shiftL, shiftR, (.&.))
import SimpleSMT qualified as SMT
import Prelude hiding (and, concat, const, not, or)

-- | One node of an expression tree, parameterised over the type of its
-- children (instantiated with 'SExpr' below).
data Expr a
  = -- | A declared SMT constant, referenced by name.
    Var String
  | -- | A bit-vector literal; its width is stored in the enclosing 'SExpr'.
    Int Integer
  | And a a
  | Or a a
  | Neg a
  | Not a
  | Eq a a
  | BvAdd a a
  | BvAShr a a
  | BvLShr a a
  | BvAnd a a
  | BvMul a a
  | BvOr a a
  | BvSDiv a a
  | BvSLeq a a
  | BvSLt a a
  | BvSGeq a a
  | BvSGt a a
  | BvSRem a a
  | BvShl a a
  | BvSub a a
  | BvUDiv a a
  | BvULeq a a
  | BvUGeq a a
  | BvUGt a a
  | BvULt a a
  | BvURem a a
  | BvXOr a a
  | Concat a a
  | Ite a a a
  | -- | @Extract offset width e@: bits @[offset, offset + width)@ of @e@.
    Extract Int Int a
  | -- | @SignExtend n e@: @e@ widened by @n@ bits.
    SignExtend Integer a
  | -- | @ZeroExtend n e@: @e@ widened by @n@ bits.
    ZeroExtend Integer a
  deriving (Show, Eq)

-- | An expression node annotated with its bit width.
data SExpr
  = SExpr
  { -- | Width of the expression in bits ('eq' results use width 1).
    width :: Int,
    sexpr :: Expr SExpr
  }
  deriving (Show, Eq)

-- | Reduce a literal to the unsigned value it denotes at bit width @w@,
-- i.e. @v mod 2^w@. This matches the SMT-LIB meaning of @(_ bvV w)@ and
-- gives two's complement wrap-around for negative literals.
toUnsigned :: Int -> Integer -> Integer
toUnsigned w v = v `mod` (1 `shiftL` w)

-- | Lower an expression to a "SimpleSMT" s-expression.
--
-- Literals are emitted as @(_ bvV w)@ with @V@ reduced modulo @2^w@, since
-- SMT-LIB numerals must be non-negative (a raw negative value would yield an
-- invalid atom such as @bv-1@).
toSMT :: SExpr -> SMT.SExpr
toSMT expr =
  case sexpr expr of
    (Var name) -> SMT.const name
    (Int v) ->
      SMT.List
        [ SMT.Atom "_",
          SMT.Atom ("bv" ++ show (toUnsigned (width expr) v)),
          SMT.Atom $ show (width expr)
        ]
    (Or lhs rhs) -> SMT.or (toSMT lhs) (toSMT rhs)
    (Ite cond lhs rhs) -> SMT.ite (toSMT cond) (toSMT lhs) (toSMT rhs)
    (And lhs rhs) -> SMT.and (toSMT lhs) (toSMT rhs)
    (Not v) -> SMT.not (toSMT v)
    (Neg v) -> SMT.bvNeg (toSMT v)
    (SignExtend n v) -> SMT.signExtend n (toSMT v)
    (ZeroExtend n v) -> SMT.zeroExtend n (toSMT v)
    (Eq lhs rhs) -> SMT.eq (toSMT lhs) (toSMT rhs)
    (Concat lhs rhs) -> SMT.concat (toSMT lhs) (toSMT rhs)
    -- SMT-LIB's extract takes the (inclusive) high and low bit indices.
    (Extract o w e) -> SMT.extract (toSMT e) (fromIntegral $ o + w - 1) (fromIntegral o)
    (BvAnd lhs rhs) -> SMT.bvAnd (toSMT lhs) (toSMT rhs)
    (BvAShr lhs rhs) -> SMT.bvAShr (toSMT lhs) (toSMT rhs)
    (BvLShr lhs rhs) -> SMT.bvLShr (toSMT lhs) (toSMT rhs)
    (BvAdd lhs rhs) -> SMT.bvAdd (toSMT lhs) (toSMT rhs)
    (BvMul lhs rhs) -> SMT.bvMul (toSMT lhs) (toSMT rhs)
    (BvOr lhs rhs) -> SMT.bvOr (toSMT lhs) (toSMT rhs)
    (BvSDiv lhs rhs) -> SMT.bvSDiv (toSMT lhs) (toSMT rhs)
    (BvSLeq lhs rhs) -> SMT.bvSLeq (toSMT lhs) (toSMT rhs)
    (BvSLt lhs rhs) -> SMT.bvSLt (toSMT lhs) (toSMT rhs)
    (BvSGeq lhs rhs) -> SMT.fun "bvsge" [toSMT lhs, toSMT rhs]
    (BvSGt lhs rhs) -> SMT.fun "bvsgt" [toSMT lhs, toSMT rhs]
    (BvSRem lhs rhs) -> SMT.bvSRem (toSMT lhs) (toSMT rhs)
    (BvShl lhs rhs) -> SMT.bvShl (toSMT lhs) (toSMT rhs)
    (BvSub lhs rhs) -> SMT.bvSub (toSMT lhs) (toSMT rhs)
    (BvUDiv lhs rhs) -> SMT.bvUDiv (toSMT lhs) (toSMT rhs)
    (BvULeq lhs rhs) -> SMT.bvULeq (toSMT lhs) (toSMT rhs)
    (BvUGeq lhs rhs) -> SMT.fun "bvuge" [toSMT lhs, toSMT rhs]
    (BvUGt lhs rhs) -> SMT.fun "bvugt" [toSMT lhs, toSMT rhs]
    (BvULt lhs rhs) -> SMT.bvULt (toSMT lhs) (toSMT rhs)
    (BvURem lhs rhs) -> SMT.bvURem (toSMT lhs) (toSMT rhs)
    (BvXOr lhs rhs) -> SMT.bvXOr (toSMT lhs) (toSMT rhs)

-- | Width assigned to the result of 'eq'.
boolWidth :: Int
boolWidth = 1

-- | Match on the expression node, ignoring the width.
pattern E :: Expr SExpr -> SExpr
pattern E expr <- SExpr {sexpr = expr, width = _}

-- | Match on the width of an expression.
pattern W :: Int -> SExpr
pattern W w <- SExpr {width = w}

-- | Matches 8-bit expressions.
pattern Byte :: SExpr
pattern Byte <- SExpr {width = 8}

-- | Matches 16-bit expressions.
pattern Half :: SExpr
pattern Half <- SExpr {width = 16}

-- | Matches 32-bit expressions.
pattern Word :: SExpr
pattern Word <- SExpr {width = 32}

-- | Matches 64-bit expressions.
pattern Long :: SExpr
pattern Long <- SExpr {width = 64}

------------------------------------------------------------------------

-- | Reference a named constant of the given width (does not declare it).
const :: String -> Int -> SExpr
const name w = SExpr w (Var name)

-- | Declare a bit-vector constant of width @w@ in the solver and return a
-- reference to it.
declareBV :: SMT.Solver -> String -> Int -> IO SExpr
declareBV solver name w =
  SMT.declare solver name (SMT.tBits $ fromIntegral w) >> pure (const name w)

-- | A bit-vector literal of the given width.
bvLit :: Int -> Integer -> SExpr
bvLit w value = SExpr w (Int value)

-- | Convert variables and literals to a "SimpleSMT" value. Literals are
-- reported as unsigned values, like the values a solver model contains;
-- any other expression becomes the atom @_@.
sexprToVal :: SExpr -> SMT.Value
sexprToVal (E (Var n)) = SMT.Other $ SMT.Atom n
sexprToVal e@(E (Int i)) = SMT.Bits (width e) (toUnsigned (width e) i)
sexprToVal _ = SMT.Other $ SMT.Atom "_"

-- | Assert a (boolean) expression in the solver.
assert :: SMT.Solver -> SExpr -> IO ()
assert solver = SMT.assert solver . toSMT

-- | Query the solver for the value of an expression.
getValue :: SMT.Solver -> SExpr -> IO SMT.Value
getValue solver = SMT.getExpr solver . toSMT

-- | Query the values of several expressions, keyed by name. Every expression
-- must lower to an atom (e.g. a variable); anything else calls 'error'.
getValues :: SMT.Solver -> [SExpr] -> IO [(String, SMT.Value)]
getValues solver exprs =
  map go <$> SMT.getExprs solver (map toSMT exprs)
  where
    go :: (SMT.SExpr, SMT.Value) -> (String, SMT.Value)
    go (SMT.Atom name, value) = (name, value)
    go _ = error "non-atomic variable in inputVars"

---------------------------------------------------------------------------

-- | If-then-else; the result has the width of the then-branch.
ite :: SExpr -> SExpr -> SExpr -> SExpr
ite cond ifT ifF = SExpr (width ifT) (Ite cond ifT ifF)

-- | Logical negation; a double negation is removed.
not :: SExpr -> SExpr
not (E (Not cond)) = cond
not expr = expr {sexpr = Not expr}

-- | Logical conjunction; the result keeps the width of @lhs@.
and :: SExpr -> SExpr -> SExpr
and lhs rhs = lhs {sexpr = And lhs rhs}

-- | Logical disjunction; the result keeps the width of @lhs@.
or :: SExpr -> SExpr -> SExpr
or lhs rhs = lhs {sexpr = Or lhs rhs}

-- | Sign-extend by @n@ bits.
signExtend :: Integer -> SExpr -> SExpr
signExtend n expr = SExpr (width expr + fromIntegral n) $ SignExtend n expr

-- | Zero-extend by @n@ bits.
zeroExtend :: Integer -> SExpr -> SExpr
zeroExtend n expr = SExpr (width expr + fromIntegral n) $ ZeroExtend n expr

------------------------------------------------------------------------

-- | Unsimplified equality with a result of width 'boolWidth'.
eq' :: SExpr -> SExpr -> SExpr
eq' lhs rhs = SExpr boolWidth $ Eq lhs rhs

-- | Equality. Eliminates ITE expressions when comparing with constant
-- values; this is useful in the QBE context to eliminate comparisons with
-- truth values.
--
-- Literals are compared by their unsigned value at their own width, so
-- e.g. @-1@ and @255@ at width 8 are treated as equal. If both ITE branches
-- hold the same value the fold is skipped: the comparison does not depend
-- on the condition there, so returning the condition would be wrong.
eq :: SExpr -> SExpr -> SExpr
eq lexpr@(E (Ite cond ifT@(E (Int tv)) ifF@(E (Int fv)))) rexpr@(E (Int ov))
  | t == f = eq' lexpr rexpr
  | o == t = cond
  | o == f = not cond
  | otherwise = eq' lexpr rexpr
  where
    t = toUnsigned (width ifT) tv
    f = toUnsigned (width ifF) fv
    o = toUnsigned (width rexpr) ov
eq lhs rhs = eq' lhs rhs

-- | Unsimplified concatenation (@lhs@ provides the high bits).
concat' :: SExpr -> SExpr -> SExpr
concat' lhs rhs =
  SExpr (width lhs + width rhs) $ Concat lhs rhs

-- | Replace 0 concats with zero extension: (concat (_ bv0 8) buf6)
concatZeros :: SExpr -> SExpr -> SExpr
concatZeros lhs@(E (Int 0)) rhs = zeroExtend (fromIntegral $ width lhs) rhs
concatZeros lhs rhs = concat' lhs rhs

-- | Concatenation. Replaces the concatenation of two adjacent extractions of
-- the same expression with a single extract expression.
concat :: SExpr -> SExpr -> SExpr
concat
  lhs@(E (Extract loff lwidth latom))
  rhs@(E (Extract roff rwidth ratom))
    -- lhs holds the high bits, so it must start right where rhs ends.
    | latom == ratom && (roff + rwidth) == loff = extract latom roff (lwidth + rwidth)
    | otherwise = concatZeros lhs rhs
concat lhs rhs = concatZeros lhs rhs

-- | Unsimplified extraction of bits @[off, off + w)@.
extract' :: SExpr -> Int -> Int -> SExpr
extract' expr off w = SExpr w $ Extract off w expr

-- | Eliminate extract expression where the value already has the desired bits.
extractSameWidth :: SExpr -> Int -> Int -> SExpr
extractSameWidth expr off w
  | off == 0 && width expr == w = expr
  | otherwise = extract' expr off w

-- | Fold nested extract expressions: bits @[off, off + w)@ of an extraction
-- starting at bit @ioff@ are bits @[ioff + off, ioff + off + w)@ of the inner
-- expression.
extractNested :: SExpr -> Int -> Int -> SExpr
extractNested (E (Extract ioff _ inner)) off w = extract inner (ioff + off) w
extractNested expr off w = extractSameWidth expr off w

-- | Performs direct extractions of constant immediate values.
extractConst :: SExpr -> Int -> Int -> SExpr
extractConst (E (Int value)) off w =
  SExpr w . Int $ truncTo (value `shiftR` off) w
  where
    truncTo v bits = v .&. ((1 `shiftL` bits) - 1)
extractConst expr off w = extractNested expr off w

-- | This performs constant propagation for subtyping of condition values
-- (i.e. the conversion from long to word).
extractIte :: SExpr -> Int -> Int -> SExpr
extractIte (E (Ite cond ifT@(E (Int _)) ifF@(E (Int _)))) off w =
  let ex x = extractConst x off w
   in SExpr w $ Ite cond (ex ifT) (ex ifF)
extractIte expr off w = extractConst expr off w

-- | Fold extractions that only read the zero bits added by a zero extension.
extractZeros ::
  SExpr ->
  Int ->
  Int ->
  SExpr
extractZeros expr@(E (ZeroExtend extBits inner)) exOff exWidth
  | exOff >= width inner && extBits > 0 = bvLit exWidth 0 -- only extracting zeros
  | otherwise = extractIte expr exOff exWidth
extractZeros outer exOff exWidth = extractIte outer exOff exWidth

-- | Shared folding logic for extractions from a sign or zero extension.
-- @cons@ rebuilds the extension node, @outer@ is the extension itself and
-- @inner@ the value being extended.
extractExt' ::
  (Integer -> SExpr -> Expr SExpr) ->
  SExpr ->
  SExpr ->
  Int ->
  Int ->
  SExpr
extractExt' cons outer inner exOff exWidth
  -- If we are only extracting the non-extended bits...
  | width inner >= exOff + exWidth = extract inner exOff exWidth
  -- Consider: ((_ extract 31 0) ((_ zero_extend 56) byte))
  -- which becomes ((_ zero_extend 24) byte). The extension amount must be
  -- chosen so that the result has exactly exWidth bits; the first guard
  -- failing with exOff == 0 guarantees exWidth > width inner.
  | exOff == 0 && exWidth < width outer =
      SExpr exWidth $ cons (fromIntegral $ exWidth - width inner) inner
  -- No folding...
  | otherwise = extractZeros outer exOff exWidth

-- | Remove ZeroExtend and SignExtend expression where we don't use
-- the extended bits because we extract below the extended size.
extractExt :: SExpr -> Int -> Int -> SExpr
extractExt expr@(E (SignExtend _ inner)) exOff exWidth =
  extractExt' SignExtend expr inner exOff exWidth
extractExt expr@(E (ZeroExtend _ inner)) exOff exWidth =
  extractExt' ZeroExtend expr inner exOff exWidth
extractExt expr off w = extractIte expr off w

-- | @extract e off w@ extracts bits @[off, off + w)@ of @e@, simplifying
-- where possible.
extract :: SExpr -> Int -> Int -> SExpr
extract = extractExt

------------------------------------------------------------------------

-- | Build a binary operation node. The result inherits the width of @lhs@;
-- note this also applies to comparison operators such as 'bvSLt', whereas
-- 'eq' uses 'boolWidth'.
binOp' :: (SExpr -> SExpr -> Expr SExpr) -> SExpr -> SExpr -> SExpr
binOp' op lhs rhs = lhs {sexpr = op lhs rhs}

binOp :: (SExpr -> SExpr -> Expr SExpr) -> SExpr -> SExpr -> SExpr
-- Consider: (bvslt ((_ zero_extend 24) byte0) ((_ zero_extend 24) byte1))
-- TODO: The following only works if 'op' does not consider sign-bits. Otherwise,
-- there is no semantic expression equivalence after this folding operation.
-- binOp op lhs@(E (ZeroExtend _ lhsInner)) rhs@(E (ZeroExtend _ rhsInner)) =
--   if width lhsInner == width rhsInner
--     then binOp op lhsInner rhsInner
--     else binOp' op lhs rhs
binOp op lhs rhs = binOp' op lhs rhs

-- TODO: Generate these using template-haskell.

bvNeg :: SExpr -> SExpr
bvNeg x = x {sexpr = Neg x}

bvAdd :: SExpr -> SExpr -> SExpr
bvAdd = binOp BvAdd

bvAShr :: SExpr -> SExpr -> SExpr
bvAShr = binOp BvAShr

bvLShr :: SExpr -> SExpr -> SExpr
bvLShr = binOp BvLShr

bvAnd :: SExpr -> SExpr -> SExpr
bvAnd = binOp BvAnd

bvMul :: SExpr -> SExpr -> SExpr
bvMul = binOp BvMul

bvOr :: SExpr -> SExpr -> SExpr
bvOr = binOp BvOr

bvSDiv :: SExpr -> SExpr -> SExpr
bvSDiv = binOp BvSDiv

bvSLeq :: SExpr -> SExpr -> SExpr
bvSLeq = binOp BvSLeq

bvSLt :: SExpr -> SExpr -> SExpr
bvSLt = binOp BvSLt

bvSGeq :: SExpr -> SExpr -> SExpr
bvSGeq = binOp BvSGeq

bvSGt :: SExpr -> SExpr -> SExpr
bvSGt = binOp BvSGt

bvSRem :: SExpr -> SExpr -> SExpr
bvSRem = binOp BvSRem

bvShl :: SExpr -> SExpr -> SExpr
bvShl = binOp BvShl

bvSub :: SExpr -> SExpr -> SExpr
bvSub = binOp BvSub

bvUDiv :: SExpr -> SExpr -> SExpr
bvUDiv = binOp BvUDiv

bvULeq :: SExpr -> SExpr -> SExpr
bvULeq = binOp BvULeq

bvUGeq :: SExpr -> SExpr -> SExpr
bvUGeq = binOp BvUGeq

bvUGt :: SExpr -> SExpr -> SExpr
bvUGt = binOp BvUGt

bvULt :: SExpr -> SExpr -> SExpr
bvULt = binOp BvULt

-- | Unsigned remainder.
--
-- Fold constant bvURem operations which are emitted a lot in our generated
-- SMT-LIB because of QBE's "shift-value modulo bitsize"-semantics. Both
-- operands are interpreted as unsigned values at their width, as bvurem
-- does; a plain 'rem' on negative literals would give the wrong result.
bvURem :: SExpr -> SExpr -> SExpr
bvURem vlhs@(E (Int lhs)) vrhs@(E (Int rhs)) =
  SExpr w . Int $
    -- XXX: On urem-by-zero, SMT-LIB returns the lhs.
    if r == 0
      then l
      else l `rem` r
  where
    w = width vlhs
    l = toUnsigned w lhs
    r = toUnsigned (width vrhs) rhs
bvURem lhs rhs = binOp BvURem lhs rhs

bvXOr :: SExpr -> SExpr -> SExpr
bvXOr = binOp BvXOr