webGL support #17610

Closed
mccoysc opened this issue Jan 26, 2025 · 2 comments
Labels: needs-triage, type: bug

Comments

mccoysc commented Jan 26, 2025

As the title says. Since many browsers still do not support the WebGPU interface, adding WebGL support is important, and improvements are needed.

mccoysc added the needs-triage and type: bug labels Jan 26, 2025
tqchen (Member) commented Jan 27, 2025

Unfortunately, WebGL does not support most of the GPU primitives, so we will need to stick with WebGPU for now.
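To make the point concrete: WebGL 2 (GLSL ES 3.00) has no compute shader stage, no storage buffers, no workgroup shared memory, no barriers and no atomics, and a fragment shader can only write to its own output pixel. A rough sketch, assuming a plain regex scan is acceptable as a first-pass filter (the helper `findUnsupportedFeatures` and its pattern list are hypothetical, not part of this project), flags WGSL kernels that rely on those features:

```typescript
// Hypothetical helper: flags WGSL constructs with no direct WebGL 2 (GLSL ES 3.00) equivalent.
// This is a heuristic substring/regex scan, not a WGSL parser.
const UNSUPPORTED_IN_WEBGL2: ReadonlyArray<[string, RegExp]> = [
  ["workgroup shared memory", /var\s*<\s*workgroup\s*>/],
  ["atomics", /\batomic\s*</],
  ["workgroup barrier", /\bworkgroupBarrier\s*\(/],
  ["read_write storage buffer", /var\s*<\s*storage\s*,\s*read_write\s*>/],
];

function findUnsupportedFeatures(wgsl: string): string[] {
  return UNSUPPORTED_IN_WEBGL2
    .filter(([, pattern]) => pattern.test(wgsl))
    .map(([name]) => name);
}

const kernel = `
@group(0) @binding(0) var<storage, read_write> output : array<f32>;
var<workgroup> tile : array<f32, 256>;
@compute @workgroup_size(256, 1, 1)
fn main(@builtin(local_invocation_id) lid : vec3<u32>) {
  workgroupBarrier();
}`;
console.log(findUnsupportedFeatures(kernel));
```

An empty result is necessary but not sufficient: a kernel that writes to computed indices is still hard to express as a fragment shader.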

tqchen closed this as completed Jan 27, 2025
mccoysc (Author) commented Jan 28, 2025

> Unfortunately, WebGL does not support most of the GPU primitives, so we will need to stick with WebGPU for now.

I have one implemented with gpu.js (the gpu.js code is generated by LLMs, but it seems the LLM generates wrong code).
If you can write some GLSL code, there is no need to generate gpu.js code from WGSL, and maybe it will work OK.
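The usual way to hand-write such GLSL is the classic GPGPU pattern: keep each buffer in a floating-point texture, draw a full-screen quad into a float render target, and let each fragment compute one output element. A minimal sketch of the shader-generation side could look like this, assuming one float per texel in R32F textures that all share the output's width (the helpers `makeElementwiseShader`, `indexToTexel` and `texelToIndex` are hypothetical):

```typescript
// Hypothetical generator for a GLSL ES 3.00 fragment shader that applies `expr` element-wise.
// Each fragment (pixel) of the render target computes exactly one output element.
// Assumes each input is an R32F texture named `<name>_tex` with the same width as the output.
function makeElementwiseShader(inputs: string[], expr: string): string {
  const samplers = inputs.map(n => `uniform highp sampler2D ${n}_tex;`).join("\n");
  const loads = inputs
    .map(n => `  float ${n} = texelFetch(${n}_tex, ivec2(gl_FragCoord.xy), 0).r;`)
    .join("\n");
  return `#version 300 es
precision highp float;
${samplers}
out vec4 outColor;
void main() {
${loads}
  outColor = vec4(${expr}, 0.0, 0.0, 1.0);
}`;
}

// Maps a flat element index to the (x, y) texel that stores it, for textures `rowWidth` texels wide.
function indexToTexel(index: number, rowWidth: number): [number, number] {
  return [index % rowWidth, Math.floor(index / rowWidth)];
}

// Inverse of indexToTexel.
function texelToIndex(x: number, y: number, rowWidth: number): number {
  return y * rowWidth + x;
}

console.log(makeElementwiseShader(["a", "b"], "a * 2.0 + b"));
console.log(indexToTexel(1030, 1024)); // [6, 1]
```

Input names must be valid GLSL identifiers, and rendering into an R32F target requires the `EXT_color_buffer_float` extension in WebGL 2. This covers only element-wise kernels; reductions or matrix multiplies need each fragment to loop over its inputs.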

// WebGPU Polyfill using WebGL
import OpenAI from 'openai';

type KernelFunction = (...args: KernelVariable[]) => KernelOutput;

interface IGPUSettings {
  mode?: 'gpu' | 'cpu' | 'dev' | 'webgl' | 'webgl2' | 'headlessgl';
  canvas?: object;
  context?: object;
  functions?: KernelFunction[];
  constants?: {
    [key: string]: number;
  };
}

interface IKernelRunShortcutMethods {
  setOutput(output: number[]): IKernelRunShortcut;
}

interface IKernelRunShortcutBase<T = KernelOutput> extends Kernel, IKernelRunShortcutMethods {
  kernel: Kernel;
  (...args: KernelVariable[]): T;
  exec(...args: KernelVariable[]): T;
}

interface IKernelRunShortcut extends IKernelRunShortcutBase<KernelOutput> {
  readonly _type: 'kernel-run-shortcut';
  constants: {
    [key: string]: number;
  };
}

type KernelVariable = boolean | number | Float32Array | Uint8Array | Uint16Array | Uint32Array | Uint8ClampedArray | number[] | number[][] | number[][][] | Float32Array[] | Float32Array[][] | Float32Array[][][];
type KernelOutput = void | number | number[] | number[][] | number[][][] | Float32Array | Float32Array[] | Float32Array[][] | Float32Array[][][];

class Kernel {
  static isSupported = true;
  output: number[] = [1];
  destroy(): void {
    // Base class: no-op
  }
}

class GPU {
  static isGPUSupported = (() => {
    try {
      return typeof navigator !== 'undefined' && navigator.gpu !== undefined;
    } catch {
      return false;
    }
  })();

  static isWebGLSupported = (() => {
    try {
      const canvas = document.createElement('canvas');
      return !!canvas.getContext('webgl');
    } catch {
      return false;
    }
  })();

  static isWebGL2Supported = (() => {
    try {
      const canvas = document.createElement('canvas');
      return !!canvas.getContext('webgl2');
    } catch {
      return false;
    }
  })();

  private settings: IGPUSettings;

  constructor(settings?: IGPUSettings) {
    // Default settings
    const defaultSettings: IGPUSettings = {
      mode: 'gpu',
      functions: [],
      constants: {}
    };

    // Merge in user-supplied settings
    this.settings = {
      ...defaultSettings,
      ...settings
    };
  }

  getSettings(): IGPUSettings {
    return this.settings;
  }

  createKernel(fn: KernelFunction, settings?: IGPUSettings): IKernelRunShortcut {
    const baseKernel = new Kernel();
    const mergedSettings = {
      ...this.settings,
      ...settings,
      constants: {
        ...this.settings.constants,
        ...settings?.constants
      }
    };
    const kernelContext = {
      thread: { x: 0, y: 0, z: 0 },
      output: baseKernel.output,
      constants: mergedSettings.constants
    };
    const kernelFunction = function(this: typeof kernelContext, ...args: KernelVariable[]): KernelOutput {
      return fn.apply(this, args);
    }.bind(kernelContext);

    const kernel = Object.assign(kernelFunction, {
      ...baseKernel,
      kernel: baseKernel,
      _type: 'kernel-run-shortcut' as const,
      exec: kernelFunction,
      destroy: () => {
        baseKernel.destroy();
      },
      constants: mergedSettings.constants,
      setOutput: (output: number[]) => {
        baseKernel.output = output;
        kernelContext.output = output;
        return kernel;
      }
    });
    // Copy the merged settings onto the kernel object
    Object.assign(kernel, mergedSettings);
    return kernel;
  }

  destroy(): Promise<void> {
    return Promise.resolve();
  }
}
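Note that `createKernel` above calls `fn` exactly once, with `this.thread` fixed at `{ x: 0, y: 0, z: 0 }`. In gpu.js, a kernel instead runs once per output element and returns an array of results. A minimal CPU-only sketch of those semantics for 1-D outputs, useful as a reference when checking generated kernels (the function `runKernel1D` and the interface `KernelThisContext` are hypothetical, not part of gpu.js):

```typescript
// Context visible to the kernel body as `this`, mirroring the fields the stub above provides.
interface KernelThisContext {
  thread: { x: number; y: number; z: number };
  output: number[];
  constants: Record<string, number>;
}

type KernelBody = (this: KernelThisContext, ...args: unknown[]) => number;

// Hypothetical CPU reference: runs `body` once per element of a 1-D output,
// with this.thread.x set to the element index, and collects the results.
function runKernel1D(
  body: KernelBody,
  size: number,
  constants: Record<string, number>,
  ...args: unknown[]
): Float32Array {
  const result = new Float32Array(size);
  const ctx: KernelThisContext = { thread: { x: 0, y: 0, z: 0 }, output: [size], constants };
  for (let x = 0; x < size; x++) {
    ctx.thread.x = x;
    result[x] = body.apply(ctx, args);
  }
  return result;
}

// Same shape as the example kernel in the conversion prompt: scale the input, guarded by size.
const scaled = runKernel1D(
  function (this: KernelThisContext, input: unknown) {
    const i = this.thread.x;
    if (i >= this.constants.size) return 0;
    return (input as number[])[i] * this.constants.scale;
  },
  4,
  { size: 3, scale: 2 },
  [1, 2, 3, 4],
);
console.log(Array.from(scaled)); // [2, 4, 6, 0]
```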

interface GPUBindGroupLayoutEntry {
  binding: number;
  visibility: number;
  buffer?: {
    type?: 'uniform' | 'storage' | 'read-only-storage';
    hasDynamicOffset?: boolean;
    minBindingSize?: number;
  };
  sampler?: {
    type?: 'filtering' | 'non-filtering' | 'comparison';
  };
  texture?: {
    sampleType?: 'float' | 'unfilterable-float' | 'depth' | 'sint' | 'uint';
    viewDimension?: '1d' | '2d' | '2d-array' | 'cube' | 'cube-array' | '3d';
    multisampled?: boolean;
  };
  storageTexture?: {
    access?: 'write-only';
    format?: string;
    viewDimension?: '1d' | '2d' | '2d-array' | '3d';
  };
}
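A `buffer.type` of `'uniform'` corresponds to a WGSL `var<uniform>` struct, and the dispatch code later reads such a buffer back as an `Int32Array` to recover fields like `batch_size`. A slightly more general sketch decodes the bytes by field name, assuming every field is a 4-byte scalar (`i32`, `u32` or `f32`), which puts field `i` at byte offset `4 * i`; structs with vectors or nested structs have padding this ignores. `decodeUniformStruct` and the field list format are hypothetical:

```typescript
type ScalarKind = "i32" | "u32" | "f32";

// Hypothetical description of a WGSL struct made only of 4-byte scalars, in declaration order.
type StructFields = ReadonlyArray<[name: string, kind: ScalarKind]>;

// Decodes a uniform buffer's bytes into named fields. Each scalar is 4-byte aligned,
// so field i sits at byte offset 4 * i. WebGPU buffer contents are little-endian.
function decodeUniformStruct(bytes: ArrayBuffer, fields: StructFields): Record<string, number> {
  const view = new DataView(bytes);
  const out: Record<string, number> = {};
  fields.forEach(([name, kind], i) => {
    const offset = 4 * i;
    if (offset + 4 > bytes.byteLength) {
      throw new RangeError(`Buffer too small for field "${name}"`);
    }
    out[name] = kind === "i32" ? view.getInt32(offset, true)
      : kind === "u32" ? view.getUint32(offset, true)
      : view.getFloat32(offset, true);
  });
  return out;
}

// Build a 12-byte buffer holding { batch_size: 8, packGridDimX: 32, scale: 0.5 }.
const raw = new ArrayBuffer(12);
const dv = new DataView(raw);
dv.setInt32(0, 8, true);
dv.setInt32(4, 32, true);
dv.setFloat32(8, 0.5, true);
console.log(decodeUniformStruct(raw, [["batch_size", "i32"], ["packGridDimX", "i32"], ["scale", "f32"]]));
```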

interface GPUBindGroupEntry {
  binding: number;
  resource: {
    buffer?: GPUBuffer;
    offset?: number;
    size?: number;
  };
}

interface GPUBindGroupLayoutDescriptor {
  entries: GPUBindGroupLayoutEntry[];
  label?: string;
}

interface GPUBindGroupDescriptor {
  layout: GPUBindGroupLayout;
  entries: GPUBindGroupEntry[];
  label?: string;
}

interface GPUPipelineLayoutDescriptor {
  bindGroupLayouts: GPUBindGroupLayout[];
  label?: string;
}

interface GPUBindGroupLayout {
  label?: string;
  entries: GPUBindGroupLayoutEntry[];
  __brand: 'GPUBindGroupLayout';
}

interface GPUPipelineLayout {
  label?: string;
  bindGroupLayouts: GPUBindGroupLayout[];
  __brand: 'GPUPipelineLayout';
}

interface GPUQueue {
  label?: string;
  submit(commandBuffers: GPUCommandBuffer[]): void;
  onSubmittedWorkDone(): Promise<void>;
  writeBuffer(buffer: GPUBuffer, offset: number, data: ArrayBuffer): void;
  writeTexture(
    destination: { texture: GPUTexture },
    data: ArrayBuffer,
    dataLayout: { bytesPerRow: number; rowsPerImage?: number },
    size: { width: number; height: number; depthOrArrayLayers?: number }
  ): void;
}

interface GPUDevice extends EventTarget {
  queue: GPUQueue;
  destroy(): void;
  createBuffer(descriptor: GPUBufferDescriptor): GPUBuffer;
  createTexture(descriptor: GPUTextureDescriptor): GPUTexture;
  createShaderModule(descriptor: GPUShaderModuleDescriptor): GPUShaderModule;
  createComputePipeline(descriptor: GPUComputePipelineDescriptor): GPUComputePipeline;
  createBindGroupLayout(descriptor: GPUBindGroupLayoutDescriptor): GPUBindGroupLayout;
  createPipelineLayout(descriptor: GPUPipelineLayoutDescriptor): GPUPipelineLayout;
  createBindGroup(descriptor: GPUBindGroupDescriptor): GPUBindGroup;
  createCommandEncoder(): GPUCommandEncoder;
  pushErrorScope(filter: GPUErrorFilter): void;
  popErrorScope(): Promise<GPUError | null>;
  lost: Promise<GPUDeviceLostInfo>;
  limits: Record<string, number>;
  features: Set<string>;
  label?: string;
}

interface GPUAdapter {
  name: string;
  features: Set<string>;
  limits: Record<string, number>;
  isFallbackAdapter: boolean;
  requestDevice(descriptor?: GPUDeviceDescriptor): Promise<GPUDevice>;
  requestAdapterInfo(): Promise<GPUAdapterInfo>;
}

interface GPUDeviceDescriptor {
  requiredFeatures?: string[];
  requiredLimits?: Record<string, number>;
  label?: string;
}

interface GPUBuffer {
  size: number;
  usage: number;
  mapState: string;
  destroy(): void;
}

interface GPUTexture {
  width: number;
  height: number;
  depthOrArrayLayers: number;
  mipLevelCount: number;
  sampleCount: number;
  dimension: string;
  format: string;
  usage: number;
  destroy(): void;
}

// gpu.js kernel type
type GPUJSKernel = IKernelRunShortcut;

interface GPUShaderModule {
  __gpuJSCode?: string;
  __originalCode?: string;
  __kernel?: GPUJSKernel;
}

interface WebGLShaderModuleExt extends GPUShaderModule {
  __gpuJSCode: string;
  __originalCode: string;
  __kernel: GPUJSKernel;
}
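`__gpuJSCode` holds whatever code the LLM returned, extracted from a Markdown code fence in its reply. A sketch of a more tolerant extractor plus the same shape check, assuming any language tag may follow the opening fence (`extractFencedCode` and `looksLikeKernelFactory` are hypothetical names; the fence string is built with `repeat` only so it can sit inside this example):

```typescript
// Three backticks, built at runtime.
const FENCE = "`".repeat(3);
// Opening fence with an optional language tag, a newline, then a lazy capture up to the closing fence.
const FENCED_BLOCK = new RegExp(`${FENCE}[\\w-]*[ \\t]*\\r?\\n([\\s\\S]*?)${FENCE}`);

// Hypothetical: returns the body of the first fenced block in `reply`, or the whole reply if none.
function extractFencedCode(reply: string): string {
  const match = reply.match(FENCED_BLOCK);
  return (match ? match[1] : reply).trim();
}

// Hypothetical shape check matching the convention used here:
// a single `function createKernel(gpu) { ... }` definition that sets an output and returns.
function looksLikeKernelFactory(code: string): boolean {
  return code.startsWith("function createKernel(gpu) {")
    && code.endsWith("}")
    && code.includes("setOutput")
    && code.includes("return");
}

const reply = `Here is the kernel:\n${FENCE}javascript\nfunction createKernel(gpu) {\n  const kernel = gpu.createKernel(function () { return this.thread.x; });\n  kernel.setOutput([4]);\n  return kernel;\n}\n${FENCE}`;
const extracted = extractFencedCode(reply);
console.log(looksLikeKernelFactory(extracted)); // true
```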

interface GPUComputePipeline {
  __kernel: GPUJSKernel;
  destroy(): void;
}

interface GPUCommandEncoder {
  beginComputePass(): GPUComputePassEncoder;
  finish(): GPUCommandBuffer;
}

interface GPUComputePassEncoder {
  setPipeline(pipeline: GPUComputePipeline): void;
  setBindGroup(index: number, bindGroup: GPUBindGroup): void;
  dispatchWorkgroups(x: number): void; // only one-dimensional dispatch is used
  end(): void;
}
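With one-dimensional dispatch, `dispatchWorkgroups(x)` on a kernel declared with `@workgroup_size(N, 1, 1)` launches `x * N` invocations, and invocation `local` of workgroup `wg` has `global_invocation_id.x = wg * N + local`. Mapping that onto a gpu.js-style kernel means a 1-D output of `x * N` with `this.thread.x` playing the role of the global id. A small sketch (all helper names are hypothetical; the regex handles only a literal integer in the attribute, not override constants):

```typescript
// Hypothetical: reads the X dimension from a WGSL `@workgroup_size(...)` attribute.
// Only literal integers are handled; override constants or expressions are not.
function workgroupSizeX(wgsl: string): number {
  const m = wgsl.match(/@workgroup_size\s*\(\s*(\d+)/);
  if (!m) throw new Error("No @workgroup_size attribute found");
  return Number(m[1]);
}

// Total 1-D invocations launched by dispatchWorkgroups(workgroups).
function invocationCount(workgroups: number, sizeX: number): number {
  return workgroups * sizeX;
}

// Splits a flat thread index (what this.thread.x would be) into WGSL workgroup/local ids.
function splitGlobalId(globalX: number, sizeX: number): { workgroupId: number; localId: number } {
  return { workgroupId: Math.floor(globalX / sizeX), localId: globalX % sizeX };
}

const wgslSrc = "@compute @workgroup_size(256, 1, 1)\nfn main() {}";
const sizeX = workgroupSizeX(wgslSrc);     // 256
console.log(invocationCount(4, sizeX));    // 1024
console.log(splitGlobalId(700, sizeX));    // { workgroupId: 2, localId: 188 }
```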

interface GPUCommandBuffer {
  __brand: 'GPUCommandBuffer';
}

interface GPUDeviceLostInfo {
  reason: 'destroyed';
  message: string;
}

interface GPUError {
  message: string;
}

type GPUErrorFilter = 'validation' | 'out-of-memory';
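For reference, WebGPU error scopes form a stack: an error goes to the innermost open scope whose filter matches, each scope keeps only the first error it captures, and `popErrorScope()` resolves with that error or `null` (and rejects if the stack is empty). A standalone sketch of that bookkeeping, where `ErrorScopeStack` is a hypothetical name and `pop` is synchronous (a device would wrap its result in a Promise):

```typescript
type ErrorFilter = "validation" | "out-of-memory";
interface ScopedError { message: string }

// Hypothetical stand-alone model of WebGPU error-scope bookkeeping.
class ErrorScopeStack {
  private scopes: { filter: ErrorFilter; error: ScopedError | null }[] = [];

  push(filter: ErrorFilter): void {
    this.scopes.push({ filter, error: null });
  }

  // Routes the error to the innermost matching scope.
  // Returns false when no scope matches (an "uncaptured" error).
  report(filter: ErrorFilter, message: string): boolean {
    for (let i = this.scopes.length - 1; i >= 0; i--) {
      const scope = this.scopes[i];
      if (scope.filter === filter) {
        if (scope.error === null) scope.error = { message }; // keep only the first error
        return true;
      }
    }
    return false;
  }

  pop(): ScopedError | null {
    const scope = this.scopes.pop();
    if (!scope) throw new Error("No error scope to pop");
    return scope.error;
  }
}

const scopes = new ErrorScopeStack();
scopes.push("out-of-memory");
scopes.push("validation");
scopes.report("out-of-memory", "allocation failed"); // skips the inner validation scope
scopes.report("validation", "bad size");
scopes.report("validation", "second error");         // ignored: this scope already holds one
console.log(scopes.pop()); // { message: 'bad size' }
console.log(scopes.pop()); // { message: 'allocation failed' }
```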

interface GPUAdapterInfo {
  vendor: string;
  architecture: string;
  device: string;
  description: string;
}

interface GPUShaderModuleDescriptor {
  code: string;
  sourceMap?: object;
}

interface GPUComputePipelineDescriptor {
  compute: {
    module: GPUShaderModule;
    entryPoint?: string;
  };
  layout?: GPUPipelineLayout;
}

interface GPUExtent3DDict {
  width: number;
  height: number;
  depthOrArrayLayers?: number;
}

interface GPUTextureDescriptor {
  size: GPUExtent3DDict;
  format: GPUTextureFormat;
  usage: number;
  dimension?: GPUTextureDimension;
  mipLevelCount?: number;
  sampleCount?: number;
  viewFormats?: GPUTextureFormat[];
}
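`createTexture` in the device class allocates every texture as `RGBA32F`, so each texel holds four floats and neither side may exceed `MAX_TEXTURE_SIZE`. A sketch of choosing dimensions for a flat array of floats under those two constraints, assuming row-major packing four floats per texel (the helpers `textureDimsForFloats` and `floatLocation` are hypothetical, not used by the code here):

```typescript
// Hypothetical: picks RGBA32F texture dimensions for `floatCount` floats, 4 floats per texel,
// filling rows of at most `maxTextureSize` texels (row-major).
function textureDimsForFloats(floatCount: number, maxTextureSize: number): { width: number; height: number } {
  const texels = Math.max(1, Math.ceil(floatCount / 4));
  const width = Math.min(texels, maxTextureSize);
  const height = Math.ceil(texels / width);
  if (height > maxTextureSize) {
    throw new RangeError(`${floatCount} floats do not fit in a ${maxTextureSize}x${maxTextureSize} RGBA32F texture`);
  }
  return { width, height };
}

// Where float `i` lives: texel (x, y) and channel 0..3 (R, G, B, A).
function floatLocation(i: number, width: number): { x: number; y: number; channel: number } {
  const texel = Math.floor(i / 4);
  return { x: texel % width, y: Math.floor(texel / width), channel: i % 4 };
}

console.log(textureDimsForFloats(10, 4096));    // { width: 3, height: 1 }
console.log(textureDimsForFloats(20000, 4096)); // { width: 4096, height: 2 }
console.log(floatLocation(16390, 4096));        // { x: 1, y: 1, channel: 2 }
```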

interface GPUBufferDescriptor {
  size: number;
  usage: number;
  mappedAtCreation?: boolean;
}
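`usage` is a bitmask of `GPUBufferUsage` flags; `createBuffer` in the device class checks the `UNIFORM` bit (0x0040) to pick a WebGL binding target. The flag values below are those defined by the WebGPU specification; the decoding helpers `describeUsage` and `webglTargetFor` are hypothetical illustrations:

```typescript
// GPUBufferUsage flag values from the WebGPU specification.
const BufferUsage = {
  MAP_READ: 0x0001,
  MAP_WRITE: 0x0002,
  COPY_SRC: 0x0004,
  COPY_DST: 0x0008,
  INDEX: 0x0010,
  VERTEX: 0x0020,
  UNIFORM: 0x0040,
  STORAGE: 0x0080,
  INDIRECT: 0x0100,
  QUERY_RESOLVE: 0x0200,
} as const;

// Hypothetical: lists the flag names set in `usage`, lowest bit first.
function describeUsage(usage: number): string[] {
  return Object.entries(BufferUsage)
    .filter(([, bit]) => (usage & bit) !== 0)
    .map(([name]) => name);
}

// Mirrors the rule in createBuffer: UNIFORM buffers use UNIFORM_BUFFER, everything else ARRAY_BUFFER.
function webglTargetFor(usage: number): "UNIFORM_BUFFER" | "ARRAY_BUFFER" {
  return (usage & BufferUsage.UNIFORM) !== 0 ? "UNIFORM_BUFFER" : "ARRAY_BUFFER";
}

const usage = BufferUsage.STORAGE | BufferUsage.COPY_DST | BufferUsage.COPY_SRC;
console.log(describeUsage(usage));  // ["COPY_SRC", "COPY_DST", "STORAGE"]
console.log(webglTargetFor(usage)); // "ARRAY_BUFFER"
```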

type GPUTextureFormat = string;
type GPUTextureDimension = '1d' | '2d' | '3d';

// Performance statistics
interface RuntimeStats {
  peakAllocatedBytes: number;
  currentAllocatedBytes: number;
  totalAllocatedBytes: number;
  shaderSubmitCount: number;
}

// Global statistics
const stats: RuntimeStats = {
  peakAllocatedBytes: 0,
  currentAllocatedBytes: 0,
  totalAllocatedBytes: 0,
  shaderSubmitCount: 0
};

// Utility: format a byte count as a human-readable string in MB (rounded up)
function computeMB(value: number): string {
  return Math.ceil(value / (1 << 20)) + "MB";
}
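The device class funnels WGSL-to-gpu.js conversions through a queue that processes one request at a time and lets identical requests share a result. The same pattern, detached from the LLM call, can be sketched as a small generic class; `DedupQueue` is hypothetical and its `worker` stands in for `convertWGSLToGPUJS`:

```typescript
// Hypothetical generic version of the conversion queue: tasks run one at a time in FIFO order,
// and concurrent requests for the same key share a single worker call.
class DedupQueue<T> {
  private tasks: { key: string; waiters: { resolve: (v: T) => void; reject: (e: unknown) => void }[] }[] = [];
  private running = false;

  constructor(private readonly worker: (key: string) => Promise<T>) {}

  enqueue(key: string): Promise<T> {
    return new Promise<T>((resolve, reject) => {
      const existing = this.tasks.find(t => t.key === key);
      if (existing) {
        existing.waiters.push({ resolve, reject });
      } else {
        this.tasks.push({ key, waiters: [{ resolve, reject }] });
        void this.drain();
      }
    });
  }

  pendingCount(): number {
    return this.tasks.length;
  }

  private async drain(): Promise<void> {
    if (this.running) return; // a drain loop is already active and will pick up new tasks
    this.running = true;
    try {
      while (this.tasks.length > 0) {
        const task = this.tasks[0];
        try {
          const value = await this.worker(task.key);
          task.waiters.forEach(w => w.resolve(value));
        } catch (error) {
          task.waiters.forEach(w => w.reject(error));
        }
        this.tasks.shift(); // remove only after every waiter has been settled
      }
    } finally {
      this.running = false;
    }
  }
}

const calls: string[] = [];
const queue = new DedupQueue<string>(async key => { calls.push(key); return key.toUpperCase(); });
const results = Promise.all([queue.enqueue("a"), queue.enqueue("a"), queue.enqueue("b")]);
console.log(calls.length, queue.pendingCount()); // 1 2 (synchronously, before the first await settles)
results.then(r => console.log(r, calls));         // [ 'A', 'A', 'B' ] [ 'a', 'b' ]
```

Unlike the device code, which resolves waiters with an empty string on failure, this sketch rejects them, so callers can tell a failed conversion from an empty result.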

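// Derive default limits from the WebGL environment
function getDefaultLimits(gl: WebGL2RenderingContext) {
  // Query the maximum texture size; the buffer limits below are derived from it
  const maxTextureSize = gl.getParameter(gl.MAX_TEXTURE_SIZE);
  // Compute a conservative maximum buffer size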
  const maxBufferSize = Math.min(
    1 << 28, // 256MB as safe default
    maxTextureSize ? maxTextureSize * maxTextureSize * 4 : 1 << 28
  );

  // Storage buffers get an even more conservative limit
  const maxStorageBufferSize = Math.min(maxBufferSize, 1 << 27); // 128MB as safe default for storage buffers

  return {
    maxBufferSize,
    maxStorageBufferBindingSize: maxStorageBufferSize,
    maxComputeWorkgroupStorageSize: 32768,
    maxComputeInvocationsPerWorkgroup: 256,
    maxComputeWorkgroupSizeX: 256,
    maxComputeWorkgroupSizeY: 256,
    maxComputeWorkgroupSizeZ: 64,
    maxComputeWorkgroupsPerDimension: 65535,
    maxStorageTextureBindings: 8,
    maxSampledTextureBindings: 16,
    maxBindGroups: 4,
    maxBindingsPerBindGroup: 16,
    maxDynamicUniformBuffersPerPipelineLayout: 8,
    maxDynamicStorageBuffersPerPipelineLayout: 4,
    maxUniformBufferBindingSize: 65536,
    maxVertexBuffers: 8,
    maxVertexAttributes: 16,
    maxVertexBufferArrayStride: 2048
  };
}
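The buffer limits above depend only on `MAX_TEXTURE_SIZE`, so their arithmetic can be checked without a WebGL context. A sketch that takes the queried value as a plain number (the function name `bufferLimitsFor` is hypothetical; the constants are copied from `getDefaultLimits`):

```typescript
// Hypothetical pure version of the buffer-size arithmetic in getDefaultLimits,
// with gl.getParameter(gl.MAX_TEXTURE_SIZE) passed in as a number.
function bufferLimitsFor(maxTextureSize: number): { maxBufferSize: number; maxStorageBufferBindingSize: number } {
  const SAFE_BUFFER_CAP = 1 << 28;  // 256 MiB
  const SAFE_STORAGE_CAP = 1 << 27; // 128 MiB
  const maxBufferSize = Math.min(
    SAFE_BUFFER_CAP,
    maxTextureSize ? maxTextureSize * maxTextureSize * 4 : SAFE_BUFFER_CAP,
  );
  return {
    maxBufferSize,
    maxStorageBufferBindingSize: Math.min(maxBufferSize, SAFE_STORAGE_CAP),
  };
}

console.log(bufferLimitsFor(16384)); // { maxBufferSize: 268435456, maxStorageBufferBindingSize: 134217728 }
console.log(bufferLimitsFor(4096));  // { maxBufferSize: 67108864, maxStorageBufferBindingSize: 67108864 }
```

With a reported size of 16384 the 256 MiB cap wins; at 4096 the texture-derived bound (4096 × 4096 × 4 = 67,108,864 bytes) is the smaller one.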

class GPUDeviceWebGL extends EventTarget implements GPUDevice {
  private gl: WebGL2RenderingContext;
  readonly limits: ReturnType<typeof getDefaultLimits>;
  readonly features: Set<string>;
  private isDestroyed = false;
  label?: string;
  lost: Promise<GPUDeviceLostInfo>;
  private lostResolver!: (info: GPUDeviceLostInfo) => void;
  private errorScopes: GPUErrorFilter[] = [];
  private errorCallbacks: ((error: GPUError | null) => void)[] = [];
  private codeCache: { [key: string]: {code:string,isErrorCode:boolean} } = {};
  private conversionTasks: {code: string, resolves: ((value: string) => void)[]}[] = [];
  private isProcessingQueue = false;
  readonly queue: GPUQueue;
  private settings: IGPUSettings;
  private gpuInstance: GPU;

  getSettings(): IGPUSettings {
    return this.settings;
  }

  constructor(gl: WebGL2RenderingContext, features: Set<string>, limits: ReturnType<typeof getDefaultLimits>) {
    super();
    this.codeCache = JSON.parse(localStorage.getItem('codeCache') || '{}');
    this.gl = gl;
    this.features = features;
    this.limits = limits;
    this.settings = {
      mode: 'webgl2',
      constants: {}
    };
    this.gpuInstance = new GPU(this.getSettings());

    this.lost = new Promise((resolve) => {
      this.lostResolver = resolve;
    });

    this.queue = {
      submit: (commandBuffers: GPUCommandBuffer[]) => {
        if (this.isDestroyed) {
          throw new Error('Device has been destroyed');
        }
        // Count each command buffer; no GPU work is executed here
        commandBuffers.forEach(() => {
          stats.shaderSubmitCount++;
        });
      },
      onSubmittedWorkDone: () => Promise.resolve(),
      writeBuffer: (buffer: GPUBuffer, offset: number, data: ArrayBuffer) => {
        if (this.isDestroyed) {
          throw new Error('Device has been destroyed');
        }
        const webglBuffer = (buffer as unknown as { __webglBuffer: WebGLBuffer }).__webglBuffer;
        this.gl.bindBuffer(this.gl.ARRAY_BUFFER, webglBuffer);
        this.gl.bufferSubData(this.gl.ARRAY_BUFFER, offset, data);
      },
      writeTexture: (
        destination: { texture: GPUTexture },
        data: ArrayBuffer,
        dataLayout: { bytesPerRow: number; rowsPerImage?: number },
        size: { width: number; height: number; depthOrArrayLayers?: number }
      ) => {
        if (this.isDestroyed) {
          throw new Error('Device has been destroyed');
        }
        const texture = (destination.texture as unknown as { __webglTexture: WebGLTexture }).__webglTexture;
        this.gl.bindTexture(this.gl.TEXTURE_2D, texture);
        this.gl.texSubImage2D(
          this.gl.TEXTURE_2D,
          0,
          0,
          0,
          size.width,
          size.height,
          this.gl.RGBA,
          this.gl.FLOAT,
          new Float32Array(data)
        );
      }
    };

    gl.canvas.addEventListener('webglcontextlost', () => {
      const lostInfo: GPUDeviceLostInfo = {
        reason: 'destroyed' as const,
        message: 'WebGL context lost'
      };
      this.lostResolver(lostInfo);
      this.dispatchEvent(new CustomEvent('lost', { detail: lostInfo }));
    });
  }

  // Error captured by each open scope, parallel to errorScopes. As in WebGPU,
  // a scope keeps only the first error reported into it.
  private capturedErrors: (GPUError | null)[] = [];

  pushErrorScope(filter: GPUErrorFilter): void {
    this.errorScopes.push(filter);
    this.capturedErrors.push(null);
  }

  popErrorScope(): Promise<GPUError | null> {
    if (this.errorScopes.length === 0) {
      return Promise.reject(new Error('No error scope to pop'));
    }
    this.errorScopes.pop();
    // Resolve with the scope's first captured error, or null if none was reported.
    return Promise.resolve(this.capturedErrors.pop() ?? null);
  }

  private reportError(type: GPUErrorFilter, message: string): void {
    // Deliver the error to the innermost open scope with a matching filter.
    const scopeIndex = this.errorScopes.lastIndexOf(type);
    if (scopeIndex !== -1 && this.capturedErrors[scopeIndex] === null) {
      this.capturedErrors[scopeIndex] = { message };
    }
  }

  destroy() {
    if (!this.isDestroyed) {
      this.isDestroyed = true;
      this.gl.getExtension('WEBGL_lose_context')?.loseContext();
      this.lostResolver({
        reason: 'destroyed' as const,
        message: 'Device destroyed'
      });
    }
  }

  createBuffer(descriptor: GPUBufferDescriptor): GPUBuffer {
    if (descriptor.size > this.limits.maxBufferSize) {
      this.reportError('validation', `Buffer size ${computeMB(descriptor.size)} exceeds limit ${computeMB(this.limits.maxBufferSize)}`);
      throw new Error('Buffer size exceeds limit');
    }

    const buffer = this.gl.createBuffer();
    if (!buffer) {
      this.reportError('out-of-memory', 'Failed to create buffer');
      throw new Error('Failed to create buffer');
    }

    // Choose the WebGL binding target from the usage flags
    const isUniformBuffer = (descriptor.usage & 0x0040) !== 0; // GPUBufferUsage.UNIFORM
    const target = isUniformBuffer ? this.gl.UNIFORM_BUFFER : this.gl.ARRAY_BUFFER;

    this.gl.bindBuffer(target, buffer);
    this.gl.bufferData(target, descriptor.size, this.gl.DYNAMIC_DRAW);

    // Update allocation statistics
    stats.currentAllocatedBytes += descriptor.size;
    stats.totalAllocatedBytes += descriptor.size;
    if (stats.currentAllocatedBytes > stats.peakAllocatedBytes) {
      stats.peakAllocatedBytes = stats.currentAllocatedBytes;
    }

    return {
      size: descriptor.size,
      usage: descriptor.usage,
      mapState: 'unmapped',
      __webglBuffer: buffer,
      destroy: () => {
        this.gl.deleteBuffer(buffer);
        stats.currentAllocatedBytes -= descriptor.size;
      }
    } as unknown as GPUBuffer;
  }

  createTexture(descriptor: GPUTextureDescriptor): GPUTexture {
    const texture = this.gl.createTexture();
    if (!texture) {
      this.reportError('out-of-memory', 'Failed to create texture');
      throw new Error('Failed to create texture');
    }

    this.gl.bindTexture(this.gl.TEXTURE_2D, texture);
    this.gl.texImage2D(
      this.gl.TEXTURE_2D,
      0,
      this.gl.RGBA32F,
      descriptor.size.width,
      descriptor.size.height,
      0,
      this.gl.RGBA,
      this.gl.FLOAT,
      null
    );

    return {
      __webglTexture: texture,
      width: descriptor.size.width,
      height: descriptor.size.height,
      depthOrArrayLayers: descriptor.size.depthOrArrayLayers || 1,
      mipLevelCount: descriptor.mipLevelCount || 1,
      sampleCount: descriptor.sampleCount || 1,
      dimension: descriptor.dimension || '2d',
      format: descriptor.format,
      usage: descriptor.usage,
      destroy: () => {
        this.gl.deleteTexture(texture);
      }
    } as unknown as GPUTexture;
  }

  private async processConversionQueue(): Promise<void> {
    // If the queue is already being processed, return immediately
    if (this.isProcessingQueue) return;
    this.isProcessingQueue = true;

    try {
      // Process tasks one at a time, in order
      while (this.conversionTasks.length > 0) {
        const task = this.conversionTasks[0];
        try {
          const code = await this.convertWGSLToGPUJS(task.code);
          this.codeCache[task.code] = {code, isErrorCode: false};
          localStorage.setItem('codeCache', JSON.stringify(this.codeCache));
          // Resolve every caller waiting on this code
          task.resolves.forEach(resolve => resolve(code));
        } catch (error) {
          console.error('Failed to convert WGSL to GPU.js:', error);
          // Settle all waiting promises even on failure (with an empty string)
          task.resolves.forEach(resolve => resolve(""));
        }
        this.conversionTasks.shift(); // remove the finished task
      }
    } catch (error) {
      console.error('Error in conversion queue:', error);
    } finally {
      this.isProcessingQueue = false;
    }
  }

  private async convertWGSLToGPUJS(wgslCode: string): Promise<string> {
    const apiKey = "<YOUR_OPENROUTER_API_KEY>"; // redacted: a key embedded in browser code is visible to every user
    const apiUrl = "https://httpproxy-bzopebyrhn.us-west-1.fcapp.run/api/v1?host=openrouter.ai&protocol=https&pathname=";

    try {
      const openai = new OpenAI({
        baseURL: apiUrl,
        apiKey: apiKey,
        dangerouslyAllowBrowser: true
      });
      const prompt = `
        Requirements:  
  1. GPU class is globally available, no need to import gpu.js library (no import {GPU} or require("gpu.js"))  
  2. Code must return a complete function definition that takes GPU class as parameter and returns a gpu.js kernel  
  3. Function definition must start with "function createKernel(gpu) {" and end with "}"  
  4. Function must strictly execute in the following order:  
     - Step 1: Create computation kernel using const kernel = gpu.createKernel()  
     - Step 2: Set output size using kernel.setOutput()  
     - Step 3: Return the created kernel using return kernel  
  5. Properly handle workgroups and thread indices:  
     - Use this.thread.x, this.thread.y, this.thread.z to access thread IDs  
     - Use this.output to access output size  
  6. Properly handle memory access:  
     - Use arrays as kernel parameters to pass data  
     - Use this.constants to pass constant values  
  7. Properly handle struct parameters:  
     - Struct fields must be passed as constants, accessed using this.constants  
     - Parameter order: only pass buffer parameters, access struct fields through constants  
     - Don't use struct names directly as parameter names  
     - Struct field names must match exactly with WGSL code  
  8. Properly handle output size:  
     - Don't use parameters in setOutput  
     - Use this.constants to access needed constant values  

  Example format:  
  // For WGSL structs and buffers:  
  // @group(0) @binding(0) var<storage, read_write> output : array<f32>;  
  // @group(0) @binding(1) var<storage, read> input : array<f32>;  
  // struct Args {  
  //   size: u32,  
  //   scale: f32  
  // }  
  // @group(0) @binding(2) var<uniform> args : Args;  
  // @compute @workgroup_size(256, 1, 1)  
  
  function createKernel(gpu) {  
    // Step 1: Create computation kernel  
    const kernel = gpu.createKernel(function(output, input) {  
      const index = this.thread.x  
      if (index >= this.constants.size) {  
        return 0  
      }  
      return input[index] * this.constants.scale  
    })  
    
    // Step 2: Set constants and output size, don't use kernel.setConstants as it doesn't exist  
    kernel.constants = {  
      size: 0,  
      scale: 1.0  
    }  
    kernel.setOutput([256])  
    
    // Final step: Return kernel  
    return kernel  
  }  

  WGSL code:  
  ${wgslCode}  

  Only return the converted gpu.js code, without any explanations, comments or additional text. Code must strictly follow the example format and return a complete function definition.`;
      
      const result = await openai.chat.completions.create({
        model: "anthropic/claude-3.5-sonnet",
        messages: [
          {
            role: "user",
            content: prompt
          }
        ]
      });
      const fullResponse = result.choices[0].message.content || '';

      // Extract the code from a Markdown code fence, if there is one
      const codeMatch = fullResponse.match(/```(?:js|javascript)?([\s\S]*?)```|```([\s\S]*?)```/);
      let gpuJSCode = codeMatch ? (codeMatch[1] || codeMatch[2]) : fullResponse;
      gpuJSCode = gpuJSCode.trim();

      // Validate the overall shape of the code
      if (!gpuJSCode.startsWith('function createKernel(gpu) {') || !gpuJSCode.endsWith('}')) {
        throw new Error('Invalid code format: Code must be a complete function definition starting with "function createKernel(gpu) {" and ending with "}"');
      }

      // Check that the required components are present
      const requiredComponents = [
        'createKernel',
        'setOutput',
        'return'
      ];

      for (const component of requiredComponents) {
        if (!gpuJSCode.includes(component)) {
          throw new Error(`Invalid code: Missing required component "${component}"`);
        }
      }

      return gpuJSCode;
    } catch (error: unknown) {
      console.error(error);
      throw error;
    }
  }

  private async enqueueConversion(code: string): Promise<string> {
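    // Return cached code directly, unless it was flagged as failing (then convert again)
    if (this.codeCache[code] && !this.codeCache[code].isErrorCode) {
      return this.codeCache[code].code;
    }

    // If the same code is already queued, wait on that task instead of adding a new one
    const existingTask = this.conversionTasks.find(task => task.code === code);
    if (existingTask) {
      return new Promise<string>(resolve => {
        existingTask.resolves.push(resolve);
      });
    }

    // Otherwise queue a new task
    const promise = new Promise<string>(resolve => {
      this.conversionTasks.push({
        code,
        resolves: [resolve] // every caller waiting on this code
      });
    });

    // Start processing (a no-op if the queue is already running)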
    this.processConversionQueue();

    return promise;
  }

  createShaderModule(descriptor: GPUShaderModuleDescriptor): GPUShaderModule {
    const shaderModule = {
      __gpuJSCode: "",
      __originalCode: descriptor.code,
      __kernel: undefined
    } as unknown as WebGLShaderModuleExt;

    this.codeCache=JSON.parse(localStorage.getItem('codeCache') || '{}');

    // Use cached code directly if present and not flagged as failing
    if (this.codeCache[descriptor.code] && !this.codeCache[descriptor.code].isErrorCode) {
      const code = this.codeCache[descriptor.code].code
      try {
        // Verify the code is a complete function definition
        if (!code.startsWith('function createKernel(gpu) {') || !code.endsWith('}')) {
          throw new Error('Invalid code format: Code must be a complete function definition');
        }

        // Evaluate the function definition and build the kernel
        const createKernelFunc = new Function('return ' + code)();
        const kernel = createKernelFunc(this.gpuInstance);

        // Verify the return value is a usable gpu.js kernel
        if (!kernel || typeof kernel !== 'function' || !kernel.setOutput) {
          throw new Error('Invalid kernel: Function must return a valid gpu.js kernel');
        }

        shaderModule.__kernel = kernel;
      } catch (error: unknown) {
        this.codeCache[descriptor.code].isErrorCode = true;
        localStorage.setItem('codeCache', JSON.stringify(this.codeCache));
        console.error('Failed to create kernel:', error);
        this.reportError('validation', `Failed to create kernel: ${error}`);
        throw error;
      }
    } else {
      // No usable cached code: queue a conversion and build the kernel when it arrives
      this.enqueueConversion(descriptor.code).then(code => {
        if (code) {
          shaderModule.__gpuJSCode = code;
          try {
            // Verify the code is a complete function definition
            if (!code.startsWith('function createKernel(gpu) {') || !code.endsWith('}')) {
              throw new Error('Invalid code format: Code must be a complete function definition');
            }

            // Evaluate the function definition and build the kernel
            const createKernelFunc = new Function('return ' + code)();
            const kernel = createKernelFunc(this.gpuInstance);

            // Verify the return value is a usable gpu.js kernel
            if (!kernel || typeof kernel !== 'function' || !kernel.setOutput) {
              throw new Error('Invalid kernel: Function must return a valid gpu.js kernel');
            }

            shaderModule.__kernel = kernel;
          } catch (error: unknown) {
            this.codeCache[descriptor.code].isErrorCode = true;
            localStorage.setItem('codeCache', JSON.stringify(this.codeCache));
            console.error('Failed to create kernel:', error);
            this.reportError('validation', `Failed to create kernel: ${error}`);
          }
        }
      }).catch(error => {
        console.error('Failed to convert WGSL to GPU.js:', error);
        this.reportError('validation', `Failed to convert WGSL to GPU.js: ${error}`);
      });
    }

    return shaderModule;
  }

  private createVertexShader(): WebGLShader {
    const shader = this.gl.createShader(this.gl.VERTEX_SHADER);
    if (!shader) {
      throw new Error('Failed to create vertex shader');
    }

    const vertexCode = `#version 300 es
      in vec2 position;
      uniform uvec3 workGroupCount;
      uniform uvec3 workGroupID;
      uniform uvec3 localInvocationID;
      uniform uvec3 globalInvocationID;
      uniform uint localInvocationIndex;
      uniform uvec3 numWorkGroups;
      
      void main() {
        gl_Position = vec4(position, 0.0, 1.0);
      }
    `;

    this.gl.shaderSource(shader, vertexCode);
    this.gl.compileShader(shader);

    if (!this.gl.getShaderParameter(shader, this.gl.COMPILE_STATUS)) {
      const error = this.gl.getShaderInfoLog(shader);
      this.gl.deleteShader(shader);
      throw new Error(`Failed to compile vertex shader: ${error}`);
    }

    return shader;
  }

  createComputePipeline(descriptor: GPUComputePipelineDescriptor): GPUComputePipeline {
    return this.createComputePipelineImpl(descriptor);
  }

  async createComputePipelineAsync(descriptor: GPUComputePipelineDescriptor): Promise<GPUComputePipeline> {
    const shaderModule = descriptor.compute.module as WebGLShaderModuleExt;
    
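    // Poll until the asynchronously generated kernel is attached to the module.
    // Give up after a deadline (an arbitrary 120 s here) so that a failed
    // WGSL-to-gpu.js conversion cannot leave this promise pending forever.
    const deadline = Date.now() + 120000;
    while (!shaderModule.__kernel) {
      if (Date.now() > deadline) {
        throw new Error('Timed out waiting for GPU.js kernel');
      }
      await new Promise(resolve => setTimeout(resolve, 10));
    }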

    return this.createComputePipelineImpl(descriptor);
  }

  private createComputePipelineImpl(descriptor: GPUComputePipelineDescriptor): GPUComputePipeline {
    const shaderModule = descriptor.compute.module as WebGLShaderModuleExt;
    if (!shaderModule.__kernel) {
      throw new Error('GPU.js kernel not ready');
    }

    return {
      __kernel: shaderModule.__kernel,
      destroy: () => {
        if (shaderModule.__kernel?.destroy) {
          shaderModule.__kernel.destroy();
        }
      }
    } as GPUComputePipeline;
  }

  createBindGroupLayout(descriptor: GPUBindGroupLayoutDescriptor): GPUBindGroupLayout {
    return {
      __brand: 'GPUBindGroupLayout',
      entries: descriptor.entries.map(entry => ({
        binding: entry.binding,
        visibility: entry.visibility,
        buffer: entry.buffer || undefined,
        sampler: entry.sampler || undefined,
        texture: entry.texture || undefined,
        storageTexture: entry.storageTexture || undefined
      }))
    } as GPUBindGroupLayout;
  }

  createPipelineLayout(descriptor: GPUPipelineLayoutDescriptor): GPUPipelineLayout {
    return {
      __brand: 'GPUPipelineLayout',
      bindGroupLayouts: descriptor.bindGroupLayouts
    } as GPUPipelineLayout;
  }

  createBindGroup(descriptor: GPUBindGroupDescriptor): GPUBindGroup {
    const uniformBuffers: WebGLBuffer[] = [];
    const uniformLocations: number[] = [];

    descriptor.entries.forEach(entry => {
      if (entry.resource && 'buffer' in entry.resource) {
        const buffer = (entry.resource.buffer as unknown as { __webglBuffer: WebGLBuffer }).__webglBuffer;
        uniformBuffers.push(buffer);
        uniformLocations.push(entry.binding);
      }
    });

    return {
      __brand: 'GPUBindGroup',
      __uniformBuffers: uniformBuffers,
      __uniformLocations: uniformLocations,
      layout: descriptor.layout
    } as unknown as GPUBindGroup;
  }

  createCommandEncoder(): GPUCommandEncoder {
    let currentPipeline: GPUComputePipeline | null = null;
    let currentBindGroups: GPUBindGroup[] = [];

    return {
      beginComputePass: () => ({
        setPipeline: (pipeline: GPUComputePipeline) => {
          currentPipeline = pipeline;
        },
        setBindGroup: (index: number, bindGroup: GPUBindGroup) => {
          if (!currentPipeline) {
            throw new Error('Pipeline must be set before binding groups');
          }
          currentBindGroups[index] = bindGroup;
        },
        dispatchWorkgroups: (x: number) => {
          if (!currentPipeline) {
            throw new Error('Pipeline must be set before dispatch');
          }

          // 从bindGroups中提取数据
          const bufferData = currentBindGroups.map(group => {
            // 获取buffer数据
            const buffer = (group as unknown as { 
              __uniformBuffers: WebGLBuffer[],
              __uniformLocations: number[]
            });
            
            // Read back the contents of each buffer
            const data = buffer.__uniformBuffers.map(buf => {
              // Bind the buffer first
              this.gl.bindBuffer(this.gl.ARRAY_BUFFER, buf);
              
              // Query the buffer size in bytes
              const bufferSize = this.gl.getBufferParameter(this.gl.ARRAY_BUFFER, this.gl.BUFFER_SIZE) as number;
              
              // Allocate a Float32Array large enough to hold the whole buffer
              const result = new Float32Array(Math.floor(bufferSize / Float32Array.BYTES_PER_ELEMENT));
              
              // Read the data back from the GPU
              this.gl.getBufferSubData(this.gl.ARRAY_BUFFER, 0, result);
              
              return result;
            });

            return data[0]; // Assumes each bind group holds exactly one buffer
          });

          // Read the struct parameters from the uniform buffer (assumed to live in bind group 2)
          const uniformGroup = currentBindGroups[2] as unknown as { 
            __uniformBuffers: WebGLBuffer[],
            __uniformLocations: number[]
          };
          
          if (uniformGroup && uniformGroup.__uniformBuffers[0]) {
            // Read back the uniform buffer contents
            this.gl.bindBuffer(this.gl.UNIFORM_BUFFER, uniformGroup.__uniformBuffers[0]);
            const uniformSize = this.gl.getBufferParameter(this.gl.UNIFORM_BUFFER, this.gl.BUFFER_SIZE) as number;
            const uniformData = new Int32Array(Math.floor(uniformSize / Int32Array.BYTES_PER_ELEMENT));
            this.gl.getBufferSubData(this.gl.UNIFORM_BUFFER, 0, uniformData);
            
            // Extract the new values (only if the buffer is large enough)
            if (uniformData.length >= 2) {
              const newBatchSize = uniformData[0];
              const newPackGridDimX = uniformData[1];
              
              // Update only the fields we need and leave the other constants unchanged
              if (!currentPipeline.__kernel.constants) {
                currentPipeline.__kernel.constants = {};
              }
              currentPipeline.__kernel.constants.batch_size = newBatchSize;
              currentPipeline.__kernel.constants.packGridDimX = newPackGridDimX;
            }
          }
          
          // Set the kernel output size
          const workgroupSize = 256; // Default workgroup size, matching @workgroup_size(256, 1, 1) in the WGSL
          const totalElements = x * workgroupSize; // Each workgroup processes workgroupSize elements
          currentPipeline.__kernel.setOutput([totalElements]); // Output length = workgroup count * workgroup_size

          // Run the gpu.js kernel synchronously and collect the result
          const result = currentPipeline.__kernel(...bufferData);
          stats.shaderSubmitCount++;

          // Write the result back to the output buffer
          const outputBuffer = (currentBindGroups[0] as unknown as {
            __uniformBuffers: WebGLBuffer[],
            __uniformLocations: number[]
          }).__uniformBuffers[0]; // The first buffer of the first bind group is the output buffer
          
          // Flatten the result into a one-dimensional array.
          // gpu.js returns Float32Array (or arrays of Float32Array) for kernel output, so typed arrays must be handled too.
          const flattenArray = (arr: unknown): number[] => {
            if (typeof arr === 'number') {
              return [arr];
            }
            if (arr instanceof Float32Array) {
              return Array.from(arr);
            }
            if (!Array.isArray(arr)) {
              return [];
            }
            return arr.reduce<number[]>((flat, item) => flat.concat(flattenArray(item)), []);
          };
          
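          const resultArray = flattenArray(result);

          // Clamp the write to the buffer size: bufferSubData fails if the data would extend past the end of the buffer
          this.gl.bindBuffer(this.gl.ARRAY_BUFFER, outputBuffer);
          const outputSize = this.gl.getBufferParameter(this.gl.ARRAY_BUFFER, this.gl.BUFFER_SIZE) as number;
          const maxElements = Math.floor(outputSize / Float32Array.BYTES_PER_ELEMENT);
          this.gl.bufferSubData(this.gl.ARRAY_BUFFER, 0, new Float32Array(resultArray.slice(0, maxElements)));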

          return result;
        },
        end: () => {
          currentPipeline = null;
          currentBindGroups = [];
        }
      }),
      finish: () => ({} as GPUCommandBuffer)
    } as unknown as GPUCommandEncoder;
  }

}

class GPUAdapterWebGL {
  private gl: WebGL2RenderingContext;
  readonly name: string;
  readonly features: Set<string>;
  readonly limits: ReturnType<typeof getDefaultLimits>;
  readonly isFallbackAdapter: boolean;
  private adapterInfo: GPUAdapterInfo;

  constructor(gl: WebGL2RenderingContext) {
    this.gl = gl;
    this.name = 'WebGL2 Fallback Adapter';
    this.features = new Set([
      'shader-f32-unclamped-float-min-max',
      'timestamp-query',
      'depth-clip-control'
    ]);

    // Advertise 'shader-f16' when half-float color buffers are available (a heuristic, not a guarantee of f16 shader support)
    if (gl.getExtension('EXT_color_buffer_half_float')) {
      this.features.add('shader-f16');
    }

    this.limits = getDefaultLimits(gl);
    this.isFallbackAdapter = true;

    const debugInfo = gl.getExtension('WEBGL_debug_renderer_info');
    this.adapterInfo = {
      vendor: debugInfo ? gl.getParameter(debugInfo.UNMASKED_VENDOR_WEBGL) : 'Unknown',
      architecture: 'WebGL2',
      device: debugInfo ? gl.getParameter(debugInfo.UNMASKED_RENDERER_WEBGL) : 'Unknown',
      description: 'WebGL2 Fallback Implementation'
    };
  }

  requestAdapterInfo(): Promise<GPUAdapterInfo> {
    return Promise.resolve().then(() => this.adapterInfo);
  }

  requestDevice(descriptor?: GPUDeviceDescriptor): Promise<GPUDevice> {
    return Promise.resolve().then(() => {
      if (descriptor?.requiredFeatures) {
        for (const feature of descriptor.requiredFeatures) {
          if (!this.features.has(feature)) {
            throw new Error(`Required feature "${feature}" is not supported by this adapter`);
          }
        }
      }

      if (descriptor?.requiredLimits) {
        for (const [key, value] of Object.entries(descriptor.requiredLimits)) {
          const limitKey = key as keyof typeof this.limits;
          if (limitKey in this.limits && this.limits[limitKey] < value) {
            throw new Error(
              `Required limit "${key}" (${computeMB(value)}) exceeds adapter limit (${computeMB(this.limits[limitKey])})`
            );
          }
        }
      }

      const device = new GPUDeviceWebGL(this.gl, this.features, this.limits);
      if (descriptor?.label) {
        device.label = descriptor.label;
      }

      return device as unknown as GPUDevice;
    });
  }
}

export function installWebGPUPolyfill(forcePolyfill = import.meta.env.DEV) {
  if (!forcePolyfill && navigator.gpu) {
    return;
  }

  const canvas = document.createElement('canvas');
  const gl = canvas.getContext('webgl2');

  if (!gl) {
    throw new Error('WebGL2 not supported');
  }

  const nav = navigator as Navigator & { gpu?: GPU };
  if ('gpu' in nav) {
    delete nav.gpu;
  }

  Object.defineProperty(nav, 'gpu', {
    value: {
      requestAdapter: async (): Promise<GPUAdapter | null> => {
        try {
          // Always use the WebGL fallback

          // Check that the WebGL2 context is still available
          if (gl.isContextLost()) {
            console.warn('WebGL2 context is lost');
            return null;
          }

          // Check for the required WebGL2 extensions
          const requiredExtensions = [
            'EXT_color_buffer_float',
            'OES_texture_float_linear'
          ];

          for (const ext of requiredExtensions) {
            if (!gl.getExtension(ext)) {
              console.warn(`Required WebGL2 extension ${ext} not supported`);
              return null;
            }
          }

          // Create the WebGL fallback adapter
          const adapter = new GPUAdapterWebGL(gl) as unknown as GPUAdapter;
          await Promise.resolve();
          console.log('Using WebGL2 fallback adapter');
          return adapter;
        } catch (error) {
          console.error('Error creating WebGPU adapter:', error);
          return null;
        }
      }
    },
    configurable: true,
    enumerable: true
  });

  console.log('WebGPU polyfill installed using WebGL2');
}

// Export runtime statistics
export function getRuntimeStats(): string {
  return `peak-memory=${computeMB(stats.peakAllocatedBytes)}, ` +
    `current-memory=${computeMB(stats.currentAllocatedBytes)}, ` +
    `total-memory=${computeMB(stats.totalAllocatedBytes)}, ` +
    `shader-submissions=${stats.shaderSubmitCount}`;
}
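The dispatch path above sizes the gpu.js output as workgroup count * 256 and then writes the flattened result into the output buffer. Below is a minimal standalone sketch of those two steps, assuming kernel results are numbers, Float32Arrays, or nested arrays of them. The helper names `flattenKernelResult`, `outputLength`, and `clampToBuffer` are hypothetical and not part of gpu.js or WebGPU.

```typescript
// Possible kernel result shapes (assumption): a scalar, a Float32Array row,
// or nested arrays of those for 2D/3D outputs.
type KernelResult = number | Float32Array | KernelResult[];

// Recursively flatten a kernel result into a plain number array.
function flattenKernelResult(result: KernelResult): number[] {
  if (typeof result === 'number') {
    return [result];
  }
  if (result instanceof Float32Array) {
    return Array.from(result);
  }
  const out: number[] = [];
  for (const item of result) {
    // Push element by element to avoid spreading very large arrays as call arguments
    for (const value of flattenKernelResult(item)) {
      out.push(value);
    }
  }
  return out;
}

// Number of output elements for a 1D dispatch: workgroup count * workgroup size.
function outputLength(workgroupCountX: number, workgroupSize = 256): number {
  return workgroupCountX * workgroupSize;
}

// Keep only as many floats as fit in a buffer of `bufferByteSize` bytes.
function clampToBuffer(values: number[], bufferByteSize: number): Float32Array {
  const maxElements = Math.floor(bufferByteSize / Float32Array.BYTES_PER_ELEMENT);
  return new Float32Array(values.slice(0, maxElements));
}
```

Clamping matters because `bufferSubData` rejects writes that extend past the end of the destination buffer, and `x * 256` can exceed the buffer size when the last workgroup is only partially used.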

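`requestDevice` rejects a descriptor when a required feature is missing or a required limit exceeds the adapter's value. The sketch below isolates that check as a pure function so it can be exercised without a WebGL context. `AdapterCaps`, `DeviceRequest`, and `validateDeviceRequest` are hypothetical names. As a simplification, keys starting with `min` (such as `minUniformBufferOffsetAlignment`) are treated as alignment limits, which the WebGPU spec compares in the opposite direction (a smaller value is stronger), whereas the polyfill applies `<` to every limit.

```typescript
// Hypothetical types describing what the adapter advertises and what a caller requests.
interface AdapterCaps {
  features: Set<string>;
  limits: Record<string, number>;
}

interface DeviceRequest {
  requiredFeatures?: string[];
  requiredLimits?: Record<string, number | undefined>;
}

// Returns a list of problems; an empty list means the request can be satisfied.
function validateDeviceRequest(caps: AdapterCaps, req: DeviceRequest): string[] {
  const problems: string[] = [];

  for (const feature of req.requiredFeatures ?? []) {
    if (!caps.features.has(feature)) {
      problems.push(`feature "${feature}" is not supported`);
    }
  }

  for (const [key, requested] of Object.entries(req.requiredLimits ?? {})) {
    const available = caps.limits[key];
    if (requested === undefined || available === undefined) {
      continue; // Unknown or unset limits are skipped, as in the polyfill
    }
    // Simplification: treat "min*" keys as alignment limits (smaller is stronger),
    // everything else as maximum limits (larger is stronger).
    const isAlignment = key.startsWith('min');
    const satisfied = isAlignment ? requested >= available : requested <= available;
    if (!satisfied) {
      problems.push(`limit "${key}" requested ${requested}, adapter provides ${available}`);
    }
  }

  return problems;
}
```

In the polyfill, a non-empty result corresponds to the `throw` inside `requestDevice`.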