diff --git a/README.md b/README.md index e28da9b..00c2935 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,8 @@ Run --host 127.0.0.1 \ --port 8080 \ --worker 4 \ - --timeout-ms 2000 + --timeout-ms 2000 \ + --max-header-size 8192 ``` This command starts a server listening on `127.0.0.1:8080` with 4 preforked worker processes and a 2‑second accept timeout. @@ -83,4 +84,3 @@ Tests cover: - URL parsing (`http/http.rs`) - Echo server logic (`process/echo.rs`) - Worker manager integration (`worker/manager.rs`) - diff --git a/src/args.rs b/src/args.rs index 57bc4ff..319ad24 100644 --- a/src/args.rs +++ b/src/args.rs @@ -13,4 +13,6 @@ pub struct Args { pub worker: u32, #[arg(short, long, default_value_t = 500)] pub timeout_ms: u64, + #[arg(long, default_value_t = 8196)] + pub max_header_size: usize, } diff --git a/src/http/http.rs b/src/http/http.rs index 2c78d22..dbdcbf8 100644 --- a/src/http/http.rs +++ b/src/http/http.rs @@ -1,7 +1,7 @@ use std::{ collections::HashMap, io::{BufRead, BufReader, Write}, - net::TcpStream, + net::{SocketAddr, TcpStream}, time::{Duration, SystemTime}, }; @@ -17,6 +17,7 @@ use crate::{ }; pub struct Http1 { + max_header_length: usize, handler: T, } @@ -36,9 +37,14 @@ where log::trace!("Write timeout: {:?}", stream.write_timeout()); let mut reader = BufReader::new(&stream); - let header_res: Result<(usize, Vec), Error> = self.read_header(&mut reader); + let header_res: Result<(usize, Vec), Error> = + self.read_header(client_addr, &mut reader); if let Err(err) = header_res { - return Err(process::Error::IoFail(format!("Read header failed: ({})", err))); + self.error_response_for_invalid_request(&stream); + return Err(process::Error::IoFail(format!( + "Read header failed: ({})", + err + ))); } let (header_readed, headers) = header_res.unwrap(); @@ -46,15 +52,7 @@ where let res_request: Result, Error> = self.init_request(client_addr, &headers, reader); if let Err(err) = res_request { - let mut response = HttpResponse::new(HttpVersion::default(), &stream); - - response.set_response_code(HttpResponseCode::BadRequest); - response.set_header(&server(HttpHeaderValue::Str("server_rs"))); - response.set_header(&content_type(HttpHeaderValue::Str("text/plain"))); - response.set_header(&date(SystemTime::now())); - let _ = response.write("Invalid request".as_bytes()); - let _ = response.flush(); - + self.error_response_for_invalid_request(&stream); return Err(process::Error::ParseFail(err.to_string())); } @@ -81,12 +79,16 @@ impl Http1 where T: Handler, { - pub fn new(handler: T) -> Self { - return Http1 { handler }; + pub fn new(max_header_length: usize, handler: T) -> Self { + return Http1 { + max_header_length, + handler, + }; } fn read_header<'a>( &self, + client_addr: &SocketAddr, reader: &mut BufReader<&'a TcpStream>, ) -> Result<(usize, Vec), Error> { let mut res = vec![]; @@ -99,6 +101,13 @@ where } readed += result.unwrap(); + if readed > self.max_header_length { + return Err(Error::BadRequest( + client_addr.clone(), + "header size limit exceed", + )); + } + while buf .chars() .nth(0) @@ -185,6 +194,17 @@ where return header_map; } + + fn error_response_for_invalid_request(&self, stream: &TcpStream) { + let mut response = HttpResponse::new(HttpVersion::default(), stream); + + response.set_response_code(HttpResponseCode::BadRequest); + response.set_header(&server(HttpHeaderValue::Str("server_rs"))); + response.set_header(&content_type(HttpHeaderValue::Str("text/plain"))); + response.set_header(&date(SystemTime::now())); + let _ = response.write("Invalid request".as_bytes()); + let _ = response.flush(); + } } fn parse_url(query: &str) -> (String, HashMap<&str, Vec<&str>>) { diff --git a/src/http/value.rs b/src/http/value.rs index f35db92..6f7fe43 100644 --- a/src/http/value.rs +++ b/src/http/value.rs @@ -1,4 +1,4 @@ -use std::{fmt::Display, hash::Hash}; +use std::{fmt::Display, hash::Hash, net::SocketAddr}; pub enum HttpVersion { Http10, @@ -139,10 +139,10 @@ impl HttpResponseCode { #[allow(dead_code)] #[derive(Debug, Clone)] pub enum Error { - #[allow(dead_code)] ParseFail(String), ReadFail(String), WriteFail(String), + BadRequest(SocketAddr, &'static str), } impl std::fmt::Display for Error { @@ -151,6 +151,7 @@ impl std::fmt::Display for Error { Error::ParseFail(m) => ("parse fail", m), Error::ReadFail(m) => ("read fail", m), Error::WriteFail(m) => ("write fail", m), + Error::BadRequest(remote, msg) => ("bad request", &format!("{} {}", remote, msg)), }; return f.write_fmt(format_args!("HttpError: [{}] {}", name.0, name.1)); diff --git a/src/main.rs b/src/main.rs index f1829f3..bfd4374 100644 --- a/src/main.rs +++ b/src/main.rs @@ -69,7 +69,7 @@ fn main() { host: arg.host.clone(), port: arg.port, worker: arg.worker, - process: Rc::new(Http1::new(SimpleHandler)), + process: Rc::new(Http1::new(arg.max_header_size, SimpleHandler)), }]; let mut server = Server::new(ServerArgs {