--[[ Archive metadata (non-code scrape residue, commented out so the file parses):
https://github.com/torch/cunn
Revision f225bb44c2cbccf90a48ffc4f25ee9056958ac63 authored by Soumith Chintala on 31 March 2017, 18:44:07 UTC, committed by GitHub on 31 March 2017, 18:44:07 UTC
Using temporary variables when performing transpose + addmm
Tip revision: f225bb44c2cbccf90a48ffc4f25ee9056958ac63 authored by Soumith Chintala on 31 March 2017, 18:44:07 UTC
Merge pull request #456 from twitter-forks/addmm-fixes
Tip revision: f225bb4
DataParallelTable.lua
]]--
--[[
This file implements data parallelism for Torch modules.
The same model is replicated on multiple GPUs. The input is split, typically
into smaller mini-batches. Each replicated model handles only its portion of the input.
The weight updates for each replica are summed together on the first replica
in accGradParameters.
By default, this module uses only one thread and relies on asynchronous kernel launches.
To use multiple threads, call DataParallelTable:threads(initFunc).
For best performance, install NCCL:
https://github.com/NVIDIA/nccl
https://github.com/ngimel/nccl.torch
]]--
-- Class declarations: DataParallelTable is the public container; Basic and
-- Threads are the two dispatch implementations (single-threaded vs thread pool).
local DataParallelTable, parent = torch.class('nn.DataParallelTable', 'nn.Container')
local Impls = {}
local BasicImpl = torch.class('nn.DataParallelTable.Basic', Impls)
local ThreadsImpl = torch.class('nn.DataParallelTable.Threads', Impls)
local unpack = unpack and unpack or table.unpack -- lua52 compatibility
-- NCCL does not work when CUDA_LAUNCH_BLOCKING is set
local cudaLaunchBlocking = os.getenv('CUDA_LAUNCH_BLOCKING') == '1'
-- extracts the value at idx from each entry in tbl
-- Collects entry [idx] from every element of tbl into a new array,
-- preserving order.
local function pluck(tbl, idx)
   local collected = {}
   for position = 1, #tbl do
      collected[position] = tbl[position][idx]
   end
   return collected
end
-- Synchronizes the current stream on dst device with src device. This is only
-- necessary if we are not on the default stream
-- Synchronizes the current stream on dst device with src device. This is only
-- necessary if we are not on the default stream (stream 0 serializes already).
local function waitForDevice(dst, src)
   local currentStream = cutorch.getStream()
   if currentStream == 0 then
      return
   end
   cutorch.streamWaitForMultiDevice(dst, currentStream, { [src] = {currentStream} })
end
--[[ Constructor.
   dimension     - tensor dimension along which the input is split
   flattenParams - flatten each replica's parameters into one contiguous tensor
   usenccl       - use NCCL for broadcast/reduce (requires flattenParams)
]]--
function DataParallelTable:__init(dimension, flattenParams, usenccl)
   parent.__init(self)
   if not dimension then
      error "must specify a dimension!"
   end
   self.typeStr = 'torch.CudaTensor'
   self.dimension = dimension
   self.modules = {}
   self.gpuAssignments = {}   -- Which gpuid each module sits on
   -- per-GPU scratch slots for inputs, outputs and gradients
   for _, field in ipairs({'inputGpu', 'gradOutputGpu', 'outputGpu', 'gradInputGpu'}) do
      self[field] = {}
   end
   self.flattenedParams = nil -- flattened parameters for each gpu
   self.flattenParams = flattenParams or false
   self.needsSync = false
   self.impl = Impls.Basic(self)
   self.usenccl = false
   if usenccl then
      assert(self.flattenParams, 'cannot use nccl without flattenParams')
      self.usenccl = pcall(require, 'nccl')
      if not self.usenccl then
         print("warning: could not load nccl, falling back to default communication")
      end
   end
end
--[[ Registers the module and its GPU assignment(s).
   Either call repeatedly with a single gpu id (only the first call supplies
   the module), or call exactly once with a table of gpu ids.
]]--
function DataParallelTable:add(module, gpus)
   if type(gpus) == 'number' then
      if #self.modules == 0 then
         self.modules[1] = module
      end
      self.gpuAssignments[#self.gpuAssignments + 1] = gpus
      return self
   end
   assert(torch.type(gpus) == 'table' and #gpus >= 1, 'table of GPU IDs required')
   assert(#self.modules == 0, 'add should only be called once with a table of GPU assignments')
   self.modules[1] = module
   self.gpuAssignments = gpus
   return self
end
-- Switches dispatch to a thread pool with one worker per replica. initFunc
-- runs in every worker thread at pool creation.
-- NOTE: Setting syncCopies will copy model to GPUs synchronously (use when
-- the model's serialization code is not thread-safe).
function DataParallelTable:threads(initFunc, syncCopies)
   require 'threads'
   self.impl:close()
   self.impl = Impls.Threads(self, initFunc, syncCopies)
   return self
end
-- Human-readable description: replica count and the wrapped module.
function DataParallelTable:__tostring()
   local nReplicas = #self.gpuAssignments
   return 'DataParallelTable: ' .. nReplicas .. ' x ' .. tostring(self.modules[1])
end
-- Returns the module replica at the given index (index 1 is the master copy).
function DataParallelTable:get(index)
   return self.modules[index]
end
-- this flattens parameters, so that syncParameters and accGradParameters can be much more efficient
-- If a replica's params (and gradParams) already share a single storage each,
-- flat views of those storages are built directly; otherwise
-- module:getParameters() is called, which reallocates and flattens.
function DataParallelTable:flattenParameters()
   local typeStr = self.typeStr
   self.flattenedParams = self.impl:exec(function(module)
      local p, dp = module:parameters()
      -- detect whether everything already lives in one storage per list
      local flattened = true
      for i=2,#p do
         if p[i]:storage() ~= p[1]:storage()
            or dp[i]:storage() ~= dp[1]:storage() then
            flattened = false
            break
         end
      end
      if flattened then
         -- build one flat tensor per storage, spanning from the first
         -- parameter's offset through the end of the last parameter
         local pp = torch[typeStr:match('torch.(%a+)')](p[1]:storage(), p[1]:storageOffset(),
                     p[#p]:storageOffset()+p[#p]:numel()-p[1]:storageOffset())
         local dpp = torch[typeStr:match('torch.(%a+)')](dp[1]:storage(), dp[1]:storageOffset(),
                     dp[#dp]:storageOffset()+dp[#dp]:numel()
                     - dp[1]:storageOffset())
         return {pp, dpp}
      else
         return { module:getParameters() }
      end
   end)
   self.flattenParams = true
end
-- Flattens the parameters (if needed) and returns the master replica's
-- flattened parameters and gradParameters as two tensors.
function DataParallelTable:getParameters()
   self:flattenParameters()
   -- use the compatibility `unpack` alias defined at the top of this file:
   -- plain Lua 5.1 has no table.unpack, so calling it directly would fail
   return unpack(self.flattenedParams[1])
end
-- True when self.flattenedParams exists and still aliases the storages of
-- the master replica's current parameters (flattening is invalidated when
-- the underlying module's parameters are reallocated).
local function hasFlattenedParameters(self)
   local flat = self.flattenedParams
   if not flat then
      return false
   end
   local masterParams = self.modules[1]:parameters()
   for i = 1, #masterParams do
      if masterParams[i]:storage() ~= flat[1][1]:storage() then
         return false
      end
   end
   return true
end
-- Puts every replica, then the container itself, in training mode.
function DataParallelTable:training()
   self.impl:exec(function(replica) replica:training() end)
   parent.training(self)
end
-- Puts every replica, then the container itself, in evaluation mode.
function DataParallelTable:evaluate()
   self.impl:exec(function(replica) replica:evaluate() end)
   parent.evaluate(self)
end
-- Clears intermediate buffers on every replica, then on the container.
function DataParallelTable:clearState()
   self.impl:exec(function(replica) replica:clearState() end)
   return parent.clearState(self)
end
-- Returns true if input -- a tensor, or an arbitrarily nested table of
-- tensors -- contains at least one non-empty tensor.
local function _hasData(input)
   if torch.isTensor(input) then
      return input:numel() ~= 0
   end
   assert(type(input) == 'table')
   for i = 1, #input do
      if _hasData(input[i]) then
         return true
      end
   end
   return false
end
-- Forward pass: scatters input along self.dimension across the GPUs, runs
-- updateOutput on each replica on its own device, and gathers the outputs
-- back onto the first GPU.
function DataParallelTable:updateOutput(input)
   -- re-flatten if the flat views no longer alias the master's parameters
   if self.flattenParams and not hasFlattenedParameters(self) then
      self:flattenParameters()
   end
   -- push parameters updated on the master out to the other replicas
   if self.needsSync then
      self:syncParameters()
   end
   local prevGpuid = cutorch.getDevice()
   -- distribute the input to GPUs; maxUsedGpu is set when the batch is
   -- smaller than the number of replicas (see _distribute)
   self.maxUsedGpu = self:_distribute(self.inputGpu, input)
   -- update output for each module
   local inputGpu = self.inputGpu
   self.outputGpu = self.impl:exec(function(m, i)
      return m:updateOutput(inputGpu[i])
   end, self.maxUsedGpu)
   -- concatenate the outputs to the base GPU
   self.output = self:_concat(self.output, self.outputGpu)
   cutorch.setDevice(prevGpuid)
   return self.output
end
-- Returns a table containing the parameters for each replica, one
-- { params, gradParams } pair per replica. Uses the flattened tensors when
-- available, otherwise queries each replica on its own GPU.
function DataParallelTable:moduleParameters()
   if not self.flattenedParams then
      return self.impl:exec(function(m)
         return { m:parameters() }
      end)
   end
   local perReplica = {}
   for i, flat in ipairs(self.flattenedParams) do
      perReplica[i] = { {flat[1]}, {flat[2]} }
   end
   return perReplica
end
-- Shared implementation behind backward / updateGradInput / accGradParameters.
-- Scatters gradOutput, runs the requested method on each replica on its own
-- GPU, gathers gradInput onto the first GPU, and -- for the gradient-
-- accumulating methods -- sums gradParams onto the first replica and zeros
-- them on the others. Sets needsSync so the next forward re-broadcasts.
function DataParallelTable:__backward(method, input, gradOutput, scale)
   local prevGpuid = cutorch.getDevice()
   local inputGpu, gradOutputGpu = self.inputGpu, self.gradOutputGpu
   if method == 'backward' or method == 'updateGradInput' then
      -- distribute the gradOutput to GPUs
      self:_distribute(self.gradOutputGpu, gradOutput)
      self.gradInputGpu = self.impl:exec(function(m, i)
         return m[method](m, inputGpu[i], gradOutputGpu[i], scale)
      end, self.maxUsedGpu)
      if self.gradInput then
         -- concatenate the gradInput to the base GPU
         self.gradInput = self:_concat(self.gradInput, self.gradInputGpu)
      end
   end
   if method == 'accGradParameters' then
      -- reuses gradOutputGpu scattered by a preceding updateGradInput call
      self.impl:exec(function(m, i)
         return m:accGradParameters(inputGpu[i], gradOutputGpu[i], scale)
      end, self.maxUsedGpu)
   end
   if method == 'backward' or method == 'accGradParameters' then
      local params = self:moduleParameters()
      -- Accumulate the gradients onto the base GPU
      if self.flattenedParams and self.usenccl and not cudaLaunchBlocking then
         if #self.gpuAssignments > 1 then
            nccl.reduce(pluck(self.flattenedParams, 2), nil, true, 1)
         end
      else
         self:_reduce(pluck(params, 2))
      end
      -- Zero out gradients on the other GPUs
      for i = 2, #self.gpuAssignments do
         cutorch.setDevice(self.gpuAssignments[i])
         for _, gradParam in ipairs(params[i][2]) do
            gradParam:zero()
         end
      end
      self.needsSync = true
   end
   cutorch.setDevice(prevGpuid)
   return self.gradInput
end
-- Full backward pass: gradInput computation plus gradient accumulation.
function DataParallelTable:backward(input, gradOutput, scale)
   return self:__backward('backward', input, gradOutput, scale)
end
-- Computes and returns gradInput only (no gradient accumulation).
function DataParallelTable:updateGradInput(input, gradOutput)
   return self:__backward('updateGradInput', input, gradOutput)
end
-- Accumulates gradParameters only; relies on gradOutput having been
-- scattered by a preceding updateGradInput call (see __backward).
function DataParallelTable:accGradParameters(input, gradOutput, scale)
   self:__backward('accGradParameters', input, gradOutput, scale)
end
-- Copies the master replica's parameters to every other replica -- via NCCL
-- broadcast when available (and CUDA_LAUNCH_BLOCKING is unset), otherwise
-- via per-tensor cross-device copies.
function DataParallelTable:syncParameters()
   local prevGpuid = cutorch.getDevice()
   local canUseNccl = self.flattenedParams and self.usenccl and not cudaLaunchBlocking
   if canUseNccl then
      if #self.gpuAssignments > 1 then
         nccl.bcast(pluck(self.flattenedParams, 1), true, 1)
      end
   else
      self:_broadcast(pluck(self:moduleParameters(), 1))
   end
   self.needsSync = false
   cutorch.setDevice(prevGpuid)
end
-- Not supported for DataParallelTable; always raises.
function DataParallelTable:accUpdateGradParameters(input, gradOutput, lr)
   error("accUpdateGradParameters not supported for DataParallelTable.")
end
-- Zeros the accumulated gradients on every replica, switching to each
-- replica's GPU when zeroing its flattened gradient tensor.
function DataParallelTable:zeroGradParameters()
   local prevGpuid = cutorch.getDevice()
   if not self.flattenedParams then
      self.impl:exec(function(replica) replica:zeroGradParameters() end)
   else
      for i, flat in ipairs(self.flattenedParams) do
         cutorch.setDevice(self.gpuAssignments[i])
         flat[2]:zero()
      end
   end
   cutorch.setDevice(prevGpuid)
end
-- Applies the parameter update on the master replica only, then broadcasts
-- the new parameters to every other replica.
function DataParallelTable:updateParameters(learningRate)
   local prevGpuid = cutorch.getDevice()
   cutorch.setDevice(self.gpuAssignments[1])
   self.modules[1]:updateParameters(learningRate)
   self:syncParameters()
   cutorch.setDevice(prevGpuid)
end
-- Returns the parameters of the master (first) replica only.
function DataParallelTable:parameters()
   return self.modules[1]:parameters()
end
-- Parameter sharing is not supported; always raises.
function DataParallelTable:share(mlp,...)
   error("Share not supported for DataParallelTable")
end
-- Deep-clones the container; clone-with-share arguments are rejected.
function DataParallelTable:clone(...)
   assert(select('#',...) == 0, "Sharing not supported for DataParallelTable")
   return parent.clone(self)
end
-- Re-initializes the master replica's weights, then syncs all replicas.
function DataParallelTable:reset(stdv)
   local prevGpuid = cutorch.getDevice()
   cutorch.setDevice(self.gpuAssignments[1])
   self.modules[1]:reset(stdv)
   self:syncParameters()
   cutorch.setDevice(prevGpuid)
end
-- Converts all replicas to one of the supported CUDA tensor types and
-- records it; any other type is rejected.
function DataParallelTable:type(typeStr)
   assert(typeStr == 'torch.CudaHalfTensor' or typeStr == 'torch.CudaTensor' or typeStr == 'torch.CudaDoubleTensor',
      'DataParallelTable supports only torch.CudaHalfTensor or torch.CudaDoubleTensor or torch.CudaTensor types')
   for idx = 1, #self.modules do
      self.modules[idx]:type(typeStr)
   end
   self.typeStr = typeStr
   return self
end
-- Backward compatibility purposes
DataParallelTable.__version = 3
-- DataParallelTable.deserializeNGPUs controls how many GPUs to deserialize
-- upon, otherwise will deserialize to as many GPUs as serialized and error
-- out if it doesn't have enough available
-- Custom deserialization. Version 1 objects are read as a plain table.
-- Version >= 2 first reads the serialized gpuAssignments; if
-- DataParallelTable.deserializeNGPUs is set it overrides them with GPUs 1..N.
-- The object table is then read on the first assigned GPU and the replicas'
-- flattened parameters are rebuilt.
function DataParallelTable:__read(file, version)
   if version < 2 then
      local var = file:readObject()
      for k, v in pairs(var) do
         self[k] = v
      end
      -- older serializations predate impl; default to single-thread dispatch
      self.impl = self.impl or Impls.Basic(self)
      return
   end
   -- Pre-read gpuAssignments and either use them or ignore them depending on
   -- whether DataParallelTable.deserializeNGPUs is set.
   local gpuAssignments = file:readObject()
   if DataParallelTable.deserializeNGPUs then
      gpuAssignments = {}
      for i = 1, DataParallelTable.deserializeNGPUs do gpuAssignments[i] = i end
      if DataParallelTable.deserializeNGPUs > cutorch.getDeviceCount() then
         error('Deserialization requested on too many GPUs: ' ..
               DataParallelTable.deserializeNGPUs .. ' vs ' ..
               cutorch.getDeviceCount() .. ' available')
      end
   end
   -- If DataParallelTable.deserializeNGPUs, deserialization overrides
   -- gpu assignments anyway. If not, we need as many GPUs as the max,
   -- there may be holes.
   local nGPUs = math.max(unpack(gpuAssignments))
   if nGPUs > cutorch.getDeviceCount() then
      error('Model was serialized on ' ..
            math.max(unpack(gpuAssignments)) ..
            ' nGPUs, but you are running on ' .. cutorch.getDeviceCount() ..
            ' please set DataParallelTable.deserializeNGPUs to ignore ' ..
            ' serialized tower-GPU assignments')
   end
   local prevGpuid = cutorch.getDevice()
   cutorch.setDevice(gpuAssignments[1])
   -- Deserialize from table
   local var = file:readObject()
   for k, v in pairs(var) do
      self[k] = v
   end
   cutorch.setDevice(prevGpuid)
   if self.usenccl then
      -- nccl may not be available in this process; degrade gracefully
      self.usenccl = pcall(require, 'nccl')
   end
   if not self.impl then
      self.impl = Impls.Basic(self)
   end
   -- use previously deserialize / recomputed gpuAssignments
   self.gpuAssignments = gpuAssignments
   assert(#self.modules == 1)
   -- only the first replica's flattened params were serialized (see __write);
   -- rebuild the others by re-flattening each replica on its own GPU
   local flattenedParams = self.flattenedParams
   if flattenedParams then
      self.flattenedParams = self.impl:exec(function(m, i)
         if i == 1 then
            return flattenedParams[1]
         else
            return { m:getParameters() }
         end
      end)
   end
end
-- Custom serialization: writes gpuAssignments first (consumed by __read),
-- then the object as a plain table, keeping only the first replica's
-- flattened params and dropping per-GPU scratch tensors and the buffer.
function DataParallelTable:__write(file)
   -- Prewrite the current assignments, we may need them to
   -- deserialize the first tower
   file:writeObject(self.gpuAssignments)
   -- Convert to table
   local t = {}
   for k, v in pairs(self) do
      -- Only keep the flattenedParams from the first module
      if k == 'flattenedParams' then
         t[k] = {v[1]}
      elseif k == 'inputGpu' or k == 'outputGpu' or k == 'gradInputGpu' or k == 'gradOutputGpu' then
         -- per-GPU scratch tensors are not serialized
         t[k] = {}
      elseif k == 'buffer' then
         t[k] = nil
      else
         t[k] = v
      end
   end
   file:writeObject(t)
   -- Force synchronization, this keeps you honest
   self:syncParameters()
end
-- Rebuilds the flattened parameter views after the module graph changed:
-- the first replica keeps its existing flat tensors, every other replica
-- re-flattens via getParameters() on its own GPU.
function DataParallelTable:_reflattenReplicaParameters()
   local flattenedParams = self.flattenedParams
   if flattenedParams then
      self.flattenedParams = self.impl:exec(function(m, i)
         if i == 1 then
            return flattenedParams[1]
         else
            return { m:getParameters() }
         end
      end)
   end
end
-- Applies callback to the container, then propagates the structural change
-- to every replica and re-flattens their parameters.
function DataParallelTable:apply(callback)
   parent.apply(self, callback)
   self.impl:applyChanges()
   self:_reflattenReplicaParameters()
end
-- Computes the sub-range of a dimension of size nElem assigned to replica
-- idx out of splits replicas. Elements are divided as evenly as possible:
-- the first (nElem % splits) replicas each get one extra element.
-- Returns the 1-based start index and the slice length (which is zero when
-- nElem < splits and idx falls past the data).
local function sliceRange(nElem, idx, splits)
   local eltsPerMod = math.floor(nElem / splits)
   local numExtra = nElem - eltsPerMod * splits
   -- BUGFIX: rangeStart was assigned without `local`, leaking a global
   -- variable on every call
   local rangeStart
   if idx <= numExtra then
      rangeStart = (idx - 1) * (eltsPerMod + 1) + 1
      return rangeStart, eltsPerMod + 1
   else
      rangeStart = numExtra * (eltsPerMod + 1) + (idx - 1 - numExtra) * eltsPerMod + 1
      return rangeStart, eltsPerMod
   end
end
-- Computes the size of the tensor that would hold all the given tensors
-- concatenated along dim: the size of the first non-empty tensor with its
-- dim-th entry replaced by the total extent along dim. Returns nil when
-- every tensor is empty.
local function sumSizes(tensors, dim)
   local size
   for i=1,#tensors do
      if tensors[i]:numel() > 0 then
         if size then
            size[dim] = size[dim] + tensors[i]:size(dim)
         else
            size = tensors[i]:size()
         end
      end
   end
   return size
end
-- Copies the parameters from the first replica to all other replicas
function DataParallelTable:_broadcast(params)
for moduleIdx = 2, #params do
for paramIdx = 1, #params[moduleIdx] do
params[moduleIdx][paramIdx]:copy(params[1][paramIdx])
end
waitForDevice(self.gpuAssignments[moduleIdx], self.gpuAssignments[1])
end
end
-- Sums all the gradParams on to the first replica
-- Each remote gradient is staged through a reusable buffer on the
-- destination GPU; cross-device stream waits bracket the copy so it cannot
-- overlap the surrounding adds.
function DataParallelTable:_reduce(gradParams)
   local dstGpuid = self.gpuAssignments[1]
   cutorch.setDevice(dstGpuid)
   -- lazily-created staging buffer, reused across parameters and replicas
   self.buffer = self.buffer or torch[self.typeStr:match('torch.(%a+)')]()
   for moduleIdx = 2, #gradParams do
      for paramIdx = 1, #gradParams[moduleIdx] do
         local dst = gradParams[1][paramIdx]
         local src = gradParams[moduleIdx][paramIdx]
         -- Synchronize before and after copy to ensure that it doesn't overlap
         -- with this add or previous adds
         waitForDevice(self.gpuAssignments[moduleIdx], dstGpuid)
         self.buffer:resizeAs(src):copy(src)
         waitForDevice(dstGpuid, self.gpuAssignments[moduleIdx])
         dst:add(self.buffer)
      end
   end
end
-- Splits src across the GPUs into dst (one slot per replica). Returns the
-- index of the last replica that actually received data when the input is
-- smaller than the number of replicas, or nil when every replica got data.
function DataParallelTable:_distribute(dst, src)
   for i = 1, #self.gpuAssignments do
      cutorch.setDevice(self.gpuAssignments[i])
      dst[i] = self:_distributeTensorRecursive(dst[i], src, i, #self.gpuAssignments)
      -- stop at the first replica with an empty slice; all later replicas
      -- would receive empty slices too (see sliceRange)
      if not _hasData(dst[i]) then return i-1 end
   end
end
-- _distributeTensorRecursive - if the src is a tensor then the function slices
-- it along self.dimension and copies each portion into each child module.
-- Otherwise it does a recursive call on tables.
function DataParallelTable:_distributeTensorRecursive(dst, src, idx, n)
   if torch.type(src) == 'table' then
      -- reuse the destination table only when its shape matches
      if torch.type(dst) ~= 'table' or #src ~= #dst then
         dst = {}
      end
      -- Recurse on the table
      for i, s in ipairs(src) do
         dst[i] = self:_distributeTensorRecursive(dst[i], s, idx, n)
      end
      return dst
   end
   assert(torch.isTensor(src), 'input must be a tensor or table of tensors')
   -- validate the input tensor type against the configured CUDA type
   if self.typeStr == 'torch.CudaHalfTensor' then
      assert(src:type() == self.typeStr or src:type() == 'torch.HalfTensor',
             'input must be a CudaHalf or Half tensor')
   elseif self.typeStr == 'torch.CudaDoubleTensor' then
      assert(src:type() == self.typeStr or src:type() == 'torch.DoubleTensor',
             'input must be a CudaDouble or Double tensor')
   else
      assert(src:type() == 'torch.CudaTensor' or src:type() == 'torch.FloatTensor',
             'input must be a CUDA or Float tensor')
   end
   -- reuse dst when it already has the right type, else allocate a fresh
   -- tensor on the current device
   dst = torch.type(dst) == self.typeStr and dst or torch[self.typeStr:match('torch.(%a+)')]()
   local srcsize = src:dim() > 0 and src:size(self.dimension) or 0
   local index, size = sliceRange(srcsize, idx, n)
   if size == 0 then
      dst:resize(0)
   else
      local slice = src:narrow(self.dimension, index, size)
      dst:resize(slice:size()):copyAsync(slice)
      -- for cross-device copies, make this device's stream wait on the source
      if slice.getDevice then
         waitForDevice(dst:getDevice(), slice:getDevice())
      end
   end
   return dst
end
-- _concat - if the src is a tensor then the function copies it
-- into the dst slice along self.dimension.
-- Otherwise it does a recursive call on tables.
function DataParallelTable:_concat(dst, src)
   dst = self:_concatTensorRecursive(dst, src)
   -- the base GPU must wait for the copies issued from every other GPU
   for i=2,#self.gpuAssignments do
      waitForDevice(self.gpuAssignments[1], self.gpuAssignments[i])
   end
   return dst
end
-- Gathers per-replica results in src (tensors, or nested tables of tensors)
-- into dst on the first GPU, concatenating along self.dimension.
function DataParallelTable:_concatTensorRecursive(dst, src)
   if torch.type(src[1]) == 'table' then
      if torch.type(dst) ~= 'table' or #src[1] ~= #dst then
         dst = {}
      end
      -- recurse per entry, gathering that entry from every replica
      for i, _ in ipairs(src[1]) do
         dst[i] = self:_concatTensorRecursive(dst[i], pluck(src, i))
      end
      return dst
   end
   assert(torch.isTensor(src[1]), 'input must be a tensor or table of tensors')
   cutorch.setDevice(self.gpuAssignments[1])
   dst = torch.type(dst) == self.typeStr and dst or torch[self.typeStr:match('torch.(%a+)')]()
   local cumsum = sumSizes(src, self.dimension)
   -- every replica produced an empty tensor: nothing to concatenate
   if cumsum == nil then return dst end
   dst:resize(cumsum)
   local start = 1
   for i, s in ipairs(src) do
      if torch.numel(s) > 0 then
         local sz = s:size(self.dimension)
         dst:narrow(self.dimension, start, sz):copy(s)
         start = start + sz
      end
   end
   return dst
end
-- Single-thread dispatch
function BasicImpl:__init(dpt)
   -- back-reference to the owning DataParallelTable
   self.dpt = dpt
end
-- Re-copies the first replica onto all the other GPUs, if already setup
function BasicImpl:applyChanges()
   if self.modules then
      local prevGpuid = cutorch.getDevice()
      -- drop the old clones first so they can be collected before re-cloning
      self.modules = { self.dpt.modules[1] }
      collectgarbage()
      for i=2,#self.dpt.gpuAssignments do
         cutorch.setDevice(self.dpt.gpuAssignments[i])
         table.insert(self.modules, self.dpt.modules[1]:clone())
      end
      cutorch.setDevice(prevGpuid)
   end
end
-- Lazily creates the replica list the first time it is needed.
function BasicImpl:setup()
   if self.modules then
      return
   end
   self.modules = {}
   self:applyChanges()
end
-- Runs closure(replica, i) on each replica's own GPU, optionally stopping
-- after maxGpuIdx replicas; collects the results into a table.
function BasicImpl:exec(closure, maxGpuIdx)
   local prevGpuid = cutorch.getDevice()
   self:setup()
   local results = {}
   local assignments = self.dpt.gpuAssignments
   local limit = #assignments
   if maxGpuIdx and maxGpuIdx < limit then
      limit = maxGpuIdx
   end
   for i = 1, limit do
      cutorch.setDevice(assignments[i])
      results[i] = closure(self.modules[i], i)
   end
   cutorch.setDevice(prevGpuid)
   return results
end
-- Serializes the implementation without its per-GPU module replicas
-- (they are rebuilt lazily by setup()).
function BasicImpl:__write(file)
   local serializable = {}
   for key, value in pairs(self) do
      if key ~= 'modules' then
         serializable[key] = value
      end
   end
   file:writeObject(serializable)
end
-- Releases the per-GPU replicas; setup() recreates them on demand.
function BasicImpl:close()
   self.modules = nil
end
-- Multi-threaded dispatch
function ThreadsImpl:__init(dpt, initFunc, syncCopies)
   self.dpt = dpt           -- owning DataParallelTable
   self.initFunc = initFunc -- run in each worker thread at pool creation
   self.syncCopies = syncCopies
   -- This makes initial copy of models to GPUs synchronous. Set this option
   -- in case your model serialization code is not thread-safe.
end
-- (Re)copies the master module into every worker thread: thread 1 shares
-- the master module directly, every other thread clones it onto its GPU.
function ThreadsImpl:applyChanges(sync)
   if self.__threads then
      local module = self.dpt.modules[1]
      for i, gpu in ipairs(self.dpt.gpuAssignments) do
         self.__threads:addjob(i, function()
            cutorch.setDevice(gpu)
            if i == 1 then
               _G.module = module
            else
               -- release the previous clone before cloning again
               _G.module = nil
               collectgarbage()
               _G.module = module:clone()
            end
         end)
         if sync then
            self.__threads:synchronize()
         end -- if sync is set, changes are applied synchronously
      end
      self.__threads:synchronize()
   end
end
-- Lazily builds the thread pool (one worker per GPU assignment) and makes
-- the initial copy of the model to each GPU.
function ThreadsImpl:setup()
   if self.__threads then
      return
   end
   local threads = require 'threads'
   threads.Threads.serialization('threads.sharedserialize')
   self.__threads = threads.Threads(
      #self.dpt.gpuAssignments,
      function() require 'cunn' end,
      self.initFunc)
   self.__threads:specific(true)
   self:applyChanges(self.syncCopies)
end
-- Dispatches closure(replica, i) to each worker thread (optionally limited
-- to the first maxGpuIdx replicas), waits for completion, and returns the
-- collected results.
function ThreadsImpl:exec(closure, maxGpuIdx)
   self:setup()
   local results = {}
   local nJobs = #self.dpt.gpuAssignments
   if maxGpuIdx and maxGpuIdx < nJobs then
      nJobs = maxGpuIdx
   end
   for i = 1, nJobs do
      self.__threads:addjob(i,
         function() return closure(_G.module, i) end,
         function(result) results[i] = result end)
   end
   self.__threads:synchronize()
   return results
end
-- Terminates the worker threads; setup() rebuilds the pool on demand.
function ThreadsImpl:close()
   self.__threads:terminate()
   self.__threads = nil
end
-- Serializes the implementation without the (unserializable) thread pool;
-- setup() rebuilds it after deserialization.
function ThreadsImpl:__write(file)
   local serializable = {}
   for key, value in pairs(self) do
      if key ~= '__threads' then
         serializable[key] = value
      end
   end
   file:writeObject(serializable)
end
-- ![swh spinner](/static/img/swh-spinner.gif)  (scrape residue, commented out so the file parses)
-- Computing file changes ...