// os/mm/memory_set.rs

1use super::{FrameTracker, frame_alloc};
2use super::{PTEFlags, PageTable, PageTableEntry};
3use super::{PhysAddr, PhysPageNum, VirtAddr, VirtPageNum};
4use super::{StepByOne, VPNRange};
5use crate::config::{MEMORY_END, MMIO, PAGE_SIZE, TRAMPOLINE};
6use crate::sync::UPSafeCell;
7use alloc::collections::BTreeMap;
8use alloc::sync::Arc;
9use alloc::vec::Vec;
10use core::arch::asm;
11use lazy_static::*;
12use riscv::register::satp;
13
// Section-boundary symbols, presumably defined by the kernel linker script
// (not visible here). In this file only their addresses are used
// (`stext as usize`, ...); they are never called as functions.
unsafe extern "C" {
    // .text bounds
    safe fn stext();
    safe fn etext();
    // .rodata bounds
    safe fn srodata();
    safe fn erodata();
    // .data bounds
    safe fn sdata();
    safe fn edata();
    // .bss bounds (the start symbol's name suggests the boot stack lives here)
    safe fn sbss_with_stack();
    safe fn ebss();
    // end of the kernel image; `new_kernel` maps [ekernel, MEMORY_END) as free memory
    safe fn ekernel();
    // physical location of the trampoline page mapped at `TRAMPOLINE`
    safe fn strampoline();
}
26
27lazy_static! {
28    pub static ref KERNEL_SPACE: Arc<UPSafeCell<MemorySet>> =
29        Arc::new(unsafe { UPSafeCell::new(MemorySet::new_kernel()) });
30}
31
32pub fn kernel_token() -> usize {
33    KERNEL_SPACE.exclusive_access().token()
34}
35
/// A virtual address space: a page table plus the areas mapped into it.
///
/// The trampoline page is mapped directly into `page_table` and is not
/// recorded in `areas`.
pub struct MemorySet {
    // Fields drop in declaration order: `page_table` is dropped before `areas`.
    page_table: PageTable,
    /// Mapped areas; each `Framed` area owns the frames backing its pages.
    areas: Vec<MapArea>,
}
40
41impl MemorySet {
42    pub fn new_bare() -> Self {
43        Self {
44            page_table: PageTable::new(),
45            areas: Vec::new(),
46        }
47    }
48    pub fn token(&self) -> usize {
49        self.page_table.token()
50    }
51    /// Assume that no conflicts.
52    pub fn insert_framed_area(
53        &mut self,
54        start_va: VirtAddr,
55        end_va: VirtAddr,
56        permission: MapPermission,
57    ) {
58        self.push(
59            MapArea::new(start_va, end_va, MapType::Framed, permission),
60            None,
61        );
62    }
63    pub fn remove_area_with_start_vpn(&mut self, start_vpn: VirtPageNum) {
64        if let Some((idx, area)) = self
65            .areas
66            .iter_mut()
67            .enumerate()
68            .find(|(_, area)| area.vpn_range.get_start() == start_vpn)
69        {
70            area.unmap(&mut self.page_table);
71            self.areas.remove(idx);
72        }
73    }
74    fn push(&mut self, mut map_area: MapArea, data: Option<&[u8]>) {
75        map_area.map(&mut self.page_table);
76        if let Some(data) = data {
77            map_area.copy_data(&self.page_table, data);
78        }
79        self.areas.push(map_area);
80    }
81    /// Mention that trampoline is not collected by areas.
82    fn map_trampoline(&mut self) {
83        self.page_table.map(
84            VirtAddr::from(TRAMPOLINE).into(),
85            PhysAddr::from(strampoline as usize).into(),
86            PTEFlags::R | PTEFlags::X,
87        );
88    }
89    /// Without kernel stacks.
90    pub fn new_kernel() -> Self {
91        let mut memory_set = Self::new_bare();
92        // map trampoline
93        memory_set.map_trampoline();
94        // map kernel sections
95        println!(".text [{:#x}, {:#x})", stext as usize, etext as usize);
96        println!(".rodata [{:#x}, {:#x})", srodata as usize, erodata as usize);
97        println!(".data [{:#x}, {:#x})", sdata as usize, edata as usize);
98        println!(
99            ".bss [{:#x}, {:#x})",
100            sbss_with_stack as usize, ebss as usize
101        );
102        println!("mapping .text section");
103        memory_set.push(
104            MapArea::new(
105                (stext as usize).into(),
106                (etext as usize).into(),
107                MapType::Identical,
108                MapPermission::R | MapPermission::X,
109            ),
110            None,
111        );
112        println!("mapping .rodata section");
113        memory_set.push(
114            MapArea::new(
115                (srodata as usize).into(),
116                (erodata as usize).into(),
117                MapType::Identical,
118                MapPermission::R,
119            ),
120            None,
121        );
122        println!("mapping .data section");
123        memory_set.push(
124            MapArea::new(
125                (sdata as usize).into(),
126                (edata as usize).into(),
127                MapType::Identical,
128                MapPermission::R | MapPermission::W,
129            ),
130            None,
131        );
132        println!("mapping .bss section");
133        memory_set.push(
134            MapArea::new(
135                (sbss_with_stack as usize).into(),
136                (ebss as usize).into(),
137                MapType::Identical,
138                MapPermission::R | MapPermission::W,
139            ),
140            None,
141        );
142        println!("mapping physical memory");
143        memory_set.push(
144            MapArea::new(
145                (ekernel as usize).into(),
146                MEMORY_END.into(),
147                MapType::Identical,
148                MapPermission::R | MapPermission::W,
149            ),
150            None,
151        );
152        println!("mapping memory-mapped registers");
153        for pair in MMIO {
154            memory_set.push(
155                MapArea::new(
156                    (*pair).0.into(),
157                    ((*pair).0 + (*pair).1).into(),
158                    MapType::Identical,
159                    MapPermission::R | MapPermission::W,
160                ),
161                None,
162            );
163        }
164        memory_set
165    }
166    /// Include sections in elf and trampoline,
167    /// also returns user_sp_base and entry point.
168    pub fn from_elf(elf_data: &[u8]) -> (Self, usize, usize) {
169        let mut memory_set = Self::new_bare();
170        // map trampoline
171        memory_set.map_trampoline();
172        // map program headers of elf, with U flag
173        let elf = xmas_elf::ElfFile::new(elf_data).unwrap();
174        let elf_header = elf.header;
175        let magic = elf_header.pt1.magic;
176        assert_eq!(magic, [0x7f, 0x45, 0x4c, 0x46], "invalid elf!");
177        let ph_count = elf_header.pt2.ph_count();
178        let mut max_end_vpn = VirtPageNum(0);
179        for i in 0..ph_count {
180            let ph = elf.program_header(i).unwrap();
181            if ph.get_type().unwrap() == xmas_elf::program::Type::Load {
182                let start_va: VirtAddr = (ph.virtual_addr() as usize).into();
183                let end_va: VirtAddr = ((ph.virtual_addr() + ph.mem_size()) as usize).into();
184                let mut map_perm = MapPermission::U;
185                let ph_flags = ph.flags();
186                if ph_flags.is_read() {
187                    map_perm |= MapPermission::R;
188                }
189                if ph_flags.is_write() {
190                    map_perm |= MapPermission::W;
191                }
192                if ph_flags.is_execute() {
193                    map_perm |= MapPermission::X;
194                }
195                let map_area = MapArea::new(start_va, end_va, MapType::Framed, map_perm);
196                max_end_vpn = map_area.vpn_range.get_end();
197                memory_set.push(
198                    map_area,
199                    Some(&elf.input[ph.offset() as usize..(ph.offset() + ph.file_size()) as usize]),
200                );
201            }
202        }
203        let max_end_va: VirtAddr = max_end_vpn.into();
204        let mut user_stack_base: usize = max_end_va.into();
205        user_stack_base += PAGE_SIZE;
206        (
207            memory_set,
208            user_stack_base,
209            elf.header.pt2.entry_point() as usize,
210        )
211    }
212    pub fn from_existed_user(user_space: &MemorySet) -> MemorySet {
213        let mut memory_set = Self::new_bare();
214        // map trampoline
215        memory_set.map_trampoline();
216        // copy data sections/trap_context/user_stack
217        for area in user_space.areas.iter() {
218            let new_area = MapArea::from_another(area);
219            memory_set.push(new_area, None);
220            // copy data from another space
221            for vpn in area.vpn_range {
222                let src_ppn = user_space.translate(vpn).unwrap().ppn();
223                let dst_ppn = memory_set.translate(vpn).unwrap().ppn();
224                dst_ppn
225                    .get_bytes_array()
226                    .copy_from_slice(src_ppn.get_bytes_array());
227            }
228        }
229        memory_set
230    }
231    pub fn activate(&self) {
232        let satp = self.page_table.token();
233        unsafe {
234            satp::write(satp);
235            asm!("sfence.vma");
236        }
237    }
238    pub fn translate(&self, vpn: VirtPageNum) -> Option<PageTableEntry> {
239        self.page_table.translate(vpn)
240    }
241    pub fn recycle_data_pages(&mut self) {
242        //*self = Self::new_bare();
243        self.areas.clear();
244    }
245}
246
247pub struct MapArea {
248    vpn_range: VPNRange,
249    data_frames: BTreeMap<VirtPageNum, FrameTracker>,
250    map_type: MapType,
251    map_perm: MapPermission,
252}
253
254impl MapArea {
255    pub fn new(
256        start_va: VirtAddr,
257        end_va: VirtAddr,
258        map_type: MapType,
259        map_perm: MapPermission,
260    ) -> Self {
261        let start_vpn: VirtPageNum = start_va.floor();
262        let end_vpn: VirtPageNum = end_va.ceil();
263        Self {
264            vpn_range: VPNRange::new(start_vpn, end_vpn),
265            data_frames: BTreeMap::new(),
266            map_type,
267            map_perm,
268        }
269    }
270    pub fn from_another(another: &MapArea) -> Self {
271        Self {
272            vpn_range: VPNRange::new(another.vpn_range.get_start(), another.vpn_range.get_end()),
273            data_frames: BTreeMap::new(),
274            map_type: another.map_type,
275            map_perm: another.map_perm,
276        }
277    }
278    pub fn map_one(&mut self, page_table: &mut PageTable, vpn: VirtPageNum) {
279        let ppn: PhysPageNum;
280        match self.map_type {
281            MapType::Identical => {
282                ppn = PhysPageNum(vpn.0);
283            }
284            MapType::Framed => {
285                let frame = frame_alloc().unwrap();
286                ppn = frame.ppn;
287                self.data_frames.insert(vpn, frame);
288            }
289        }
290        let pte_flags = PTEFlags::from_bits(self.map_perm.bits).unwrap();
291        page_table.map(vpn, ppn, pte_flags);
292    }
293    pub fn unmap_one(&mut self, page_table: &mut PageTable, vpn: VirtPageNum) {
294        if self.map_type == MapType::Framed {
295            self.data_frames.remove(&vpn);
296        }
297        page_table.unmap(vpn);
298    }
299    pub fn map(&mut self, page_table: &mut PageTable) {
300        for vpn in self.vpn_range {
301            self.map_one(page_table, vpn);
302        }
303    }
304    pub fn unmap(&mut self, page_table: &mut PageTable) {
305        for vpn in self.vpn_range {
306            self.unmap_one(page_table, vpn);
307        }
308    }
309    /// data: start-aligned but maybe with shorter length
310    /// assume that all frames were cleared before
311    pub fn copy_data(&mut self, page_table: &PageTable, data: &[u8]) {
312        assert_eq!(self.map_type, MapType::Framed);
313        let mut start: usize = 0;
314        let mut current_vpn = self.vpn_range.get_start();
315        let len = data.len();
316        loop {
317            let src = &data[start..len.min(start + PAGE_SIZE)];
318            let dst = &mut page_table
319                .translate(current_vpn)
320                .unwrap()
321                .ppn()
322                .get_bytes_array()[..src.len()];
323            dst.copy_from_slice(src);
324            start += PAGE_SIZE;
325            if start >= len {
326                break;
327            }
328            current_vpn.step();
329        }
330    }
331}
332
/// How the virtual pages of a [`MapArea`] are backed by physical memory.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum MapType {
    /// Each virtual page maps to the physical page with the same number.
    Identical,
    /// Each virtual page is backed by a newly allocated frame owned by the area.
    Framed,
}
338
339bitflags! {
340    pub struct MapPermission: u8 {
341        const R = 1 << 1;
342        const W = 1 << 2;
343        const X = 1 << 3;
344        const U = 1 << 4;
345    }
346}
347
348#[allow(unused)]
349pub fn remap_test() {
350    let mut kernel_space = KERNEL_SPACE.exclusive_access();
351    let mid_text: VirtAddr = ((stext as usize + etext as usize) / 2).into();
352    let mid_rodata: VirtAddr = ((srodata as usize + erodata as usize) / 2).into();
353    let mid_data: VirtAddr = ((sdata as usize + edata as usize) / 2).into();
354    assert!(
355        !kernel_space
356            .page_table
357            .translate(mid_text.floor())
358            .unwrap()
359            .writable(),
360    );
361    assert!(
362        !kernel_space
363            .page_table
364            .translate(mid_rodata.floor())
365            .unwrap()
366            .writable(),
367    );
368    assert!(
369        !kernel_space
370            .page_table
371            .translate(mid_data.floor())
372            .unwrap()
373            .executable(),
374    );
375    println!("remap_test passed!");
376}