blob: a9dc9972ff820ae227866c1d3a90d83227b13773 [file] [log] [blame]
// Copyright 2018 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tmutex
import (
"fmt"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
)
// TestBasicLock checks that Lock blocks while the mutex is held elsewhere,
// that the blocked caller proceeds once the holder calls Unlock, and that the
// mutex remains usable afterwards.
func TestBasicLock(t *testing.T) {
	const grace = 100 * time.Millisecond

	var mu Mutex
	mu.Init()

	// Hold the mutex so that the contender below has to wait for it.
	mu.Lock()

	// Contend for the mutex from a second goroutine. Its Lock call must
	// block for as long as the mutex is held here. The test consumes only
	// one signal; the second send relies on the channel's one-slot buffer.
	acquired := make(chan struct{}, 1)
	go func() {
		mu.Lock()
		acquired <- struct{}{}
		mu.Unlock()
		acquired <- struct{}{}
	}()

	// signalled reports whether the contender announced that it got the
	// mutex before the grace period ran out.
	signalled := func() bool {
		select {
		case <-time.After(grace):
			return false
		case <-acquired:
			return true
		}
	}

	// While the mutex is still held, no signal may arrive.
	if signalled() {
		t.Fatalf("Lock succeeded on locked mutex")
	}

	// Release the mutex; the contender blocked in Lock() must now get it.
	mu.Unlock()
	if !signalled() {
		t.Fatalf("Lock failed to acquire unlocked mutex")
	}

	// The mutex must still be usable for another lock/unlock cycle.
	mu.Lock()
	mu.Unlock()
}
// TestTryLock checks that TryLock acquires a free mutex, is refused on a held
// one, and that a mutex taken through TryLock makes Lock callers wait until
// Unlock is called.
func TestTryLock(t *testing.T) {
	const grace = 100 * time.Millisecond

	var mu Mutex
	mu.Init()

	// Nothing holds the mutex yet, so the first attempt must win it.
	if ok := mu.TryLock(); !ok {
		t.Fatalf("TryLock failed on unlocked mutex")
	}

	// The mutex is now held, so a second attempt must be refused.
	if ok := mu.TryLock(); ok {
		t.Fatalf("TryLock succeeded on locked mutex")
	}

	// Contend for the mutex from a second goroutine. Its Lock call must
	// block for as long as the mutex is held here.
	acquired := make(chan struct{}, 1)
	go func() {
		mu.Lock()
		acquired <- struct{}{}
		mu.Unlock()
	}()

	// signalled reports whether the contender announced that it got the
	// mutex before the grace period ran out.
	signalled := func() bool {
		select {
		case <-time.After(grace):
			return false
		case <-acquired:
			return true
		}
	}

	// While the mutex is still held, no signal may arrive.
	if signalled() {
		t.Fatalf("Lock succeeded on locked mutex")
	}

	// Release the mutex; the contender blocked in Lock() must now get it.
	mu.Unlock()
	if !signalled() {
		t.Fatalf("Lock failed to acquire unlocked mutex")
	}
}
// TestMutualExclusion hammers one Mutex from many goroutines, each of which
// bumps a plain (non-atomic) counter inside the critical section.
//
// If the final count differs from workers * rounds, at least two goroutines
// were inside the critical section at once. If the test never finishes, some
// goroutine most likely got stuck waiting on the mutex.
func TestMutualExclusion(t *testing.T) {
	const (
		workers = 1000
		rounds  = 100000
	)

	var (
		mu      Mutex
		wg      sync.WaitGroup
		counter int
	)
	mu.Init()

	// bump increments the shared counter "rounds" times, taking the mutex
	// around every single increment.
	bump := func() {
		for r := 0; r < rounds; r++ {
			mu.Lock()
			counter++
			mu.Unlock()
		}
		wg.Done()
	}

	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go bump()
	}
	wg.Wait()

	if want := workers * rounds; counter != want {
		t.Fatalf("Bad count: got %v, want %v", counter, want)
	}
}
// TestMutualExclusionWithTryLock repeats the mutual-exclusion check with two
// kinds of goroutine sharing one Mutex: half always take it with Lock, the
// other half bump the counter only when TryLock succeeds.
//
// The expected count is every Lock-based increment plus the number of TryLock
// attempts that succeeded, which the TryLock goroutines tally separately.
func TestMutualExclusionWithTryLock(t *testing.T) {
	const (
		pairs  = 1000
		rounds = 100000
	)

	var (
		mu       Mutex
		wg       sync.WaitGroup
		counter  int64
		tryTotal int64
	)
	mu.Init()

	// viaLock always waits for the mutex, so it contributes exactly
	// "rounds" increments.
	viaLock := func() {
		for r := 0; r < rounds; r++ {
			mu.Lock()
			counter++
			mu.Unlock()
		}
		wg.Done()
	}

	// viaTryLock increments only on a successful TryLock and publishes how
	// many times that happened once its loop is over.
	viaTryLock := func() {
		var wins int64
		for r := 0; r < rounds; r++ {
			if !mu.TryLock() {
				continue
			}
			counter++
			mu.Unlock()
			wins++
		}
		atomic.AddInt64(&tryTotal, wins)
		wg.Done()
	}

	for p := 0; p < pairs; p++ {
		wg.Add(2)
		go viaLock()
		go viaTryLock()
	}
	wg.Wait()

	t.Logf("tryTotal = %d", tryTotal)

	want := int64(pairs*rounds) + tryTotal
	if counter != want {
		t.Fatalf("Bad count: got %v, want %v", counter, want)
	}
}
// BenchmarkTmutex measures Lock/Unlock pairs on a single Mutex under varying
// contention. It resembles TestMutualExclusion, except that:
//
// - The goroutine count is swept through powers of two, from 1 up to
// 4*GOMAXPROCS, with one sub-benchmark per count.
//
// - Each goroutine performs b.N Lock/Unlock pairs, so the iteration count is
// chosen by the benchmarking framework.
//
// - Every goroutine is started and parked before the timer is reset, so
// goroutine creation is not part of the measurement.
func BenchmarkTmutex(b *testing.B) {
	limit := 4 * runtime.GOMAXPROCS(0)
	// The n > 0 check ends the sweep if doubling ever wraps around.
	for n := 1; n > 0 && n <= limit; n *= 2 {
		b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
			var (
				mu      Mutex
				started sync.WaitGroup
				done    sync.WaitGroup
			)
			mu.Init()
			start := make(chan struct{})

			started.Add(n)
			done.Add(n)
			for g := 0; g < n; g++ {
				go func() {
					// Report in, then park until the start gate opens.
					started.Done()
					<-start
					for i := 0; i < b.N; i++ {
						mu.Lock()
						mu.Unlock()
					}
					done.Done()
				}()
			}

			// Open the gate only after all goroutines exist and the
			// timer has been reset.
			started.Wait()
			b.ResetTimer()
			close(start)
			done.Wait()
		})
	}
}
// BenchmarkSyncMutex runs the same contention sweep as BenchmarkTmutex, but
// against sync.Mutex, to provide a baseline for comparison.
func BenchmarkSyncMutex(b *testing.B) {
	limit := 4 * runtime.GOMAXPROCS(0)
	// The n > 0 check ends the sweep if doubling ever wraps around.
	for n := 1; n > 0 && n <= limit; n *= 2 {
		b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
			var (
				mu      sync.Mutex
				started sync.WaitGroup
				done    sync.WaitGroup
			)
			start := make(chan struct{})

			started.Add(n)
			done.Add(n)
			for g := 0; g < n; g++ {
				go func() {
					// Report in, then park until the start gate opens.
					started.Done()
					<-start
					for i := 0; i < b.N; i++ {
						mu.Lock()
						mu.Unlock()
					}
					done.Done()
				}()
			}

			// Open the gate only after all goroutines exist and the
			// timer has been reset.
			started.Wait()
			b.ResetTimer()
			close(start)
			done.Wait()
		})
	}
}