Skip to content

Commit cb40957

Browse files
committed
增加多线程,在同一块内存上读写,用barrier进行同步,没有使用mpsc。
目前行为不对,待改正
1 parent 9ae93ad commit cb40957

10 files changed

+189
-124
lines changed

Cargo.lock

+10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[package]
22
name = "rust-ppm"
33
version = "0.1.0"
4+
authors = ["ChrisZhang <[email protected]>"]
45
edition = "2018"
56

67
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -13,4 +14,5 @@ rand = "0.8.4"
1314
lazy_static = "1.4.0"
1415
image = "0.23.14"
1516
num-complex = "0.4.0"
16-
adqselect = "0.1.3"
17+
adqselect = "0.1.3"
18+
threadpool = "1.8.1"

src/camera.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use core::f64;
2-
use std::rc::Rc;
2+
use std::sync::Arc;
33
use json::JsonValue;
44
use vecmat::matrix::Matrix3x3;
55
use vecmat::{Matrix, Vector, traits::Dot, vector::Vector3};
@@ -107,7 +107,7 @@ impl Camera for DoFCamera {
107107
]).transpose();
108108
let dir = rot.dot(dir).normalize();
109109
let temp_ray = Ray::new(self.center, dir, None);
110-
let temp_material = Rc::new(DiffuseMaterial::new(Vector3::<f64>::from([1., 1., 1.])));
110+
let temp_material = Arc::new(DiffuseMaterial::new(Vector3::<f64>::from([1., 1., 1.])));
111111
let focus_plane = Plane::new(temp_material, self.direction, self.focus_dist);
112112
let hit = focus_plane.intersect(&temp_ray, 0.015).unwrap(); //这里保证有交
113113
let delta: Vector3<f64> = self.aperture * (normal_x * self.horizental + normal_y * self.up);
@@ -123,7 +123,7 @@ impl Camera for DoFCamera {
123123
}
124124
}
125125

126-
pub fn build_camera(camera_attr: &JsonValue) -> Box<dyn Camera> {
126+
pub fn build_camera(camera_attr: &JsonValue) -> Arc<dyn Camera + Send + Sync> {
127127
let cam_type = camera_attr["Type"].as_str().unwrap();
128128
let center = parse_vector(&camera_attr["Center"]);
129129
let direction = parse_vector(&camera_attr["Direction"]);
@@ -132,11 +132,11 @@ pub fn build_camera(camera_attr: &JsonValue) -> Box<dyn Camera> {
132132
let width = camera_attr["Width"].as_u32().unwrap();
133133
let height = camera_attr["Height"].as_u32().unwrap();
134134
match cam_type {
135-
"Perspective" => Box::new(PerspectiveCamera::new(center, direction, up, angle, width, height)),
135+
"Perspective" => Arc::new(PerspectiveCamera::new(center, direction, up, angle, width, height)),
136136
"DoF" => {
137137
let focus = parse_vector(&camera_attr["Focus"]);
138138
let aperture = camera_attr["Aperture"].as_f64().unwrap();
139-
Box::new(DoFCamera::new(center, direction, up, angle, width, height, focus, aperture))
139+
Arc::new(DoFCamera::new(center, direction, up, angle, width, height, focus, aperture))
140140
},
141141
_ => panic!("Invalid Camera Type!")
142142
}

src/hit.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
use crate::materials::Material;
22
use vecmat::vector::Vector3;
3-
use std::rc::Rc;
3+
use std::sync::Arc;
44

55
pub struct Hit {
66
t: f64,
7-
material: Rc<dyn Material>,
7+
material: Arc<dyn Material>,
88
normal: Vector3::<f64>
99
}
1010

1111
impl Hit {
12-
pub fn new(t: f64, material: Rc<dyn Material>, normal: Vector3::<f64>) -> Self {
12+
pub fn new(t: f64, material: Arc<dyn Material>, normal: Vector3::<f64>) -> Self {
1313
Self { t, material, normal }
1414
}
1515

1616
pub fn clone_obj(&self) -> Self {
17-
Self { t: self.t, material: Rc::clone(&self.material), normal: self.normal }
17+
Self { t: self.t, material: Arc::clone(&self.material), normal: self.normal }
1818
}
1919

2020
pub fn get_t(&self) -> f64 {
2121
self.t
2222
}
2323

24-
pub fn get_material(&self) -> &Rc<dyn Material> {
24+
pub fn get_material(&self) -> &Arc<dyn Material> {
2525
&self.material
2626
}
2727

src/lights.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use core::f64;
2-
2+
use std::sync::Arc;
33
use vecmat::vector::Vector3;
44
use json::JsonValue;
55
use crate::ray::Ray;
@@ -108,26 +108,26 @@ impl Light for DirectionCircleLight {
108108
}
109109
}
110110

111-
pub fn build_light(light_attr: &JsonValue) -> Box<dyn Light> {
111+
pub fn build_light(light_attr: &JsonValue) -> Arc<dyn Light + Send + Sync> {
112112
let light_type = light_attr["Type"].as_str().unwrap();
113113
let scale = light_attr["Scale"].as_f64().unwrap();
114114
let pos = parse_vector(&light_attr["Position"]);
115115
let flux = parse_vector(&light_attr["Flux"]);
116116
match light_type {
117-
"SphereLiht" => Box::new(SphereLight::new(Some(scale), pos, flux)),
117+
"SphereLiht" => Arc::new(SphereLight::new(Some(scale), pos, flux)),
118118
"ConeLight" => {
119119
let normal = parse_vector(&light_attr["Normal"]);
120120
let angle = light_attr["Angle"].as_f64().unwrap();
121-
Box::new(ConeLight::new(Some(scale), pos, normal, flux, angle))
121+
Arc::new(ConeLight::new(Some(scale), pos, normal, flux, angle))
122122
},
123123
"HalfSphereLight" => {
124124
let normal = parse_vector(&light_attr["Normal"]);
125-
Box::new(ConeLight::new(Some(scale), pos, normal, flux, 90.))
125+
Arc::new(ConeLight::new(Some(scale), pos, normal, flux, 90.))
126126
},
127127
"DirectionCircleLight" => {
128128
let normal = parse_vector(&light_attr["Normal"]);
129129
let radius = light_attr["Radius"].as_f64().unwrap();
130-
Box::new(DirectionCircleLight::new(Some(scale), pos, normal, flux, radius))
130+
Arc::new(DirectionCircleLight::new(Some(scale), pos, normal, flux, radius))
131131
},
132132
_ => {
133133
panic!("Wrong light type!");

src/main.rs

+107-54
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
use core::f64;
2+
use std::sync::{Arc, RwLock, Barrier};
3+
use std::{thread, u32};
24
use std::{env, usize};
35
mod sceneparser;
46
mod camera;
@@ -14,44 +16,76 @@ use image::ImageError;
1416
use materials::MaterialType;
1517
use object3d::Group;
1618
use object3d::Object3d;
19+
// use threadpool::ThreadPool;
1720
use vecmat::vector::Vector2;
18-
use vecmat::vector::Vector3;
19-
use image::{Rgb, ImageResult, ImageBuffer};
20-
21+
use image::{Rgb, ImageResult};
22+
use image::RgbImage;
23+
// use image::ImageBuffer;
2124
use crate::{photon::{HitPoint, KDTree, Photon}, sceneparser::build_sceneparser};
2225
use crate::ray::Ray;
2326
use crate::matrix::trunc;
24-
use rand::{thread_rng, Rng};
2527

2628
static PHOTON_NUMBER: u32 = 100000;
2729
static ROUND_NUMBER: u32 = 3;
2830
static SAMPLE_NUMBER: u32 = 3;
29-
static _PARALLEL_NUMBER: u32 = 8;
30-
static _PHOTONS_PER_ROUND: u32 = PHOTON_NUMBER / _PARALLEL_NUMBER;
31+
static PARALLEL_NUMBER: usize = 8;
32+
static _PHOTONS_PER_ROUND: u32 = PHOTON_NUMBER / PARALLEL_NUMBER as u32;
3133
static TMIN: f64 = 0.015;
3234

33-
fn render(pic: &Vec<Vec<HitPoint>>, output_file: &str) -> ImageResult<()> {
34-
let width = pic.len() as u32;
35-
let height = pic[0].len() as u32;
36-
let img = ImageBuffer::from_fn(
37-
width,
38-
height,
39-
|x, y| {
40-
let point = &pic[x as usize][(height - 1 - y) as usize];
41-
let area = f64::consts::PI * point.radius * point.radius;
42-
let number = (PHOTON_NUMBER * ROUND_NUMBER) as f64;
43-
Rgb([
44-
trunc(point.tau.x() / (area * number)),
45-
trunc(point.tau.y() / (area * number)),
46-
trunc(point.tau.z() / (area * number)),
47-
])
35+
fn render(pic: &Vec<Arc<RwLock<Vec<Vec<HitPoint>>>>>, output_file: &str, width: u32, height: u32) -> ImageResult<()> {
36+
let number = (PHOTON_NUMBER * ROUND_NUMBER) as f64;
37+
let mut img = RgbImage::new(width, height);
38+
// for x in 0 .. width {
39+
// let interval = width as usize / PARALLEL_NUMBER;
40+
// let dim_1 = x as usize / interval;
41+
// let dim_2 = x as usize % interval;
42+
// // let point_vector = &pic[dim_1].read().unwrap();
43+
// for y in 0 .. height {
44+
// let point = &pic[dim_1].read().unwrap()[dim_2][(height - 1 - y) as usize];
45+
// // let point = &point_vector[dim_2][(height - 1 - y) as usize];
46+
// let area = f64::consts::PI * point.radius * point.radius;
47+
// *img.get_pixel_mut(x, y) = Rgb([
48+
// trunc(point.tau.x() / (area * number)),
49+
// trunc(point.tau.y() / (area * number)),
50+
// trunc(point.tau.z() / (area * number)),
51+
// ]);
52+
// }
53+
// }
54+
for dim_1 in 0 .. PARALLEL_NUMBER {
55+
for dim_2 in 0 .. width as usize / PARALLEL_NUMBER {
56+
for y in 0 .. height {
57+
let x = dim_1 * width as usize / PARALLEL_NUMBER + dim_2;
58+
let point = &pic[dim_1].read().unwrap()[dim_2][(height - 1 - y) as usize];
59+
let area = f64::consts::PI * point.radius * point.radius;
60+
*img.get_pixel_mut(x as u32, y) = Rgb([
61+
trunc(point.tau.x() / (area * number)),
62+
trunc(point.tau.y() / (area * number)),
63+
trunc(point.tau.z() / (area * number)),
64+
]);
65+
}
4866
}
49-
);
67+
}
68+
// let img = ImageBuffer::from_fn(
69+
// width,
70+
// height,
71+
// |x, y| {
72+
// let interval = width as usize / PARALLEL_NUMBER;
73+
// let dim_1 = x as usize / interval;
74+
// let dim_2 = x as usize % interval;
75+
// let point = &pic[dim_1].read().unwrap()[dim_2][(height - 1 - y) as usize];
76+
// let area = f64::consts::PI * point.radius * point.radius;
77+
// Rgb([
78+
// trunc(point.tau.x() / (area * number)),
79+
// trunc(point.tau.y() / (area * number)),
80+
// trunc(point.tau.z() / (area * number)),
81+
// ])
82+
// }
83+
// );
5084
img.save(output_file)?;
5185
Ok(())
5286
}
5387

54-
fn photon_trace(group: &Box<Group>, mut ray: Ray, photon_map: &mut Vec<Photon>) {
88+
fn photon_trace(group: &Arc<Group>, mut ray: Ray, photon_map: &mut Vec<Photon>) {
5589
let mut depth = 0;
5690
loop {
5791
if depth > 100 {
@@ -84,8 +118,7 @@ fn photon_trace(group: &Box<Group>, mut ray: Ray, photon_map: &mut Vec<Photon>)
84118
}
85119

86120
fn ray_trace(
87-
x: usize, y: usize, group: &Box<Group>, mut ray: Ray, kd_tree: &KDTree,
88-
picture: &Vec<Vec<HitPoint>>, buffer: &mut Vec<Vec<HitPoint>>
121+
group: &Arc<Group>, mut ray: Ray, kd_tree: &Arc<KDTree>, radius: f64, buffer_pixel: &mut HitPoint
89122
) {
90123
let mut depth = 0;
91124
loop {
@@ -100,9 +133,9 @@ fn ray_trace(
100133
depth += 1;
101134
match material.get_type() {
102135
&MaterialType::DIFFUSE => {
103-
buffer[x][y].radius = picture[x][y].radius;
104-
buffer[x][y].pos = Some(position);
105-
kd_tree.search(&mut buffer[x][y], color, hit.get_normal(), ray.get_flux());
136+
buffer_pixel.radius = radius;
137+
buffer_pixel.pos = Some(position);
138+
kd_tree.search(buffer_pixel, color, hit.get_normal(), ray.get_flux());
106139
break;
107140
},
108141
&MaterialType::SPECULAR | &MaterialType::REFRACTION => {
@@ -127,10 +160,10 @@ fn main() -> Result<(), ImageError> {
127160
let group = parser.group;
128161
let width = camera.get_width() as usize;
129162
let height = camera.get_height() as usize;
130-
let mut picture = vec![vec![HitPoint::new(); height]; width];
131-
let mut buffer = vec![vec![HitPoint::new(); height]; width];
163+
let pictures: Vec<Arc<RwLock<Vec<Vec<HitPoint>>>>> =
164+
vec![Arc::new(RwLock::new(vec![vec![HitPoint::new(); height]; width / PARALLEL_NUMBER])); PARALLEL_NUMBER];
165+
let barrier = Arc::new(Barrier::new(PARALLEL_NUMBER + 1));
132166

133-
let mut rng = thread_rng();
134167
for round in 0 .. ROUND_NUMBER {
135168
let mut photon_map: Vec<Photon> = Vec::new();
136169
for light in &lights {
@@ -141,34 +174,54 @@ fn main() -> Result<(), ImageError> {
141174
}
142175
println!("Round {} photon pass complete", &round);
143176
let kd_tree = KDTree::new(photon_map);
177+
let arc_kd_tree = Arc::new(kd_tree);
144178
println!("Round {} kdtree build complete", &round);
145-
for x in 0 .. width {
146-
for y in 0 .. height {
147-
buffer[x][y].tau = Vector3::<f64>::from([0., 0., 0.]);
148-
buffer[x][y].n = 0.;
149-
for _ in 0 .. SAMPLE_NUMBER {
150-
let mut ray = camera.generate_ray(&Vector2::<f64>::from([
151-
x as f64 + rng.gen_range(0. .. 1.),
152-
y as f64 + rng.gen_range(0. .. 1.)
153-
]));
154-
ray.set_color(*ray.get_flux() / (SAMPLE_NUMBER as f64));
155-
ray_trace(x, y, &group, ray, &kd_tree, &picture, &mut buffer);
156-
}
157-
if round == 0 {
158-
picture[x][y].n = buffer[x][y].n;
159-
picture[x][y].tau = buffer[x][y].tau;
160-
} else {
161-
if picture[x][y].n + buffer[x][y].n > 0. {
162-
let ratio = (picture[x][y].n + photon::ALPHA * buffer[x][y].n) / (picture[x][y].n + buffer[x][y].n);
163-
picture[x][y].radius *= f64::sqrt(ratio);
164-
picture[x][y].tau = (picture[x][y].tau + buffer[x][y].tau) * ratio;
165-
picture[x][y].n += buffer[x][y].n * ratio;
179+
for i in 0 .. PARALLEL_NUMBER as usize {
180+
let group = group.clone();
181+
let camera = camera.clone();
182+
let arc_kd_tree = arc_kd_tree.clone();
183+
let picture = pictures[i].clone();
184+
let barrier = barrier.clone();
185+
thread::spawn( move || {
186+
let column_begin = width * i / PARALLEL_NUMBER;
187+
let column_end = width * (i + 1) / PARALLEL_NUMBER;
188+
println!("thread {} spawns with column range [{}, {})", &i, &column_begin, &column_end);
189+
let mut buffer = vec![vec![HitPoint::new(); height]; column_end - column_begin];
190+
let mut picture = picture.write().unwrap();
191+
for (x, global_x) in (column_begin .. column_end).enumerate() {
192+
for y in 0 .. height {
193+
let buffer_pixel = &mut buffer[x][y];
194+
let picture_pixel = &mut picture[x][y];
195+
for _ in 0 .. SAMPLE_NUMBER {
196+
let dest_x = global_x as f64 + rand::random::<f64>();
197+
let dest_y = y as f64 + rand::random::<f64>();
198+
let mut ray = camera.generate_ray(&Vector2::<f64>::from([
199+
dest_x,
200+
dest_y
201+
]));
202+
ray.set_color(*ray.get_flux() / (SAMPLE_NUMBER as f64));
203+
ray_trace(&group, ray, &arc_kd_tree, picture_pixel.radius, buffer_pixel);
204+
}
205+
if round == 0 {
206+
picture_pixel.n = buffer_pixel.n;
207+
picture_pixel.tau = buffer_pixel.tau;
208+
} else {
209+
if picture_pixel.n + buffer_pixel.n > 0. {
210+
let ratio = (picture_pixel.n + photon::ALPHA * buffer_pixel.n) / (picture_pixel.n + buffer_pixel.n);
211+
picture_pixel.radius = picture_pixel.radius * f64::sqrt(ratio);
212+
picture_pixel.tau = (picture_pixel.tau + buffer_pixel.tau) * ratio;
213+
picture_pixel.n = picture_pixel.n + buffer_pixel.n * ratio;
214+
}
215+
}
166216
}
167217
}
168-
}
218+
drop(picture);
219+
barrier.wait();
220+
});
169221
}
222+
barrier.wait();
170223
println!("Round {} complete", &round);
171224
}
172-
render(&picture, &output_file)?;
225+
render(&pictures, &output_file, width as u32, height as u32)?;
173226
Ok(())
174227
}

0 commit comments

Comments
 (0)