#+private
package nbio

import "base:runtime"

import "core:container/queue"
import "core:log"
import "core:mem"
import "core:net"
import "core:os"
import "core:time"

import win "core:sys/windows"

// Windows specific state of the event loop.
_IO :: struct {
	// Completion port that sockets are associated with in `prepare_socket`.
	iocp:            win.HANDLE,
	allocator:       mem.Allocator,
	// Timeouts that have not expired yet, see `flush_timeouts`.
	timeouts:        [dynamic]^Completion,
	// Completions that are ready to be passed to `handle_completion`.
	completed:       queue.Queue(^Completion),
	completion_pool: Pool(Completion),
	// Incremented by `handle_completion` every time an operation reports that it has not finished yet.
	io_pending:      int,
	// The asynchronous Windows API's don't support reading at the current offset of a file, so we keep track ourselves.
	offsets:         map[os.Handle]u32,
}

// Windows specific part of a completion.
_Completion :: struct {
	over: win.OVERLAPPED,
	ctx:  runtime.Context,
	op:   Operation,
}
#assert(offset_of(Completion, over) == 0, "needs to be the first field to work")

// State of an accept; `client` is the socket created to receive the incoming connection.
Op_Accept :: struct {
	callback: On_Accept,
	socket:   win.SOCKET,
	client:   win.SOCKET,
	addr:     win.SOCKADDR_STORAGE_LH,
	// Set once the overlapped call has been issued; later invocations only query its result.
	pending:  bool,
}

// State of a connect to `addr`.
Op_Connect :: struct {
	callback: On_Connect,
	socket:   win.SOCKET,
	addr:     win.SOCKADDR_STORAGE_LH,
	// Set once the overlapped call has been issued; later invocations only query its result.
	pending:  bool,
}

// State of a close of a file handle or socket.
Op_Close :: struct {
	callback: On_Close,
	fd:       Closable,
}

// State of a file read.
Op_Read :: struct {
	callback: On_Read,
	fd:       os.Handle,
	// Explicit file offset, a negative value means "use the offset tracked in `IO.offsets`".
	offset:   int,
	buf:      []byte,
	pending:  bool,
	// When set, the read is re-submitted until `len` bytes have been read in total.
	all:      bool,
	read:     int,
	len:      int,
}

// State of a file write.
Op_Write :: struct {
	callback: On_Write,
	fd:       os.Handle,
	// Explicit file offset, a negative value means "use the offset tracked in `IO.offsets`".
	offset:   int,
	buf:      []byte,
	pending:  bool,

	written:  int,
	len:      int,
	// When set, the write is re-submitted until `len` bytes have been written in total.
	all:      bool,
}

// State of a socket receive.
Op_Recv :: struct {
	callback: On_Recv,
	socket:   net.Any_Socket,
	buf:      win.WSABUF,
	pending:  bool,
	// When set, the receive is re-submitted until `len` bytes have been received in total.
	all:      bool,
	received: int,
	len:      int,
}

// State of a socket send.
Op_Send :: struct {
	callback: On_Sent,
	socket:   net.Any_Socket,
	buf:      win.WSABUF,
	pending:  bool,

	len:      int,
	sent:     int,
	// When set, the send is re-submitted until `len` bytes have been sent in total.
	all:      bool,
}

// State of a timeout that fires once `expires` has passed.
Op_Timeout :: struct {
	callback: On_Timeout,
	expires:  time.Time,
}

// These operations carry no state here; `handle_completion` treats them as unreachable.
Op_Next_Tick :: struct {}

Op_Poll :: struct {}

Op_Poll_Remove :: struct {}

// Moves every expired timeout from `io.timeouts` to the back of `io.completed`.
//
// Returns the shortest remaining duration among the timeouts that are still waiting,
// or nil when no timeout is left waiting.
flush_timeouts :: proc(io: ^IO) -> (expires: Maybe(time.Duration)) {
	curr: time.Time
	timeout_len := len(io.timeouts)

	// PERF: could use a faster clock, is getting time since program start fast?
	if timeout_len > 0 do curr = time.now()

	// `i` is only advanced when the entry at `i` is kept, because a removal shifts the
	// following entries into its place.
	for i := 0; i < timeout_len; {
		completion := io.timeouts[i]
		op := &completion.op.(Op_Timeout)
		cexpires := time.diff(curr, op.expires)

		// Timeout done.
		if (cexpires <= 0) {
			ordered_remove(&io.timeouts, i)
			queue.push_back(&io.completed, completion)
			timeout_len -= 1
			continue
		}

		// Update minimum timeout.
		exp, ok := expires.?
		expires = min(exp, cexpires) if ok else cexpires

		i += 1
	}
	return
}

// Configures `socket` for use with this event loop:
// sets the `.Reuse_Address` and `.TCP_Nodelay` options, associates the socket with the
// completion port of `io` and sets its completion notification modes.
//
// Returns the error of the first step that failed, or nil.
prepare_socket :: proc(io: ^IO, socket: net.Any_Socket) -> net.Network_Error {
	net.set_option(socket, .Reuse_Address, true) or_return
	net.set_option(socket, .TCP_Nodelay, true) or_return

	handle := win.HANDLE(uintptr(net.any_socket_to_socket(socket)))

	handle_iocp := win.CreateIoCompletionPort(handle, io.iocp, 0, 0)
	assert(handle_iocp == io.iocp)

	mode: byte
	mode |= FILE_SKIP_COMPLETION_PORT_ON_SUCCESS
	mode |= FILE_SKIP_SET_EVENT_ON_HANDLE
	if !win.SetFileCompletionNotificationModes(handle, mode) {
		return net.Socket_Option_Error(win.GetLastError())
	}

	return nil
}

// Takes a completion from the pool, fills it with the current context, `user` and `op`,
// and queues it on `io.completed`.
//
// Returns the queued completion.
submit :: proc(io: ^IO, user: rawptr, op: Operation) -> ^Completion {
	completion := pool_get(&io.completion_pool)

	completion.ctx = context
	completion.user_data = user
	completion.op = op

	queue.push_back(&io.completed, completion)
	return completion
}

// Drives the operation of `completion` one step.
//
// When the operation reports that it is not finished yet, `io.io_pending` is incremented and
// the completion is kept alive. Otherwise the user callback is invoked and the completion is
// returned to the pool. Operations with `all` set are re-submitted until everything is
// transferred, an error occurs, or no further progress is possible.
handle_completion :: proc(io: ^IO, completion: ^Completion) {
	switch &op in completion.op {
	case Op_Accept:
		// TODO: we should directly call the accept callback here, no need for it to be on the Op_Accept struct.
		source, err := accept_callback(io, completion, &op)
		if wsa_err_incomplete(err) {
			io.io_pending += 1
			return
		}

		rerr := net.Accept_Error(err)
		if rerr != nil do win.closesocket(op.client)

		op.callback(completion.user_data, net.TCP_Socket(op.client), source, rerr)

	case Op_Connect:
		err := connect_callback(io, completion, &op)
		if wsa_err_incomplete(err) {
			io.io_pending += 1
			return
		}

		rerr := net.Dial_Error(err)
		if rerr != nil do win.closesocket(op.socket)

		op.callback(completion.user_data, net.TCP_Socket(op.socket), rerr)

	case Op_Close:
		op.callback(completion.user_data, close_callback(io, op))

	case Op_Read:
		read, err := read_callback(io, completion, &op)
		if err_incomplete(err) {
			io.io_pending += 1
			return
		}

		// End of file is reported to the user as a short (possibly empty) read, not as an error.
		eof := err == win.ERROR_HANDLE_EOF
		if eof {
			err = win.NO_ERROR
		}

		op.read += int(read)

		if err != win.NO_ERROR {
			op.callback(completion.user_data, op.read, os.Platform_Error(err))
		} else if op.all && op.read < op.len && !eof {
			// Not everything has been read yet, read into the remainder of the buffer.
			// This is skipped at the end of the file: there is nothing left to read, so
			// re-submitting would never make progress.
			op.buf = op.buf[read:]

			if op.offset >= 0 {
				op.offset += int(read)
			}

			op.pending = false

			handle_completion(io, completion)
			return
		} else {
			op.callback(completion.user_data, op.read, os.ERROR_NONE)
		}

	case Op_Write:
		written, err := write_callback(io, completion, &op)
		if err_incomplete(err) {
			io.io_pending += 1
			return
		}

		op.written += int(written)

		oerr := os.Platform_Error(err)
		if oerr != os.ERROR_NONE {
			op.callback(completion.user_data, op.written, oerr)
		} else if op.all && op.written < op.len {
			// Not everything has been written yet, write the remainder of the buffer.
			op.buf = op.buf[written:]

			if op.offset >= 0 {
				op.offset += int(written)
			}

			op.pending = false

			handle_completion(io, completion)
			return
		} else {
			op.callback(completion.user_data, op.written, os.ERROR_NONE)
		}

	case Op_Recv:
		received, err := recv_callback(io, completion, &op)
		if wsa_err_incomplete(err) {
			io.io_pending += 1
			return
		}

		op.received += int(received)

		nerr := net.TCP_Recv_Error(err)
		if nerr != nil {
			op.callback(completion.user_data, op.received, {}, nerr)
		} else if op.all && op.received < op.len && received > 0 {
			// Not everything has been received yet, receive into the remainder of the buffer.
			// A successful receive of 0 bytes means no more data is coming, re-submitting
			// in that case would never make progress.
			op.buf = win.WSABUF{
				len = op.buf.len - win.ULONG(received),
				buf = (cast([^]byte)op.buf.buf)[received:],
			}
			op.pending = false

			handle_completion(io, completion)
			return
		} else {
			op.callback(completion.user_data, op.received, {}, nil)
		}

	case Op_Send:
		sent, err := send_callback(io, completion, &op)
		if wsa_err_incomplete(err) {
			io.io_pending += 1
			return
		}

		op.sent += int(sent)

		nerr := net.TCP_Send_Error(err)
		if nerr != nil {
			op.callback(completion.user_data, op.sent, nerr)
		} else if op.all && op.sent < op.len {
			// Not everything has been sent yet, send the remainder of the buffer.
			op.buf = win.WSABUF{
				len = op.buf.len - win.ULONG(sent),
				buf = (cast([^]byte)op.buf.buf)[sent:],
			}
			op.pending = false

			handle_completion(io, completion)
			return
		} else {
			op.callback(completion.user_data, op.sent, nil)
		}

	case Op_Timeout:
		op.callback(completion.user_data)

	case Op_Next_Tick, Op_Poll, Op_Poll_Remove:
		unreachable()

	}
	pool_put(&io.completion_pool, completion)
}

// Starts an accept on `op.socket`, or, when one was already started, queries its result.
//
// Returns the endpoint of the accepted connection on success, otherwise `err` is set to the
// code of the failure (which may be one of the "incomplete" codes of `wsa_err_incomplete`).
accept_callback :: proc(io: ^IO, comp: ^Completion, op: ^Op_Accept) -> (source: net.Endpoint, err: win.c_int) {
	ok: win.BOOL
	if op.pending {
		// Get status update, we've already initiated the accept.
		flags: win.DWORD
		transferred: win.DWORD
		ok = win.WSAGetOverlappedResult(op.socket, &comp.over, &transferred, win.FALSE, &flags)
	} else {
		op.pending = true

		oclient, oerr := open_socket(io, .IP4, .TCP)

		err = win.c_int(net_err_to_code(oerr))
		if err != win.NO_ERROR do return

		op.client = win.SOCKET(net.any_socket_to_socket(oclient))

		accept_ex: LPFN_ACCEPTEX
		load_socket_fn(op.socket, win.WSAID_ACCEPTEX, &accept_ex)

		#assert(size_of(win.SOCKADDR_STORAGE_LH) >= size_of(win.sockaddr_in) + 16)
		bytes_read: win.DWORD
		ok = accept_ex(
			op.socket,
			op.client,
			&op.addr,
			0,
			size_of(win.sockaddr_in) + 16,
			size_of(win.sockaddr_in) + 16,
			&bytes_read,
			&comp.over,
		)
	}

	if !ok {
		err = win.WSAGetLastError()
		return
	}

	// enables getsockopt, setsockopt, getsockname, getpeername.
	// The option value has to be the handle of the listening socket.
	win.setsockopt(op.client, win.SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, &op.socket, size_of(op.socket))

	source = sockaddr_to_endpoint(&op.addr)
	return
}

// Opens a socket and starts connecting it to `op.addr`, or, when a connect was already
// started, queries its result.
//
// Returns 0 on success, otherwise the code of the failure (which may be one of the
// "incomplete" codes of `wsa_err_incomplete`).
connect_callback :: proc(io: ^IO, comp: ^Completion, op: ^Op_Connect) -> (err: win.c_int) {
	transferred: win.DWORD
	ok: win.BOOL
	if op.pending {
		flags: win.DWORD
		ok = win.WSAGetOverlappedResult(op.socket, &comp.over, &transferred, win.FALSE, &flags)
	} else {
		op.pending = true

		osocket, oerr := open_socket(io, .IP4, .TCP)

		err = win.c_int(net_err_to_code(oerr))
		if err != win.NO_ERROR do return

		op.socket = win.SOCKET(net.any_socket_to_socket(osocket))

		// The socket is bound to the wildcard address and port 0 before connecting.
		sockaddr := endpoint_to_sockaddr({net.IP4_Any, 0})
		res := win.bind(op.socket, &sockaddr, size_of(sockaddr))
		if res < 0 do return win.WSAGetLastError()

		connect_ex: LPFN_CONNECTEX
		load_socket_fn(op.socket, WSAID_CONNECTEX, &connect_ex)
		// TODO: size_of(win.sockaddr_in6) when ip6.
		ok = connect_ex(op.socket, &op.addr, size_of(win.sockaddr_in) + 16, nil, 0, &transferred, &comp.over)
	}
	if !ok do return win.WSAGetLastError()

	// enables getsockopt, setsockopt, getsockname, getpeername.
	// Sockets connected through ConnectEx need the connect variant of this option.
	win.setsockopt(op.socket, win.SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, nil, 0)
	return
}

// Closes the file handle or socket of `op`; for a file handle its tracked offset is forgotten too.
//
// Returns whether closing succeeded.
close_callback :: proc(io: ^IO, op: Op_Close) -> bool {
	// NOTE: This might cause problems if there is still IO queued/pending.
	// Is that our responsibility to check/keep track of?
	// Might want to call win.CancelIoEx to cancel all pending operations first.

	switch h in op.fd {
	case os.Handle:
		delete_key(&io.offsets, h)
		return win.CloseHandle(win.HANDLE(h)) == true
	case net.TCP_Socket:
		return win.closesocket(win.SOCKET(h)) == win.NO_ERROR
	case net.UDP_Socket:
		return win.closesocket(win.SOCKET(h)) == win.NO_ERROR
	case net.Socket:
		return win.closesocket(win.SOCKET(h)) == win.NO_ERROR
	case:
		unreachable()
	}
}

// Starts an overlapped read into `op.buf`, or, when one was already started, queries its result.
//
// Returns the number of bytes read by this step and 0 or the code of the failure
// (which may be the "incomplete" code of `err_incomplete`).
read_callback :: proc(io: ^IO, comp: ^Completion, op: ^Op_Read) -> (read: win.DWORD, err: win.DWORD) {
	ok: win.BOOL
	if op.pending {
		ok = win.GetOverlappedResult(win.HANDLE(op.fd), &comp.over, &read, win.FALSE)
	} else {
		// An explicit offset is split over the low and high halves, the tracked offset is
		// 32 bits wide and thus only ever fills the low half.
		offset := u64(op.offset) if op.offset >= 0 else u64(io.offsets[op.fd])
		comp.over.Offset     = u32(offset & 0xFFFF_FFFF)
		comp.over.OffsetHigh = u32(offset >> 32)

		ok = win.ReadFile(win.HANDLE(op.fd), raw_data(op.buf), win.DWORD(len(op.buf)), &read, &comp.over)

		// Not sure if this also happens with correctly set up handles sometimes.
		if ok do log.info("non-blocking read returned immediately, is the handle set up correctly?")

		op.pending = true
	}

	if !ok do err = win.GetLastError()

	// Increment offset if this was not a call with an offset set.
	if op.offset < 0 {
		io.offsets[op.fd] += read
	}

	return
}

// Starts an overlapped write of `op.buf`, or, when one was already started, queries its result.
//
// Returns the number of bytes written by this step and 0 or the code of the failure
// (which may be the "incomplete" code of `err_incomplete`).
write_callback :: proc(io: ^IO, comp: ^Completion, op: ^Op_Write) -> (written: win.DWORD, err: win.DWORD) {
	ok: win.BOOL
	if op.pending {
		ok = win.GetOverlappedResult(win.HANDLE(op.fd), &comp.over, &written, win.FALSE)
	} else {
		// An explicit offset is split over the low and high halves, the tracked offset is
		// 32 bits wide and thus only ever fills the low half.
		offset := u64(op.offset) if op.offset >= 0 else u64(io.offsets[op.fd])
		comp.over.Offset     = u32(offset & 0xFFFF_FFFF)
		comp.over.OffsetHigh = u32(offset >> 32)

		ok = win.WriteFile(win.HANDLE(op.fd), raw_data(op.buf), win.DWORD(len(op.buf)), &written, &comp.over)

		// Not sure if this also happens with correctly set up handles sometimes.
		if ok do log.debug("non-blocking write returned immediately, is the handle set up correctly?")

		op.pending = true
	}

	if !ok do err = win.GetLastError()

	// Increment offset if this was not a call with an offset set.
	if op.offset < 0 {
		io.offsets[op.fd] += written
	}

	return
}

// Starts an overlapped receive into `op.buf`, or, when one was already started, queries its result.
//
// Returns the number of bytes received by this step and 0 or the code of the failure
// (which may be one of the "incomplete" codes of `wsa_err_incomplete`).
recv_callback :: proc(io: ^IO, comp: ^Completion, op: ^Op_Recv) -> (received: win.DWORD, err: win.c_int) {
	sock := win.SOCKET(net.any_socket_to_socket(op.socket))
	ok: win.BOOL
	if op.pending {
		flags: win.DWORD
		ok = win.WSAGetOverlappedResult(sock, &comp.over, &received, win.FALSE, &flags)
	} else {
		flags: win.DWORD
		err_code := win.WSARecv(sock, &op.buf, 1, &received, &flags, win.LPWSAOVERLAPPED(&comp.over), nil)
		ok = err_code != win.SOCKET_ERROR
		op.pending = true
	}

	if !ok do err = win.WSAGetLastError()
	return
}

// Starts an overlapped send of `op.buf`, or, when one was already started, queries its result.
//
// Returns the number of bytes sent by this step and 0 or the code of the failure
// (which may be one of the "incomplete" codes of `wsa_err_incomplete`).
send_callback :: proc(io: ^IO, comp: ^Completion, op: ^Op_Send) -> (sent: win.DWORD, err: win.c_int) {
	sock := win.SOCKET(net.any_socket_to_socket(op.socket))
	ok: win.BOOL
	if op.pending {
		flags: win.DWORD
		ok = win.WSAGetOverlappedResult(sock, &comp.over, &sent, win.FALSE, &flags)
	} else {
		err_code := win.WSASend(sock, &op.buf, 1, &sent, 0, win.LPWSAOVERLAPPED(&comp.over), nil)
		ok = err_code != win.SOCKET_ERROR
		op.pending = true
	}

	if !ok do err = win.WSAGetLastError()
	return
}

// Flags for `SetFileCompletionNotificationModes`.
FILE_SKIP_COMPLETION_PORT_ON_SUCCESS :: 0x1
FILE_SKIP_SET_EVENT_ON_HANDLE :: 0x2

// Socket options (level `SOL_SOCKET`) to set on a socket after AcceptEx and ConnectEx respectively.
SO_UPDATE_ACCEPT_CONTEXT :: 28683
SO_UPDATE_CONNECT_CONTEXT :: 0x7010

// GUID used to look up the ConnectEx extension function with `load_socket_fn`.
WSAID_CONNECTEX :: win.GUID{0x25a207b9, 0xddf3, 0x4660, [8]win.BYTE{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e}}

// Signature of the ConnectEx extension function.
LPFN_CONNECTEX :: #type proc "stdcall" (
	socket: win.SOCKET,
	addr: ^win.SOCKADDR_STORAGE_LH,
	namelen: win.c_int,
	send_buf: win.PVOID,
	send_data_len: win.DWORD,
	bytes_sent: win.LPDWORD,
	overlapped: win.LPOVERLAPPED,
) -> win.BOOL

// Signature of the AcceptEx extension function.
LPFN_ACCEPTEX :: #type proc "stdcall" (
	listen_sock: win.SOCKET,
	accept_sock: win.SOCKET,
	addr_buf: win.PVOID,
	addr_len: win.DWORD,
	local_addr_len: win.DWORD,
	remote_addr_len: win.DWORD,
	bytes_received: win.LPDWORD,
	overlapped: win.LPOVERLAPPED,
) -> win.BOOL

// Reports whether `err` is one of the socket error codes that this file treats as
// "operation has not finished yet" instead of as a failure.
wsa_err_incomplete :: proc(err: win.c_int) -> bool {
	#partial switch win.System_Error(err) {
	case .WSAEWOULDBLOCK, .IO_PENDING, .IO_INCOMPLETE, .WSAEALREADY:
		return true
	case:
		return false
	}
}

// Reports whether `err` is the file error code that this file treats as
// "operation has not finished yet" instead of as a failure.
err_incomplete :: proc(err: win.DWORD) -> bool {
	return err == win.ERROR_IO_PENDING
}

// Converts a native IPv4 or IPv6 socket address into an endpoint, panics on any other family.
//
// Verbatim copy of private proc in core:net.
sockaddr_to_endpoint :: proc(native_addr: ^win.SOCKADDR_STORAGE_LH) -> (ep: net.Endpoint) {
	switch native_addr.ss_family {
	case u16(win.AF_INET):
		addr := cast(^win.sockaddr_in)native_addr
		port := int(addr.sin_port)
		ep = net.Endpoint {
			address = net.IP4_Address(transmute([4]byte)addr.sin_addr),
			port    = port,
		}
	case u16(win.AF_INET6):
		addr := cast(^win.sockaddr_in6)native_addr
		port := int(addr.sin6_port)
		ep = net.Endpoint {
			address = net.IP6_Address(transmute([8]u16be)addr.sin6_addr),
			port    = port,
		}
	case:
		panic("native_addr is neither IP4 or IP6 address")
	}
	return
}

// Converts an endpoint into a native IPv4 or IPv6 socket address.
//
// Verbatim copy of private proc in core:net.
endpoint_to_sockaddr :: proc(ep: net.Endpoint) -> (sockaddr: win.SOCKADDR_STORAGE_LH) {
	switch a in ep.address {
	case net.IP4_Address:
		(^win.sockaddr_in)(&sockaddr)^ = win.sockaddr_in {
			sin_port   = u16be(win.USHORT(ep.port)),
			sin_addr   = transmute(win.in_addr)a,
			sin_family = u16(win.AF_INET),
		}
		return
	case net.IP6_Address:
		(^win.sockaddr_in6)(&sockaddr)^ = win.sockaddr_in6 {
			sin6_port   = u16be(win.USHORT(ep.port)),
			sin6_addr   = transmute(win.in6_addr)a,
			sin6_family = u16(win.AF_INET6),
		}
		return
	}
	unreachable()
}

// Converts whichever variant `err` holds into an `os.Platform_Error`, nil when it holds none.
net_err_to_code :: proc(err: net.Network_Error) -> os.Platform_Error {
	switch e in err {
	case net.Create_Socket_Error:
		return os.Platform_Error(e)
	case net.Socket_Option_Error:
		return os.Platform_Error(e)
	case net.General_Error:
		return os.Platform_Error(e)
	case net.Platform_Error:
		return os.Platform_Error(e)
	case net.Dial_Error:
		return os.Platform_Error(e)
	case net.Listen_Error:
		return os.Platform_Error(e)
	case net.Accept_Error:
		return os.Platform_Error(e)
	case net.Bind_Error:
		return os.Platform_Error(e)
	case net.TCP_Send_Error:
		return os.Platform_Error(e)
	case net.UDP_Send_Error:
		return os.Platform_Error(e)
	case net.TCP_Recv_Error:
		return os.Platform_Error(e)
	case net.UDP_Recv_Error:
		return os.Platform_Error(e)
	case net.Shutdown_Error:
		return os.Platform_Error(e)
	case net.Set_Blocking_Error:
		return os.Platform_Error(e)
	case net.Parse_Endpoint_Error:
		return os.Platform_Error(e)
	case net.Resolve_Error:
		return os.Platform_Error(e)
	case net.DNS_Error:
		return os.Platform_Error(e)
	case:
		return nil
	}
}

// Looks up the extension function identified by `guid` for the socket `subject` and stores
// its address in `fn`; asserts that the lookup succeeded.
//
// TODO: loading this takes an overlapped parameter, maybe we can do this async?
load_socket_fn :: proc(subject: win.SOCKET, guid: win.GUID, fn: ^$T) {
	guid := guid
	bytes: u32
	rc := win.WSAIoctl(
		subject,
		win.SIO_GET_EXTENSION_FUNCTION_POINTER,
		&guid,
		size_of(guid),
		fn,
		// Size of the output buffer, which is the function pointer that `fn` points at.
		size_of(fn^),
		&bytes,
		nil,
		nil,
	)
	assert(rc != win.SOCKET_ERROR)
	assert(bytes == size_of(fn^))
}