Move away from OO, IRC command parser + working test

This commit is contained in:
Neale Pickett 2008-02-27 22:50:27 -07:00
parent ea2fe1ed1c
commit ef64cd9f8a
11 changed files with 373 additions and 378 deletions

View File

@ -2,10 +2,11 @@ USE_OCAMLFIND = true
OCAMLPACKS[] = OCAMLPACKS[] =
equeue equeue
pcre pcre
str
.DEFAULT: pgircd .DEFAULT: ircd
OCamlProgram(pgircd, pgircd ircd connection) OCamlProgram(ircd, ircd irc client server)
section section
OCAMLPACKS[] += OCAMLPACKS[] +=
@ -15,7 +16,7 @@ section
tests.cmi: tests.cmi:
tests$(EXT_OBJ): tests$(EXT_OBJ):
OCamlProgram(tests, tests chat ircd connection) OCamlProgram(tests, tests chat ircd irc client server)
.PHONY: clean .PHONY: clean
clean: clean:

68
chat.ml
View File

@ -1,5 +1,7 @@
open Unixqueue open Unixqueue
exception Buffer_overrun
type chat_event = type chat_event =
| Send of string | Send of string
| Recv of string | Recv of string
@ -37,18 +39,73 @@ let read_fd fd =
String.sub buf 0 len String.sub buf 0 len
class chat_handler chatscript (ues : unix_event_system) fd = class chat_handler chatscript
?(input_timeout=0.1)
?(output_timeout = 0.1)
?(output_max = 4096)
?(input_max = 4096)
(ues : unix_event_system) fd =
object (self) object (self)
inherit Connection.bare_connection ~input_timeout:0.1 ~output_timeout:0.1 ues fd val g = ues#new_group ()
val mutable debug = false
val obuf = String.create output_max
val mutable obuf_len = 0
val mutable script = chatscript val mutable script = chatscript
val inbuf = Buffer.create 4096 val inbuf = Buffer.create 4096
initializer initializer
ues#add_handler g self#handle_event;
ues#add_resource g (Wait_in fd, input_timeout);
self#run_script (); self#run_script ();
method handle_timeout op = method write data =
let data_len = String.length data in
if (data_len + obuf_len > output_max) then
raise Buffer_overrun;
String.blit data 0 obuf obuf_len data_len;
obuf_len <- obuf_len + data_len;
ues#add_resource g (Wait_out fd, output_timeout)
method handle_event ues esys e =
match e with
| Input_arrived (g, fd) ->
let data = String.create input_max in
let len = Unix.read fd data 0 input_max in
if (len > 0) then
begin
Buffer.add_string inbuf (String.sub data 0 len);
self#run_script ()
end
else
begin
Unix.close fd;
ues#clear g;
end
| Output_readiness (g, fd) ->
let size = obuf_len in
let n = Unix.single_write fd obuf 0 size in
obuf_len <- obuf_len - n;
if (obuf_len = 0) then
(* Don't check for output readiness anymore *)
begin
ues#remove_resource g (Wait_out fd)
end
else
(* Put unwritten output back into the output queue *)
begin
String.blit obuf n obuf 0 (obuf_len)
end
| Out_of_band (g, fd) ->
raise (Failure "Out of band data")
| Timeout (g, op) ->
raise (Chat_timeout (List.hd script)) raise (Chat_timeout (List.hd script))
| Signal ->
raise (Failure "Signal")
| Extra exn ->
raise (Failure "Extra")
method run_script () = method run_script () =
match script with match script with
@ -80,11 +137,6 @@ object (self)
else else
() ()
method handle_input data =
Buffer.add_string inbuf data;
self#run_script ()
end end

93
client.ml Normal file
View File

@ -0,0 +1,93 @@
open Irc
(* ==========================================
* Client stuff
*)
let ibuf_max = 4096
let max_outq = 50
let obuf_max = 4096
let shutdown ues g fd =
Unix.close fd;
Unixqueue.remove_resource ues g (Unixqueue.Wait_in fd);
try
Unixqueue.remove_resource ues g (Unixqueue.Wait_out fd);
with Not_found ->
()
let write cli line =
let was_empty = Queue.is_empty cli.outq in
Queue.add line cli.outq;
if was_empty then
cli.output_ready ()
let handle_close srv cli =
()
let handle_command_login srv cli command =
(* Handle a command during the login phase *)
match command.command with
| "USER"
| "NICK" ->
()
| _ ->
print_endline "NO CAN DO SIR"
let rec handle_input srv cli =
match cli.ibuf with
| "" ->
()
| ibuf ->
let p = String.index ibuf '\n' in
let s = String.sub ibuf 0 p in
if p >= !(cli.ibuf_len) then
raise Not_found;
cli.ibuf_len := !(cli.ibuf_len) - (p + 1);
String.blit ibuf (p + 1) ibuf 0 !(cli.ibuf_len);
let parsed = Irc.command_of_string s in
cli.handle_command srv cli parsed;
handle_input srv cli
let create_event_handler srv =
fun ues esys e ->
match e with
| Unixqueue.Input_arrived (g, fd) ->
let cli = Server.get_client_by_file_descr srv fd in
let size = ibuf_max - !(cli.ibuf_len) in
let len = Unix.read fd cli.ibuf !(cli.ibuf_len) size in
if (len > 0) then
begin
cli.ibuf_len := !(cli.ibuf_len) + len;
try
handle_input srv cli
with Not_found ->
if (!(cli.ibuf_len) = ibuf_max) then
(* No newline found, and the buffer is full *)
raise (Failure "Buffer overrun");
end
else
shutdown ues g fd
| Unixqueue.Output_readiness (g, fd) ->
print_endline "out"
| Unixqueue.Out_of_band (g, fd) ->
print_endline "oob"
| Unixqueue.Timeout (g, op) ->
print_endline "timeout"
| Unixqueue.Signal ->
print_endline "signal"
| Unixqueue.Extra exn ->
print_endline "extra"
let create ues g fd =
{outq = Queue.create ();
unsent = ref "";
ibuf = String.create ibuf_max;
ibuf_len = ref 0;
output_ready =
begin
fun () -> Unixqueue.add_resource ues g (Unixqueue.Wait_out fd, -.1.0)
end;
handle_command = handle_command_login;
channels = []}

5
client.mli Normal file
View File

@ -0,0 +1,5 @@
val create : Unixqueue.event_system -> Unixqueue.group -> Unix.file_descr -> Irc.client
val create_event_handler : Irc.server -> (Unixqueue.event_system -> Unixqueue.event Equeue.t -> Unixqueue.event -> unit)
val write : Irc.client -> string list -> unit

View File

@ -1,250 +0,0 @@
open Unixqueue
exception Buffer_overrun
(** Generic equeue connection class. *)
class virtual connection
(ues : unix_event_system)
?(input_timeout = -.1.0)
fd =
object (self)
val g = ues#new_group ()
val mutable debug = false
initializer
ues#add_handler g self#handle_event;
ues#add_resource g (Wait_in fd, input_timeout)
method debug v =
debug <- v
method log msg =
if debug then
print_endline msg
method handle_event ues esys e =
match e with
| Input_arrived (g, fd) ->
self#input_ready fd
| Output_readiness (g, fd) ->
self#output_ready fd
| Out_of_band (g, fd) ->
self#handle_oob fd
| Timeout (g, op) ->
self#handle_timeout op
| Signal ->
self#handle_signal ()
| Extra exn ->
self#handle_extra exn
method virtual output_ready : Unix.file_descr -> unit
method virtual input_ready : Unix.file_descr -> unit
method handle_oob fd =
self#log "Unhandled OOB";
raise Equeue.Reject
method handle_timeout op =
self#log "Unhandled timeout";
raise Equeue.Reject
method handle_signal () =
self#log "Unhandled signal";
raise Equeue.Reject
method handle_extra exn =
self#log "Unhandled extra";
raise Equeue.Reject
method handle_close () =
self#log "Closed"
end
(** Bare connection for reading and writing.
You can inherit this and define appropriate [handle_*] methods.
A [write] method is provided for your convenience.
*)
class bare_connection
(ues : unix_event_system)
?(input_timeout = -.1.0)
?(output_timeout = -.1.0)
?(input_max = 1024)
?(output_max = 1024)
fd =
object (self)
inherit connection ues ~input_timeout fd
val obuf = String.create output_max
val mutable obuf_len = 0
method write data =
let data_len = String.length data in
if (data_len + obuf_len > output_max) then
raise Buffer_overrun;
String.blit data 0 obuf obuf_len data_len;
obuf_len <- obuf_len + data_len;
ues#add_resource g (Wait_out fd, output_timeout)
method output_ready fd =
let size = obuf_len in
let n = Unix.single_write fd obuf 0 size in
obuf_len <- obuf_len - n;
if (obuf_len = 0) then
(* Don't check for output readiness anymore *)
begin
ues#remove_resource g (Wait_out fd)
end
else
(* Put unwritten output back into the output queue *)
begin
String.blit obuf n obuf 0 (obuf_len)
end
method input_ready fd =
let data = String.create input_max in
let len = Unix.read fd data 0 input_max in
if (len > 0) then
self#handle_input (String.sub data 0 len)
else
begin
self#handle_close ();
Unix.close fd;
ues#clear g;
end
method handle_input data =
self#log ("<-- [" ^ (String.escaped data) ^ "]")
end
(** Write s to fd, returning any unwritten data. *)
let write fd s =
let sl = String.length s in
let n = Unix.single_write fd s 0 sl in
(String.sub s n (sl - n))
(** Buffered connection class.
Input is split by newlines and sent to [handle_line].
Output is done with [write]. Send a list of words to be joined by a
space. This is intended to make one-to-many communications more
memory-efficient: the common strings need not be copied to all
recipients.
*)
class virtual buffered_connection
(ues : unix_event_system)
?(output_timeout = -.1.0)
?(ibuf_max = 4096)
?(max_outq = 50)
?(max_unsent = 4096)
fd =
object (self)
inherit connection ues fd
(* This allocates a string of length ibuf_max for each connection.
That could add up. *)
val mutable ibuf = String.create ibuf_max
val mutable ibuf_len = 0
val mutable unsent = ""
val mutable outq = Queue.create ()
method output_ready fd =
(* This could be better optimized, I'm sure. *)
match (unsent, Queue.is_empty outq) with
| ("", true) ->
ues#remove_resource g (Wait_out fd)
| ("", false) ->
let s = (String.concat " " (Queue.pop outq)) ^ "\n" in
unsent <- write fd s;
if (unsent = "") then
self#output_ready fd
| (s, _) ->
unsent <- write fd s;
if (unsent = "") then
self#output_ready fd
method virtual handle_line : string -> unit
(** Split ibuf on newline, feeding each split into self#handle_input.
Does not send the trailing newline. You can add it back if you want.
*)
method split_handle_input () =
match ibuf with
| "" ->
()
| ibuf ->
let p = String.index ibuf '\n' in
let s = String.sub ibuf 0 p in
if p >= ibuf_len then
raise Not_found;
ibuf_len <- ibuf_len - (p + 1);
String.blit ibuf (p + 1) ibuf 0 ibuf_len;
self#handle_line s;
self#split_handle_input ()
method input_ready fd =
let size = ibuf_max - ibuf_len in
let len = Unix.read fd ibuf ibuf_len size in
if (len > 0) then
begin
ibuf_len <- ibuf_len + len;
try
self#split_handle_input ()
with Not_found ->
if (ibuf_len = ibuf_max) then
(* No newline found, and the buffer is full *)
raise Buffer_overrun;
end
else
begin
self#handle_close ();
Unix.close fd;
ues#clear g;
end
method write line =
if (Queue.length outq) >= max_outq then
raise (Failure "Maximum output queue length exceeded")
else
begin
Queue.add line outq;
ues#add_resource g (Wait_out fd, output_timeout)
end
end
(** Establish a server on the given address.
[connection_handler] will be called with the file descriptor of
any new connections.
*)
let establish_server ues connection_handler addr =
let g = ues#new_group () in
let handle_event ues esys e =
match e with
| Input_arrived (g, fd) ->
let cli_fd, cli_addr = Unix.accept fd in
connection_handler cli_fd
| _ ->
raise Equeue.Reject
in
let srv = Unix.socket Unix.PF_INET Unix.SOCK_STREAM 0 in
Unix.bind srv addr;
Unix.listen srv 50;
Unix.setsockopt srv Unix.SO_REUSEADDR true;
ues#add_handler g handle_event;
ues#add_resource g (Wait_in srv, -.1.0)

104
irc.ml Normal file
View File

@ -0,0 +1,104 @@
type command = {sender: string option;
command: string;
args: string list;
text: string option}
type server = {clients_by_name: (string, client) Hashtbl.t;
clients_by_file_descr: (Unix.file_descr, client) Hashtbl.t;
channels_by_name: (string, channel) Hashtbl.t}
and client = {outq: string list Queue.t;
unsent: string ref;
ibuf: string;
ibuf_len: int ref;
output_ready: unit -> unit;
handle_command: server -> client -> command -> unit;
channels: channel list}
and channel = {name: string}
let newline_re = Pcre.regexp "\n\r?"
let argsep_re = Pcre.regexp " :"
let space_re = Pcre.regexp " "
let dbg msg a =
prerr_endline ("[" ^ msg ^ "]");
a
let string_map f s =
let l = String.length s in
if l = 0 then
s
else
let r = String.create l in
for i = 0 to l - 1 do
String.unsafe_set r i (f (String.unsafe_get s i))
done;
r
let lowercase_char c =
if (c >= 'A' && c <= '^') then
Char.unsafe_chr(Char.code c + 32)
else
c
let uppercase_char c =
if (c >= 'a' && c <= '~') then
Char.unsafe_chr(Char.code c - 32)
else
c
let uppercase s = string_map uppercase_char s
let lowercase s = string_map lowercase_char s
let extract_word s =
try
let pos = String.index s ' ' in
(Str.string_before s pos, Str.string_after s (pos + 1))
with Not_found ->
(s, "")
let string_list_of_command cmd =
([] @
(match cmd.sender with
| None -> []
| Some s -> [":" ^ s]) @
[cmd.command] @
cmd.args @
(match cmd.text with
| None -> []
| Some s -> [":" ^ s]))
let string_of_command cmd =
String.concat " " (string_list_of_command cmd)
let rec command_of_string line =
(* Very simple. Pull out words until you get one starting with ":".
The very first word might start with ":", that doesn't count
because it's the sender.. *)
let rec loop sender acc line =
let c = (if (line = "") then None else (Some line.[0])) in
match (c, acc) with
| (None, cmd :: args) ->
(* End of line, no text part *)
{sender = sender;
command = cmd;
args = args;
text = None}
| (None, []) ->
(* End of line, no text part, no args, no command *)
raise (Failure "No command, eh?")
| (Some ':', []) ->
(* First word, starts with ':' *)
let (word, rest) = extract_word line in
loop (Some (Str.string_after word 1)) acc rest
| (Some ':', cmd :: args) ->
(* Not first word, starts with ':' *)
{sender = sender;
command = cmd;
args = args;
text = Some (Str.string_after line 1)}
| (Some _, _) ->
(* Argument *)
let (word, rest) = extract_word line in
loop sender (acc @ [word]) rest
in
loop None [] line

21
irc.mli Normal file
View File

@ -0,0 +1,21 @@
type command = {sender: string option;
command: string;
args: string list;
text: string option}
type server = {clients_by_name: (string, client) Hashtbl.t;
clients_by_file_descr: (Unix.file_descr, client) Hashtbl.t;
channels_by_name: (string, channel) Hashtbl.t}
and client = {outq: string list Queue.t;
unsent: string ref;
ibuf: string;
ibuf_len: int ref;
output_ready: unit -> unit;
handle_command: server -> client -> command -> unit;
channels: channel list}
and channel = {name: string}
val uppercase : string -> string
val lowercase : string -> string
val command_of_string : string -> command
val string_of_command : command -> string

135
ircd.ml
View File

@ -1,128 +1,45 @@
type server = {clients_by_name: (string, client) Hashtbl.t; let dbg msg a =
clients_by_file_descr: (Unix.file_descr, client) Hashtbl.t;
channels_by_name: (string, channel) Hashtbl.t}
and client = {outq: string list Queue.t;
unsent: string ref;
ibuf: string;
ibuf_len: int ref;
out_ready: unit -> unit;
channels: channel list}
and channel = {name: string}
let dump msg a =
prerr_endline msg; prerr_endline msg;
a a
(* ========================================== (** Establish a server on the given address.
* Server stuff
[connection_handler] will be called with the file descriptor of
any new connections.
*) *)
let create_server () = let establish_server ues connection_handler addr =
{clients_by_name = Hashtbl.create 25; let g = Unixqueue.new_group ues in
clients_by_file_descr = Hashtbl.create 25; let handle_event ues esys e =
channels_by_name = Hashtbl.create 10}
let get_client_by_name srv name =
Hashtbl.find srv.clients_by_name name
let get_client_by_file_descr srv fd =
Hashtbl.find srv.clients_by_file_descr fd
let get_channel_by_name srv name =
Hashtbl.find srv.channels_by_name name
(* ==========================================
* Client stuff
*)
let ibuf_max = 4096
let max_outq = 50
let obuf_max = 4096
let create_client ues g fd =
{outq = Queue.create ();
unsent = ref "";
ibuf = String.create ibuf_max;
ibuf_len = ref 0;
out_ready =
begin
fun () -> Unixqueue.add_resource ues g (Unixqueue.Wait_out fd, -.1.0)
end;
channels = []}
let client_shutdown ues g fd =
Unix.close fd;
Unixqueue.remove_resource ues g (Unixqueue.Wait_in fd);
try
Unixqueue.remove_resource ues g (Unixqueue.Wait_out fd);
with Not_found ->
()
let client_handle_line srv cli line =
print_endline line
let client_handle_close srv cli =
()
let rec client_handle_input srv cli =
match cli.ibuf with
| "" ->
()
| ibuf ->
let p = String.index ibuf '\n' in
let s = String.sub ibuf 0 p in
if p >= !(cli.ibuf_len) then
raise Not_found;
cli.ibuf_len := !(cli.ibuf_len) - (p + 1);
String.blit ibuf (p + 1) ibuf 0 !(cli.ibuf_len);
client_handle_line srv cli s;
client_handle_input srv cli
let create_event_handler srv =
fun ues esys e ->
match e with match e with
| Unixqueue.Input_arrived (g, fd) -> | Unixqueue.Input_arrived (g, fd) ->
let cli = dump "input" get_client_by_file_descr srv fd in let cli_fd, cli_addr = Unix.accept fd in
let size = dump "size" ibuf_max - !(cli.ibuf_len) in connection_handler cli_fd
let len = dump "read" Unix.read fd cli.ibuf !(cli.ibuf_len) size in | _ ->
if (len > 0) then raise Equeue.Reject
begin in
cli.ibuf_len := !(cli.ibuf_len) + len; let srv = Unix.socket Unix.PF_INET Unix.SOCK_STREAM 0 in
try Unix.bind srv addr;
client_handle_input srv cli Unix.listen srv 50;
with Not_found -> Unix.setsockopt srv Unix.SO_REUSEADDR true;
if (!(cli.ibuf_len) = ibuf_max) then Unixqueue.add_handler ues g handle_event;
(* No newline found, and the buffer is full *) Unixqueue.add_resource ues g (Unixqueue.Wait_in srv, -.1.0)
raise (Failure "Buffer overrun");
end
else
client_shutdown ues g fd
| Unixqueue.Output_readiness (g, fd) ->
print_endline "out"
| Unixqueue.Out_of_band (g, fd) ->
print_endline "oob"
| Unixqueue.Timeout (g, op) ->
print_endline "timeout"
| Unixqueue.Signal ->
print_endline "signal"
| Unixqueue.Extra exn ->
print_endline "extra"
let main () = let main () =
let srv = create_server () in let srv = Server.create () in
let handle_event = create_event_handler srv in let handle_event = Client.create_event_handler srv in
let ues = Unixqueue.create_unix_event_system () in let ues = Unixqueue.create_unix_event_system () in
let g = Unixqueue.new_group ues in let g = Unixqueue.new_group ues in
let handle_connection fd = let handle_connection fd =
let cli = create_client ues g fd in let cli = Client.create ues g fd in
Hashtbl.replace srv.clients_by_file_descr fd cli; Hashtbl.replace srv.Irc.clients_by_file_descr fd cli;
Unixqueue.add_resource ues g (Unixqueue.Wait_in fd, -.1.0); Unixqueue.add_resource ues g (Unixqueue.Wait_in fd, -.1.0);
in in
Unixqueue.add_handler ues g handle_event; Unixqueue.add_handler ues g handle_event;
Connection.establish_server establish_server
ues ues
handle_connection handle_connection
(Unix.ADDR_INET (Unix.inet_addr_any, 7777)); (Unix.ADDR_INET (Unix.inet_addr_any, 7777));
ues#run () ues#run ()
let _ =
main ()

View File

@ -1,2 +0,0 @@
let _ =
Ircd.main ()

18
server.ml Normal file
View File

@ -0,0 +1,18 @@
open Irc
(* ==========================================
* Server stuff
*)
let create () =
{clients_by_name = Hashtbl.create 25;
clients_by_file_descr = Hashtbl.create 25;
channels_by_name = Hashtbl.create 10}
let get_client_by_name srv name =
Hashtbl.find srv.clients_by_name name
let get_client_by_file_descr srv fd =
Hashtbl.find srv.clients_by_file_descr fd
let get_channel_by_name srv name =
Hashtbl.find srv.channels_by_name name

View File

@ -1,15 +1,51 @@
open Unixqueue open Unixqueue
open OUnit open OUnit
open Chat open Chat
open Irc
let do_chat script () = let do_chat script () =
let ircd_instance ues fd = let ircd_instance ues fd =
let irc = new Ircd.ircd_connection ues fd in let srv = Server.create () in
irc#debug true let handle_event = Client.create_event_handler srv in
let g = Unixqueue.new_group ues in
let cli = Client.create ues g fd in
Hashtbl.replace srv.Irc.clients_by_file_descr fd cli;
Unixqueue.add_handler ues g handle_event;
Unixqueue.add_resource ues g (Unixqueue.Wait_in fd, -.1.0)
in in
chat script ircd_instance chat script ircd_instance
let normal_tests = let unit_tests =
"Unit tests" >:::
[
"command_of_string" >::
(fun () ->
assert_equal
~printer:string_of_command
{sender = None;
command = "NICK";
args = ["name"];
text = None}
(command_of_string "NICK name");
assert_equal
~printer:string_of_command
{sender = Some "foo";
command = "NICK";
args = ["name"];
text = None}
(command_of_string ":foo NICK name");
assert_equal
~printer:string_of_command
{sender = Some "foo.bar";
command = "PART";
args = ["#foo"; "#bar"];
text = Some "ta ta"}
(command_of_string ":foo.bar PART #foo #bar :ta ta");
)
]
let regression_tests =
let login_script = let login_script =
[ [
Send "USER nick +iw nick :gecos\n"; Send "USER nick +iw nick :gecos\n";
@ -19,7 +55,7 @@ let normal_tests =
Send "PONG :12345\n"; Send "PONG :12345\n";
] ]
in in
"Normal tests" >::: "Regression tests" >:::
[ [
"Simple connection" >:: "Simple connection" >::
(do_chat (do_chat
@ -50,6 +86,6 @@ let normal_tests =
] ]
let _ = let _ =
run_test_tt_main (TestList [normal_tests]) run_test_tt_main (TestList [unit_tests; regression_tests])