Convolution updateWeights function needs to be modified. #106

Closed
leeys888 opened this issue Apr 25, 2018 · 1 comment

@leeys888

open func updateWeights(device: MTLDevice) {
    guard let network = network else {
        return
    }

    if #available(iOS 11.0, *) {
        if let weightsPointer = weightsPointer {
            dataSource = ConvolutionDataSource(cnnDescriptor: cnnDescriptor,
                                               weights: UnsafeMutableRawPointer(mutating: weightsPointer.pointer()),
                                               bias: UnsafeMutablePointer(mutating: biasPointer?.pointer() as UnsafePointer<Float>?))
        } else {
            // biasCount is passed here even when useBias is false
            dataSource = ConvolutionDataSource(cnnDescriptor: cnnDescriptor, parameterLoader: network.parameterLoader,
                                               layerId: id, weightCount: getWeightsSize(), biasCount: convSize.outputChannels)
        }
        makeConv(device: device, weights: nil, bias: nil)
    } else {
        let weights = weightsPointer?.pointer() ?? network.parameterLoader.loadWeights(for: id,
                                                                                       modifier: Convolution.weightModifier,
                                                                                       size: getWeightsSize())

        // Pre-iOS 11 path: bias is only loaded when useBias is true
        var bias: UnsafePointer<Float>? = nil
        if useBias {
            bias = biasPointer?.pointer() ?? network.parameterLoader.loadWeights(for: id,
                                                                                 modifier: Convolution.biasModifier,
                                                                                 size: convSize.outputChannels)
        }
        makeConv(device: device, weights: weights, bias: bias)
    }
}

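I think the "ConvolutionDataSource" initialization in the iOS 11 branch needs to be modified.
When no weightsPointer is set, it always passes biasCount: convSize.outputChannels and does not check useBias, unlike the pre-iOS 11 branch, which only loads bias when useBias is true.
So I suggest the following change: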

else {
    dataSource = ConvolutionDataSource(cnnDescriptor: cnnDescriptor, parameterLoader: network.parameterLoader,
                                       layerId: id, weightCount: getWeightsSize(), biasCount: useBias ? convSize.outputChannels : 0)
}
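To illustrate what the suggested change does, here is a minimal, Metal-free sketch. `ParameterLoading`, `SketchDataSource`, `RecordingLoader` and `makeDataSource` are hypothetical stand-ins, not the project's actual types, and the "weights"/"bias" modifier strings are placeholders. The sketch assumes that a data source given `biasCount: 0` skips loading bias entirely; I have not verified that the real ConvolutionDataSource behaves this way.

```swift
// Hypothetical sketch: none of these types are the project's real API.

// Stand-in for the framework's parameter loader (assumed shape).
protocol ParameterLoading {
    func loadWeights(for layerId: String, modifier: String, size: Int) -> [Float]
}

// Simplified stand-in for ConvolutionDataSource. Assumption: a biasCount
// of 0 means "this layer has no bias", so no bias values are requested.
final class SketchDataSource {
    let weights: [Float]
    let bias: [Float]?

    init(loader: ParameterLoading, layerId: String, weightCount: Int, biasCount: Int) {
        weights = loader.loadWeights(for: layerId, modifier: "weights", size: weightCount)
        if biasCount > 0 {
            bias = loader.loadWeights(for: layerId, modifier: "bias", size: biasCount)
        } else {
            bias = nil
        }
    }
}

// Test double that records every load request and returns zero-filled arrays.
final class RecordingLoader: ParameterLoading {
    private(set) var requests: [(modifier: String, size: Int)] = []

    func loadWeights(for layerId: String, modifier: String, size: Int) -> [Float] {
        requests.append((modifier: modifier, size: size))
        return [Float](repeating: 0, count: size)
    }
}

// Mirrors the suggested fix: biasCount is 0 when the layer does not use bias.
func makeDataSource(loader: ParameterLoading, layerId: String, useBias: Bool,
                    weightCount: Int, outputChannels: Int) -> SketchDataSource {
    return SketchDataSource(loader: loader, layerId: layerId, weightCount: weightCount,
                            biasCount: useBias ? outputChannels : 0)
}

let loader = RecordingLoader()
let dataSource = makeDataSource(loader: loader, layerId: "conv1", useBias: false,
                                weightCount: 27, outputChannels: 16)
print(dataSource.bias == nil)   // true
print(loader.requests.count)    // 1 (weights only)
```

With useBias set to false only the weights are requested, which matches how the pre-iOS 11 branch already guards bias loading with useBias.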

Thank you.

@bryant1410
Member

It seems you are right. @mats-claassen can you take a look when you are back?
