blob: 0ad50e2d36d7dcb5575f069cab874e80b34d576c [file] [log] [blame]
// Copyright 2012 Google Inc. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
// +build appengine
package socket
import (
"fmt"
"io"
"net"
"strconv"
"time"
"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"google.golang.org/appengine/internal"
pb "google.golang.org/appengine/internal/socket"
)
// Dial connects to the address addr on the given network protocol.
// addr has the form host:port, where host is either a hostname or an IP
// address. The supported protocols are "tcp" and "udp".
// The returned *Conn satisfies net.Conn and remains usable only while ctx is
// valid; to keep using it after ctx becomes invalid, call SetContext with a
// new context. Dial applies no timeout of its own; see DialTimeout.
func Dial(ctx context.Context, protocol, addr string) (*Conn, error) {
	const noTimeout time.Duration = 0
	return DialTimeout(ctx, protocol, addr, noTimeout)
}
// ipFamilies is the list of address families passed to the Resolve RPC when
// looking up a hostname: IPv4 first, then IPv6.
var ipFamilies = []pb.CreateSocketRequest_SocketFamily{pb.CreateSocketRequest_IPv4, pb.CreateSocketRequest_IPv6}
// DialTimeout is like Dial but takes a timeout.
// The timeout includes name resolution, if required. A timeout of zero or
// less applies no extra limit beyond ctx itself.
//
// Only the first resolved address is dialed. Its family (IPv4 or IPv6) is
// chosen from the length of the packed address.
func DialTimeout(ctx context.Context, protocol, addr string, timeout time.Duration) (*Conn, error) {
	dialCtx := ctx // Used for dialing and name resolution, but not stored in the *Conn.
	if timeout > 0 {
		var cancel context.CancelFunc
		dialCtx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}

	host, portStr, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		return nil, fmt.Errorf("socket: bad port %q: %v", portStr, err)
	}
	// Reject out-of-range ports up front. Otherwise int32(port) below would
	// silently truncate a large value into an unrelated port number.
	if port < 0 || port > 65535 {
		return nil, fmt.Errorf("socket: bad port %q: out of range", portStr)
	}

	var prot pb.CreateSocketRequest_SocketProtocol
	switch protocol {
	case "tcp":
		prot = pb.CreateSocketRequest_TCP
	case "udp":
		prot = pb.CreateSocketRequest_UDP
	default:
		return nil, fmt.Errorf("socket: unknown protocol %q", protocol)
	}

	packedAddrs, resolved, err := resolve(dialCtx, ipFamilies, host)
	if err != nil {
		return nil, fmt.Errorf("socket: failed resolving %q: %v", host, err)
	}
	if len(packedAddrs) == 0 {
		return nil, fmt.Errorf("socket: no addresses for %q", host)
	}
	packedAddr := packedAddrs[0] // use first address
	fam := pb.CreateSocketRequest_IPv4
	if len(packedAddr) == net.IPv6len {
		fam = pb.CreateSocketRequest_IPv6
	}

	req := &pb.CreateSocketRequest{
		Family:   fam.Enum(),
		Protocol: prot.Enum(),
		RemoteIp: &pb.AddressPort{
			Port:          proto.Int32(int32(port)),
			PackedAddress: packedAddr,
		},
	}
	// Pass the original hostname along only when it was actually resolved,
	// not when host was already an IP literal.
	if resolved {
		req.RemoteIp.HostnameHint = &host
	}
	res := &pb.CreateSocketReply{}
	if err := internal.Call(dialCtx, "remote_socket", "CreateSocket", req, res); err != nil {
		return nil, err
	}

	return &Conn{
		ctx:    ctx,
		desc:   res.GetSocketDescriptor(),
		prot:   prot,
		local:  res.ProxyExternalIp,
		remote: req.RemoteIp,
	}, nil
}
// LookupIP returns the IP addresses of host. If host is already an IP
// literal, it is parsed locally instead of being sent to the resolver service.
func LookupIP(ctx context.Context, host string) (addrs []net.IP, err error) {
	packed, _, resolveErr := resolve(ctx, ipFamilies, host)
	if resolveErr != nil {
		return nil, fmt.Errorf("socket: failed resolving %q: %v", host, resolveErr)
	}
	addrs = make([]net.IP, 0, len(packed))
	for _, raw := range packed {
		addrs = append(addrs, net.IP(raw))
	}
	return addrs, nil
}
// resolve converts host into a list of packed IP addresses.
//
// If host is an IP literal, it is handled locally: IPv4 addresses are returned
// in their 4-byte form, and anything else is returned as produced by
// net.ParseIP. In that case the bool result is false.
//
// Otherwise the remote_socket "Resolve" RPC is issued, restricted to the
// address families in fams, and the bool result is true.
func resolve(ctx context.Context, fams []pb.CreateSocketRequest_SocketFamily, host string) ([][]byte, bool, error) {
	if literal := net.ParseIP(host); literal != nil {
		packed := literal.To4()
		if packed == nil {
			packed = literal
		}
		return [][]byte{packed}, false, nil
	}

	res := &pb.ResolveReply{}
	req := &pb.ResolveRequest{Name: &host, AddressFamilies: fams}
	if err := internal.Call(ctx, "remote_socket", "Resolve", req, res); err != nil {
		// XXX: need to map to pb.ResolveReply_ErrorCode?
		return nil, false, err
	}
	return res.PackedAddress, true, nil
}
// withDeadline wraps context.WithDeadline, treating the zero time as "no
// deadline". In that case parent is returned unchanged, together with a cancel
// function that does nothing.
func withDeadline(parent context.Context, deadline time.Time) (context.Context, context.CancelFunc) {
	if !deadline.IsZero() {
		return context.WithDeadline(parent, deadline)
	}
	noop := func() {}
	return parent, noop
}
// Conn is a socket connection whose operations are carried out as
// remote_socket RPCs. It implements net.Conn.
type Conn struct {
	ctx    context.Context // context used for this connection's RPCs; replaceable via SetContext
	desc   string          // socket descriptor from CreateSocket ("CLOSED" after a successful Close)
	offset int64           // running stream offset sent with each Send RPC

	prot   pb.CreateSocketRequest_SocketProtocol // protocol chosen at dial time
	local  *pb.AddressPort                       // local endpoint reported by the service (may be nil)
	remote *pb.AddressPort                       // remote endpoint that was dialed

	// Optional deadlines; the zero time means no deadline.
	readDeadline  time.Time
	writeDeadline time.Time
}
// SetContext replaces the context that this Conn uses for its RPCs.
// This is needed when a Conn outlives the context it was created in, for
// example a connection opened during a warmup request and later used while
// serving a user request.
func (cn *Conn) SetContext(ctx context.Context) { cn.ctx = ctx }
// Read reads up to len(b) bytes from the connection into b, using a single
// "Receive" RPC. At most 1 MiB is requested per call.
//
// If a read deadline is set, it is passed both as the RPC's TimeoutSeconds
// and as a deadline on the RPC's context. An empty reply to a non-empty
// request is reported as io.EOF.
func (cn *Conn) Read(b []byte) (n int, err error) {
	// A zero-length read must not reach the service: its empty reply would be
	// misreported as io.EOF. io.Reader allows returning 0, nil when len(b) == 0.
	if len(b) == 0 {
		return 0, nil
	}
	const maxRead = 1 << 20
	if len(b) > maxRead {
		b = b[:maxRead]
	}
	req := &pb.ReceiveRequest{
		SocketDescriptor: &cn.desc,
		DataSize:         proto.Int32(int32(len(b))),
	}
	res := &pb.ReceiveReply{}
	if !cn.readDeadline.IsZero() {
		req.TimeoutSeconds = proto.Float64(cn.readDeadline.Sub(time.Now()).Seconds())
	}
	ctx, cancel := withDeadline(cn.ctx, cn.readDeadline)
	defer cancel()
	if err := internal.Call(ctx, "remote_socket", "Receive", req, res); err != nil {
		return 0, err
	}
	if len(res.Data) == 0 {
		return 0, io.EOF
	}
	if len(res.Data) > len(b) {
		return 0, fmt.Errorf("socket: internal error: read too much data: %d > %d", len(res.Data), len(b))
	}
	return copy(b, res.Data), nil
}
// Write sends b over the connection in chunks of at most 1 MiB, one "Send"
// RPC per chunk. Each RPC carries the current stream offset (cn.offset), and
// the offset advances by the number of bytes the service reports as sent.
//
// It returns the number of bytes sent. A non-nil error is returned whenever
// n < len(b): either the RPC error, or io.ErrShortWrite if the service makes
// no progress.
func (cn *Conn) Write(b []byte) (n int, err error) {
	const lim = 1 << 20 // max per chunk
	for n < len(b) {
		chunk := b[n:]
		if len(chunk) > lim {
			chunk = chunk[:lim]
		}
		req := &pb.SendRequest{
			SocketDescriptor: &cn.desc,
			Data:             chunk,
			StreamOffset:     &cn.offset,
		}
		res := &pb.SendReply{}
		if !cn.writeDeadline.IsZero() {
			req.TimeoutSeconds = proto.Float64(cn.writeDeadline.Sub(time.Now()).Seconds())
		}
		// Release each chunk's context as soon as its RPC returns. A defer here
		// would keep every chunk's context (and timer) alive until Write returns.
		ctx, cancel := withDeadline(cn.ctx, cn.writeDeadline)
		err = internal.Call(ctx, "remote_socket", "Send", req, res)
		cancel()
		if err != nil {
			// assume zero bytes were sent in this RPC
			return n, err
		}
		sent := int(res.GetDataSent())
		if sent <= 0 {
			// No progress was made. Retrying would loop forever, and io.Writer
			// requires a non-nil error whenever n < len(b).
			return n, io.ErrShortWrite
		}
		n += sent
		cn.offset += int64(sent)
	}
	return n, nil
}
// Close closes the socket with a "Close" RPC. The descriptor is replaced with
// the placeholder "CLOSED" only if the RPC succeeds. On failure the RPC error
// is returned and the descriptor is left unchanged.
func (cn *Conn) Close() error {
	err := internal.Call(cn.ctx, "remote_socket", "Close",
		&pb.CloseRequest{SocketDescriptor: &cn.desc},
		&pb.CloseReply{})
	if err == nil {
		cn.desc = "CLOSED"
	}
	return err
}
// addr converts ap into a net.Addr matching prot: *net.TCPAddr for TCP and
// *net.UDPAddr for UDP. It returns nil if ap is nil.
// It panics for any other protocol; DialTimeout only ever stores TCP or UDP.
func addr(prot pb.CreateSocketRequest_SocketProtocol, ap *pb.AddressPort) net.Addr {
	if ap == nil {
		return nil
	}
	// GetPort returns 0 for an unset Port instead of dereferencing a nil
	// pointer. Nothing here guarantees that the proxy-supplied local address
	// sets the field.
	ip, port := net.IP(ap.PackedAddress), int(ap.GetPort())
	switch prot {
	case pb.CreateSocketRequest_TCP:
		return &net.TCPAddr{IP: ip, Port: port}
	case pb.CreateSocketRequest_UDP:
		return &net.UDPAddr{IP: ip, Port: port}
	}
	panic("unknown protocol " + prot.String())
}
// LocalAddr returns the local address reported by the service when the socket
// was created, or nil if none was reported.
func (cn *Conn) LocalAddr() net.Addr {
	return addr(cn.prot, cn.local)
}

// RemoteAddr returns the remote address the socket was dialed to.
func (cn *Conn) RemoteAddr() net.Addr {
	return addr(cn.prot, cn.remote)
}
// SetDeadline sets both the read and write deadlines. The zero time clears
// them. It always returns nil.
func (cn *Conn) SetDeadline(t time.Time) error {
	cn.readDeadline, cn.writeDeadline = t, t
	return nil
}

// SetReadDeadline sets the deadline applied to subsequent Read calls. The
// zero time clears it. It always returns nil.
func (cn *Conn) SetReadDeadline(t time.Time) error {
	cn.readDeadline = t
	return nil
}

// SetWriteDeadline sets the deadline applied to subsequent Write calls. The
// zero time clears it. It always returns nil.
func (cn *Conn) SetWriteDeadline(t time.Time) error {
	cn.writeDeadline = t
	return nil
}
// KeepAlive signals that the connection is still in use. Calling it can keep
// the socket from being closed due to inactivity. It issues a "GetSocketName"
// RPC and discards the reply, returning only the RPC error.
func (cn *Conn) KeepAlive() error {
	return internal.Call(cn.ctx, "remote_socket", "GetSocketName",
		&pb.GetSocketNameRequest{SocketDescriptor: &cn.desc},
		&pb.GetSocketNameReply{})
}
func init() {
internal.RegisterErrorCodeMap("remote_socket", pb.RemoteSocketServiceError_ErrorCode_name)
}