| // +build linux |
| |
| package ipvs |
| |
| import ( |
| "bytes" |
| "encoding/binary" |
| "fmt" |
| "net" |
| "os/exec" |
| "strings" |
| "sync" |
| "syscall" |
| "unsafe" |
| |
| "github.com/Sirupsen/logrus" |
| "github.com/vishvananda/netlink/nl" |
| "github.com/vishvananda/netns" |
| ) |
| |
| var ( |
| native = nl.NativeEndian() |
| ipvsFamily int |
| ipvsOnce sync.Once |
| ) |
| |
// genlMsgHdr is the generic netlink header that newGenlRequest places right
// after the netlink header of every request. Its memory is reinterpreted as
// raw bytes (see Serialize / deserializeGenlMsg), so the field order and
// sizes must not change.
type genlMsgHdr struct {
	cmd      uint8
	version  uint8
	reserved uint16
}

// ipvsFlags is the flags/mask pair sent as the IPVS service flags attribute.
// Like genlMsgHdr it is serialized by reinterpreting its memory, so the
// layout must stay as declared.
type ipvsFlags struct {
	flags uint32
	mask  uint32
}

// deserializeGenlMsg views the leading bytes of b as a genlMsgHdr.
//
// No copy is made: the returned header aliases b. The slice expression
// panics if cap(b) is smaller than the header size.
func deserializeGenlMsg(b []byte) (hdr *genlMsgHdr) {
	const size = unsafe.Sizeof(genlMsgHdr{})
	head := b[:size]
	return (*genlMsgHdr)(unsafe.Pointer(&head[0]))
}

// Serialize returns the in-memory bytes of hdr. The slice aliases hdr, so
// later modifications of hdr are visible through it.
func (hdr *genlMsgHdr) Serialize() []byte {
	const size = unsafe.Sizeof(genlMsgHdr{})
	raw := (*[size]byte)(unsafe.Pointer(hdr))
	return raw[:]
}

// Len reports the size in bytes of a serialized genlMsgHdr.
func (hdr *genlMsgHdr) Len() int {
	return int(unsafe.Sizeof(genlMsgHdr{}))
}

// Serialize returns the in-memory bytes of f (flags followed by mask, in
// host byte order). The slice aliases f.
func (f *ipvsFlags) Serialize() []byte {
	const size = unsafe.Sizeof(ipvsFlags{})
	raw := (*[size]byte)(unsafe.Pointer(f))
	return raw[:]
}

// Len reports the size in bytes of a serialized ipvsFlags.
func (f *ipvsFlags) Len() int {
	return int(unsafe.Sizeof(ipvsFlags{}))
}
| |
| func setup() { |
| ipvsOnce.Do(func() { |
| var err error |
| if out, err := exec.Command("modprobe", "-va", "ip_vs").CombinedOutput(); err != nil { |
| logrus.Warnf("Running modprobe ip_vs failed with message: `%s`, error: %v", strings.TrimSpace(string(out)), err) |
| } |
| |
| ipvsFamily, err = getIPVSFamily() |
| if err != nil { |
| logrus.Errorf("Could not get ipvs family information from the kernel. It is possible that ipvs is not enabled in your kernel. Native loadbalancing will not work until this is fixed.") |
| } |
| }) |
| } |
| |
| func fillService(s *Service) nl.NetlinkRequestData { |
| cmdAttr := nl.NewRtAttr(ipvsCmdAttrService, nil) |
| nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrAddressFamily, nl.Uint16Attr(s.AddressFamily)) |
| if s.FWMark != 0 { |
| nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrFWMark, nl.Uint32Attr(s.FWMark)) |
| } else { |
| nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrProtocol, nl.Uint16Attr(s.Protocol)) |
| nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrAddress, rawIPData(s.Address)) |
| |
| // Port needs to be in network byte order. |
| portBuf := new(bytes.Buffer) |
| binary.Write(portBuf, binary.BigEndian, s.Port) |
| nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrPort, portBuf.Bytes()) |
| } |
| |
| nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrSchedName, nl.ZeroTerminated(s.SchedName)) |
| if s.PEName != "" { |
| nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrPEName, nl.ZeroTerminated(s.PEName)) |
| } |
| |
| f := &ipvsFlags{ |
| flags: s.Flags, |
| mask: 0xFFFFFFFF, |
| } |
| nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrFlags, f.Serialize()) |
| nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrTimeout, nl.Uint32Attr(s.Timeout)) |
| nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrNetmask, nl.Uint32Attr(s.Netmask)) |
| return cmdAttr |
| } |
| |
| func fillDestinaton(d *Destination) nl.NetlinkRequestData { |
| cmdAttr := nl.NewRtAttr(ipvsCmdAttrDest, nil) |
| |
| nl.NewRtAttrChild(cmdAttr, ipvsDestAttrAddress, rawIPData(d.Address)) |
| // Port needs to be in network byte order. |
| portBuf := new(bytes.Buffer) |
| binary.Write(portBuf, binary.BigEndian, d.Port) |
| nl.NewRtAttrChild(cmdAttr, ipvsDestAttrPort, portBuf.Bytes()) |
| |
| nl.NewRtAttrChild(cmdAttr, ipvsDestAttrForwardingMethod, nl.Uint32Attr(d.ConnectionFlags&ConnectionFlagFwdMask)) |
| nl.NewRtAttrChild(cmdAttr, ipvsDestAttrWeight, nl.Uint32Attr(uint32(d.Weight))) |
| nl.NewRtAttrChild(cmdAttr, ipvsDestAttrUpperThreshold, nl.Uint32Attr(d.UpperThreshold)) |
| nl.NewRtAttrChild(cmdAttr, ipvsDestAttrLowerThreshold, nl.Uint32Attr(d.LowerThreshold)) |
| |
| return cmdAttr |
| } |
| |
| func (i *Handle) doCmd(s *Service, d *Destination, cmd uint8) error { |
| req := newIPVSRequest(cmd) |
| req.AddData(fillService(s)) |
| |
| if d != nil { |
| req.AddData(fillDestinaton(d)) |
| } |
| |
| if _, err := execute(i.sock, req, 0); err != nil { |
| return err |
| } |
| |
| return nil |
| } |
| |
| func getIPVSFamily() (int, error) { |
| sock, err := nl.GetNetlinkSocketAt(netns.None(), netns.None(), syscall.NETLINK_GENERIC) |
| if err != nil { |
| return 0, err |
| } |
| |
| req := newGenlRequest(genlCtrlID, genlCtrlCmdGetFamily) |
| req.AddData(nl.NewRtAttr(genlCtrlAttrFamilyName, nl.ZeroTerminated("IPVS"))) |
| |
| msgs, err := execute(sock, req, 0) |
| if err != nil { |
| return 0, err |
| } |
| |
| for _, m := range msgs { |
| hdr := deserializeGenlMsg(m) |
| attrs, err := nl.ParseRouteAttr(m[hdr.Len():]) |
| if err != nil { |
| return 0, err |
| } |
| |
| for _, attr := range attrs { |
| switch int(attr.Attr.Type) { |
| case genlCtrlAttrFamilyID: |
| return int(native.Uint16(attr.Value[0:2])), nil |
| } |
| } |
| } |
| |
| return 0, fmt.Errorf("no family id in the netlink response") |
| } |
| |
| func rawIPData(ip net.IP) []byte { |
| family := nl.GetIPFamily(ip) |
| if family == nl.FAMILY_V4 { |
| return ip.To4() |
| } |
| |
| return ip |
| } |
| |
| func newIPVSRequest(cmd uint8) *nl.NetlinkRequest { |
| return newGenlRequest(ipvsFamily, cmd) |
| } |
| |
| func newGenlRequest(familyID int, cmd uint8) *nl.NetlinkRequest { |
| req := nl.NewNetlinkRequest(familyID, syscall.NLM_F_ACK) |
| req.AddData(&genlMsgHdr{cmd: cmd, version: 1}) |
| return req |
| } |
| |
| func execute(s *nl.NetlinkSocket, req *nl.NetlinkRequest, resType uint16) ([][]byte, error) { |
| var ( |
| err error |
| ) |
| |
| if err := s.Send(req); err != nil { |
| return nil, err |
| } |
| |
| pid, err := s.GetPid() |
| if err != nil { |
| return nil, err |
| } |
| |
| var res [][]byte |
| |
| done: |
| for { |
| msgs, err := s.Receive() |
| if err != nil { |
| return nil, err |
| } |
| for _, m := range msgs { |
| if m.Header.Seq != req.Seq { |
| return nil, fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq) |
| } |
| if m.Header.Pid != pid { |
| return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid) |
| } |
| if m.Header.Type == syscall.NLMSG_DONE { |
| break done |
| } |
| if m.Header.Type == syscall.NLMSG_ERROR { |
| error := int32(native.Uint32(m.Data[0:4])) |
| if error == 0 { |
| break done |
| } |
| return nil, syscall.Errno(-error) |
| } |
| if resType != 0 && m.Header.Type != resType { |
| continue |
| } |
| res = append(res, m.Data) |
| if m.Header.Flags&syscall.NLM_F_MULTI == 0 { |
| break done |
| } |
| } |
| } |
| return res, nil |
| } |