feat(pubsub): Allow Message and PublishResult to be used outside the package (#3200)
Added NewMessage which can be provided a custom ack/nack handler. Added NewPublishResult which returns the set() func.
diff --git a/pubsub/iterator.go b/pubsub/iterator.go
index 8f1d931..da888ad 100644
--- a/pubsub/iterator.go
+++ b/pubsub/iterator.go
@@ -218,14 +218,15 @@
it.mu.Lock()
now := time.Now()
for _, m := range msgs {
- m.receiveTime = now
- addRecv(m.ID, m.ackID, now)
- m.doneFunc = it.done
- it.keepAliveDeadlines[m.ackID] = maxExt
+ ackh, _ := m.ackHandler()
+ ackh.receiveTime = now
+ addRecv(m.ID, ackh.ackID, now)
+ ackh.doneFunc = it.done
+ it.keepAliveDeadlines[ackh.ackID] = maxExt
// Don't change the mod-ack if the message is going to be nacked. This is
// possible if there are retries.
- if !it.pendingNacks[m.ackID] {
- ackIDs[m.ackID] = true
+ if !it.pendingNacks[ackh.ackID] {
+ ackIDs[ackh.ackID] = true
}
}
deadline := it.ackDeadline()
diff --git a/pubsub/message.go b/pubsub/message.go
index 00076ef..54be363 100644
--- a/pubsub/message.go
+++ b/pubsub/message.go
@@ -36,18 +36,12 @@
// labelled with.
Attributes map[string]string
- // ackID is the identifier to acknowledge this message.
- ackID string
-
// PublishTime is the time at which the message was published. This is
// populated by the server for Messages obtained from a subscription.
//
// This field is read-only.
PublishTime time.Time
- // receiveTime is the time the message was received by the client.
- receiveTime time.Time
-
// DeliveryAttempt is the number of times a message has been delivered.
// This is part of the dead lettering feature that forwards messages that
// fail to be processed (from nack/ack deadline timeout) to a dead letter topic.
@@ -59,19 +53,23 @@
// size is the approximate size of the message's data and attributes.
size int
- calledDone bool
-
- // The done method of the iterator that created this Message.
- doneFunc func(string, bool, time.Time)
-
// OrderingKey identifies related messages for which publish order should
// be respected. If empty string is used, message will be sent unordered.
OrderingKey string
+
+ // ackh handles Ack() or Nack().
+ ackh ackHandler
+}
+
+// NewMessage creates a message with a custom ack/nack handler, which should not
+// be nil.
+func NewMessage(ackh ackHandler) *Message {
+ return &Message{ackh: ackh}
}
func toMessage(resp *pb.ReceivedMessage) (*Message, error) {
if resp.Message == nil {
- return &Message{ackID: resp.AckId}, nil
+ return &Message{ackh: &psAckHandler{ackID: resp.AckId}}, nil
}
pubTime, err := ptypes.Timestamp(resp.Message.PublishTime)
@@ -86,13 +84,13 @@
}
return &Message{
- ackID: resp.AckId,
Data: resp.Message.Data,
Attributes: resp.Message.Attributes,
ID: resp.Message.MessageId,
PublishTime: pubTime,
DeliveryAttempt: deliveryAttempt,
OrderingKey: resp.Message.OrderingKey,
+ ackh: &psAckHandler{ackID: resp.AckId},
}, nil
}
@@ -102,7 +100,9 @@
// Client code must call Ack or Nack when finished for each received Message.
// Calls to Ack or Nack have no effect after the first call.
func (m *Message) Ack() {
- m.done(true)
+ if m.ackh != nil {
+ m.ackh.OnAck()
+ }
}
// Nack indicates that the client will not or cannot process a Message passed to the Subscriber.Receive callback.
@@ -111,15 +111,58 @@
// Client code must call Ack or Nack when finished for each received Message.
// Calls to Ack or Nack have no effect after the first call.
func (m *Message) Nack() {
- m.done(false)
+ if m.ackh != nil {
+ m.ackh.OnNack()
+ }
}
-func (m *Message) done(ack bool) {
- if m.calledDone {
+// ackHandler performs a safe cast of the message's ack handler to psAckHandler.
+func (m *Message) ackHandler() (*psAckHandler, bool) {
+ ackh, ok := m.ackh.(*psAckHandler)
+ return ackh, ok
+}
+
+func (m *Message) ackID() string {
+ if ackh, ok := m.ackh.(*psAckHandler); ok {
+ return ackh.ackID
+ }
+ return ""
+}
+
+// ackHandler implements ack/nack handling.
+type ackHandler interface {
+ OnAck()
+ OnNack()
+}
+
+// psAckHandler handles ack/nack for the pubsub package.
+type psAckHandler struct {
+ // ackID is the identifier to acknowledge this message.
+ ackID string
+
+ // receiveTime is the time the message was received by the client.
+ receiveTime time.Time
+
+ calledDone bool
+
+ // The done method of the iterator that created this Message.
+ doneFunc func(string, bool, time.Time)
+}
+
+func (ah *psAckHandler) OnAck() {
+ ah.done(true)
+}
+
+func (ah *psAckHandler) OnNack() {
+ ah.done(false)
+}
+
+func (ah *psAckHandler) done(ack bool) {
+ if ah.calledDone {
return
}
- m.calledDone = true
- if m.doneFunc != nil {
- m.doneFunc(m.ackID, ack, m.receiveTime)
+ ah.calledDone = true
+ if ah.doneFunc != nil {
+ ah.doneFunc(ah.ackID, ack, ah.receiveTime)
}
}
diff --git a/pubsub/streaming_pull_test.go b/pubsub/streaming_pull_test.go
index ba40a2d..e70d2e7 100644
--- a/pubsub/streaming_pull_test.go
+++ b/pubsub/streaming_pull_test.go
@@ -67,7 +67,7 @@
func testStreamingPullIteration(t *testing.T, client *Client, server *mockServer, msgs []*pb.ReceivedMessage) {
sub := client.Subscription("S")
gotMsgs, err := pullN(context.Background(), sub, len(msgs), func(_ context.Context, m *Message) {
- id, err := strconv.Atoi(m.ackID)
+ id, err := strconv.Atoi(m.ackID())
if err != nil {
panic(err)
}
@@ -83,20 +83,21 @@
}
gotMap := map[string]*Message{}
for _, m := range gotMsgs {
- gotMap[m.ackID] = m
+ gotMap[m.ackID()] = m
}
for i, msg := range msgs {
want, err := toMessage(msg)
if err != nil {
t.Fatal(err)
}
- want.calledDone = true
- got := gotMap[want.ackID]
+ wantAckh, _ := want.ackHandler()
+ wantAckh.calledDone = true
+ got := gotMap[wantAckh.ackID]
if got == nil {
- t.Errorf("%d: no message for ackID %q", i, want.ackID)
+ t.Errorf("%d: no message for ackID %q", i, wantAckh.ackID)
continue
}
- if !testutil.Equal(got, want, cmp.AllowUnexported(Message{}), cmpopts.IgnoreTypes(time.Time{}, func(string, bool, time.Time) {})) {
+ if !testutil.Equal(got, want, cmp.AllowUnexported(Message{}, psAckHandler{}), cmpopts.IgnoreTypes(time.Time{}, func(string, bool, time.Time) {})) {
t.Errorf("%d: got\n%#v\nwant\n%#v", i, got, want)
}
}
@@ -235,10 +236,10 @@
}
seen := map[string]bool{}
for _, gm := range gotMsgs {
- if seen[gm.ackID] {
- t.Fatalf("duplicate ID %q", gm.ackID)
+ if seen[gm.ackID()] {
+ t.Fatalf("duplicate ID %q", gm.ackID())
}
- seen[gm.ackID] = true
+ seen[gm.ackID()] = true
}
if len(seen) != nMessages {
t.Fatalf("got %d messages, want %d", len(seen), nMessages)
diff --git a/pubsub/subscription.go b/pubsub/subscription.go
index 59f296e..7e5f201 100644
--- a/pubsub/subscription.go
+++ b/pubsub/subscription.go
@@ -910,9 +910,10 @@
// Return nil if the context is done, not err.
return nil
}
- old := msg.doneFunc
+ ackh, _ := msg.ackHandler()
+ old := ackh.doneFunc
msgLen := len(msg.Data)
- msg.doneFunc = func(ackID string, ack bool, receiveTime time.Time) {
+ ackh.doneFunc = func(ackID string, ack bool, receiveTime time.Time) {
defer fc.release(msgLen)
old(ackID, ack, receiveTime)
}
diff --git a/pubsub/topic.go b/pubsub/topic.go
index 07e392e..1dd50c0 100644
--- a/pubsub/topic.go
+++ b/pubsub/topic.go
@@ -470,6 +470,13 @@
err error
}
+// NewPublishResult returns the set() function to enable callers from outside
+// this package to store and call it (e.g. unit tests).
+func NewPublishResult() (*PublishResult, func(string, error)) {
+ result := &PublishResult{ready: make(chan struct{})}
+ return result, result.set
+}
+
// Ready returns a channel that is closed when the result is ready.
// When the Ready channel is closed, Get is guaranteed not to block.
func (r *PublishResult) Ready() <-chan struct{} { return r.ready }