Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ Run
--host 127.0.0.1 \
--port 8080 \
--worker 4 \
--timeout-ms 2000
--timeout-ms 2000 \
--max-header-size 8192
```

This command starts a server listening on `127.0.0.1:8080` with 4 preforked worker processes and a 2‑second accept timeout.
Expand All @@ -83,4 +84,3 @@ Tests cover:
- URL parsing (`http/http.rs`)
- Echo server logic (`process/echo.rs`)
- Worker manager integration (`worker/manager.rs`)

2 changes: 2 additions & 0 deletions src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ pub struct Args {
pub worker: u32,
#[arg(short, long, default_value_t = 500)]
pub timeout_ms: u64,
#[arg(long, default_value_t = 8196)]
pub max_header_size: usize,
}
48 changes: 34 additions & 14 deletions src/http/http.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{
collections::HashMap,
io::{BufRead, BufReader, Write},
net::TcpStream,
net::{SocketAddr, TcpStream},
time::{Duration, SystemTime},
};

Expand All @@ -17,6 +17,7 @@ use crate::{
};

pub struct Http1<T: Handler> {
max_header_length: usize,
handler: T,
}

Expand All @@ -36,25 +37,22 @@ where
log::trace!("Write timeout: {:?}", stream.write_timeout());

let mut reader = BufReader::new(&stream);
let header_res: Result<(usize, Vec<String>), Error> = self.read_header(&mut reader);
let header_res: Result<(usize, Vec<String>), Error> =
self.read_header(client_addr, &mut reader);
if let Err(err) = header_res {
return Err(process::Error::IoFail(format!("Read header failed: ({})", err)));
self.error_response_for_invalid_request(&stream);
return Err(process::Error::IoFail(format!(
"Read header failed: ({})",
err
)));
}

let (header_readed, headers) = header_res.unwrap();

let res_request: Result<HttpRequest<'_>, Error> =
self.init_request(client_addr, &headers, reader);
if let Err(err) = res_request {
let mut response = HttpResponse::new(HttpVersion::default(), &stream);

response.set_response_code(HttpResponseCode::BadRequest);
response.set_header(&server(HttpHeaderValue::Str("server_rs")));
response.set_header(&content_type(HttpHeaderValue::Str("text/plain")));
response.set_header(&date(SystemTime::now()));
let _ = response.write("Invalid request".as_bytes());
let _ = response.flush();

self.error_response_for_invalid_request(&stream);
return Err(process::Error::ParseFail(err.to_string()));
}

Expand All @@ -81,12 +79,16 @@ impl<T> Http1<T>
where
T: Handler,
{
pub fn new(handler: T) -> Self {
return Http1 { handler };
pub fn new(max_header_length: usize, handler: T) -> Self {
return Http1 {
max_header_length,
handler,
};
}

fn read_header<'a>(
&self,
client_addr: &SocketAddr,
reader: &mut BufReader<&'a TcpStream>,
) -> Result<(usize, Vec<String>), Error> {
let mut res = vec![];
Expand All @@ -99,6 +101,13 @@ where
}
readed += result.unwrap();

if readed > self.max_header_length {
return Err(Error::BadRequest(
client_addr.clone(),
"header size limit exceed",
));
}

while buf
.chars()
.nth(0)
Expand Down Expand Up @@ -185,6 +194,17 @@ where

return header_map;
}

fn error_response_for_invalid_request(&self, stream: &TcpStream) {
let mut response = HttpResponse::new(HttpVersion::default(), stream);

response.set_response_code(HttpResponseCode::BadRequest);
response.set_header(&server(HttpHeaderValue::Str("server_rs")));
response.set_header(&content_type(HttpHeaderValue::Str("text/plain")));
response.set_header(&date(SystemTime::now()));
let _ = response.write("Invalid request".as_bytes());
let _ = response.flush();
}
}

fn parse_url(query: &str) -> (String, HashMap<&str, Vec<&str>>) {
Expand Down
5 changes: 3 additions & 2 deletions src/http/value.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{fmt::Display, hash::Hash};
use std::{fmt::Display, hash::Hash, net::SocketAddr};

pub enum HttpVersion {
Http10,
Expand Down Expand Up @@ -139,10 +139,10 @@ impl HttpResponseCode {
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub enum Error {
#[allow(dead_code)]
ParseFail(String),
ReadFail(String),
WriteFail(String),
BadRequest(SocketAddr, &'static str),
}

impl std::fmt::Display for Error {
Expand All @@ -151,6 +151,7 @@ impl std::fmt::Display for Error {
Error::ParseFail(m) => ("parse fail", m),
Error::ReadFail(m) => ("read fail", m),
Error::WriteFail(m) => ("write fail", m),
Error::BadRequest(remote, msg) => ("bad request", &format!("{} {}", remote, msg)),
};

return f.write_fmt(format_args!("HttpError: [{}] {}", name.0, name.1));
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ fn main() {
host: arg.host.clone(),
port: arg.port,
worker: arg.worker,
process: Rc::new(Http1::new(SimpleHandler)),
process: Rc::new(Http1::new(arg.max_header_size, SimpleHandler)),
}];

let mut server = Server::new(ServerArgs {
Expand Down