1use z3::{
2 ast::{Array, BV},
3 Context, Sort,
4};
5
6pub struct Memory<'ctx> {
7 ctx: &'ctx Context,
8 pub data: Array<'ctx>,
9}
10
11impl<'ctx> Memory<'ctx> {
12 pub fn new(ctx: &'ctx Context) -> Memory<'ctx> {
13 let ary = Array::new_const(
14 ctx,
15 "memory",
16 &Sort::bitvector(&ctx, 64), // index type
17 &Sort::bitvector(&ctx, 8), // value type
18 );
19
20 Memory {
21 ctx: ctx,
22 data: ary,
23 }
24 }
25
26 pub fn store_byte(&mut self, addr: BV<'ctx>, value: BV<'ctx>) {
27 assert!(addr.get_size() == 64);
28 assert!(value.get_size() == 8);
29 self.data = self.data.store(&addr, &value);
30 }
31
32 pub fn load_byte(&self, addr: BV<'ctx>) -> BV<'ctx> {
33 assert!(addr.get_size() == 64);
34 self.data.select(&addr).as_bv().unwrap()
35 }
36
37 pub fn store_bitvector(&mut self, addr: BV<'ctx>, value: BV<'ctx>) {
38 assert!(value.get_size() % 8 == 0);
39 let amount = value.get_size() / 8;
40
41 // Extract nth bytes from the bitvector
42 let bytes = (1..=amount)
43 .into_iter()
44 .rev()
45 .map(|n| value.extract((n * 8) - 1, (n - 1) * 8));
46
47 // Store each byte in memory
48 bytes.enumerate().for_each(|(n, b)| {
49 assert!(b.get_size() == 8);
50 self.store_byte(addr.bvadd(&BV::from_u64(self.ctx, n as u64, 64)), b)
51 });
52 }
53
54 pub fn load_bitvector(&self, addr: BV<'ctx>, amount: u64) -> BV<'ctx> {
55 // Load amount bytes from memory
56 let bytes = (0..amount)
57 .into_iter()
58 .map(|n| self.load_byte(addr.bvadd(&BV::from_u64(self.ctx, n, 64))));
59
60 // Concat the bytes into a single bitvector
61 bytes.reduce(|acc, e| acc.concat(&e)).unwrap()
62 }
63
64 pub fn store_string(&mut self, addr: BV<'ctx>, str: &str) -> BV<'ctx> {
65 let mut cur_addr = addr;
66 for c in str.chars() {
67 let code: u8 = c.try_into().unwrap();
68 self.store_byte(cur_addr.clone(), BV::from_u64(self.ctx, code.into(), 8));
69 cur_addr = cur_addr.bvadd(&BV::from_u64(self.ctx, 1, 64));
70 }
71
72 cur_addr
73 }
74
75 pub fn store_word(&mut self, addr: BV<'ctx>, value: BV<'ctx>) {
76 assert!(value.get_size() == 32);
77 self.store_bitvector(addr, value)
78 }
79
80 pub fn load_word(&self, addr: BV<'ctx>) -> BV<'ctx> {
81 assert!(addr.get_size() == 64);
82 self.load_bitvector(addr, 4)
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use z3::ast::Ast;
    use z3::Config;
    use z3::SatResult;
    use z3::Solver;

    /// Asserts that `lhs == rhs` holds in every model, i.e. that the negated
    /// equality is unsatisfiable. Checking plain satisfiability is not enough:
    /// it also succeeds when `lhs` is unconstrained (e.g. a load from memory
    /// that was never written).
    fn assert_always_eq<'ctx>(ctx: &'ctx Context, lhs: &BV<'ctx>, rhs: &BV<'ctx>) {
        let solver = Solver::new(ctx);
        solver.assert(&lhs._eq(rhs).not());
        assert_eq!(SatResult::Unsat, solver.check());
    }

    #[test]
    fn test_byte() {
        let cfg = Config::new();
        let ctx = Context::new(&cfg);
        let mut mem = Memory::new(&ctx);

        let addr = BV::from_u64(&ctx, 0x800000, 64);
        let value = BV::from_u64(&ctx, 0x23, 8);

        mem.store_byte(addr.clone(), value.clone());
        let loaded = mem.load_byte(addr);

        assert_always_eq(&ctx, &loaded, &value);
    }

    #[test]
    fn test_string() {
        let cfg = Config::new();
        let ctx = Context::new(&cfg);
        let mut mem = Memory::new(&ctx);

        let addr = BV::from_u64(&ctx, 0x0, 64);
        let end = mem.store_string(addr, "hello");

        for (offset, expected) in "hello".bytes().enumerate() {
            let loaded = mem.load_byte(BV::from_u64(&ctx, offset as u64, 64));
            assert_always_eq(&ctx, &loaded, &BV::from_u64(&ctx, u64::from(expected), 8));
        }

        // The returned address points just past the last byte written.
        assert_always_eq(&ctx, &end, &BV::from_u64(&ctx, 5, 64));
    }

    #[test]
    fn test_word() {
        let cfg = Config::new();
        let ctx = Context::new(&cfg);
        let mut mem = Memory::new(&ctx);

        let addr = BV::from_u64(&ctx, 0x1000, 64);
        let word = BV::from_u64(&ctx, 0xdeadbeef, 32);

        mem.store_word(addr.clone(), word.clone());

        // Big-endian layout: the most significant byte sits at the lowest address.
        let expected_bytes: [u64; 4] = [0xde, 0xad, 0xbe, 0xef];
        for (offset, &expected) in expected_bytes.iter().enumerate() {
            let loaded = mem.load_byte(BV::from_u64(&ctx, 0x1000 + offset as u64, 64));
            assert_always_eq(&ctx, &loaded, &BV::from_u64(&ctx, expected, 8));
        }

        let loaded_word = mem.load_word(addr);
        assert_always_eq(&ctx, &loaded_word, &word);
    }

    #[test]
    fn test_symbolic_bitvector_roundtrip() {
        let cfg = Config::new();
        let ctx = Context::new(&cfg);
        let mut mem = Memory::new(&ctx);

        // Both the address and the value are unconstrained symbols, so the
        // round trip must hold for every concrete choice.
        let addr = BV::new_const(&ctx, "addr", 64);
        let value = BV::new_const(&ctx, "value", 64);

        mem.store_bitvector(addr.clone(), value.clone());
        let loaded = mem.load_bitvector(addr, 8);

        assert_always_eq(&ctx, &loaded, &value);
    }
}