...

Source file src/cmd/vendor/golang.org/x/tools/go/analysis/passes/loopclosure/loopclosure.go

Documentation: cmd/vendor/golang.org/x/tools/go/analysis/passes/loopclosure

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package loopclosure
     6  
     7  import (
     8  	_ "embed"
     9  	"go/ast"
    10  	"go/types"
    11  
    12  	"golang.org/x/tools/go/analysis"
    13  	"golang.org/x/tools/go/analysis/passes/inspect"
    14  	"golang.org/x/tools/go/analysis/passes/internal/analysisutil"
    15  	"golang.org/x/tools/go/ast/inspector"
    16  	"golang.org/x/tools/go/types/typeutil"
    17  	"golang.org/x/tools/internal/typesinternal"
    18  	"golang.org/x/tools/internal/versions"
    19  )
    20  
// doc is the text of doc.go, embedded at build time; the Analyzer's Doc
// string is extracted from it by analysisutil.MustExtractDoc.
//
//go:embed doc.go
var doc string

// Analyzer is the loopclosure analysis. Its Run function (run) reports
// references to loop variables from within function literals that may run
// after the current loop iteration: literals started by go, defer, or
// errgroup.Group.Go statements, and parallel subtests passed to t.Run.
// Only files whose Go version is older than go1.22 are checked.
var Analyzer = &analysis.Analyzer{
	Name:     "loopclosure",
	Doc:      analysisutil.MustExtractDoc(doc, "loopclosure"),
	URL:      "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/loopclosure",
	Requires: []*analysis.Analyzer{inspect.Analyzer},
	Run:      run,
}
    31  
    32  func run(pass *analysis.Pass) (interface{}, error) {
    33  	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    34  
    35  	nodeFilter := []ast.Node{
    36  		(*ast.File)(nil),
    37  		(*ast.RangeStmt)(nil),
    38  		(*ast.ForStmt)(nil),
    39  	}
    40  	inspect.Nodes(nodeFilter, func(n ast.Node, push bool) bool {
    41  		if !push {
    42  			// inspect.Nodes is slightly suboptimal as we only use push=true.
    43  			return true
    44  		}
    45  		// Find the variables updated by the loop statement.
    46  		var vars []types.Object
    47  		addVar := func(expr ast.Expr) {
    48  			if id, _ := expr.(*ast.Ident); id != nil {
    49  				if obj := pass.TypesInfo.ObjectOf(id); obj != nil {
    50  					vars = append(vars, obj)
    51  				}
    52  			}
    53  		}
    54  		var body *ast.BlockStmt
    55  		switch n := n.(type) {
    56  		case *ast.File:
    57  			// Only traverse the file if its goversion is strictly before go1.22.
    58  			goversion := versions.FileVersion(pass.TypesInfo, n)
    59  			return versions.Before(goversion, versions.Go1_22)
    60  		case *ast.RangeStmt:
    61  			body = n.Body
    62  			addVar(n.Key)
    63  			addVar(n.Value)
    64  		case *ast.ForStmt:
    65  			body = n.Body
    66  			switch post := n.Post.(type) {
    67  			case *ast.AssignStmt:
    68  				// e.g. for p = head; p != nil; p = p.next
    69  				for _, lhs := range post.Lhs {
    70  					addVar(lhs)
    71  				}
    72  			case *ast.IncDecStmt:
    73  				// e.g. for i := 0; i < n; i++
    74  				addVar(post.X)
    75  			}
    76  		}
    77  		if vars == nil {
    78  			return true
    79  		}
    80  
    81  		// Inspect statements to find function literals that may be run outside of
    82  		// the current loop iteration.
    83  		//
    84  		// For go, defer, and errgroup.Group.Go, we ignore all but the last
    85  		// statement, because it's hard to prove go isn't followed by wait, or
    86  		// defer by return. "Last" is defined recursively.
    87  		//
    88  		// TODO: consider allowing the "last" go/defer/Go statement to be followed by
    89  		// N "trivial" statements, possibly under a recursive definition of "trivial"
    90  		// so that that checker could, for example, conclude that a go statement is
    91  		// followed by an if statement made of only trivial statements and trivial expressions,
    92  		// and hence the go statement could still be checked.
    93  		forEachLastStmt(body.List, func(last ast.Stmt) {
    94  			var stmts []ast.Stmt
    95  			switch s := last.(type) {
    96  			case *ast.GoStmt:
    97  				stmts = litStmts(s.Call.Fun)
    98  			case *ast.DeferStmt:
    99  				stmts = litStmts(s.Call.Fun)
   100  			case *ast.ExprStmt: // check for errgroup.Group.Go
   101  				if call, ok := s.X.(*ast.CallExpr); ok {
   102  					stmts = litStmts(goInvoke(pass.TypesInfo, call))
   103  				}
   104  			}
   105  			for _, stmt := range stmts {
   106  				reportCaptured(pass, vars, stmt)
   107  			}
   108  		})
   109  
   110  		// Also check for testing.T.Run (with T.Parallel).
   111  		// We consider every t.Run statement in the loop body, because there is
   112  		// no commonly used mechanism for synchronizing parallel subtests.
   113  		// It is of course theoretically possible to synchronize parallel subtests,
   114  		// though such a pattern is likely to be exceedingly rare as it would be
   115  		// fighting against the test runner.
   116  		for _, s := range body.List {
   117  			switch s := s.(type) {
   118  			case *ast.ExprStmt:
   119  				if call, ok := s.X.(*ast.CallExpr); ok {
   120  					for _, stmt := range parallelSubtest(pass.TypesInfo, call) {
   121  						reportCaptured(pass, vars, stmt)
   122  					}
   123  
   124  				}
   125  			}
   126  		}
   127  		return true
   128  	})
   129  	return nil, nil
   130  }
   131  
   132  // reportCaptured reports a diagnostic stating a loop variable
   133  // has been captured by a func literal if checkStmt has escaping
   134  // references to vars. vars is expected to be variables updated by a loop statement,
   135  // and checkStmt is expected to be a statements from the body of a func literal in the loop.
   136  func reportCaptured(pass *analysis.Pass, vars []types.Object, checkStmt ast.Stmt) {
   137  	ast.Inspect(checkStmt, func(n ast.Node) bool {
   138  		id, ok := n.(*ast.Ident)
   139  		if !ok {
   140  			return true
   141  		}
   142  		obj := pass.TypesInfo.Uses[id]
   143  		if obj == nil {
   144  			return true
   145  		}
   146  		for _, v := range vars {
   147  			if v == obj {
   148  				pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name)
   149  			}
   150  		}
   151  		return true
   152  	})
   153  }
   154  
   155  // forEachLastStmt calls onLast on each "last" statement in a list of statements.
   156  // "Last" is defined recursively so, for example, if the last statement is
   157  // a switch statement, then each switch case is also visited to examine
   158  // its last statements.
   159  func forEachLastStmt(stmts []ast.Stmt, onLast func(last ast.Stmt)) {
   160  	if len(stmts) == 0 {
   161  		return
   162  	}
   163  
   164  	s := stmts[len(stmts)-1]
   165  	switch s := s.(type) {
   166  	case *ast.IfStmt:
   167  	loop:
   168  		for {
   169  			forEachLastStmt(s.Body.List, onLast)
   170  			switch e := s.Else.(type) {
   171  			case *ast.BlockStmt:
   172  				forEachLastStmt(e.List, onLast)
   173  				break loop
   174  			case *ast.IfStmt:
   175  				s = e
   176  			case nil:
   177  				break loop
   178  			}
   179  		}
   180  	case *ast.ForStmt:
   181  		forEachLastStmt(s.Body.List, onLast)
   182  	case *ast.RangeStmt:
   183  		forEachLastStmt(s.Body.List, onLast)
   184  	case *ast.SwitchStmt:
   185  		for _, c := range s.Body.List {
   186  			cc := c.(*ast.CaseClause)
   187  			forEachLastStmt(cc.Body, onLast)
   188  		}
   189  	case *ast.TypeSwitchStmt:
   190  		for _, c := range s.Body.List {
   191  			cc := c.(*ast.CaseClause)
   192  			forEachLastStmt(cc.Body, onLast)
   193  		}
   194  	case *ast.SelectStmt:
   195  		for _, c := range s.Body.List {
   196  			cc := c.(*ast.CommClause)
   197  			forEachLastStmt(cc.Body, onLast)
   198  		}
   199  	default:
   200  		onLast(s)
   201  	}
   202  }
   203  
   204  // litStmts returns all statements from the function body of a function
   205  // literal.
   206  //
   207  // If fun is not a function literal, it returns nil.
   208  func litStmts(fun ast.Expr) []ast.Stmt {
   209  	lit, _ := fun.(*ast.FuncLit)
   210  	if lit == nil {
   211  		return nil
   212  	}
   213  	return lit.Body.List
   214  }
   215  
   216  // goInvoke returns a function expression that would be called asynchronously
   217  // (but not awaited) in another goroutine as a consequence of the call.
   218  // For example, given the g.Go call below, it returns the function literal expression.
   219  //
   220  //	import "sync/errgroup"
   221  //	var g errgroup.Group
   222  //	g.Go(func() error { ... })
   223  //
   224  // Currently only "golang.org/x/sync/errgroup.Group()" is considered.
   225  func goInvoke(info *types.Info, call *ast.CallExpr) ast.Expr {
   226  	if !isMethodCall(info, call, "golang.org/x/sync/errgroup", "Group", "Go") {
   227  		return nil
   228  	}
   229  	return call.Args[0]
   230  }
   231  
   232  // parallelSubtest returns statements that can be easily proven to execute
   233  // concurrently via the go test runner, as t.Run has been invoked with a
   234  // function literal that calls t.Parallel.
   235  //
   236  // In practice, users rely on the fact that statements before the call to
   237  // t.Parallel are synchronous. For example by declaring test := test inside the
   238  // function literal, but before the call to t.Parallel.
   239  //
   240  // Therefore, we only flag references in statements that are obviously
   241  // dominated by a call to t.Parallel. As a simple heuristic, we only consider
   242  // statements following the final labeled statement in the function body, to
   243  // avoid scenarios where a jump would cause either the call to t.Parallel or
   244  // the problematic reference to be skipped.
   245  //
   246  //	import "testing"
   247  //
   248  //	func TestFoo(t *testing.T) {
   249  //		tests := []int{0, 1, 2}
   250  //		for i, test := range tests {
   251  //			t.Run("subtest", func(t *testing.T) {
   252  //				println(i, test) // OK
   253  //		 		t.Parallel()
   254  //				println(i, test) // Not OK
   255  //			})
   256  //		}
   257  //	}
   258  func parallelSubtest(info *types.Info, call *ast.CallExpr) []ast.Stmt {
   259  	if !isMethodCall(info, call, "testing", "T", "Run") {
   260  		return nil
   261  	}
   262  
   263  	if len(call.Args) != 2 {
   264  		// Ignore calls such as t.Run(fn()).
   265  		return nil
   266  	}
   267  
   268  	lit, _ := call.Args[1].(*ast.FuncLit)
   269  	if lit == nil {
   270  		return nil
   271  	}
   272  
   273  	// Capture the *testing.T object for the first argument to the function
   274  	// literal.
   275  	if len(lit.Type.Params.List[0].Names) == 0 {
   276  		return nil
   277  	}
   278  
   279  	tObj := info.Defs[lit.Type.Params.List[0].Names[0]]
   280  	if tObj == nil {
   281  		return nil
   282  	}
   283  
   284  	// Match statements that occur after a call to t.Parallel following the final
   285  	// labeled statement in the function body.
   286  	//
   287  	// We iterate over lit.Body.List to have a simple, fast and "frequent enough"
   288  	// dominance relationship for t.Parallel(): lit.Body.List[i] dominates
   289  	// lit.Body.List[j] for i < j unless there is a jump.
   290  	var stmts []ast.Stmt
   291  	afterParallel := false
   292  	for _, stmt := range lit.Body.List {
   293  		stmt, labeled := unlabel(stmt)
   294  		if labeled {
   295  			// Reset: naively we don't know if a jump could have caused the
   296  			// previously considered statements to be skipped.
   297  			stmts = nil
   298  			afterParallel = false
   299  		}
   300  
   301  		if afterParallel {
   302  			stmts = append(stmts, stmt)
   303  			continue
   304  		}
   305  
   306  		// Check if stmt is a call to t.Parallel(), for the correct t.
   307  		exprStmt, ok := stmt.(*ast.ExprStmt)
   308  		if !ok {
   309  			continue
   310  		}
   311  		expr := exprStmt.X
   312  		if isMethodCall(info, expr, "testing", "T", "Parallel") {
   313  			call, _ := expr.(*ast.CallExpr)
   314  			if call == nil {
   315  				continue
   316  			}
   317  			x, _ := call.Fun.(*ast.SelectorExpr)
   318  			if x == nil {
   319  				continue
   320  			}
   321  			id, _ := x.X.(*ast.Ident)
   322  			if id == nil {
   323  				continue
   324  			}
   325  			if info.Uses[id] == tObj {
   326  				afterParallel = true
   327  			}
   328  		}
   329  	}
   330  
   331  	return stmts
   332  }
   333  
   334  // unlabel returns the inner statement for the possibly labeled statement stmt,
   335  // stripping any (possibly nested) *ast.LabeledStmt wrapper.
   336  //
   337  // The second result reports whether stmt was an *ast.LabeledStmt.
   338  func unlabel(stmt ast.Stmt) (ast.Stmt, bool) {
   339  	labeled := false
   340  	for {
   341  		labelStmt, ok := stmt.(*ast.LabeledStmt)
   342  		if !ok {
   343  			return stmt, labeled
   344  		}
   345  		labeled = true
   346  		stmt = labelStmt.Stmt
   347  	}
   348  }
   349  
   350  // isMethodCall reports whether expr is a method call of
   351  // <pkgPath>.<typeName>.<method>.
   352  func isMethodCall(info *types.Info, expr ast.Expr, pkgPath, typeName, method string) bool {
   353  	call, ok := expr.(*ast.CallExpr)
   354  	if !ok {
   355  		return false
   356  	}
   357  
   358  	// Check that we are calling a method <method>
   359  	f := typeutil.StaticCallee(info, call)
   360  	if f == nil || f.Name() != method {
   361  		return false
   362  	}
   363  	recv := f.Type().(*types.Signature).Recv()
   364  	if recv == nil {
   365  		return false
   366  	}
   367  
   368  	// Check that the receiver is a <pkgPath>.<typeName> or
   369  	// *<pkgPath>.<typeName>.
   370  	_, named := typesinternal.ReceiverNamed(recv)
   371  	return analysisutil.IsNamedType(named, pkgPath, typeName)
   372  }
   373  

View as plain text