1-- SPDX-FileCopyrightText: 2025 Sören Tempel <soeren+git@soeren-tempel.net>
2--
3-- SPDX-License-Identifier: MIT AND GPL-3.0-only
4{-# LANGUAGE PatternSynonyms #-}
5
6module SimpleBV
7 ( SExpr,
8 SMT.Solver,
9 SMT.defaultConfig,
10 SMT.newLogger,
11 SMT.newLoggerWithHandle,
12 SMT.newSolver,
13 SMT.newSolverWithConfig,
14 SMT.solverLogger,
15 SMT.smtSolverLogger,
16 SMT.setLogic,
17 SMT.push,
18 SMT.pop,
19 SMT.popMany,
20 SMT.check,
21 SMT.Result (..),
22 SMT.Value (..),
23 pattern W,
24 pattern Byte,
25 pattern Half,
26 pattern Word,
27 pattern Long,
28 width,
29 const,
30 declareBV,
31 assert,
32 sexprToVal,
33 getValue,
34 getValues,
35 toSMT,
36 ite,
37 and,
38 or,
39 not,
40 eq,
41 bvLit,
42 bvAdd,
43 bvAShr,
44 bvLShr,
45 bvAnd,
46 bvMul,
47 bvNeg,
48 bvOr,
49 bvSDiv,
50 bvSLeq,
51 bvSLt,
52 bvSRem,
53 bvShl,
54 bvSub,
55 bvUDiv,
56 bvULeq,
57 bvULt,
58 bvURem,
59 bvXOr,
60 concat,
61 extract,
62 signExtend,
63 zeroExtend,
64 )
65where
66
67import Data.Bits (shiftL, shiftR, (.&.))
68import SimpleSMT qualified as SMT
69import Prelude hiding (and, concat, const, not, or)
70
-- Node type of the expression tree. The type parameter is the type of the
-- sub-expressions; 'SExpr' instantiates it with itself so that every
-- sub-expression carries its own bit width.
data Expr a
  = Var String -- named variable
  | Int Integer -- literal value
  -- Logical connectives, equality and conditional.
  | And a a
  | Or a a
  | Not a
  | Eq a a
  | Ite a a a
  -- Bit-vector arithmetic and bitwise operations.
  | Neg a
  | BvAdd a a
  | BvSub a a
  | BvMul a a
  | BvSDiv a a
  | BvUDiv a a
  | BvSRem a a
  | BvURem a a
  | BvAnd a a
  | BvOr a a
  | BvXOr a a
  | BvShl a a
  | BvAShr a a
  | BvLShr a a
  -- Bit-vector comparisons.
  | BvSLeq a a
  | BvSLt a a
  | BvULeq a a
  | BvULt a a
  -- Width-changing operations. 'Extract' stores the offset of the lowest
  -- extracted bit followed by the number of extracted bits; the extension
  -- nodes store the number of bits that are added.
  | Concat a a
  | Extract Int Int a
  | SignExtend Integer a
  | ZeroExtend Integer a
  deriving (Show, Eq)
102
-- An expression node annotated with a width. The smart constructors in this
-- module maintain the annotation so that simplifications (e.g. in 'extract')
-- can be decided without consulting the solver.
data SExpr = SExpr
  { -- Width, in bits, attributed to the value of this expression.
    width :: Int,
    -- The expression node itself; its children are again 'SExpr' values.
    sexpr :: Expr SExpr
  }
  deriving (Show, Eq)
109
-- Lowers a width-annotated expression into the SimpleSMT S-expression
-- representation that is handed to the solver.
--
-- Literals are emitted in hexadecimal when their width is a multiple of four
-- and in binary otherwise. A hexadecimal digit always denotes four bits, so a
-- hexadecimal literal for any other width (e.g. a 1-bit value) would have a
-- sort that is wider than the annotated width.
toSMT :: SExpr -> SMT.SExpr
toSMT expr =
  case sexpr expr of
    Var name -> SMT.const name
    Int v -> literal (width expr) v
    Or lhs rhs -> binary SMT.or lhs rhs
    Ite cond lhs rhs -> SMT.ite (toSMT cond) (toSMT lhs) (toSMT rhs)
    And lhs rhs -> binary SMT.and lhs rhs
    Not v -> unary SMT.not v
    Neg v -> unary SMT.bvNeg v
    SignExtend n v -> unary (SMT.signExtend n) v
    ZeroExtend n v -> unary (SMT.zeroExtend n) v
    Eq lhs rhs -> binary SMT.eq lhs rhs
    Concat lhs rhs -> binary SMT.concat lhs rhs
    -- 'Extract' stores (offset, width) whereas SMT-LIB wants the index of
    -- the highest bit followed by the index of the lowest bit.
    Extract o w e -> SMT.extract (toSMT e) (fromIntegral $ o + w - 1) (fromIntegral o)
    BvAnd lhs rhs -> binary SMT.bvAnd lhs rhs
    BvAShr lhs rhs -> binary SMT.bvAShr lhs rhs
    BvLShr lhs rhs -> binary SMT.bvLShr lhs rhs
    BvAdd lhs rhs -> binary SMT.bvAdd lhs rhs
    BvMul lhs rhs -> binary SMT.bvMul lhs rhs
    BvOr lhs rhs -> binary SMT.bvOr lhs rhs
    BvSDiv lhs rhs -> binary SMT.bvSDiv lhs rhs
    BvSLeq lhs rhs -> binary SMT.bvSLeq lhs rhs
    BvSLt lhs rhs -> binary SMT.bvSLt lhs rhs
    BvSRem lhs rhs -> binary SMT.bvSRem lhs rhs
    BvShl lhs rhs -> binary SMT.bvShl lhs rhs
    BvSub lhs rhs -> binary SMT.bvSub lhs rhs
    BvUDiv lhs rhs -> binary SMT.bvUDiv lhs rhs
    BvULeq lhs rhs -> binary SMT.bvULeq lhs rhs
    BvULt lhs rhs -> binary SMT.bvULt lhs rhs
    BvURem lhs rhs -> binary SMT.bvURem lhs rhs
    BvXOr lhs rhs -> binary SMT.bvXOr lhs rhs
  where
    -- Lower the operand(s) first, then apply the SimpleSMT constructor.
    unary f v = f (toSMT v)
    binary f lhs rhs = f (toSMT lhs) (toSMT rhs)

    -- Pick a literal syntax whose sort matches the annotated width.
    literal w v
      | w `mod` 4 == 0 = SMT.bvHex w v
      | otherwise = SMT.bvBin w v
142
-- Width annotation given to the result of an equality comparison built by
-- this module.
boolWidth :: Int
boolWidth = 1
145
-- Matches any 'SExpr' and binds its expression node, ignoring the width.
pattern E :: Expr SExpr -> SExpr
pattern E node <- SExpr _ node
148
-- Matches any 'SExpr' and binds its width, ignoring the expression node.
pattern W :: Int -> SExpr
pattern W bits <- SExpr bits _
151
-- Matches expressions that are exactly 8 bits wide.
pattern Byte :: SExpr
pattern Byte <- SExpr 8 _
154
-- Matches expressions that are exactly 16 bits wide.
pattern Half :: SExpr
pattern Half <- SExpr 16 _
157
-- Matches expressions that are exactly 32 bits wide.
pattern Word :: SExpr
pattern Word <- SExpr 32 _
160
-- Matches expressions that are exactly 64 bits wide.
pattern Long :: SExpr
pattern Long <- SExpr 64 _
163
164------------------------------------------------------------------------
165
-- Refers to a variable by name, annotating it with the given width in bits.
-- Nothing is declared in the solver; see 'declareBV' for that.
const :: String -> Int -> SExpr
const name bits = SExpr {width = bits, sexpr = Var name}
168
-- Declares a bit-vector variable of the given width in the solver and
-- returns an expression that refers to it.
declareBV :: SMT.Solver -> String -> Int -> IO SExpr
declareBV solver name bits =
  const name bits <$ SMT.declare solver name sort
  where
    sort = SMT.tBits (fromIntegral bits)
173
-- Creates a literal with the given width in bits and the given value.
bvLit :: Int -> Integer -> SExpr
bvLit bits value = SExpr {width = bits, sexpr = Int value}
176
-- Converts an expression into a SimpleSMT value: variables become an atom
-- with their name, literals become a bit-vector value of the annotated width,
-- and every other expression is represented by the placeholder atom "_".
sexprToVal :: SExpr -> SMT.Value
sexprToVal expr =
  case sexpr expr of
    Var name -> SMT.Other (SMT.Atom name)
    Int value -> SMT.Bits (width expr) value
    _ -> SMT.Other (SMT.Atom "_")
181
-- Lowers the expression with 'toSMT' and asserts it in the solver.
assert :: SMT.Solver -> SExpr -> IO ()
assert solver expr = SMT.assert solver (toSMT expr)
184
-- Lowers the expression with 'toSMT' and queries the solver for its value.
getValue :: SMT.Solver -> SExpr -> IO SMT.Value
getValue solver expr = SMT.getExpr solver (toSMT expr)
187
-- Queries the solver for the values of several expressions at once and
-- returns them keyed by name. Every expression the solver reports back must
-- be a plain atom (i.e. a variable); anything else raises an 'error' when
-- the corresponding result element is evaluated.
getValues :: SMT.Solver -> [SExpr] -> IO [(String, SMT.Value)]
getValues solver exprs = fmap (map named) (SMT.getExprs solver lowered)
  where
    lowered = map toSMT exprs

    named :: (SMT.SExpr, SMT.Value) -> (String, SMT.Value)
    named entry =
      case entry of
        (SMT.Atom name, value) -> (name, value)
        _ -> error "non-atomic variable in inputVars"
195
196---------------------------------------------------------------------------
197
-- Builds an if-then-else expression. The result is annotated with the width
-- of the "then" branch; the width of the "else" branch is not checked.
ite :: SExpr -> SExpr -> SExpr -> SExpr
ite cond ifT ifF = SExpr {width = width ifT, sexpr = Ite cond ifT ifF}
200
-- Negates an expression. A double negation is removed instead of being
-- nested; otherwise the result keeps the width of the operand.
not :: SExpr -> SExpr
not expr =
  case sexpr expr of
    Not inner -> inner
    _ -> SExpr (width expr) (Not expr)
204
-- Conjunction of two expressions; the result takes the width of the left
-- operand.
and :: SExpr -> SExpr -> SExpr
and lhs rhs = SExpr (width lhs) (And lhs rhs)
207
-- Disjunction of two expressions; the result takes the width of the left
-- operand.
or :: SExpr -> SExpr -> SExpr
or lhs rhs = SExpr (width lhs) (Or lhs rhs)
210
-- Sign-extends an expression by the given number of bits; the width of the
-- result grows by that amount.
signExtend :: Integer -> SExpr -> SExpr
signExtend n expr = SExpr extended (SignExtend n expr)
  where
    extended = width expr + fromIntegral n
213
-- Zero-extends an expression by the given number of bits; the width of the
-- result grows by that amount.
zeroExtend :: Integer -> SExpr -> SExpr
zeroExtend n expr = SExpr extended (ZeroExtend n expr)
  where
    extended = width expr + fromIntegral n
216
217------------------------------------------------------------------------
218
-- Builds an equality node without attempting any simplification. The result
-- is annotated with 'boolWidth'.
eq' :: SExpr -> SExpr -> SExpr
eq' lhs rhs = SExpr {width = boolWidth, sexpr = Eq lhs rhs}
221
-- Equality of two expressions.
--
-- Eliminates ITE expressions when comparing with constant values, this is
-- useful in the QBE context to eliminate comparisons with truth values:
-- comparing @ite cond a b@ against the constant @a@ yields @cond@, and
-- comparing it against the constant @b@ yields the negation of @cond@.
-- All other comparisons produce a plain equality node.
eq :: SExpr -> SExpr -> SExpr
eq lexpr@(E (Ite cond (E (Int ifT)) (E (Int ifF)))) rexpr@(E (Int other))
  -- Both branches hold the same constant, so the value of the ITE does not
  -- depend on 'cond'. Rewriting to 'cond' (or its negation) would make the
  -- result depend on it; keep the comparison as it is.
  | ifT == ifF = eq' lexpr rexpr
  | other == ifT = cond
  | other == ifF = not cond
  | otherwise = eq' lexpr rexpr
eq lhs rhs = eq' lhs rhs
230
-- Builds a concatenation node without attempting any simplification. The
-- width of the result is the sum of the operand widths.
concat' :: SExpr -> SExpr -> SExpr
concat' lhs rhs = SExpr total (Concat lhs rhs)
  where
    total = width lhs + width rhs
234
-- Concatenates two expressions. When both operands extract contiguous bit
-- ranges from a variable of the same name (the right operand's range ends
-- directly below the start of the left operand's range), the concatenation is
-- replaced by a single extract covering both ranges.
concat :: SExpr -> SExpr -> SExpr
concat lhs rhs =
  case (sexpr lhs, sexpr rhs) of
    (Extract loff lwidth latom@(E (Var varLhs)), Extract roff rwidth (E (Var varRhs)))
      | varLhs == varRhs && roff + rwidth == loff ->
          extract latom roff (lwidth + rwidth)
    _ -> concat' lhs rhs
243
-- Builds an extract node without attempting any simplification. The second
-- argument is the offset of the lowest extracted bit, the third argument the
-- number of extracted bits (which is also the width of the result).
extract' :: SExpr -> Int -> Int -> SExpr
extract' expr off w = SExpr {width = w, sexpr = Extract off w expr}
246
-- Drops the extraction entirely when it starts at bit zero and covers the
-- whole width of the operand; otherwise an extract node is created.
extractSameWidth :: SExpr -> Int -> Int -> SExpr
extractSameWidth expr off w =
  if off == 0 && width expr == w
    then expr
    else extract' expr off w
252
-- Avoids nesting an extraction inside an extract node that has the very same
-- offset and width: in that case the existing extract node is returned
-- unchanged. Everything else is passed on to 'extractSameWidth'.
--
-- NOTE(review): for a non-zero offset the requested range appears to lie
-- outside of the inner extract's result (which is only that many bits wide),
-- so this rule treats a repeated extraction as idempotent -- confirm that
-- this is what callers intend.
extractNested :: SExpr -> Int -> Int -> SExpr
extractNested expr off w =
  case sexpr expr of
    Extract innerOff innerWidth _
      | innerOff == off && innerWidth == w -> expr
    _ -> extractSameWidth expr off w
260
-- Evaluates extractions from literals directly: the value is shifted right
-- by the offset and masked to the requested number of bits, yielding a new
-- literal. Non-literal operands are passed on to 'extractNested'.
extractConst :: SExpr -> Int -> Int -> SExpr
extractConst expr off w =
  case sexpr expr of
    Int value -> SExpr w (Int ((value `shiftR` off) .&. mask))
    _ -> extractNested expr off w
  where
    -- The lowest 'w' bits set.
    mask :: Integer
    mask = (1 `shiftL` w) - 1
268
-- Pushes an extraction into both branches of an ITE whose branches are
-- literals, so that the branches become narrower literals. This performs
-- constant propagation for subtyping of condition values (i.e. the
-- conversion from long to word). Other operands are passed on to
-- 'extractConst'.
extractIte :: SExpr -> Int -> Int -> SExpr
extractIte expr off w =
  case sexpr expr of
    Ite cond ifT@(E (Int _)) ifF@(E (Int _)) ->
      SExpr w (Ite cond (narrow ifT) (narrow ifF))
    _ -> extractConst expr off w
  where
    narrow branch = extractConst branch off w
276
-- Simplifies an extraction from a sign- or zero-extended expression.
--
-- Arguments, in order: the constructor of the extension node ('SignExtend' or
-- 'ZeroExtend'), the complete extension expression, the number of bits added
-- by the extension, the expression that was extended, the offset of the
-- lowest extracted bit, and the number of extracted bits.
extractExt' ::
  (Integer -> SExpr -> Expr SExpr) ->
  SExpr ->
  Integer ->
  SExpr ->
  Int ->
  Int ->
  SExpr
extractExt' cons outer extBits inner exOff exWidth
  -- If we are only extracting the non-extended bits, the extension can be
  -- dropped and we extract from the inner expression directly.
  | innerWidth >= exOff + exWidth = extractIte inner exOff exWidth
  -- The extraction starts at bit zero and ends within the extended bits, e.g.
  --
  --   ((_ extract 31 0) ((_ zero_extend 56) byte))
  --
  -- This equals extending the inner expression up to the extracted width
  -- only. The previous guard failed, hence exWidth > innerWidth here and the
  -- number of added bits is positive.
  | exOff == 0 && exWidth < extendedWidth =
      SExpr exWidth $ cons (fromIntegral $ exWidth - innerWidth) inner
  -- No folding...
  | otherwise = extractIte outer exOff exWidth
  where
    innerWidth = width inner
    extendedWidth = innerWidth + fromIntegral extBits
293
-- Entry point of the extract simplifications for extension nodes: removes or
-- narrows ZeroExtend and SignExtend expressions when the extracted range does
-- not need all of the extended bits (see 'extractExt''). Any other operand
-- is passed on to 'extractIte'.
extractExt :: SExpr -> Int -> Int -> SExpr
extractExt expr exOff exWidth =
  case sexpr expr of
    SignExtend extBits inner -> fold SignExtend extBits inner
    ZeroExtend extBits inner -> fold ZeroExtend extBits inner
    _ -> extractIte expr exOff exWidth
  where
    fold cons extBits inner = extractExt' cons expr extBits inner exOff exWidth
302
-- Extracts a bit range from an expression, applying the simplifications of
-- 'extractExt' and the functions it delegates to. The second argument is the
-- offset of the lowest extracted bit, the third the number of extracted bits.
extract :: SExpr -> Int -> Int -> SExpr
extract expr off w = extractExt expr off w
305
306------------------------------------------------------------------------
307
-- Builds a binary node from the given constructor without attempting any
-- simplification. The result takes the width of the left operand.
binOp' :: (SExpr -> SExpr -> Expr SExpr) -> SExpr -> SExpr -> SExpr
binOp' op lhs rhs = SExpr (width lhs) (op lhs rhs)
310
-- Generic constructor for binary operators; currently this is a plain alias
-- for 'binOp''.
--
-- TODO: A folding rule for two zero-extended operands is disabled. Consider:
--
--   (bvslt ((_ zero_extend 24) byte0) ((_ zero_extend 24) byte1))
--
-- The rule below only works if 'op' does not consider sign-bits. Otherwise,
-- there is no semantic expression equivalence after this folding operation.
--
--   binOp op lhs@(E (ZeroExtend _ lhsInner)) rhs@(E (ZeroExtend _ rhsInner)) =
--     if width lhsInner == width rhsInner
--       then binOp op lhsInner rhsInner
--       else binOp' op lhs rhs
binOp :: (SExpr -> SExpr -> Expr SExpr) -> SExpr -> SExpr -> SExpr
binOp = binOp'
320
321-- TODO: Generate these using template-haskell.
322
-- Builds a 'Neg' node (lowered by 'toSMT' to 'SMT.bvNeg'); the result keeps
-- the width of the operand.
bvNeg :: SExpr -> SExpr
bvNeg operand = SExpr (width operand) (Neg operand)
325
-- Builds a 'BvAdd' node (lowered by 'toSMT' to 'SMT.bvAdd'); the result
-- takes the width of the left operand.
bvAdd :: SExpr -> SExpr -> SExpr
bvAdd lhs rhs = binOp BvAdd lhs rhs
328
-- Builds a 'BvAShr' node (lowered by 'toSMT' to 'SMT.bvAShr'); the result
-- takes the width of the left operand.
bvAShr :: SExpr -> SExpr -> SExpr
bvAShr lhs rhs = binOp BvAShr lhs rhs
331
-- Builds a 'BvLShr' node (lowered by 'toSMT' to 'SMT.bvLShr'); the result
-- takes the width of the left operand.
bvLShr :: SExpr -> SExpr -> SExpr
bvLShr lhs rhs = binOp BvLShr lhs rhs
334
-- Builds a 'BvAnd' node (lowered by 'toSMT' to 'SMT.bvAnd'); the result
-- takes the width of the left operand.
bvAnd :: SExpr -> SExpr -> SExpr
bvAnd lhs rhs = binOp BvAnd lhs rhs
337
-- Builds a 'BvMul' node (lowered by 'toSMT' to 'SMT.bvMul'); the result
-- takes the width of the left operand.
bvMul :: SExpr -> SExpr -> SExpr
bvMul lhs rhs = binOp BvMul lhs rhs
340
-- Builds a 'BvOr' node (lowered by 'toSMT' to 'SMT.bvOr'); the result takes
-- the width of the left operand.
bvOr :: SExpr -> SExpr -> SExpr
bvOr lhs rhs = binOp BvOr lhs rhs
343
-- Builds a 'BvSDiv' node (lowered by 'toSMT' to 'SMT.bvSDiv'); the result
-- takes the width of the left operand.
bvSDiv :: SExpr -> SExpr -> SExpr
bvSDiv lhs rhs = binOp BvSDiv lhs rhs
346
-- Builds a 'BvSLeq' comparison node (lowered by 'toSMT' to 'SMT.bvSLeq').
-- Like every node built through 'binOp', the result is annotated with the
-- width of the left operand.
bvSLeq :: SExpr -> SExpr -> SExpr
bvSLeq lhs rhs = binOp BvSLeq lhs rhs
349
-- Builds a 'BvSLt' comparison node (lowered by 'toSMT' to 'SMT.bvSLt').
-- Like every node built through 'binOp', the result is annotated with the
-- width of the left operand.
bvSLt :: SExpr -> SExpr -> SExpr
bvSLt lhs rhs = binOp BvSLt lhs rhs
352
-- Builds a 'BvSRem' node (lowered by 'toSMT' to 'SMT.bvSRem'); the result
-- takes the width of the left operand.
bvSRem :: SExpr -> SExpr -> SExpr
bvSRem lhs rhs = binOp BvSRem lhs rhs
355
-- Builds a 'BvShl' node (lowered by 'toSMT' to 'SMT.bvShl'); the result
-- takes the width of the left operand.
bvShl :: SExpr -> SExpr -> SExpr
bvShl lhs rhs = binOp BvShl lhs rhs
358
-- Builds a 'BvSub' node (lowered by 'toSMT' to 'SMT.bvSub'); the result
-- takes the width of the left operand.
bvSub :: SExpr -> SExpr -> SExpr
bvSub lhs rhs = binOp BvSub lhs rhs
361
-- Builds a 'BvUDiv' node (lowered by 'toSMT' to 'SMT.bvUDiv'); the result
-- takes the width of the left operand.
bvUDiv :: SExpr -> SExpr -> SExpr
bvUDiv lhs rhs = binOp BvUDiv lhs rhs
364
-- Builds a 'BvULeq' comparison node (lowered by 'toSMT' to 'SMT.bvULeq').
-- Like every node built through 'binOp', the result is annotated with the
-- width of the left operand.
bvULeq :: SExpr -> SExpr -> SExpr
bvULeq lhs rhs = binOp BvULeq lhs rhs
367
-- Builds a 'BvULt' comparison node (lowered by 'toSMT' to 'SMT.bvULt').
-- Like every node built through 'binOp', the result is annotated with the
-- width of the left operand.
bvULt :: SExpr -> SExpr -> SExpr
bvULt lhs rhs = binOp BvULt lhs rhs
370
-- Builds a 'BvURem' node (lowered by 'toSMT' to 'SMT.bvURem'); the result
-- takes the width of the left operand.
bvURem :: SExpr -> SExpr -> SExpr
bvURem lhs rhs = binOp BvURem lhs rhs
373
-- Builds a 'BvXOr' node (lowered by 'toSMT' to 'SMT.bvXOr'); the result
-- takes the width of the left operand.
bvXOr :: SExpr -> SExpr -> SExpr
bvXOr lhs rhs = binOp BvXOr lhs rhs