1-- SPDX-FileCopyrightText: 2025 Sören Tempel <soeren+git@soeren-tempel.net>
2--
3-- SPDX-License-Identifier: GPL-3.0-only
4
5module Language.QBE.Simulator.Symbolic.Expression
6 ( BitVector,
7 fromByte,
8 fromReg,
9 fromSExpr,
10 toSExpr,
11 symbolic,
12 bitSize,
13 toCond,
14 )
15where
16
17import Control.Exception (assert)
18import Data.Bits (shiftL, (.&.))
19import Data.Word (Word64, Word8)
20import Language.QBE.Simulator.Default.Expression qualified as D
21import Language.QBE.Simulator.Expression qualified as E
22import Language.QBE.Simulator.Memory qualified as MEM
23import Language.QBE.Types qualified as QBE
24import SimpleBV qualified as SMT
25
-- | A symbolic value: a thin wrapper around an 'SMT.SExpr' whose width can
-- be queried through 'bitSize'.
--
-- TODO: Floating point support.
newtype BitVector
  = BitVector SMT.SExpr
  deriving (Show, Eq)
29
-- | Lift a concrete byte into an 8-bit wide bit-vector literal.
fromByte :: Word8 -> BitVector
fromByte = BitVector . SMT.bvLit 8 . fromIntegral
32
-- | Convert a concrete register value into a bit-vector literal of the
-- matching width (8, 16, 32 or 64 bits).
--
-- Floating point register values are not supported and abort via 'error'.
fromReg :: D.RegVal -> BitVector
fromReg regVal =
  case regVal of
    D.VByte v -> literal 8 (fromIntegral v)
    D.VHalf v -> literal 16 (fromIntegral v)
    D.VWord v -> literal 32 (fromIntegral v)
    D.VLong v -> literal 64 (fromIntegral v)
    D.VSingle _ -> error "symbolic floats not supported"
    D.VDouble _ -> error "symbolic doubles not supported"
  where
    -- Build a literal of the given bit width.
    literal width = BitVector . SMT.bvLit width
40
-- | Wrap a raw SMT expression as a 'BitVector' without any checks.
--
-- TODO: remove
fromSExpr :: SMT.SExpr -> BitVector
fromSExpr expr = BitVector expr
44
-- | Unwrap the SMT expression stored inside a 'BitVector'.
toSExpr :: BitVector -> SMT.SExpr
toSExpr bv =
  case bv of
    BitVector expr -> expr
47
-- | Create a named SMT constant whose width is the bit size of the given
-- extended type.
symbolic :: String -> QBE.ExtType -> BitVector
symbolic name = BitVector . SMT.const name . QBE.extTypeBitSize
50
-- | Width of the wrapped expression, as reported by 'SMT.width'.
bitSize :: BitVector -> Int
bitSize (BitVector expr) = SMT.width expr
53
-- | Turn a Word-sized value into a boolean SMT condition.
--
-- In QBE, a condition (see @jnz@) is true if the Word value is not zero.
-- With @isTrue@ set the result states that the value differs from zero,
-- otherwise it states that the value equals zero.
toCond :: Bool -> BitVector -> SMT.SExpr
toCond isTrue bv =
  -- Equality is only defined for Words.
  assert (bitSize bv == QBE.baseTypeBitSize QBE.Word) $
    if isTrue
      then SMT.not isZero -- /= 0
      else isZero -- == 0
  where
    zeroWord :: SMT.SExpr
    zeroWord = toSExpr . fromReg $ E.fromLit (QBE.Base QBE.Word) 0

    isZero :: SMT.SExpr
    isZero = SMT.eq (toSExpr bv) zeroWord
65
66------------------------------------------------------------------------
67
-- | Byte-wise (de)serialisation of symbolic values.
--
-- Values are split into 8-bit slices with the slice at bit offset 0 first;
-- 'fromBytes' reassembles them in the same order, so later bytes of the list
-- end up in the more significant positions.
instance MEM.Storable BitVector BitVector where
  -- Split an expression whose width is a multiple of 8 into 8-bit slices.
  toBytes (BitVector s) =
    assert (width `mod` 8 == 0) $
      map (BitVector . byteAt) [0 .. width `div` 8 - 1]
    where
      width :: Int
      width = SMT.width s

      -- Slice of width 8 starting at bit offset @idx * 8@ (assuming
      -- 'SMT.extract' takes the expression, a bit offset and a width, as its
      -- other uses in this module suggest).
      byteAt :: Int -> SMT.SExpr
      byteAt idx = SMT.extract s (idx * 8) 8

  -- Reassemble a value of load type @ty@ from its bytes. Returns 'Nothing'
  -- if the number of bytes does not match the size of the load type.
  -- Sub-word loads are extended to 32 bits according to their signedness.
  -- Floating point loads are not implemented and abort via 'error'.
  fromBytes _ [] = Nothing
  fromBytes ty bytes@(BitVector s : xs) =
    if length bytes /= fromIntegral (QBE.loadByteSize ty)
      then Nothing
      else case (ty, bytes) of
        (QBE.LSubWord QBE.UnsignedByte, [_]) ->
          Just (BitVector (SMT.zeroExtend 24 concated))
        (QBE.LSubWord QBE.SignedByte, [_]) ->
          Just (BitVector (SMT.signExtend 24 concated))
        (QBE.LSubWord QBE.SignedHalf, [_, _]) ->
          Just (BitVector (SMT.signExtend 16 concated))
        (QBE.LSubWord QBE.UnsignedHalf, [_, _]) ->
          Just (BitVector (SMT.zeroExtend 16 concated))
        (QBE.LBase QBE.Word, [_, _, _, _]) ->
          Just (BitVector concated)
        (QBE.LBase QBE.Long, [_, _, _, _, _, _, _, _]) ->
          Just (BitVector concated)
        (QBE.LBase QBE.Single, [_, _, _, _]) ->
          error "float loading not implemented"
        (QBE.LBase QBE.Double, [_, _, _, _, _, _, _, _]) ->
          error "double loading not implemented"
        _ -> Nothing
    where
      -- All bytes glued together; the first byte is the innermost operand.
      -- The first byte is width-checked as well, since it seeds the fold and
      -- would otherwise bypass the assertion applied to the remaining bytes.
      concated :: SMT.SExpr
      concated = foldl concatBV (checkedByte s) xs

      -- Prepend the next byte in front of the bits accumulated so far.
      concatBV :: SMT.SExpr -> BitVector -> SMT.SExpr
      concatBV acc (BitVector byte) = SMT.concat (checkedByte byte) acc

      -- Every element handed to 'fromBytes' must be exactly one byte wide.
      checkedByte :: SMT.SExpr -> SMT.SExpr
      checkedByte byte = assert (SMT.width byte == 8) byte
109
110------------------------------------------------------------------------
111
-- | Apply a binary SMT operation to two bit-vectors of equal width.
--
-- Returns 'Nothing' if the operand widths differ.
binaryOp :: (SMT.SExpr -> SMT.SExpr -> SMT.SExpr) -> BitVector -> BitVector -> Maybe BitVector
binaryOp op (BitVector lhs) (BitVector rhs) =
  if SMT.width lhs /= SMT.width rhs
    then Nothing
    else Just (BitVector (op lhs rhs))
116
-- | Reduce a shift amount modulo @size@ (unsigned remainder), where @size@
-- is given as a Word literal.
--
-- TODO: Move this into the expression abstraction.
toShiftAmount :: Word64 -> BitVector -> Maybe BitVector
toShiftAmount size amount = E.urem amount modulus
  where
    modulus :: BitVector
    modulus = E.fromLit (QBE.Base QBE.Word) size
120
-- | Apply a shift operation to a 32-bit or 64-bit value.
--
-- The shift amount has to be a Word and is reduced modulo the width of the
-- value first. For 64-bit values the reduced amount is zero-extended to a
-- Long so that both operands have the same width. Any other combination of
-- operands yields 'Nothing'.
shiftOp :: (SMT.SExpr -> SMT.SExpr -> SMT.SExpr) -> BitVector -> BitVector -> Maybe BitVector
shiftOp op value amount@(BitVector SMT.Word)
  | valueWidth == 32 =
      toShiftAmount 32 amount >>= binaryOp op value
  | valueWidth == 64 =
      toShiftAmount 64 amount
        >>= E.wordToLong QBE.SLUnsignedWord
        >>= binaryOp op value
  | otherwise = Nothing
  where
    valueWidth :: Int
    valueWidth = bitSize value
shiftOp _ _ _ = Nothing -- Shift amount must always be a Word.
130
-- | Apply a binary predicate to two bit-vectors of equal width and encode
-- the outcome as a Long: 1 if the predicate holds, 0 otherwise.
--
-- Returns 'Nothing' if the operand widths differ (see 'binaryOp').
binaryBoolOp :: (SMT.SExpr -> SMT.SExpr -> SMT.SExpr) -> BitVector -> BitVector -> Maybe BitVector
binaryBoolOp op lhs rhs = boolToLong <$> binaryOp op lhs rhs
  where
    -- Select between the two Long literals based on the predicate.
    boolToLong :: BitVector -> BitVector
    boolToLong cond = fromSExpr (SMT.ite (toSExpr cond) longOne longZero)

    -- TODO: Declare these as constants.
    longOne :: SMT.SExpr
    longOne = toSExpr $ E.fromLit (QBE.Base QBE.Long) 1

    longZero :: SMT.SExpr
    longZero = toSExpr $ E.fromLit (QBE.Base QBE.Long) 0
142
-- | Symbolic implementation of the simulator's value representation.
--
-- Arithmetic, bitwise and comparison operations return 'Nothing' when the
-- operand widths differ (see 'binaryOp'). Comparisons yield a Long that is
-- 1 if the relation holds and 0 otherwise (see 'binaryBoolOp'). Floating
-- point values are not supported.
instance E.ValueRepr BitVector where
  -- Build a literal of the width of @ty@, truncating @n@ to that width.
  fromLit ty n =
    let size = QBE.extTypeBitSize ty
        -- For size == 64 the shift yields 0 on a 64-bit word and the
        -- subtraction wraps around to an all-ones mask (assumes @n@ is a
        -- Word64, as it is used by 'toShiftAmount').
        mask = (1 `shiftL` size) - 1
     in BitVector $ SMT.bvLit (fromIntegral size) $ fromIntegral (n .&. mask)

  fromFloat = error "symbolic floats currently unsupported"
  fromDouble = error "symbolic doubles currently unsupported"

  -- XXX: This only works for constant values, but this is fine since we
  -- implement concolic execution and can obtain the address from the
  -- concrete value part.
  toWord64 (BitVector value) =
    case SMT.sexprToVal value of
      SMT.Bits _ n -> fromIntegral n
      -- Reachable for any expression that is not a bit-vector literal, so
      -- report what actually went wrong instead of claiming unreachability.
      _ -> error "toWord64: expression is not a constant bit-vector"

  -- Extend a Word (or one of its sub-word slices, taken at bit offset 0) to
  -- a Long, sign- or zero-extending as requested. Non-Word inputs yield
  -- 'Nothing'.
  wordToLong (QBE.SLSubWord QBE.SignedByte) (BitVector s@SMT.Word) =
    Just $ BitVector (SMT.signExtend 56 (SMT.extract s 0 8))
  wordToLong (QBE.SLSubWord QBE.UnsignedByte) (BitVector s@SMT.Word) =
    Just $ BitVector (SMT.zeroExtend 56 (SMT.extract s 0 8))
  wordToLong (QBE.SLSubWord QBE.SignedHalf) (BitVector s@SMT.Word) =
    Just $ BitVector (SMT.signExtend 48 (SMT.extract s 0 16))
  wordToLong (QBE.SLSubWord QBE.UnsignedHalf) (BitVector s@SMT.Word) =
    Just $ BitVector (SMT.zeroExtend 48 (SMT.extract s 0 16))
  wordToLong QBE.SLSignedWord (BitVector s@SMT.Word) =
    Just $ BitVector (SMT.signExtend 32 s)
  wordToLong QBE.SLUnsignedWord (BitVector s@SMT.Word) =
    Just $ BitVector (SMT.zeroExtend 32 s)
  wordToLong _ _ = Nothing

  -- Interpret a value as the given base type: a Long can be narrowed to a
  -- Word by keeping the 32 bits at offset 0; Words and Longs are returned
  -- unchanged for their own type. Everything else yields 'Nothing'.
  subType QBE.Word v@(BitVector SMT.Word) = Just v
  subType QBE.Word (BitVector s@SMT.Long) =
    Just $ BitVector (SMT.extract s 0 32)
  subType QBE.Long v@(BitVector SMT.Long) = Just v
  subType _ _ = Nothing

  add = binaryOp SMT.bvAdd
  sub = binaryOp SMT.bvSub
  mul = binaryOp SMT.bvMul
  div = binaryOp SMT.bvSDiv
  or = binaryOp SMT.bvOr
  xor = binaryOp SMT.bvXOr
  and = binaryOp SMT.bvAnd
  urem = binaryOp SMT.bvURem
  srem = binaryOp SMT.bvSRem
  udiv = binaryOp SMT.bvUDiv

  neg (BitVector v) = Just $ BitVector (SMT.bvNeg v)

  sar = shiftOp SMT.bvAShr
  shr = shiftOp SMT.bvLShr
  shl = shiftOp SMT.bvShl

  -- The "greater" variants are expressed by swapping the operands of the
  -- corresponding "less" predicate: a >= b <=> b <= a and a > b <=> b < a.
  eq = binaryBoolOp SMT.eq
  ne = binaryBoolOp (\lhs rhs -> SMT.not $ SMT.eq lhs rhs)
  sle = binaryBoolOp SMT.bvSLeq
  slt = binaryBoolOp SMT.bvSLt
  sge = binaryBoolOp (flip SMT.bvSLeq)
  sgt = binaryBoolOp (flip SMT.bvSLt)
  ule = binaryBoolOp SMT.bvULeq
  ult = binaryBoolOp SMT.bvULt
  uge = binaryBoolOp (flip SMT.bvULeq)
  ugt = binaryBoolOp (flip SMT.bvULt)